vfio/type1: block on invalid vaddr
authorSteve Sistare <steven.sistare@oracle.com>
Fri, 29 Jan 2021 16:54:12 +0000 (08:54 -0800)
committerAlex Williamson <alex.williamson@redhat.com>
Mon, 1 Feb 2021 20:20:07 +0000 (13:20 -0700)
Block translation of host virtual address while an iova range has an
invalid vaddr.

Signed-off-by: Steve Sistare <steven.sistare@oracle.com>
Reviewed-by: Cornelia Huck <cohuck@redhat.com>
Signed-off-by: Alex Williamson <alex.williamson@redhat.com>
drivers/vfio/vfio_iommu_type1.c

index 2109803..6cf1dad 100644 (file)
@@ -31,6 +31,7 @@
 #include <linux/rbtree.h>
 #include <linux/sched/signal.h>
 #include <linux/sched/mm.h>
+#include <linux/kthread.h>
 #include <linux/slab.h>
 #include <linux/uaccess.h>
 #include <linux/vfio.h>
@@ -71,6 +72,7 @@ struct vfio_iommu {
        unsigned int            dma_avail;
        unsigned int            vaddr_invalid_count;
        uint64_t                pgsize_bitmap;
+       wait_queue_head_t       vaddr_wait;
        bool                    v2;
        bool                    nesting;
        bool                    dirty_page_tracking;
@@ -146,6 +148,8 @@ struct vfio_regions {
 #define DIRTY_BITMAP_PAGES_MAX  ((u64)INT_MAX)
 #define DIRTY_BITMAP_SIZE_MAX   DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
 
+#define WAITED 1
+
 static int put_pfn(unsigned long pfn, int prot);
 
 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
@@ -507,6 +511,61 @@ done:
        return ret;
 }
 
+static int vfio_wait(struct vfio_iommu *iommu)
+{
+       DEFINE_WAIT(wait);
+
+       prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
+       mutex_unlock(&iommu->lock);
+       schedule();
+       mutex_lock(&iommu->lock);
+       finish_wait(&iommu->vaddr_wait, &wait);
+       if (kthread_should_stop() || !iommu->container_open ||
+           fatal_signal_pending(current)) {
+               return -EFAULT;
+       }
+       return WAITED;
+}
+
+/*
+ * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
+ * if the task waits, but is re-locked on return.  Return result in *dma_p.
+ * Return 0 on success with no waiting, WAITED on success if waited, and -errno
+ * on error.
+ */
+static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
+                              size_t size, struct vfio_dma **dma_p)
+{
+       int ret;
+
+       do {
+               *dma_p = vfio_find_dma(iommu, start, size);
+               if (!*dma_p)
+                       ret = -EINVAL;
+               else if (!(*dma_p)->vaddr_invalid)
+                       ret = 0;
+               else
+                       ret = vfio_wait(iommu);
+       } while (ret > 0);
+
+       return ret;
+}
+
+/*
+ * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
+ * if the task waits, but is re-locked on return.  Return 0 on success with no
+ * waiting, WAITED on success if waited, and -errno on error.
+ */
+static int vfio_wait_all_valid(struct vfio_iommu *iommu)
+{
+       int ret = 0;
+
+       while (iommu->vaddr_invalid_count && ret >= 0)
+               ret = vfio_wait(iommu);
+
+       return ret;
+}
+
 /*
  * Attempt to pin pages.  We really don't want to track all the pfns and
  * the iommu can only map chunks of consecutive pfns anyway, so get the
@@ -668,6 +727,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
        unsigned long remote_vaddr;
        struct vfio_dma *dma;
        bool do_accounting;
+       dma_addr_t iova;
 
        if (!iommu || !user_pfn || !phys_pfn)
                return -EINVAL;
@@ -678,6 +738,22 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 
        mutex_lock(&iommu->lock);
 
+       /*
+        * Wait for all necessary vaddr's to be valid so they can be used in
+        * the main loop without dropping the lock, to avoid racing vs unmap.
+        */
+again:
+       if (iommu->vaddr_invalid_count) {
+               for (i = 0; i < npage; i++) {
+                       iova = user_pfn[i] << PAGE_SHIFT;
+                       ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
+                       if (ret < 0)
+                               goto pin_done;
+                       if (ret == WAITED)
+                               goto again;
+               }
+       }
+
        /* Fail if notifier list is empty */
        if (!iommu->notifier.head) {
                ret = -EINVAL;
@@ -692,7 +768,6 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
        do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
 
        for (i = 0; i < npage; i++) {
-               dma_addr_t iova;
                struct vfio_pfn *vpfn;
 
                iova = user_pfn[i] << PAGE_SHIFT;
@@ -977,8 +1052,10 @@ static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
        vfio_unlink_dma(iommu, dma);
        put_task_struct(dma->task);
        vfio_dma_bitmap_free(dma);
-       if (dma->vaddr_invalid)
+       if (dma->vaddr_invalid) {
                iommu->vaddr_invalid_count--;
+               wake_up_all(&iommu->vaddr_wait);
+       }
        kfree(dma);
        iommu->dma_avail++;
 }
@@ -1406,6 +1483,7 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
                        dma->vaddr = vaddr;
                        dma->vaddr_invalid = false;
                        iommu->vaddr_invalid_count--;
+                       wake_up_all(&iommu->vaddr_wait);
                }
                goto out_unlock;
        } else if (dma) {
@@ -1505,6 +1583,10 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
        unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
        int ret;
 
+       ret = vfio_wait_all_valid(iommu);
+       if (ret < 0)
+               return ret;
+
        /* Arbitrarily pick the first domain in the list for lookups */
        if (!list_empty(&iommu->domain_list))
                d = list_first_entry(&iommu->domain_list,
@@ -2524,6 +2606,7 @@ static void *vfio_iommu_type1_open(unsigned long arg)
        iommu->container_open = true;
        mutex_init(&iommu->lock);
        BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
+       init_waitqueue_head(&iommu->vaddr_wait);
 
        return iommu;
 }
@@ -2992,12 +3075,13 @@ static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
        struct vfio_dma *dma;
        bool kthread = current->mm == NULL;
        size_t offset;
+       int ret;
 
        *copied = 0;
 
-       dma = vfio_find_dma(iommu, user_iova, 1);
-       if (!dma)
-               return -EINVAL;
+       ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
+       if (ret < 0)
+               return ret;
 
        if ((write && !(dma->prot & IOMMU_WRITE)) ||
                        !(dma->prot & IOMMU_READ))
@@ -3099,6 +3183,7 @@ static void vfio_iommu_type1_notify(void *iommu_data,
        mutex_lock(&iommu->lock);
        iommu->container_open = false;
        mutex_unlock(&iommu->lock);
+       wake_up_all(&iommu->vaddr_wait);
 }
 
 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {