Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
[linux-2.6-microblaze.git] / drivers / vhost / vdpa.c
index 9479f7f..f41d081 100644 (file)
@@ -116,12 +116,13 @@ static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
        irq_bypass_unregister_producer(&vq->call_ctx.producer);
 }
 
-static void vhost_vdpa_reset(struct vhost_vdpa *v)
+static int vhost_vdpa_reset(struct vhost_vdpa *v)
 {
        struct vdpa_device *vdpa = v->vdpa;
 
-       vdpa_reset(vdpa);
        v->in_batch = 0;
+
+       return vdpa_reset(vdpa);
 }
 
 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
@@ -157,7 +158,7 @@ static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u8 status, status_old;
-       int nvqs = v->nvqs;
+       int ret, nvqs = v->nvqs;
        u16 i;
 
        if (copy_from_user(&status, statusp, sizeof(status)))
@@ -172,7 +173,12 @@ static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
        if (status != 0 && (ops->get_status(vdpa) & ~status) != 0)
                return -EINVAL;
 
-       ops->set_status(vdpa, status);
+       if (status == 0) {
+               ret = ops->reset(vdpa);
+               if (ret)
+                       return ret;
+       } else
+               ops->set_status(vdpa, status);
 
        if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
                for (i = 0; i < nvqs; i++)
@@ -498,7 +504,7 @@ static long vhost_vdpa_unlocked_ioctl(struct file *filep,
        return r;
 }
 
-static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
+static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, u64 start, u64 last)
 {
        struct vhost_dev *dev = &v->vdev;
        struct vhost_iotlb *iotlb = dev->iotlb;
@@ -507,19 +513,44 @@ static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
        unsigned long pfn, pinned;
 
        while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
-               pinned = map->size >> PAGE_SHIFT;
-               for (pfn = map->addr >> PAGE_SHIFT;
+               pinned = PFN_DOWN(map->size);
+               for (pfn = PFN_DOWN(map->addr);
                     pinned > 0; pfn++, pinned--) {
                        page = pfn_to_page(pfn);
                        if (map->perm & VHOST_ACCESS_WO)
                                set_page_dirty_lock(page);
                        unpin_user_page(page);
                }
-               atomic64_sub(map->size >> PAGE_SHIFT, &dev->mm->pinned_vm);
+               atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
                vhost_iotlb_map_free(iotlb, map);
        }
 }
 
+static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, u64 start, u64 last)
+{
+       struct vhost_dev *dev = &v->vdev;
+       struct vhost_iotlb *iotlb = dev->iotlb;
+       struct vhost_iotlb_map *map;
+       struct vdpa_map_file *map_file;
+
+       while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
+               map_file = (struct vdpa_map_file *)map->opaque;
+               fput(map_file->file);
+               kfree(map_file);
+               vhost_iotlb_map_free(iotlb, map);
+       }
+}
+
+static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
+{
+       struct vdpa_device *vdpa = v->vdpa;
+
+       if (vdpa->use_va)
+               return vhost_vdpa_va_unmap(v, start, last);
+
+       return vhost_vdpa_pa_unmap(v, start, last);
+}
+
 static void vhost_vdpa_iotlb_free(struct vhost_vdpa *v)
 {
        struct vhost_dev *dev = &v->vdev;
@@ -551,21 +582,21 @@ static int perm_to_iommu_flags(u32 perm)
        return flags | IOMMU_CACHE;
 }
 
