Merge tag 'for-5.15/drivers-2021-08-30' of git://git.kernel.dk/linux-block
[linux-2.6-microblaze.git] / drivers / block / nbd.c
index 19f5d5a..5170a63 100644 (file)
@@ -49,6 +49,7 @@
 
 static DEFINE_IDR(nbd_index_idr);
 static DEFINE_MUTEX(nbd_index_mutex);
+static struct workqueue_struct *nbd_del_wq;
 static int nbd_total_devices = 0;
 
 struct nbd_sock {
@@ -113,12 +114,12 @@ struct nbd_device {
        struct mutex config_lock;
        struct gendisk *disk;
        struct workqueue_struct *recv_workq;
+       struct work_struct remove_work;
 
        struct list_head list;
        struct task_struct *task_recv;
        struct task_struct *task_setup;
 
-       struct completion *destroy_complete;
        unsigned long flags;
 
        char *backend;
@@ -237,32 +238,36 @@ static void nbd_dev_remove(struct nbd_device *nbd)
 {
        struct gendisk *disk = nbd->disk;
 
-       if (disk) {
-               del_gendisk(disk);
-               blk_cleanup_disk(disk);
-               blk_mq_free_tag_set(&nbd->tag_set);
-       }
+       del_gendisk(disk);
+       blk_cleanup_disk(disk);
+       blk_mq_free_tag_set(&nbd->tag_set);
 
        /*
-        * Place this in the last just before the nbd is freed to
-        * make sure that the disk and the related kobject are also
-        * totally removed to avoid duplicate creation of the same
-        * one.
+        * Remove from idr after del_gendisk() completes, so if the same ID is
+        * reused, the following add_disk() will succeed.
         */
-       if (test_bit(NBD_DESTROY_ON_DISCONNECT, &nbd->flags) && nbd->destroy_complete)
-               complete(nbd->destroy_complete);
+       mutex_lock(&nbd_index_mutex);
+       idr_remove(&nbd_index_idr, nbd->index);
+       mutex_unlock(&nbd_index_mutex);
 
        kfree(nbd);
 }
 
+static void nbd_dev_remove_work(struct work_struct *work)
+{
+       nbd_dev_remove(container_of(work, struct nbd_device, remove_work));
+}
+
 static void nbd_put(struct nbd_device *nbd)
 {
-       if (refcount_dec_and_mutex_lock(&nbd->refs,
-                                       &nbd_index_mutex)) {
-               idr_remove(&nbd_index_idr, nbd->index);
+       if (!refcount_dec_and_test(&nbd->refs))
+               return;
+
+       /* Call del_gendisk() asynchrounously to prevent deadlock */
+       if (test_bit(NBD_DESTROY_ON_DISCONNECT, &nbd->flags))
+               queue_work(nbd_del_wq, &nbd->remove_work);
+       else
                nbd_dev_remove(nbd);
-               mutex_unlock(&nbd_index_mutex);
-       }
 }
 
 static int nbd_disconnected(struct nbd_config *config)
@@ -1388,6 +1393,7 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd,
                       unsigned int cmd, unsigned long arg)
 {
        struct nbd_config *config = nbd->config;
+       loff_t bytesize;
 
        switch (cmd) {
        case NBD_DISCONNECT:
@@ -1402,8 +1408,9 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd,
        case NBD_SET_SIZE:
                return nbd_set_size(nbd, arg, config->blksize);
        case NBD_SET_SIZE_BLOCKS:
-               return nbd_set_size(nbd, arg * config->blksize,
-                                   config->blksize);
+               if (check_mul_overflow((loff_t)arg, config->blksize, &bytesize))
+                       return -EINVAL;
+               return nbd_set_size(nbd, bytesize, config->blksize);
        case NBD_SET_TIMEOUT:
                nbd_set_cmd_timeout(nbd, arg);
                return 0;
@@ -1665,7 +1672,7 @@ static const struct blk_mq_ops nbd_mq_ops = {
        .timeout        = nbd_xmit_timeout,
 };
 
-static int nbd_dev_add(int index)
+static struct nbd_device *nbd_dev_add(int index, unsigned int refs)
 {
        struct nbd_device *nbd;
        struct gendisk *disk;
@@ -1683,13 +1690,14 @@ static int nbd_dev_add(int index)
        nbd->tag_set.flags = BLK_MQ_F_SHOULD_MERGE |
                BLK_MQ_F_BLOCKING;
        nbd->tag_set.driver_data = nbd;
-       nbd->destroy_complete = NULL;
+       INIT_WORK(&nbd->remove_work, nbd_dev_remove_work);
        nbd->backend = NULL;
 
        err = blk_mq_alloc_tag_set(&nbd->tag_set);
        if (err)
                goto out_free_nbd;
 
+       mutex_lock(&nbd_index_mutex);
        if (index >= 0) {
                err = idr_alloc(&nbd_index_idr, nbd, index, index + 1,
                                GFP_KERNEL);
@@ -1700,9 +1708,10 @@ static int nbd_dev_add(int index)
                if (err >= 0)
                        index = err;
        }
+       nbd->index = index;
+       mutex_unlock(&nbd_index_mutex);
        if (err < 0)
                goto out_free_tags;
-       nbd->index = index;
 
        disk = blk_mq_alloc_disk(&nbd->tag_set, NULL);
        if (IS_ERR(disk)) {
@@ -1726,38 +1735,65 @@ static int nbd_dev_add(int index)
 
        mutex_init(&nbd->config_lock);
        refcount_set(&nbd->config_refs, 0);
-       refcount_set(&nbd->refs, 1);
+       /*
+        * Start out with a zero references to keep other threads from using
+        * this device until it is fully initialized.
+        */
+       refcount_set(&nbd->refs, 0);
        INIT_LIST_HEAD(&nbd->list);
        disk->major = NBD_MAJOR;
+
+       /* Too big first_minor can cause duplicate creation of
+        * sysfs files/links, since first_minor will be truncated to
+        * byte in __device_add_disk().
+        */
        disk->first_minor = index << part_shift;
+       if (disk->first_minor > 0xff) {
+               err = -EINVAL;
+               goto out_free_idr;
+       }
+
        disk->minors = 1 << part_shift;
        disk->fops = &nbd_fops;
        disk->private_data = nbd;
        sprintf(disk->disk_name, "nbd%d", index);
        add_disk(disk);
+
+       /*
+        * Now publish the device.
+        */
+       refcount_set(&nbd->refs, refs);
        nbd_total_devices++;
-       return index;
+       return nbd;
 
 out_free_idr:
+       mutex_lock(&nbd_index_mutex);
        idr_remove(&nbd_index_idr, index);
+       mutex_unlock(&nbd_index_mutex);
 out_free_tags:
        blk_mq_free_tag_set(&nbd->tag_set);
 out_free_nbd:
        kfree(nbd);
 out:
-       return err;
+       return ERR_PTR(err);
 }
 
-static int find_free_cb(int id, void *ptr, void *data)
+static struct nbd_device *nbd_find_get_unused(void)
 {
-       struct nbd_device *nbd = ptr;
-       struct nbd_device **found = data;
+       struct nbd_device *nbd;
+       int id;
 
-       if (!refcount_read(&nbd->config_refs)) {
-               *found = nbd;
-               return 1;
+       lockdep_assert_held(&nbd_index_mutex);
+
+       idr_for_each_entry(&nbd_index_idr, nbd, id) {
+               if (refcount_read(&nbd->config_refs) ||
+                   test_bit(NBD_DESTROY_ON_DISCONNECT, &nbd->flags))
+                       continue;
+               if (refcount_inc_not_zero(&nbd->refs))
+                       return nbd;
        }
-       return 0;
+
+       return NULL;
 }
 
 /* Netlink interface. */
@@ -1806,8 +1842,7 @@ static int nbd_genl_size_set(struct genl_info *info, struct nbd_device *nbd)
 
 static int nbd_genl_connect(struct sk_buff *skb, struct genl_info *info)
 {
-       DECLARE_COMPLETION_ONSTACK(destroy_complete);
-       struct nbd_device *nbd = NULL;
+       struct nbd_device *nbd;
        struct nbd_config *config;
        int index = -1;
        int ret;
@@ -1829,55 +1864,29 @@ static int nbd_genl_connect(struct sk_buff *skb, struct genl_info *info)
 again:
        mutex_lock(&nbd_index_mutex);
        if (index == -1) {
-               ret = idr_for_each(&nbd_index_idr, &find_free_cb, &nbd);
-               if (ret == 0) {
-                       int new_index;
-                       new_index = nbd_dev_add(-1);
-                       if (new_index < 0) {
-                               mutex_unlock(&nbd_index_mutex);
-                               printk(KERN_ERR "nbd: failed to add new device\n");
-                               return new_index;
-                       }
-                       nbd = idr_find(&nbd_index_idr, new_index);
-               }
+               nbd = nbd_find_get_unused();
        } else {
                nbd = idr_find(&nbd_index_idr, index);
-               if (!nbd) {
-                       ret = nbd_dev_add(index);
-                       if (ret < 0) {
+               if (nbd) {
+                       if ((test_bit(NBD_DESTROY_ON_DISCONNECT, &nbd->flags) &&
+                            test_bit(NBD_DISCONNECT_REQUESTED, &nbd->flags)) ||
+                           !refcount_inc_not_zero(&nbd->refs)) {
                                mutex_unlock(&nbd_index_mutex);
-                               printk(KERN_ERR "nbd: failed to add new device\n");
-                               return ret;
+                               pr_err("nbd: device at index %d is going down\n",
+                                       index);
+                               return -EINVAL;
                        }
-                       nbd = idr_find(&nbd_index_idr, index);
                }
        }
-       if (!nbd) {
-               printk(KERN_ERR "nbd: couldn't find device at index %d\n",
-                      index);
-               mutex_unlock(&nbd_index_mutex);
-               return -EINVAL;
-       }
-
-       if (test_bit(NBD_DESTROY_ON_DISCONNECT, &nbd->flags) &&
-           test_bit(NBD_DISCONNECT_REQUESTED, &nbd->flags)) {
-               nbd->destroy_complete = &destroy_complete;
-               mutex_unlock(&nbd_index_mutex);
-
-               /* Wait untill the the nbd stuff is totally destroyed */
-               wait_for_completion(&destroy_complete);
-               goto again;
-       }
+       mutex_unlock(&nbd_index_mutex);
 
-       if (!refcount_inc_not_zero(&nbd->refs)) {
-               mutex_unlock(&nbd_index_mutex);
-               if (index == -1)
-                       goto again;
-               printk(KERN_ERR "nbd: device at index %d is going down\n",
-                      index);
-               return -EINVAL;
+       if (!nbd) {
+               nbd = nbd_dev_add(index, 2);
+               if (IS_ERR(nbd)) {
+                       pr_err("nbd: failed to add new device\n");
+                       return PTR_ERR(nbd);
+               }
        }
-       mutex_unlock(&nbd_index_mutex);
 
        mutex_lock(&nbd->config_lock);
        if (refcount_read(&nbd->config_refs)) {
@@ -2424,16 +2433,21 @@ static int __init nbd_init(void)
        if (register_blkdev(NBD_MAJOR, "nbd"))
                return -EIO;
 
+       nbd_del_wq = alloc_workqueue("nbd-del", WQ_UNBOUND, 0);
+       if (!nbd_del_wq) {
+               unregister_blkdev(NBD_MAJOR, "nbd");
+               return -ENOMEM;
+       }
+
        if (genl_register_family(&nbd_genl_family)) {
+               destroy_workqueue(nbd_del_wq);
                unregister_blkdev(NBD_MAJOR, "nbd");
                return -EINVAL;
        }
        nbd_dbg_init();
 
-       mutex_lock(&nbd_index_mutex);
        for (i = 0; i < nbds_max; i++)
-               nbd_dev_add(i);
-       mutex_unlock(&nbd_index_mutex);
+               nbd_dev_add(i, 1);
        return 0;
 }
 
@@ -2442,7 +2456,10 @@ static int nbd_exit_cb(int id, void *ptr, void *data)
        struct list_head *list = (struct list_head *)data;
        struct nbd_device *nbd = ptr;
 
-       list_add_tail(&nbd->list, list);
+       /* Skip nbd that is being removed asynchronously */
+       if (refcount_read(&nbd->refs))
+               list_add_tail(&nbd->list, list);
+
        return 0;
 }
 
@@ -2465,6 +2482,9 @@ static void __exit nbd_cleanup(void)
                nbd_put(nbd);
        }
 
+       /* Also wait for nbd_dev_remove_work() completes */
+       destroy_workqueue(nbd_del_wq);
+
        idr_destroy(&nbd_index_idr);
        genl_unregister_family(&nbd_genl_family);
        unregister_blkdev(NBD_MAJOR, "nbd");