Merge tag 'nfsd-5.13-1' of git://git.kernel.org/pub/scm/linux/kernel/git/cel/linux
[linux-2.6-microblaze.git] / drivers / vfio / vfio.c
index 38779e6..5e631c3 100644 (file)
@@ -46,7 +46,6 @@ static struct vfio {
        struct mutex                    group_lock;
        struct cdev                     group_cdev;
        dev_t                           group_devt;
-       wait_queue_head_t               release_q;
 } vfio;
 
 struct vfio_iommu_driver {
@@ -90,15 +89,6 @@ struct vfio_group {
        struct blocking_notifier_head   notifier;
 };
 
-struct vfio_device {
-       struct kref                     kref;
-       struct device                   *dev;
-       const struct vfio_device_ops    *ops;
-       struct vfio_group               *group;
-       struct list_head                group_next;
-       void                            *device_data;
-};
-
 #ifdef CONFIG_VFIO_NOIOMMU
 static bool noiommu __read_mostly;
 module_param_named(enable_unsafe_noiommu_mode,
@@ -109,8 +99,8 @@ MODULE_PARM_DESC(enable_unsafe_noiommu_mode, "Enable UNSAFE, no-IOMMU mode.  Thi
 /*
  * vfio_iommu_group_{get,put} are only intended for VFIO bus driver probe
  * and remove functions, any use cases other than acquiring the first
- * reference for the purpose of calling vfio_add_group_dev() or removing
- * that symmetric reference after vfio_del_group_dev() should use the raw
+ * reference for the purpose of calling vfio_register_group_dev() or removing
+ * that symmetric reference after vfio_unregister_group_dev() should use the raw
  * iommu_group_{get,put} functions.  In particular, vfio_iommu_group_put()
  * removes the device from the dummy group and cannot be nested.
  */
@@ -532,67 +522,17 @@ static struct vfio_group *vfio_group_get_from_dev(struct device *dev)
 /**
  * Device objects - create, release, get, put, search
  */
-static
-struct vfio_device *vfio_group_create_device(struct vfio_group *group,
-                                            struct device *dev,
-                                            const struct vfio_device_ops *ops,
-                                            void *device_data)
-{
-       struct vfio_device *device;
-
-       device = kzalloc(sizeof(*device), GFP_KERNEL);
-       if (!device)
-               return ERR_PTR(-ENOMEM);
-
-       kref_init(&device->kref);
-       device->dev = dev;
-       device->group = group;
-       device->ops = ops;
-       device->device_data = device_data;
-       dev_set_drvdata(dev, device);
-
-       /* No need to get group_lock, caller has group reference */
-       vfio_group_get(group);
-
-       mutex_lock(&group->device_lock);
-       list_add(&device->group_next, &group->device_list);
-       group->dev_counter++;
-       mutex_unlock(&group->device_lock);
-
-       return device;
-}
-
-static void vfio_device_release(struct kref *kref)
-{
-       struct vfio_device *device = container_of(kref,
-                                                 struct vfio_device, kref);
-       struct vfio_group *group = device->group;
-
-       list_del(&device->group_next);
-       group->dev_counter--;
-       mutex_unlock(&group->device_lock);
-
-       dev_set_drvdata(device->dev, NULL);
-
-       kfree(device);
-
-       /* vfio_del_group_dev may be waiting for this device */
-       wake_up(&vfio.release_q);
-}
-
 /* Device reference always implies a group reference */
 void vfio_device_put(struct vfio_device *device)
 {
-       struct vfio_group *group = device->group;
-       kref_put_mutex(&device->kref, vfio_device_release, &group->device_lock);
-       vfio_group_put(group);
+       if (refcount_dec_and_test(&device->refcount))
+               complete(&device->comp);
 }
 EXPORT_SYMBOL_GPL(vfio_device_put);
 
-static void vfio_device_get(struct vfio_device *device)
+static bool vfio_device_try_get(struct vfio_device *device)
 {
-       vfio_group_get(device->group);
-       kref_get(&device->kref);
+       return refcount_inc_not_zero(&device->refcount);
 }
 
 static struct vfio_device *vfio_group_get_device(struct vfio_group *group,
@@ -602,8 +542,7 @@ static struct vfio_device *vfio_group_get_device(struct vfio_group *group,
 
        mutex_lock(&group->device_lock);
        list_for_each_entry(device, &group->device_list, group_next) {
-               if (device->dev == dev) {
-                       vfio_device_get(device);
+               if (device->dev == dev && vfio_device_try_get(device)) {
                        mutex_unlock(&group->device_lock);
                        return device;
                }
@@ -801,14 +740,22 @@ static int vfio_iommu_group_notifier(struct notifier_block *nb,
 /**
  * VFIO driver API
  */
-int vfio_add_group_dev(struct device *dev,
-                      const struct vfio_device_ops *ops, void *device_data)
+void vfio_init_group_dev(struct vfio_device *device, struct device *dev,
+                        const struct vfio_device_ops *ops)
+{
+       init_completion(&device->comp);
+       device->dev = dev;
+       device->ops = ops;
+}
+EXPORT_SYMBOL_GPL(vfio_init_group_dev);
+
+int vfio_register_group_dev(struct vfio_device *device)
 {
+       struct vfio_device *existing_device;
        struct iommu_group *iommu_group;
        struct vfio_group *group;
-       struct vfio_device *device;
 
-       iommu_group = iommu_group_get(dev);
+       iommu_group = iommu_group_get(device->dev);
        if (!iommu_group)
                return -EINVAL;
 
@@ -827,31 +774,29 @@ int vfio_add_group_dev(struct device *dev,
                iommu_group_put(iommu_group);
        }
 
-       device = vfio_group_get_device(group, dev);
-       if (device) {
-               dev_WARN(dev, "Device already exists on group %d\n",
+       existing_device = vfio_group_get_device(group, device->dev);
+       if (existing_device) {
+               dev_WARN(device->dev, "Device already exists on group %d\n",
                         iommu_group_id(iommu_group));
-               vfio_device_put(device);
+               vfio_device_put(existing_device);
                vfio_group_put(group);
                return -EBUSY;
        }
 
-       device = vfio_group_create_device(group, dev, ops, device_data);
-       if (IS_ERR(device)) {
-               vfio_group_put(group);
-               return PTR_ERR(device);
-       }
+       /* Our reference on group is moved to the device */
+       device->group = group;
 
-       /*
-        * Drop all but the vfio_device reference.  The vfio_device holds
-        * a reference to the vfio_group, which holds a reference to the
-        * iommu_group.
-        */
-       vfio_group_put(group);
+       /* Refcounting can't start until the driver calls register */
+       refcount_set(&device->refcount, 1);
+
+       mutex_lock(&group->device_lock);
+       list_add(&device->group_next, &group->device_list);
+       group->dev_counter++;
+       mutex_unlock(&group->device_lock);
 
        return 0;
 }
-EXPORT_SYMBOL_GPL(vfio_add_group_dev);
+EXPORT_SYMBOL_GPL(vfio_register_group_dev);
 
 /**
  * Get a reference to the vfio_device for a device.  Even if the
@@ -886,7 +831,7 @@ static struct vfio_device *vfio_device_get_from_name(struct vfio_group *group,
                int ret;
 
                if (it->ops->match) {
-                       ret = it->ops->match(it->device_data, buf);
+                       ret = it->ops->match(it, buf);
                        if (ret < 0) {
                                device = ERR_PTR(ret);
                                break;
@@ -895,9 +840,8 @@ static struct vfio_device *vfio_device_get_from_name(struct vfio_group *group,
                        ret = !strcmp(dev_name(it->dev), buf);
                }
 
-               if (ret) {
+               if (ret && vfio_device_try_get(it)) {
                        device = it;
-                       vfio_device_get(device);
                        break;
                }
        }
@@ -906,33 +850,16 @@ static struct vfio_device *vfio_device_get_from_name(struct vfio_group *group,
        return device;
 }
 
-/*
- * Caller must hold a reference to the vfio_device
- */
-void *vfio_device_data(struct vfio_device *device)
-{
-       return device->device_data;
-}
-EXPORT_SYMBOL_GPL(vfio_device_data);
-
 /*
  * Decrement the device reference count and wait for the device to be
  * removed.  Open file descriptors for the device... */
-void *vfio_del_group_dev(struct device *dev)
+void vfio_unregister_group_dev(struct vfio_device *device)
 {
-       DEFINE_WAIT_FUNC(wait, woken_wake_function);
-       struct vfio_device *device = dev_get_drvdata(dev);
        struct vfio_group *group = device->group;
-       void *device_data = device->device_data;
        struct vfio_unbound_dev *unbound;
        unsigned int i = 0;
        bool interrupted = false;
-
-       /*
-        * The group exists so long as we have a device reference.  Get
-        * a group reference and use it to scan for the device going away.
-        */
-       vfio_group_get(group);
+       long rc;
 
        /*
         * When the device is removed from the group, the group suddenly
@@ -945,7 +872,7 @@ void *vfio_del_group_dev(struct device *dev)
         */
        unbound = kzalloc(sizeof(*unbound), GFP_KERNEL);
        if (unbound) {
-               unbound->dev = dev;
+               unbound->dev = device->dev;
                mutex_lock(&group->unbound_lock);
                list_add(&unbound->unbound_next, &group->unbound_list);
                mutex_unlock(&group->unbound_lock);
@@ -953,44 +880,33 @@ void *vfio_del_group_dev(struct device *dev)
        WARN_ON(!unbound);
 
        vfio_device_put(device);
-
-       /*
-        * If the device is still present in the group after the above
-        * 'put', then it is in use and we need to request it from the
-        * bus driver.  The driver may in turn need to request the
-        * device from the user.  We send the request on an arbitrary
-        * interval with counter to allow the driver to take escalating
-        * measures to release the device if it has the ability to do so.
-        */
-       add_wait_queue(&vfio.release_q, &wait);
-
-       do {
-               device = vfio_group_get_device(group, dev);
-               if (!device)
-                       break;
-
+       rc = try_wait_for_completion(&device->comp);
+       while (rc <= 0) {
                if (device->ops->request)
-                       device->ops->request(device_data, i++);
-
-               vfio_device_put(device);
+                       device->ops->request(device, i++);
 
                if (interrupted) {
-                       wait_woken(&wait, TASK_UNINTERRUPTIBLE, HZ * 10);
+                       rc = wait_for_completion_timeout(&device->comp,
+                                                        HZ * 10);
                } else {
-                       wait_woken(&wait, TASK_INTERRUPTIBLE, HZ * 10);
-                       if (signal_pending(current)) {
+                       rc = wait_for_completion_interruptible_timeout(
+                               &device->comp, HZ * 10);
+                       if (rc < 0) {
                                interrupted = true;
-                               dev_warn(dev,
+                               dev_warn(device->dev,
                                         "Device is currently in use, task"
                                         " \"%s\" (%d) "
                                         "blocked until device is released",
                                         current->comm, task_pid_nr(current));
                        }
                }
+       }
 
-       } while (1);
+       mutex_lock(&group->device_lock);
+       list_del(&device->group_next);
+       group->dev_counter--;
+       mutex_unlock(&group->device_lock);
 
-       remove_wait_queue(&vfio.release_q, &wait);
        /*
         * In order to support multiple devices per group, devices can be
         * plucked from the group while other devices in the group are still
@@ -1008,11 +924,10 @@ void *vfio_del_group_dev(struct device *dev)
        if (list_empty(&group->device_list))
                wait_event(group->container_q, !group->container);
 
+       /* Matches the get in vfio_register_group_dev() */
        vfio_group_put(group);
-
-       return device_data;
 }
-EXPORT_SYMBOL_GPL(vfio_del_group_dev);
+EXPORT_SYMBOL_GPL(vfio_unregister_group_dev);
 
 /**
  * VFIO base fd, /dev/vfio/vfio
@@ -1454,7 +1369,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
        if (IS_ERR(device))
                return PTR_ERR(device);
 
-       ret = device->ops->open(device->device_data);
+       ret = device->ops->open(device);
        if (ret) {
                vfio_device_put(device);
                return ret;
@@ -1466,7 +1381,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
         */
        ret = get_unused_fd_flags(O_CLOEXEC);
        if (ret < 0) {
-               device->ops->release(device->device_data);
+               device->ops->release(device);
                vfio_device_put(device);
                return ret;
        }
@@ -1476,7 +1391,7 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
        if (IS_ERR(filep)) {
                put_unused_fd(ret);
                ret = PTR_ERR(filep);
-               device->ops->release(device->device_data);
+               device->ops->release(device);
                vfio_device_put(device);
                return ret;
        }
@@ -1633,7 +1548,7 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep)
 {
        struct vfio_device *device = filep->private_data;
 
-       device->ops->release(device->device_data);
+       device->ops->release(device);
 
        vfio_group_try_dissolve_container(device->group);
 
@@ -1650,7 +1565,7 @@ static long vfio_device_fops_unl_ioctl(struct file *filep,
        if (unlikely(!device->ops->ioctl))
                return -EINVAL;
 
-       return device->ops->ioctl(device->device_data, cmd, arg);
+       return device->ops->ioctl(device, cmd, arg);
 }
 
 static ssize_t vfio_device_fops_read(struct file *filep, char __user *buf,
@@ -1661,7 +1576,7 @@ static ssize_t vfio_device_fops_read(struct file *filep, char __user *buf,
        if (unlikely(!device->ops->read))
                return -EINVAL;
 
-       return device->ops->read(device->device_data, buf, count, ppos);
+       return device->ops->read(device, buf, count, ppos);
 }
 
 static ssize_t vfio_device_fops_write(struct file *filep,
@@ -1673,7 +1588,7 @@ static ssize_t vfio_device_fops_write(struct file *filep,
        if (unlikely(!device->ops->write))
                return -EINVAL;
 
-       return device->ops->write(device->device_data, buf, count, ppos);
+       return device->ops->write(device, buf, count, ppos);
 }
 
 static int vfio_device_fops_mmap(struct file *filep, struct vm_area_struct *vma)
@@ -1683,7 +1598,7 @@ static int vfio_device_fops_mmap(struct file *filep, struct vm_area_struct *vma)
        if (unlikely(!device->ops->mmap))
                return -EINVAL;
 
-       return device->ops->mmap(device->device_data, vma);
+       return device->ops->mmap(device, vma);
 }
 
 static const struct file_operations vfio_device_fops = {
@@ -2379,7 +2294,6 @@ static int __init vfio_init(void)
        mutex_init(&vfio.iommu_drivers_lock);
        INIT_LIST_HEAD(&vfio.group_list);
        INIT_LIST_HEAD(&vfio.iommu_drivers_list);
-       init_waitqueue_head(&vfio.release_q);
 
        ret = misc_register(&vfio_dev);
        if (ret) {