Merge tag 'i3c/for-5.11' of git://git.kernel.org/pub/scm/linux/kernel/git/i3c/linux
diff --git a/drivers/hv/channel.c b/drivers/hv/channel.c
index 3ebda77..6fb0c76 100644
 
 #include "hyperv_vmbus.h"
 
-#define NUM_PAGES_SPANNED(addr, len) \
-((PAGE_ALIGN(addr + len) >> PAGE_SHIFT) - (addr >> PAGE_SHIFT))
+/*
+ * hv_gpadl_size - Return the real size of a gpadl, the size that Hyper-V uses
+ *
+ * For BUFFER gpadl, Hyper-V uses the exact same size as the guest does.
+ *
+ * For RING gpadl, in each ring, the guest uses one PAGE_SIZE as the header
+ * (because of the alignment requirement); however, the hypervisor only
+ * uses the first HV_HYP_PAGE_SIZE as the header, leaving a
+ * (PAGE_SIZE - HV_HYP_PAGE_SIZE) gap. And since there are two rings in a
+ * ringbuffer, the total size that Hyper-V uses for a RING gpadl is the
+ * total size that the guest uses minus twice the gap size.
+ */
+static inline u32 hv_gpadl_size(enum hv_gpadl_type type, u32 size)
+{
+       switch (type) {
+       case HV_GPADL_BUFFER:
+               return size;
+       case HV_GPADL_RING:
+               /* The size of a ringbuffer must be page-aligned */
+               BUG_ON(size % PAGE_SIZE);
+               /*
+                * Two things to notice here:
+                * 1) We're processing two ring buffers as a unit
+                * 2) We're skipping the space beyond the first HV_HYP_PAGE_SIZE
+                * in the first guest-size page of each of the two ring buffers.
+                * So we effectively subtract out two guest-size pages, and add
+                * back two Hyper-V size pages.
+                */
+               return size - 2 * (PAGE_SIZE - HV_HYP_PAGE_SIZE);
+       }
+       BUG();
+       return 0;
+}
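+
+/*
+ * For illustration (hypothetical configuration, not introduced by this
+ * patch): on ARM64 with 64K guest pages, PAGE_SIZE = 64K while
+ * HV_HYP_PAGE_SIZE = 4K. A ringbuffer whose send and receive rings take
+ * two guest pages each (size = 4 * 64K = 256K) yields a RING gpadl of
+ * 256K - 2 * (64K - 4K) = 136K, i.e. 34 Hyper-V pages.
+ */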
 
-static unsigned long virt_to_hvpfn(void *addr)
+/*
+ * hv_ring_gpadl_send_hvpgoffset - Calculate the send offset (in units of
+ *                                 HV_HYP_PAGE) in a ring gpadl based on the
+ *                                 offset in the guest
+ *
+ * @offset: the offset (in bytes) where the send ringbuffer starts in the
+ *               virtual address space of the guest
+ */
+static inline u32 hv_ring_gpadl_send_hvpgoffset(u32 offset)
 {
-       phys_addr_t paddr;
 
-       if (is_vmalloc_addr(addr))
-               paddr = page_to_phys(vmalloc_to_page(addr)) +
-                                        offset_in_page(addr);
-       else
-               paddr = __pa(addr);
+       /*
+        * For RING gpadl, in each ring, the guest uses one PAGE_SIZE as the
+        * header (because of the alignment requirement), however, the
+        * hypervisor only uses the first HV_HYP_PAGE_SIZE as the header,
+        * therefore leaving a (PAGE_SIZE - HV_HYP_PAGE_SIZE) gap.
+        *
+        * To calculate the effective send offset in the gpadl, we need to
+        * subtract this gap.
+        */
+       return (offset - (PAGE_SIZE - HV_HYP_PAGE_SIZE)) >> HV_HYP_PAGE_SHIFT;
+}
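+
+/*
+ * Worked example (hypothetical values): with PAGE_SIZE = 64K and
+ * HV_HYP_PAGE_SIZE = 4K, a send ring starting at guest byte offset 128K
+ * maps to Hyper-V page offset (128K - 60K) / 4K = 17, rather than the
+ * naive 128K / 4K = 32.
+ */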
+
+/*
+ * hv_gpadl_hvpfn - Return the Hyper-V page PFN of the @i-th Hyper-V page in
+ *                  the gpadl
+ *
+ * @type: the type of the gpadl
+ * @kbuffer: the pointer to the gpadl in the guest
+ * @size: the total size (in bytes) of the gpadl
+ * @send_offset: the offset (in bytes) where the send ringbuffer starts in the
+ *               virtual address space of the guest
+ * @i: the index of the Hyper-V page in the gpadl
+ */
+static inline u64 hv_gpadl_hvpfn(enum hv_gpadl_type type, void *kbuffer,
+                                u32 size, u32 send_offset, int i)
+{
+       int send_idx = hv_ring_gpadl_send_hvpgoffset(send_offset);
+       unsigned long delta = 0UL;
+
+       switch (type) {
+       case HV_GPADL_BUFFER:
+               break;
+       case HV_GPADL_RING:
+               if (i == 0)
+                       delta = 0;
+               else if (i <= send_idx)
+                       delta = PAGE_SIZE - HV_HYP_PAGE_SIZE;
+               else
+                       delta = 2 * (PAGE_SIZE - HV_HYP_PAGE_SIZE);
+               break;
+       default:
+               BUG();
+               break;
+       }
 
-       return  paddr >> PAGE_SHIFT;
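+       /*
+        * delta compensates for the header gaps described above: Hyper-V
+        * page 0 maps to guest offset 0, pages 1..send_idx skip the first
+        * ring's (PAGE_SIZE - HV_HYP_PAGE_SIZE) gap, and pages beyond
+        * send_idx skip both rings' gaps.
+        */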
+       return virt_to_hvpfn(kbuffer + delta + (HV_HYP_PAGE_SIZE * i));
 }
 
 /*
@@ -112,160 +189,6 @@ int vmbus_alloc_ring(struct vmbus_channel *newchannel,
 }
 EXPORT_SYMBOL_GPL(vmbus_alloc_ring);
 
-static int __vmbus_open(struct vmbus_channel *newchannel,
-                      void *userdata, u32 userdatalen,
-                      void (*onchannelcallback)(void *context), void *context)
-{
-       struct vmbus_channel_open_channel *open_msg;
-       struct vmbus_channel_msginfo *open_info = NULL;
-       struct page *page = newchannel->ringbuffer_page;
-       u32 send_pages, recv_pages;
-       unsigned long flags;
-       int err;
-
-       if (userdatalen > MAX_USER_DEFINED_BYTES)
-               return -EINVAL;
-
-       send_pages = newchannel->ringbuffer_send_offset;
-       recv_pages = newchannel->ringbuffer_pagecount - send_pages;
-
-       if (newchannel->state != CHANNEL_OPEN_STATE)
-               return -EINVAL;
-
-       newchannel->state = CHANNEL_OPENING_STATE;
-       newchannel->onchannel_callback = onchannelcallback;
-       newchannel->channel_callback_context = context;
-
-       err = hv_ringbuffer_init(&newchannel->outbound, page, send_pages);
-       if (err)
-               goto error_clean_ring;
-
-       err = hv_ringbuffer_init(&newchannel->inbound,
-                                &page[send_pages], recv_pages);
-       if (err)
-               goto error_clean_ring;
-
-       /* Establish the gpadl for the ring buffer */
-       newchannel->ringbuffer_gpadlhandle = 0;
-
-       err = vmbus_establish_gpadl(newchannel,
-                                   page_address(newchannel->ringbuffer_page),
-                                   (send_pages + recv_pages) << PAGE_SHIFT,
-                                   &newchannel->ringbuffer_gpadlhandle);
-       if (err)
-               goto error_clean_ring;
-
-       /* Create and init the channel open message */
-       open_info = kmalloc(sizeof(*open_info) +
-                          sizeof(struct vmbus_channel_open_channel),
-                          GFP_KERNEL);
-       if (!open_info) {
-               err = -ENOMEM;
-               goto error_free_gpadl;
-       }
-
-       init_completion(&open_info->waitevent);
-       open_info->waiting_channel = newchannel;
-
-       open_msg = (struct vmbus_channel_open_channel *)open_info->msg;
-       open_msg->header.msgtype = CHANNELMSG_OPENCHANNEL;
-       open_msg->openid = newchannel->offermsg.child_relid;
-       open_msg->child_relid = newchannel->offermsg.child_relid;
-       open_msg->ringbuffer_gpadlhandle = newchannel->ringbuffer_gpadlhandle;
-       open_msg->downstream_ringbuffer_pageoffset = newchannel->ringbuffer_send_offset;
-       open_msg->target_vp = hv_cpu_number_to_vp_number(newchannel->target_cpu);
-
-       if (userdatalen)
-               memcpy(open_msg->userdata, userdata, userdatalen);
-
-       spin_lock_irqsave(&vmbus_connection.channelmsg_lock, flags);
-       list_add_tail(&open_info->msglistentry,
-                     &vmbus_connection.chn_msg_list);
-       spin_unlock_irqrestore(&vmbus_connection.channelmsg_lock, flags);
-
-       if (newchannel->rescind) {
-               err = -ENODEV;
-               goto error_free_info;
-       }
-
-       err = vmbus_post_msg(open_msg,
-                            sizeof(struct vmbus_channel_open_channel), true);
-
-       trace_vmbus_open(open_msg, err);
-
-       if (err != 0)
-               goto error_clean_msglist;
-
-       wait_for_completion(&open_info->waitevent);
-
-       spin_lock_irqsave(&vmbus_connection.channelmsg_lock, flags);
-       list_del(&open_info->msglistentry);
-       spin_unlock_irqrestore(&vmbus_connection.channelmsg_lock, flags);
-
-       if (newchannel->rescind) {
-               err = -ENODEV;
-               goto error_free_info;
-       }
-
-       if (open_info->response.open_result.status) {
-               err = -EAGAIN;
-               goto error_free_info;
-       }
-
-       newchannel->state = CHANNEL_OPENED_STATE;
-       kfree(open_info);
-       return 0;
-
-error_clean_msglist:
-       spin_lock_irqsave(&vmbus_connection.channelmsg_lock, flags);
-       list_del(&open_info->msglistentry);
-       spin_unlock_irqrestore(&vmbus_connection.channelmsg_lock, flags);
-error_free_info:
-       kfree(open_info);
-error_free_gpadl:
-       vmbus_teardown_gpadl(newchannel, newchannel->ringbuffer_gpadlhandle);
-       newchannel->ringbuffer_gpadlhandle = 0;
-error_clean_ring:
-       hv_ringbuffer_cleanup(&newchannel->outbound);
-       hv_ringbuffer_cleanup(&newchannel->inbound);
-       newchannel->state = CHANNEL_OPEN_STATE;
-       return err;
-}
-
-/*
- * vmbus_connect_ring - Open the channel but reuse ring buffer
- */
-int vmbus_connect_ring(struct vmbus_channel *newchannel,
-                      void (*onchannelcallback)(void *context), void *context)
-{
-       return  __vmbus_open(newchannel, NULL, 0, onchannelcallback, context);
-}
-EXPORT_SYMBOL_GPL(vmbus_connect_ring);
-
-/*
- * vmbus_open - Open the specified channel.
- */
-int vmbus_open(struct vmbus_channel *newchannel,
-              u32 send_ringbuffer_size, u32 recv_ringbuffer_size,
-              void *userdata, u32 userdatalen,
-              void (*onchannelcallback)(void *context), void *context)
-{
-       int err;
-
-       err = vmbus_alloc_ring(newchannel, send_ringbuffer_size,
-                              recv_ringbuffer_size);
-       if (err)
-               return err;
-
-       err = __vmbus_open(newchannel, userdata, userdatalen,
-                          onchannelcallback, context);
-       if (err)
-               vmbus_free_ring(newchannel);
-
-       return err;
-}
-EXPORT_SYMBOL_GPL(vmbus_open);
-
 /* Used for Hyper-V Socket: a guest client's connect() to the host */
 int vmbus_send_tl_connect_request(const guid_t *shv_guest_servie_id,
                                  const guid_t *shv_host_servie_id)
@@ -317,7 +240,8 @@ EXPORT_SYMBOL_GPL(vmbus_send_modifychannel);
 /*
  * create_gpadl_header - Creates a gpadl for the specified buffer
  */
-static int create_gpadl_header(void *kbuffer, u32 size,
+static int create_gpadl_header(enum hv_gpadl_type type, void *kbuffer,
+                              u32 size, u32 send_offset,
                               struct vmbus_channel_msginfo **msginfo)
 {
        int i;
@@ -330,7 +254,7 @@ static int create_gpadl_header(void *kbuffer, u32 size,
 
        int pfnsum, pfncount, pfnleft, pfncurr, pfnsize;
 
-       pagecount = size >> PAGE_SHIFT;
+       pagecount = hv_gpadl_size(type, size) >> HV_HYP_PAGE_SHIFT;
 
        /* do we need a gpadl body msg */
        pfnsize = MAX_SIZE_CHANNEL_MESSAGE -
@@ -357,10 +281,10 @@ static int create_gpadl_header(void *kbuffer, u32 size,
                gpadl_header->range_buflen = sizeof(struct gpa_range) +
                                         pagecount * sizeof(u64);
                gpadl_header->range[0].byte_offset = 0;
-               gpadl_header->range[0].byte_count = size;
+               gpadl_header->range[0].byte_count = hv_gpadl_size(type, size);
                for (i = 0; i < pfncount; i++)
-                       gpadl_header->range[0].pfn_array[i] = virt_to_hvpfn(
-                               kbuffer + PAGE_SIZE * i);
+                       gpadl_header->range[0].pfn_array[i] = hv_gpadl_hvpfn(
+                               type, kbuffer, size, send_offset, i);
                *msginfo = msgheader;
 
                pfnsum = pfncount;
@@ -411,8 +335,8 @@ static int create_gpadl_header(void *kbuffer, u32 size,
                         * so the hypervisor guarantees that this is ok.
                         */
                        for (i = 0; i < pfncurr; i++)
-                               gpadl_body->pfn[i] = virt_to_hvpfn(
-                                       kbuffer + PAGE_SIZE * (pfnsum + i));
+                               gpadl_body->pfn[i] = hv_gpadl_hvpfn(type,
+                                       kbuffer, size, send_offset, pfnsum + i);
 
                        /* add to msg header */
                        list_add_tail(&msgbody->msglistentry,
@@ -438,10 +362,10 @@ static int create_gpadl_header(void *kbuffer, u32 size,
                gpadl_header->range_buflen = sizeof(struct gpa_range) +
                                         pagecount * sizeof(u64);
                gpadl_header->range[0].byte_offset = 0;
-               gpadl_header->range[0].byte_count = size;
+               gpadl_header->range[0].byte_count = hv_gpadl_size(type, size);
                for (i = 0; i < pagecount; i++)
-                       gpadl_header->range[0].pfn_array[i] = virt_to_hvpfn(
-                               kbuffer + PAGE_SIZE * i);
+                       gpadl_header->range[0].pfn_array[i] = hv_gpadl_hvpfn(
+                               type, kbuffer, size, send_offset, i);
 
                *msginfo = msgheader;
        }
@@ -454,15 +378,20 @@ nomem:
 }
 
 /*
- * vmbus_establish_gpadl - Establish a GPADL for the specified buffer
+ * __vmbus_establish_gpadl - Establish a GPADL for a buffer or ringbuffer
  *
  * @channel: a channel
+ * @type: the type of the corresponding GPADL, only meaningful for the guest.
  * @kbuffer: from kmalloc or vmalloc
  * @size: page-size multiple
+ * @send_offset: the offset (in bytes) where the send ring buffer starts;
+ *              should be 0 for a BUFFER type gpadl
  * @gpadl_handle: some funky thing
  */
-int vmbus_establish_gpadl(struct vmbus_channel *channel, void *kbuffer,
-                              u32 size, u32 *gpadl_handle)
+static int __vmbus_establish_gpadl(struct vmbus_channel *channel,
+                                  enum hv_gpadl_type type, void *kbuffer,
+                                  u32 size, u32 send_offset,
+                                  u32 *gpadl_handle)
 {
        struct vmbus_channel_gpadl_header *gpadlmsg;
        struct vmbus_channel_gpadl_body *gpadl_body;
@@ -476,7 +405,7 @@ int vmbus_establish_gpadl(struct vmbus_channel *channel, void *kbuffer,
        next_gpadl_handle =
                (atomic_inc_return(&vmbus_connection.next_gpadl_handle) - 1);
 
-       ret = create_gpadl_header(kbuffer, size, &msginfo);
+       ret = create_gpadl_header(type, kbuffer, size, send_offset, &msginfo);
        if (ret)
                return ret;
 
@@ -557,8 +486,255 @@ cleanup:
        kfree(msginfo);
        return ret;
 }
+
+/*
+ * vmbus_establish_gpadl - Establish a GPADL for the specified buffer
+ *
+ * @channel: a channel
+ * @kbuffer: from kmalloc or vmalloc
+ * @size: page-size multiple
+ * @gpadl_handle: the handle to the newly established GPADL
+ */
+int vmbus_establish_gpadl(struct vmbus_channel *channel, void *kbuffer,
+                         u32 size, u32 *gpadl_handle)
+{
+       return __vmbus_establish_gpadl(channel, HV_GPADL_BUFFER, kbuffer, size,
+                                      0U, gpadl_handle);
+}
 EXPORT_SYMBOL_GPL(vmbus_establish_gpadl);
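+
+/*
+ * A minimal usage sketch (hypothetical, not part of this patch) for a
+ * driver-owned buffer:
+ *
+ *     u32 gpadl_handle;
+ *     void *buf = kzalloc(PAGE_SIZE, GFP_KERNEL);
+ *
+ *     ret = vmbus_establish_gpadl(channel, buf, PAGE_SIZE, &gpadl_handle);
+ *     ...
+ *     vmbus_teardown_gpadl(channel, gpadl_handle);
+ */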
 
+/**
+ * request_arr_init - Allocates memory for the requestor array. Each slot
+ * keeps track of the next available slot in the array. Initially, each
+ * slot points to the next one (as in a linked list). The last slot
+ * does not point to anything, so its value is set to U64_MAX.
+ * @size: The size of the array
+ */
+static u64 *request_arr_init(u32 size)
+{
+       int i;
+       u64 *req_arr;
+
+       req_arr = kcalloc(size, sizeof(u64), GFP_KERNEL);
+       if (!req_arr)
+               return NULL;
+
+       for (i = 0; i < size - 1; i++)
+               req_arr[i] = i + 1;
+
+       /* Last slot (no more available slots) */
+       req_arr[i] = U64_MAX;
+
+       return req_arr;
+}
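+
+/*
+ * For example, request_arr_init(4) returns { 1, 2, 3, U64_MAX }: slot 0 is
+ * the head of the free list, each slot stores the index of the next free
+ * slot, and U64_MAX marks the end of the list.
+ */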
+
+/*
+ * vmbus_alloc_requestor - Initializes @rqstor's fields.
+ * Index 0 is the first free slot.
+ * @size: Size of the requestor array
+ */
+static int vmbus_alloc_requestor(struct vmbus_requestor *rqstor, u32 size)
+{
+       u64 *rqst_arr;
+       unsigned long *bitmap;
+
+       rqst_arr = request_arr_init(size);
+       if (!rqst_arr)
+               return -ENOMEM;
+
+       bitmap = bitmap_zalloc(size, GFP_KERNEL);
+       if (!bitmap) {
+               kfree(rqst_arr);
+               return -ENOMEM;
+       }
+
+       rqstor->req_arr = rqst_arr;
+       rqstor->req_bitmap = bitmap;
+       rqstor->size = size;
+       rqstor->next_request_id = 0;
+       spin_lock_init(&rqstor->req_lock);
+
+       return 0;
+}
+
+/*
+ * vmbus_free_requestor - Frees memory allocated for @rqstor
+ * @rqstor: Pointer to the requestor struct
+ */
+static void vmbus_free_requestor(struct vmbus_requestor *rqstor)
+{
+       kfree(rqstor->req_arr);
+       bitmap_free(rqstor->req_bitmap);
+}
+
+static int __vmbus_open(struct vmbus_channel *newchannel,
+                      void *userdata, u32 userdatalen,
+                      void (*onchannelcallback)(void *context), void *context)
+{
+       struct vmbus_channel_open_channel *open_msg;
+       struct vmbus_channel_msginfo *open_info = NULL;
+       struct page *page = newchannel->ringbuffer_page;
+       u32 send_pages, recv_pages;
+       unsigned long flags;
+       int err;
+
+       if (userdatalen > MAX_USER_DEFINED_BYTES)
+               return -EINVAL;
+
+       send_pages = newchannel->ringbuffer_send_offset;
+       recv_pages = newchannel->ringbuffer_pagecount - send_pages;
+
+       if (newchannel->state != CHANNEL_OPEN_STATE)
+               return -EINVAL;
+
+       /* Create and init requestor */
+       if (newchannel->rqstor_size) {
+               if (vmbus_alloc_requestor(&newchannel->requestor, newchannel->rqstor_size))
+                       return -ENOMEM;
+       }
+
+       newchannel->state = CHANNEL_OPENING_STATE;
+       newchannel->onchannel_callback = onchannelcallback;
+       newchannel->channel_callback_context = context;
+
+       err = hv_ringbuffer_init(&newchannel->outbound, page, send_pages);
+       if (err)
+               goto error_clean_ring;
+
+       err = hv_ringbuffer_init(&newchannel->inbound,
+                                &page[send_pages], recv_pages);
+       if (err)
+               goto error_clean_ring;
+
+       /* Establish the gpadl for the ring buffer */
+       newchannel->ringbuffer_gpadlhandle = 0;
+
+       err = __vmbus_establish_gpadl(newchannel, HV_GPADL_RING,
+                                     page_address(newchannel->ringbuffer_page),
+                                     (send_pages + recv_pages) << PAGE_SHIFT,
+                                     newchannel->ringbuffer_send_offset << PAGE_SHIFT,
+                                     &newchannel->ringbuffer_gpadlhandle);
+       if (err)
+               goto error_clean_ring;
+
+       /* Create and init the channel open message */
+       open_info = kmalloc(sizeof(*open_info) +
+                          sizeof(struct vmbus_channel_open_channel),
+                          GFP_KERNEL);
+       if (!open_info) {
+               err = -ENOMEM;
+               goto error_free_gpadl;
+       }
+
+       init_completion(&open_info->waitevent);
+       open_info->waiting_channel = newchannel;
+
+       open_msg = (struct vmbus_channel_open_channel *)open_info->msg;
+       open_msg->header.msgtype = CHANNELMSG_OPENCHANNEL;
+       open_msg->openid = newchannel->offermsg.child_relid;
+       open_msg->child_relid = newchannel->offermsg.child_relid;
+       open_msg->ringbuffer_gpadlhandle = newchannel->ringbuffer_gpadlhandle;
+       /*
+        * The unit of ->downstream_ringbuffer_pageoffset is HV_HYP_PAGE and
+        * the unit of ->ringbuffer_send_offset (i.e. send_pages) is PAGE, so
+        * we convert it into HV_HYP_PAGE units here.
+        */
+       open_msg->downstream_ringbuffer_pageoffset =
+               hv_ring_gpadl_send_hvpgoffset(send_pages << PAGE_SHIFT);
+       open_msg->target_vp = hv_cpu_number_to_vp_number(newchannel->target_cpu);
+
+       if (userdatalen)
+               memcpy(open_msg->userdata, userdata, userdatalen);
+
+       spin_lock_irqsave(&vmbus_connection.channelmsg_lock, flags);
+       list_add_tail(&open_info->msglistentry,
+                     &vmbus_connection.chn_msg_list);
+       spin_unlock_irqrestore(&vmbus_connection.channelmsg_lock, flags);
+
+       if (newchannel->rescind) {
+               err = -ENODEV;
+               goto error_free_info;
+       }
+
+       err = vmbus_post_msg(open_msg,
+                            sizeof(struct vmbus_channel_open_channel), true);
+
+       trace_vmbus_open(open_msg, err);
+
+       if (err != 0)
+               goto error_clean_msglist;
+
+       wait_for_completion(&open_info->waitevent);
+
+       spin_lock_irqsave(&vmbus_connection.channelmsg_lock, flags);
+       list_del(&open_info->msglistentry);
+       spin_unlock_irqrestore(&vmbus_connection.channelmsg_lock, flags);
+
+       if (newchannel->rescind) {
+               err = -ENODEV;
+               goto error_free_info;
+       }
+
+       if (open_info->response.open_result.status) {
+               err = -EAGAIN;
+               goto error_free_info;
+       }
+
+       newchannel->state = CHANNEL_OPENED_STATE;
+       kfree(open_info);
+       return 0;
+
+error_clean_msglist:
+       spin_lock_irqsave(&vmbus_connection.channelmsg_lock, flags);
+       list_del(&open_info->msglistentry);
+       spin_unlock_irqrestore(&vmbus_connection.channelmsg_lock, flags);
+error_free_info:
+       kfree(open_info);
+error_free_gpadl:
+       vmbus_teardown_gpadl(newchannel, newchannel->ringbuffer_gpadlhandle);
+       newchannel->ringbuffer_gpadlhandle = 0;
+error_clean_ring:
+       hv_ringbuffer_cleanup(&newchannel->outbound);
+       hv_ringbuffer_cleanup(&newchannel->inbound);
+       vmbus_free_requestor(&newchannel->requestor);
+       newchannel->state = CHANNEL_OPEN_STATE;
+       return err;
+}
+
+/*
+ * vmbus_connect_ring - Open the channel but reuse ring buffer
+ */
+int vmbus_connect_ring(struct vmbus_channel *newchannel,
+                      void (*onchannelcallback)(void *context), void *context)
+{
+       return  __vmbus_open(newchannel, NULL, 0, onchannelcallback, context);
+}
+EXPORT_SYMBOL_GPL(vmbus_connect_ring);
+
+/*
+ * vmbus_open - Open the specified channel.
+ */
+int vmbus_open(struct vmbus_channel *newchannel,
+              u32 send_ringbuffer_size, u32 recv_ringbuffer_size,
+              void *userdata, u32 userdatalen,
+              void (*onchannelcallback)(void *context), void *context)
+{
+       int err;
+
+       err = vmbus_alloc_ring(newchannel, send_ringbuffer_size,
+                              recv_ringbuffer_size);
+       if (err)
+               return err;
+
+       err = __vmbus_open(newchannel, userdata, userdatalen,
+                          onchannelcallback, context);
+       if (err)
+               vmbus_free_ring(newchannel);
+
+       return err;
+}
+EXPORT_SYMBOL_GPL(vmbus_open);
+
 /*
  * vmbus_teardown_gpadl -Teardown the specified GPADL handle
  */
@@ -703,6 +879,9 @@ static int vmbus_close_internal(struct vmbus_channel *channel)
                channel->ringbuffer_gpadlhandle = 0;
        }
 
+       if (!ret)
+               vmbus_free_requestor(&channel->requestor);
+
        return ret;
 }
 
@@ -783,7 +962,7 @@ int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
        /* in 8-bytes granularity */
        desc.offset8 = sizeof(struct vmpacket_descriptor) >> 3;
        desc.len8 = (u16)(packetlen_aligned >> 3);
-       desc.trans_id = requestid;
+       desc.trans_id = VMBUS_RQST_ERROR; /* will be updated in hv_ringbuffer_write() */
 
        bufferlist[0].iov_base = &desc;
        bufferlist[0].iov_len = sizeof(struct vmpacket_descriptor);
@@ -792,7 +971,7 @@ int vmbus_sendpacket(struct vmbus_channel *channel, void *buffer,
        bufferlist[2].iov_base = &aligned_data;
        bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-       return hv_ringbuffer_write(channel, bufferlist, num_vecs);
+       return hv_ringbuffer_write(channel, bufferlist, num_vecs, requestid);
 }
 EXPORT_SYMBOL(vmbus_sendpacket);
 
@@ -834,7 +1013,7 @@ int vmbus_sendpacket_pagebuffer(struct vmbus_channel *channel,
        desc.flags = VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED;
        desc.dataoffset8 = descsize >> 3; /* in 8-bytes granularity */
        desc.length8 = (u16)(packetlen_aligned >> 3);
-       desc.transactionid = requestid;
+       desc.transactionid = VMBUS_RQST_ERROR; /* will be updated in hv_ringbuffer_write() */
        desc.reserved = 0;
        desc.rangecount = pagecount;
 
@@ -851,7 +1030,7 @@ int vmbus_sendpacket_pagebuffer(struct vmbus_channel *channel,
        bufferlist[2].iov_base = &aligned_data;
        bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-       return hv_ringbuffer_write(channel, bufferlist, 3);
+       return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
 }
 EXPORT_SYMBOL_GPL(vmbus_sendpacket_pagebuffer);
 
@@ -878,7 +1057,7 @@ int vmbus_sendpacket_mpb_desc(struct vmbus_channel *channel,
        desc->flags = VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED;
        desc->dataoffset8 = desc_size >> 3; /* in 8-bytes granularity */
        desc->length8 = (u16)(packetlen_aligned >> 3);
-       desc->transactionid = requestid;
+       desc->transactionid = VMBUS_RQST_ERROR; /* will be updated in hv_ringbuffer_write() */
        desc->reserved = 0;
        desc->rangecount = 1;
 
@@ -889,7 +1068,7 @@ int vmbus_sendpacket_mpb_desc(struct vmbus_channel *channel,
        bufferlist[2].iov_base = &aligned_data;
        bufferlist[2].iov_len = (packetlen_aligned - packetlen);
 
-       return hv_ringbuffer_write(channel, bufferlist, 3);
+       return hv_ringbuffer_write(channel, bufferlist, 3, requestid);
 }
 EXPORT_SYMBOL_GPL(vmbus_sendpacket_mpb_desc);
 
@@ -937,3 +1116,91 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
                                  buffer_actual_len, requestid, true);
 }
 EXPORT_SYMBOL_GPL(vmbus_recvpacket_raw);
+
+/*
+ * vmbus_next_request_id - Returns a new request id; the given guest memory
+ * address is stored at index (id - 1) in the requestor array.
+ * Uses a spin lock to avoid race conditions.
+ * @rqstor: Pointer to the requestor struct
+ * @rqst_addr: Guest memory address to be stored in the array
+ */
+u64 vmbus_next_request_id(struct vmbus_requestor *rqstor, u64 rqst_addr)
+{
+       unsigned long flags;
+       u64 current_id;
+       const struct vmbus_channel *channel =
+               container_of(rqstor, const struct vmbus_channel, requestor);
+
+       /* Check rqstor has been initialized */
+       if (!channel->rqstor_size)
+               return VMBUS_NO_RQSTOR;
+
+       spin_lock_irqsave(&rqstor->req_lock, flags);
+       current_id = rqstor->next_request_id;
+
+       /* Requestor array is full */
+       if (current_id >= rqstor->size) {
+               spin_unlock_irqrestore(&rqstor->req_lock, flags);
+               return VMBUS_RQST_ERROR;
+       }
+
+       rqstor->next_request_id = rqstor->req_arr[current_id];
+       rqstor->req_arr[current_id] = rqst_addr;
+
+       /* The already held spin lock provides atomicity */
+       bitmap_set(rqstor->req_bitmap, current_id, 1);
+
+       spin_unlock_irqrestore(&rqstor->req_lock, flags);
+
+       /*
+        * Cannot return an ID of 0, which is reserved for an unsolicited
+        * message from Hyper-V.
+        */
+       return current_id + 1;
+}
+EXPORT_SYMBOL_GPL(vmbus_next_request_id);
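+
+/*
+ * Illustration (hypothetical address): on a freshly initialized requestor,
+ * vmbus_next_request_id(&channel->requestor, 0xdeadbeef) stores 0xdeadbeef
+ * in slot 0, makes slot 0's old content (1) the next free slot, and returns
+ * 0 + 1 = 1, since an ID of 0 is reserved for unsolicited host messages.
+ */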
+
+/*
+ * vmbus_request_addr - Returns the memory address stored at index
+ * (@trans_id - 1) in @rqstor. Uses a spin lock to avoid race conditions.
+ * @rqstor: Pointer to the requestor struct
+ * @trans_id: Request id sent back from Hyper-V. Its slot (@trans_id - 1)
+ * becomes the requestor's next free slot.
+ */
+u64 vmbus_request_addr(struct vmbus_requestor *rqstor, u64 trans_id)
+{
+       unsigned long flags;
+       u64 req_addr;
+       const struct vmbus_channel *channel =
+               container_of(rqstor, const struct vmbus_channel, requestor);
+
+       /* Check rqstor has been initialized */
+       if (!channel->rqstor_size)
+               return VMBUS_NO_RQSTOR;
+
+       /* Hyper-V can send an unsolicited message with ID of 0 */
+       if (!trans_id)
+               return trans_id;
+
+       spin_lock_irqsave(&rqstor->req_lock, flags);
+
+       /* Data corresponding to trans_id is stored at trans_id - 1 */
+       trans_id--;
+
+       /* Invalid trans_id */
+       if (trans_id >= rqstor->size || !test_bit(trans_id, rqstor->req_bitmap)) {
+               spin_unlock_irqrestore(&rqstor->req_lock, flags);
+               return VMBUS_RQST_ERROR;
+       }
+
+       req_addr = rqstor->req_arr[trans_id];
+       rqstor->req_arr[trans_id] = rqstor->next_request_id;
+       rqstor->next_request_id = trans_id;
+
+       /* The already held spin lock provides atomicity */
+       bitmap_clear(rqstor->req_bitmap, trans_id, 1);
+
+       spin_unlock_irqrestore(&rqstor->req_lock, flags);
+       return req_addr;
+}
+EXPORT_SYMBOL_GPL(vmbus_request_addr);
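+
+/*
+ * Continuing the illustration above: vmbus_request_addr(&channel->requestor, 1)
+ * looks up slot 1 - 1 = 0, returns the stored 0xdeadbeef, clears slot 0's bit
+ * in the bitmap and puts the slot back at the head of the free list.
+ */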