vdpa/mlx5: Introduce reference counting to mrs
[linux-2.6-microblaze.git] / drivers / vdpa / mlx5 / core / mr.c
index 2197c46..c7dc891 100644 (file)
@@ -498,32 +498,52 @@ static void destroy_user_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_mr *mr
 
 static void _mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_mr *mr)
 {
+       if (WARN_ON(!mr))
+               return;
+
        if (mr->user_mr)
                destroy_user_mr(mvdev, mr);
        else
                destroy_dma_mr(mvdev, mr);
 
        vhost_iotlb_free(mr->iotlb);
+
+       kfree(mr);
 }
 
-void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
-                         struct mlx5_vdpa_mr *mr)
+static void _mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
+                             struct mlx5_vdpa_mr *mr)
 {
        if (!mr)
                return;
 
+       if (refcount_dec_and_test(&mr->refcount))
+               _mlx5_vdpa_destroy_mr(mvdev, mr);
+}
+
+void mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
+                     struct mlx5_vdpa_mr *mr)
+{
        mutex_lock(&mvdev->mr_mtx);
+       _mlx5_vdpa_put_mr(mvdev, mr);
+       mutex_unlock(&mvdev->mr_mtx);
+}
 
-       _mlx5_vdpa_destroy_mr(mvdev, mr);
+static void _mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
+                             struct mlx5_vdpa_mr *mr)
+{
+       if (!mr)
+               return;
 
-       for (int i = 0; i < MLX5_VDPA_NUM_AS; i++) {
-               if (mvdev->mr[i] == mr)
-                       mvdev->mr[i] = NULL;
-       }
+       refcount_inc(&mr->refcount);
+}
 
+void mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
+                     struct mlx5_vdpa_mr *mr)
+{
+       mutex_lock(&mvdev->mr_mtx);
+       _mlx5_vdpa_get_mr(mvdev, mr);
        mutex_unlock(&mvdev->mr_mtx);
-
-       kfree(mr);
 }
 
 void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
@@ -534,20 +554,16 @@ void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
 
        mutex_lock(&mvdev->mr_mtx);
 
+       _mlx5_vdpa_put_mr(mvdev, old_mr);
        mvdev->mr[asid] = new_mr;
-       if (old_mr) {
-               _mlx5_vdpa_destroy_mr(mvdev, old_mr);
-               kfree(old_mr);
-       }
 
        mutex_unlock(&mvdev->mr_mtx);
-
 }
 
 void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev)
 {
        for (int i = 0; i < MLX5_VDPA_NUM_AS; i++)
-               mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[i]);
+               mlx5_vdpa_update_mr(mvdev, NULL, i);
 
        prune_iotlb(mvdev->cvq.iotlb);
 }
@@ -607,6 +623,8 @@ struct mlx5_vdpa_mr *mlx5_vdpa_create_mr(struct mlx5_vdpa_dev *mvdev,
        if (err)
                goto out_err;
 
+       refcount_set(&mr->refcount, 1);
+
        return mr;
 
 out_err:
@@ -651,7 +669,7 @@ int mlx5_vdpa_reset_mr(struct mlx5_vdpa_dev *mvdev, unsigned int asid)
        if (asid >= MLX5_VDPA_NUM_AS)
                return -EINVAL;
 
-       mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[asid]);
+       mlx5_vdpa_update_mr(mvdev, NULL, asid);
 
        if (asid == 0 && MLX5_CAP_GEN(mvdev->mdev, umem_uid_0)) {
                if (mlx5_vdpa_create_dma_mr(mvdev))