-static int vhost_vdpa_map(struct vhost_vdpa *v,
-                         u64 iova, u64 size, u64 pa, u32 perm)
+static int vhost_vdpa_map(struct vhost_vdpa *v, u64 iova,
+                         u64 size, u64 pa, u32 perm, void *opaque)
 {
        struct vhost_dev *dev = &v->vdev;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        int r = 0;
 
-       r = vhost_iotlb_add_range(dev->iotlb, iova, iova + size - 1,
-                                 pa, perm);
+       r = vhost_iotlb_add_range_ctx(dev->iotlb, iova, iova + size - 1,
+                                     pa, perm, opaque);
        if (r)
                return r;
 
        if (ops->dma_map) {
-               r = ops->dma_map(vdpa, iova, size, pa, perm);
+               r = ops->dma_map(vdpa, iova, size, pa, perm, opaque);
        } else if (ops->set_map) {
                if (!v->in_batch)
                        r = ops->set_map(vdpa, dev->iotlb);
@@ -573,13 +604,15 @@ static int vhost_vdpa_map(struct vhost_vdpa *v,
                r = iommu_map(v->domain, iova, pa, size,
                              perm_to_iommu_flags(perm));
        }
-
-       if (r)
+       if (r) {
                vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
-       else
-               atomic64_add(size >> PAGE_SHIFT, &dev->mm->pinned_vm);
+               return r;
+       }
 
-       return r;
+       if (!vdpa->use_va)
+               atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
+
+       return 0;
 }
 
 static void vhost_vdpa_unmap(struct vhost_vdpa *v, u64 iova, u64 size)
@@ -600,38 +633,78 @@ static void vhost_vdpa_unmap(struct vhost_vdpa *v, u64 iova, u64 size)
        }
 }
 
