Merge tag 'for-linus-5.15-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git/rw/uml
[linux-2.6-microblaze.git] / samples / vfio-mdev / mbochs.c
index 6c0f229..c313ab4 100644 (file)
@@ -129,7 +129,7 @@ static dev_t                mbochs_devt;
 static struct class    *mbochs_class;
 static struct cdev     mbochs_cdev;
 static struct device   mbochs_dev;
-static int             mbochs_used_mbytes;
+static atomic_t mbochs_avail_mbytes;
 static const struct vfio_device_ops mbochs_dev_ops;
 
 struct vfio_region_info_ext {
@@ -507,18 +507,22 @@ static int mbochs_reset(struct mdev_state *mdev_state)
 
 static int mbochs_probe(struct mdev_device *mdev)
 {
+       int avail_mbytes = atomic_read(&mbochs_avail_mbytes);
        const struct mbochs_type *type =
                &mbochs_types[mdev_get_type_group_id(mdev)];
        struct device *dev = mdev_dev(mdev);
        struct mdev_state *mdev_state;
        int ret = -ENOMEM;
 
-       if (type->mbytes + mbochs_used_mbytes > max_mbytes)
-               return -ENOMEM;
+       do {
+               if (avail_mbytes < type->mbytes)
+                       return -ENOSPC;
+       } while (!atomic_try_cmpxchg(&mbochs_avail_mbytes, &avail_mbytes,
+                                    avail_mbytes - type->mbytes));
 
        mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL);
        if (mdev_state == NULL)
-               return -ENOMEM;
+               goto err_avail;
        vfio_init_group_dev(&mdev_state->vdev, &mdev->dev, &mbochs_dev_ops);
 
        mdev_state->vconfig = kzalloc(MBOCHS_CONFIG_SPACE_SIZE, GFP_KERNEL);
@@ -549,17 +553,18 @@ static int mbochs_probe(struct mdev_device *mdev)
        mbochs_create_config_space(mdev_state);
        mbochs_reset(mdev_state);
 
-       mbochs_used_mbytes += type->mbytes;
-
        ret = vfio_register_group_dev(&mdev_state->vdev);
        if (ret)
                goto err_mem;
        dev_set_drvdata(&mdev->dev, mdev_state);
        return 0;
-
 err_mem:
+       vfio_uninit_group_dev(&mdev_state->vdev);
+       kfree(mdev_state->pages);
        kfree(mdev_state->vconfig);
        kfree(mdev_state);
+err_avail:
+       atomic_add(type->mbytes, &mbochs_avail_mbytes);
        return ret;
 }
 
@@ -567,8 +572,9 @@ static void mbochs_remove(struct mdev_device *mdev)
 {
        struct mdev_state *mdev_state = dev_get_drvdata(&mdev->dev);
 
-       mbochs_used_mbytes -= mdev_state->type->mbytes;
        vfio_unregister_group_dev(&mdev_state->vdev);
+       vfio_uninit_group_dev(&mdev_state->vdev);
+       atomic_add(mdev_state->type->mbytes, &mbochs_avail_mbytes);
        kfree(mdev_state->pages);
        kfree(mdev_state->vconfig);
        kfree(mdev_state);
@@ -1272,15 +1278,7 @@ static long mbochs_ioctl(struct vfio_device *vdev, unsigned int cmd,
        return -ENOTTY;
 }
 
-static int mbochs_open(struct vfio_device *vdev)
-{
-       if (!try_module_get(THIS_MODULE))
-               return -ENODEV;
-
-       return 0;
-}
-
-static void mbochs_close(struct vfio_device *vdev)
+static void mbochs_close_device(struct vfio_device *vdev)
 {
        struct mdev_state *mdev_state =
                container_of(vdev, struct mdev_state, vdev);
@@ -1300,7 +1298,6 @@ static void mbochs_close(struct vfio_device *vdev)
        mbochs_put_pages(mdev_state);
 
        mutex_unlock(&mdev_state->ops_lock);
-       module_put(THIS_MODULE);
 }
 
 static ssize_t
@@ -1355,7 +1352,7 @@ static ssize_t available_instances_show(struct mdev_type *mtype,
 {
        const struct mbochs_type *type =
                &mbochs_types[mtype_get_type_group_id(mtype)];
-       int count = (max_mbytes - mbochs_used_mbytes) / type->mbytes;
+       int count = atomic_read(&mbochs_avail_mbytes) / type->mbytes;
 
        return sprintf(buf, "%d\n", count);
 }
@@ -1399,8 +1396,7 @@ static struct attribute_group *mdev_type_groups[] = {
 };
 
 static const struct vfio_device_ops mbochs_dev_ops = {
-       .open = mbochs_open,
-       .release = mbochs_close,
+       .close_device = mbochs_close_device,
        .read = mbochs_read,
        .write = mbochs_write,
        .ioctl = mbochs_ioctl,
@@ -1437,6 +1433,8 @@ static int __init mbochs_dev_init(void)
 {
        int ret = 0;
 
+       atomic_set(&mbochs_avail_mbytes, max_mbytes);
+
        ret = alloc_chrdev_region(&mbochs_devt, 0, MINORMASK + 1, MBOCHS_NAME);
        if (ret < 0) {
                pr_err("Error: failed to register mbochs_dev, err: %d\n", ret);