vhost-vdpa: fix page pinning leakage in error path (rework)
[linux-2.6-microblaze.git] drivers/vhost/vdpa.c
index 2754f30..f2db990 100644
--- a/drivers/vhost/vdpa.c
+++ b/drivers/vhost/vdpa.c
@@ -577,6 +577,8 @@ static int vhost_vdpa_map(struct vhost_vdpa *v,
 
        if (r)
                vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
+       else
+               atomic64_add(size >> PAGE_SHIFT, &dev->mm->pinned_vm);
 
        return r;
 }
@@ -608,8 +610,9 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
        unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
        unsigned int gup_flags = FOLL_LONGTERM;
        unsigned long npages, cur_base, map_pfn, last_pfn = 0;
-       unsigned long locked, lock_limit, pinned, i;
+       unsigned long lock_limit, sz2pin, nchunks, i;
        u64 iova = msg->iova;
+       long pinned;
        int ret = 0;
 
        if (msg->iova < v->range.first ||
@@ -620,6 +623,7 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
                                    msg->iova + msg->size - 1))
                return -EEXIST;
 
+       /* Limit the use of memory for bookkeeping */
        page_list = (struct page **) __get_free_page(GFP_KERNEL);
        if (!page_list)
                return -ENOMEM;
@@ -628,52 +632,75 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
                gup_flags |= FOLL_WRITE;
 
        npages = PAGE_ALIGN(msg->size + (iova & ~PAGE_MASK)) >> PAGE_SHIFT;
-       if (!npages)
-               return -EINVAL;
+       if (!npages) {
+               ret = -EINVAL;
+               goto free;
+       }
 
        mmap_read_lock(dev->mm);
 
-       locked = atomic64_add_return(npages, &dev->mm->pinned_vm);
        lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
-
-       if (locked > lock_limit) {
+       if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
                ret = -ENOMEM;
-               goto out;
+               goto unlock;
        }
 
        cur_base = msg->uaddr & PAGE_MASK;
        iova &= PAGE_MASK;
+       nchunks = 0;
 
        while (npages) {
-               pinned = min_t(unsigned long, npages, list_size);
-               ret = pin_user_pages(cur_base, pinned,
-                                    gup_flags, page_list, NULL);
-               if (ret != pinned)
+               sz2pin = min_t(unsigned long, npages, list_size);
+               pinned = pin_user_pages(cur_base, sz2pin,
+                                       gup_flags, page_list, NULL);
+               if (sz2pin != pinned) {
+                       if (pinned < 0) {
+                               ret = pinned;
+                       } else {
+                               unpin_user_pages(page_list, pinned);
+                               ret = -ENOMEM;
+                       }
                        goto out;
+               }
+               nchunks++;
 
                if (!last_pfn)
                        map_pfn = page_to_pfn(page_list[0]);
 
-               for (i = 0; i < ret; i++) {
+               for (i = 0; i < pinned; i++) {
                        unsigned long this_pfn = page_to_pfn(page_list[i]);
                        u64 csize;
 
                        if (last_pfn && (this_pfn != last_pfn + 1)) {
                                /* Pin a contiguous chunk of memory */
                                csize = (last_pfn - map_pfn + 1) << PAGE_SHIFT;
-                               if (vhost_vdpa_map(v, iova, csize,
-                                                  map_pfn << PAGE_SHIFT,
-                                                  msg->perm))
+                               ret = vhost_vdpa_map(v, iova, csize,
+                                                    map_pfn << PAGE_SHIFT,
+                                                    msg->perm);
+                               if (ret) {
+                                       /*
+                                        * Unpin the pages that are left unmapped
+                                        * from this point on in the current
+                                        * page_list. The remaining outstanding
+                                        * ones which may stride across several
+                                        * chunks will be covered in the common
+                                        * error path subsequently.
+                                        */
+                                       unpin_user_pages(&page_list[i],
+                                                        pinned - i);
                                        goto out;
+                               }
+
                                map_pfn = this_pfn;
                                iova += csize;
+                               nchunks = 0;
                        }
 
                        last_pfn = this_pfn;
                }
 
-               cur_base += ret << PAGE_SHIFT;
-               npages -= ret;
+               cur_base += pinned << PAGE_SHIFT;
+               npages -= pinned;
        }
 
        /* Pin the rest chunk */
@@ -681,10 +708,27 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
                             map_pfn << PAGE_SHIFT, msg->perm);
 out:
        if (ret) {
+               if (nchunks) {
+                       unsigned long pfn;
+
+                       /*
+                        * Unpin the outstanding pages which should have been
+                        * mapped but weren't, due to a vhost_vdpa_map() or
+                        * pin_user_pages() failure.
+                        *
+                        * Mapped pages are accounted in vdpa_map(), hence
+                        * the corresponding unpinning will be handled by
+                        * vdpa_unmap().
+                        */
+                       WARN_ON(!last_pfn);
+                       for (pfn = map_pfn; pfn <= last_pfn; pfn++)
+                               unpin_user_page(pfn_to_page(pfn));
+               }
                vhost_vdpa_unmap(v, msg->iova, msg->size);
-               atomic64_sub(npages, &dev->mm->pinned_vm);
        }
+unlock:
        mmap_read_unlock(dev->mm);
+free:
        free_page((unsigned long)page_list);
        return ret;
 }
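
For reference, below is a minimal userspace-style sketch (not driver code) of the error-handling pattern this patch adopts: the pinned_vm counter is bumped only when a map succeeds, and on failure the pages that were pinned but never mapped are unpinned directly, while already-mapped ranges are torn down (and un-accounted) by the unmap path. The demo_* helpers, CHUNK, and FAIL_AT are invented stand-ins for pin_user_pages(), vhost_vdpa_map(), and vhost_vdpa_unmap(); the real code additionally splits each pinned chunk into contiguous pfn runs.

/*
 * Illustrative sketch only: account pinned pages on successful map,
 * unpin pinned-but-unmapped pages on failure, let "unmap" release the rest.
 */
#include <stdio.h>

#define CHUNK	4		/* pages handled per iteration */
#define FAIL_AT	10		/* pretend mapping fails past this page */

static long pinned_vm;		/* stand-in for mm->pinned_vm accounting */
static long mapped;		/* pages the simulated "IOTLB" currently holds */

static int demo_map(unsigned long first, unsigned long n)
{
	if (first + n > FAIL_AT)
		return -1;	/* simulated vhost_vdpa_map() failure */
	mapped += n;
	pinned_vm += n;		/* account only after a successful map */
	return 0;
}

static void demo_unmap_all(void)
{
	pinned_vm -= mapped;	/* unmap drops the accounting for mapped pages */
	mapped = 0;
}

static void demo_unpin(unsigned long first, unsigned long n)
{
	printf("unpin pages %lu..%lu\n", first, first + n - 1);
}

static int demo_update(unsigned long npages)
{
	unsigned long done = 0;

	while (done < npages) {
		unsigned long n = npages - done < CHUNK ? npages - done : CHUNK;

		/* The chunk is "pinned" here; now hand it to the map layer. */
		if (demo_map(done, n)) {
			/*
			 * Roll back: unpin the pinned-but-unmapped chunk,
			 * then tear down everything already mapped.
			 */
			demo_unpin(done, n);
			demo_unmap_all();
			return -1;
		}
		done += n;
	}
	return 0;
}

int main(void)
{
	if (demo_update(16))
		fprintf(stderr, "update failed, pinned_vm=%ld, mapped=%ld\n",
			pinned_vm, mapped);
	return 0;
}

The property worth noting is that the accounting counter only ever moves together with a successful map, so regardless of where the update fails, the unmap path finds the counter in a consistent state and nothing is left pinned behind its back.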