-static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
-                                          struct vhost_iotlb_msg *msg)
+static int vhost_vdpa_va_map(struct vhost_vdpa *v,
+                            u64 iova, u64 size, u64 uaddr, u32 perm)
+{
+       struct vhost_dev *dev = &v->vdev;
+       u64 offset, map_size, map_iova = iova;
+       struct vdpa_map_file *map_file;
+       struct vm_area_struct *vma;
+       int ret = 0;
+
+       mmap_read_lock(dev->mm);
+
+       while (size) {
+               vma = find_vma(dev->mm, uaddr);
+               if (!vma) {
+                       ret = -EINVAL;
+                       break;
+               }
+               map_size = min(size, vma->vm_end - uaddr);
+               if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
+                       !(vma->vm_flags & (VM_IO | VM_PFNMAP))))
+                       goto next;
+
+               map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
+               if (!map_file) {
+                       ret = -ENOMEM;
+                       break;
+               }
+               offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
+               map_file->offset = offset;
+               map_file->file = get_file(vma->vm_file);
+               ret = vhost_vdpa_map(v, map_iova, map_size, uaddr,
+                                    perm, map_file);
+               if (ret) {
+                       fput(map_file->file);
+                       kfree(map_file);
+                       break;
+               }
+next:
+               size -= map_size;
+               uaddr += map_size;
+               map_iova += map_size;
+       }
+       if (ret)
+               vhost_vdpa_unmap(v, iova, map_iova - iova);
+
+       mmap_read_unlock(dev->mm);
+
+       return ret;
+}
+
+static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
+                            u64 iova, u64 size, u64 uaddr, u32 perm)
 {
        struct vhost_dev *dev = &v->vdev;
-       struct vhost_iotlb *iotlb = dev->iotlb;
        struct page **page_list;
        unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
        unsigned int gup_flags = FOLL_LONGTERM;
        unsigned long npages, cur_base, map_pfn, last_pfn = 0;
        unsigned long lock_limit, sz2pin, nchunks, i;
-       u64 iova = msg->iova;
+       u64 start = iova;
        long pinned;
        int ret = 0;
 
-       if (msg->iova < v->range.first || !msg->size ||
-           msg->iova > U64_MAX - msg->size + 1 ||
-           msg->iova + msg->size - 1 > v->range.last)
-               return -EINVAL;
-
-       if (vhost_iotlb_itree_first(iotlb, msg->iova,
-                                   msg->iova + msg->size - 1))
-               return -EEXIST;
-
        /* Limit the use of memory for bookkeeping */
        page_list = (struct page **) __get_free_page(GFP_KERNEL);
        if (!page_list)
                return -ENOMEM;
 
-       if (msg->perm & VHOST_ACCESS_WO)
+       if (perm & VHOST_ACCESS_WO)
                gup_flags |= FOLL_WRITE;
 
-       npages = PAGE_ALIGN(msg->size + (iova & ~PAGE_MASK)) >> PAGE_SHIFT;
+       npages = PFN_UP(size + (iova & ~PAGE_MASK));
        if (!npages) {
                ret = -EINVAL;
                goto free;
@@ -639,13 +712,13 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
 
        mmap_read_lock(dev->mm);
 
-       lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
+       lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
        if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
                ret = -ENOMEM;
                goto unlock;
        }
 
-       cur_base = msg->uaddr & PAGE_MASK;
+       cur_base = uaddr & PAGE_MASK;
        iova &= PAGE_MASK;
        nchunks = 0;
 
@@ -673,10 +746,10 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
 
                        if (last_pfn && (this_pfn != last_pfn + 1)) {
                                /* Pin a contiguous chunk of memory */
-                               csize = (last_pfn - map_pfn + 1) << PAGE_SHIFT;
+                               csize = PFN_PHYS(last_pfn - map_pfn + 1);
                                ret = vhost_vdpa_map(v, iova, csize,
-                                                    map_pfn << PAGE_SHIFT,
-                                                    msg->perm);
+                                                    PFN_PHYS(map_pfn),
+                                                    perm, NULL);
                                if (ret) {
                                        /*
                                         * Unpin the pages that are left unmapped
@@ -699,13 +772,13 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
                        last_pfn = this_pfn;
                }
 
-               cur_base += pinned << PAGE_SHIFT;
+               cur_base += PFN_PHYS(pinned);
                npages -= pinned;
        }
 
        /* Pin the rest chunk */
-       ret = vhost_vdpa_map(v, iova, (last_pfn - map_pfn + 1) << PAGE_SHIFT,
-                            map_pfn << PAGE_SHIFT, msg->perm);
+       ret = vhost_vdpa_map(v, iova, PFN_PHYS(last_pfn - map_pfn + 1),
+                            PFN_PHYS(map_pfn), perm, NULL);
 out:
        if (ret) {
                if (nchunks) {
@@ -724,13 +797,38 @@ out:
                        for (pfn = map_pfn; pfn <= last_pfn; pfn++)
                                unpin_user_page(pfn_to_page(pfn));
                }
-               vhost_vdpa_unmap(v, msg->iova, msg->size);
+               vhost_vdpa_unmap(v, start, size);
        }
 unlock:
        mmap_read_unlock(dev->mm);
 free:
        free_page((unsigned long)page_list);
        return ret;
+
+}
+
+static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
+                                          struct vhost_iotlb_msg *msg)
+{
+       struct vhost_dev *dev = &v->vdev;
+       struct vdpa_device *vdpa = v->vdpa;
+       struct vhost_iotlb *iotlb = dev->iotlb;
+
+       if (msg->iova < v->range.first || !msg->size ||
+           msg->iova > U64_MAX - msg->size + 1 ||
+           msg->iova + msg->size - 1 > v->range.last)
+               return -EINVAL;
+
+       if (vhost_iotlb_itree_first(iotlb, msg->iova,
+                                   msg->iova + msg->size - 1))
+               return -EEXIST;
+
+       if (vdpa->use_va)
+               return vhost_vdpa_va_map(v, msg->iova, msg->size,
+                                        msg->uaddr, msg->perm);
+
+       return vhost_vdpa_pa_map(v, msg->iova, msg->size, msg->uaddr,
+                                msg->perm);
 }
 
 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev,
@@ -860,7 +958,9 @@ static int vhost_vdpa_open(struct inode *inode, struct file *filep)
                return -EBUSY;
 
        nvqs = v->nvqs;
-       vhost_vdpa_reset(v);
+       r = vhost_vdpa_reset(v);
+       if (r)
+               goto err;
 
        vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
@@ -945,7 +1045,7 @@ static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
 
        vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
        if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
-                           notify.addr >> PAGE_SHIFT, PAGE_SIZE,
+                           PFN_DOWN(notify.addr), PAGE_SIZE,
                            vma->vm_page_prot))
                return VM_FAULT_SIGBUS;