Merge tag 'ide-5.11-2021-02-28' of git://git.kernel.dk/linux-block
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 0b4deda..4bb162c 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -24,6 +24,7 @@
 #include <linux/compat.h>
 #include <linux/device.h>
 #include <linux/fs.h>
+#include <linux/highmem.h>
 #include <linux/iommu.h>
 #include <linux/module.h>
 #include <linux/mm.h>
@@ -69,11 +70,15 @@ struct vfio_iommu {
        struct rb_root          dma_list;
        struct blocking_notifier_head notifier;
        unsigned int            dma_avail;
+       unsigned int            vaddr_invalid_count;
        uint64_t                pgsize_bitmap;
+       uint64_t                num_non_pinned_groups;
+       wait_queue_head_t       vaddr_wait;
        bool                    v2;
        bool                    nesting;
        bool                    dirty_page_tracking;
-       bool                    pinned_page_dirty_scope;
+       bool                    container_open;
 };
 
 struct vfio_domain {
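
The new vfio_iommu fields above drive the VFIO_UPDATE_VADDR flow added here:
vaddr_invalid_count tracks how many mappings currently carry a stale user
virtual address, vaddr_wait is where the pinning and dma_rw paths sleep until
a valid vaddr is supplied again, and container_open lets the container-close
path wake those sleepers. num_non_pinned_groups replaces the iommu-wide
pinned_page_dirty_scope re-scan with a plain counter (see the removal of
update_pinned_page_dirty_scope() further down). A minimal userspace sketch of
the intended vaddr-update sequence, assuming uapi headers from a kernel with
this support, a configured type1 container fd, and one existing mapping at
iova/size (the helper name and the omitted error handling are illustrative
only):

    #include <sys/ioctl.h>
    #include <linux/vfio.h>

    static void remap_dma_vaddr(int container, __u64 iova, __u64 size,
                                __u64 new_vaddr)
    {
            struct vfio_iommu_type1_dma_unmap unmap = {
                    .argsz = sizeof(unmap),
                    .flags = VFIO_DMA_UNMAP_FLAG_VADDR, /* invalidate vaddr only */
                    .iova  = iova,
                    .size  = size,
            };
            struct vfio_iommu_type1_dma_map map = {
                    .argsz = sizeof(map),
                    .flags = VFIO_DMA_MAP_FLAG_VADDR,   /* supply the new vaddr */
                    .vaddr = new_vaddr,
                    .iova  = iova,
                    .size  = size,
            };

            ioctl(container, VFIO_IOMMU_UNMAP_DMA, &unmap); /* DMA stays mapped */
            /* ... recreate the memory, e.g. across exec() or mremap() ... */
            ioctl(container, VFIO_IOMMU_MAP_DMA, &map);     /* wakes waiters */
    }
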
@@ -92,11 +97,20 @@ struct vfio_dma {
        int                     prot;           /* IOMMU_READ/WRITE */
        bool                    iommu_mapped;
        bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
+       bool                    vaddr_invalid;
        struct task_struct      *task;
        struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
        unsigned long           *bitmap;
 };
 
+struct vfio_batch {
+       struct page             **pages;        /* for pin_user_pages_remote */
+       struct page             *fallback_page; /* if pages alloc fails */
+       int                     capacity;       /* length of pages array */
+       int                     size;           /* of batch currently */
+       int                     offset;         /* of next entry in pages */
+};
+
 struct vfio_group {
        struct iommu_group      *iommu_group;
        struct list_head        next;
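
For scale: vfio_batch_init() further down backs the pages array with a single
page, so on a 64-bit kernel with 4 KiB pages the capacity works out to
PAGE_SIZE / sizeof(struct page *) = 4096 / 8 = 512 entries, letting one
pin_user_pages_remote() call cover up to 2 MiB of user memory. If that
allocation fails, or hugepages are disabled, the batch degrades gracefully to
the single embedded fallback_page slot.
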
@@ -143,12 +157,13 @@ struct vfio_regions {
 #define DIRTY_BITMAP_PAGES_MAX  ((u64)INT_MAX)
 #define DIRTY_BITMAP_SIZE_MAX   DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
 
+#define WAITED 1
+
 static int put_pfn(unsigned long pfn, int prot);
 
 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
                                               struct iommu_group *iommu_group);
 
-static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu);
 /*
  * This code handles mapping and unmapping of user data buffers
  * into DMA'ble space using the IOMMU
@@ -173,6 +188,31 @@ static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
        return NULL;
 }
 
+static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
+                                               dma_addr_t start, size_t size)
+{
+       struct rb_node *res = NULL;
+       struct rb_node *node = iommu->dma_list.rb_node;
+       struct vfio_dma *dma_res = NULL;
+
+       while (node) {
+               struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
+
+               if (start < dma->iova + dma->size) {
+                       res = node;
+                       dma_res = dma;
+                       if (start >= dma->iova)
+                               break;
+                       node = node->rb_left;
+               } else {
+                       node = node->rb_right;
+               }
+       }
+       if (res && size && dma_res->iova >= start + size)
+               res = NULL;
+       return res;
+}
+
 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
 {
        struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
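
Unlike vfio_find_dma(), which may return any mapping intersecting the range,
vfio_find_dma_first_node() above returns the lowest-iova intersecting node,
which is what lets the reworked unmap loop below walk the range in order with
rb_next(). As a worked example: with mappings [0x1000, 0x2000) and
[0x3000, 0x4000), a lookup with start = 0x2800 and size = 0x1000 records the
[0x3000, 0x4000) node and returns it because 0x3000 < 0x2800 + 0x1000, while
the same lookup with size = 0x400 returns NULL since 0x3000 >= 0x2c00.
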
@@ -236,6 +276,18 @@ static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
        }
 }
 
+static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
+{
+       struct rb_node *n;
+       unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
+
+       for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
+               struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
+
+               bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
+       }
+}
+
 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
 {
        struct rb_node *n;
@@ -415,13 +467,54 @@ static int put_pfn(unsigned long pfn, int prot)
        return 0;
 }
 
+#define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
+
+static void vfio_batch_init(struct vfio_batch *batch)
+{
+       batch->size = 0;
+       batch->offset = 0;
+
+       if (unlikely(disable_hugepages))
+               goto fallback;
+
+       batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
+       if (!batch->pages)
+               goto fallback;
+
+       batch->capacity = VFIO_BATCH_MAX_CAPACITY;
+       return;
+
+fallback:
+       batch->pages = &batch->fallback_page;
+       batch->capacity = 1;
+}
+
+static void vfio_batch_unpin(struct vfio_batch *batch, struct vfio_dma *dma)
+{
+       while (batch->size) {
+               unsigned long pfn = page_to_pfn(batch->pages[batch->offset]);
+
+               put_pfn(pfn, dma->prot);
+               batch->offset++;
+               batch->size--;
+       }
+}
+
+static void vfio_batch_fini(struct vfio_batch *batch)
+{
+       if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
+               free_page((unsigned long)batch->pages);
+}
+
 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
                            unsigned long vaddr, unsigned long *pfn,
                            bool write_fault)
 {
+       pte_t *ptep;
+       spinlock_t *ptl;
        int ret;
 
-       ret = follow_pfn(vma, vaddr, pfn);
+       ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
        if (ret) {
                bool unlocked = false;
 
@@ -435,16 +528,28 @@ static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
                if (ret)
                        return ret;
 
-               ret = follow_pfn(vma, vaddr, pfn);
+               ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
+               if (ret)
+                       return ret;
        }
 
+       if (write_fault && !pte_write(*ptep))
+               ret = -EFAULT;
+       else
+               *pfn = pte_pfn(*ptep);
+
+       pte_unmap_unlock(ptep, ptl);
        return ret;
 }
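
The follow_pfn() to follow_pte() conversion is what makes the pte_write()
test above safe: follow_pte() returns with the PTE still mapped and its lock
held, so the permission check and the pte_pfn() read cannot race with a
concurrent zap or COW, whereas follow_pfn() handed back a raw pfn with no
locking guarantee at all.
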
 
-static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
-                        int prot, unsigned long *pfn)
+/*
+ * Returns the positive number of pfns successfully obtained or a negative
+ * error code.
+ */
+static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
+                         long npages, int prot, unsigned long *pfn,
+                         struct page **pages)
 {
-       struct page *page[1];
        struct vm_area_struct *vma;
        unsigned int flags = 0;
        int ret;
@@ -453,11 +558,10 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
                flags |= FOLL_WRITE;
 
        mmap_read_lock(mm);
-       ret = pin_user_pages_remote(mm, vaddr, 1, flags | FOLL_LONGTERM,
-                                   page, NULL, NULL);
-       if (ret == 1) {
-               *pfn = page_to_pfn(page[0]);
-               ret = 0;
+       ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
+                                   pages, NULL, NULL);
+       if (ret > 0) {
+               *pfn = page_to_pfn(pages[0]);
                goto done;
        }
 
@@ -471,14 +575,73 @@ retry:
                if (ret == -EAGAIN)
                        goto retry;
 
-               if (!ret && !is_invalid_reserved_pfn(*pfn))
-                       ret = -EFAULT;
+               if (!ret) {
+                       if (is_invalid_reserved_pfn(*pfn))
+                               ret = 1;
+                       else
+                               ret = -EFAULT;
+               }
        }
 done:
        mmap_read_unlock(mm);
        return ret;
 }
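
The changed contract matters to every caller: vaddr_get_pfns() now returns
how many pages it pinned starting at pages[0], which may be fewer than
npages, and for a VM_PFNMAP vma it returns exactly 1 with the pfn delivered
through *pfn while pages[] is left untouched (there is no struct page to
store). That is why the batch loop in vfio_pin_pages_remote() below is
careful to dereference batch->pages only when it knows the entries came from
a normal vma.
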
 
+static int vfio_wait(struct vfio_iommu *iommu)
+{
+       DEFINE_WAIT(wait);
+
+       prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
+       mutex_unlock(&iommu->lock);
+       schedule();
+       mutex_lock(&iommu->lock);
+       finish_wait(&iommu->vaddr_wait, &wait);
+       if (kthread_should_stop() || !iommu->container_open ||
+           fatal_signal_pending(current)) {
+               return -EFAULT;
+       }
+       return WAITED;
+}
+
+/*
+ * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
+ * if the task waits, but is re-locked on return.  Return result in *dma_p.
+ * Return 0 on success with no waiting, WAITED on success if waited, and -errno
+ * on error.
+ */
+static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
+                              size_t size, struct vfio_dma **dma_p)
+{
+       int ret;
+
+       do {
+               *dma_p = vfio_find_dma(iommu, start, size);
+               if (!*dma_p)
+                       ret = -EINVAL;
+               else if (!(*dma_p)->vaddr_invalid)
+                       ret = 0;
+               else
+                       ret = vfio_wait(iommu);
+       } while (ret > 0);
+
+       return ret;
+}
+
+/*
+ * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
+ * if the task waits, but is re-locked on return.  Return 0 on success with no
+ * waiting, WAITED on success if waited, and -errno on error.
+ */
+static int vfio_wait_all_valid(struct vfio_iommu *iommu)
+{
+       int ret = 0;
+
+       while (iommu->vaddr_invalid_count && ret >= 0)
+               ret = vfio_wait(iommu);
+
+       return ret;
+}
+
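
The waiting protocol: vfio_wait() drops the iommu lock around a TASK_KILLABLE
sleep and retakes it before returning, so WAITED doubles as a warning that
the lock was released and any prior lookup is stale. vfio_find_dma_valid()
accordingly redoes its lookup after every wakeup, and
vfio_iommu_type1_pin_pages() below restarts its whole validation pass when it
sees WAITED. Wakeups come from three places: VFIO_IOMMU_MAP_DMA with
VFIO_DMA_MAP_FLAG_VADDR, vfio_remove_dma(), and container close via the new
notify callback at the bottom of this diff.
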
 /*
  * Attempt to pin pages.  We really don't want to track all the pfns and
  * the iommu can only map chunks of consecutive pfns anyway, so get the
@@ -486,76 +649,102 @@ done:
  * first page and all consecutive pages with the same locking.
  */
 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
                                  long npage, unsigned long *pfn_base,
-                                 unsigned long limit)
+                                 unsigned long limit, struct vfio_batch *batch)
 {
-       unsigned long pfn = 0;
+       unsigned long pfn;
+       struct mm_struct *mm = current->mm;
        long ret, pinned = 0, lock_acct = 0;
        bool rsvd;
        dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
 
        /* This code path is only user initiated */
-       if (!current->mm)
+       if (!mm)
                return -ENODEV;
 
-       ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
-       if (ret)
-               return ret;
-
-       pinned++;
-       rsvd = is_invalid_reserved_pfn(*pfn_base);
-
-       /*
-        * Reserved pages aren't counted against the user, externally pinned
-        * pages are already counted against the user.
-        */
-       if (!rsvd && !vfio_find_vpfn(dma, iova)) {
-               if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
-                       put_pfn(*pfn_base, dma->prot);
-                       pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
-                                       limit << PAGE_SHIFT);
-                       return -ENOMEM;
-               }
-               lock_acct++;
+       if (batch->size) {
+               /* Leftover pages in batch from an earlier call. */
+               *pfn_base = page_to_pfn(batch->pages[batch->offset]);
+               pfn = *pfn_base;
+               rsvd = is_invalid_reserved_pfn(*pfn_base);
+       } else {
+               *pfn_base = 0;
        }
 
-       if (unlikely(disable_hugepages))
-               goto out;
+       while (npage) {
+               if (!batch->size) {
+                       /* Empty batch, so refill it. */
+                       long req_pages = min_t(long, npage, batch->capacity);
 
-       /* Lock all the consecutive pages from pfn_base */
-       for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
-            pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
-               ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
-               if (ret)
-                       break;
+                       ret = vaddr_get_pfns(mm, vaddr, req_pages, dma->prot,
+                                            &pfn, batch->pages);
+                       if (ret < 0)
+                               goto unpin_out;
 
-               if (pfn != *pfn_base + pinned ||
-                   rsvd != is_invalid_reserved_pfn(pfn)) {
-                       put_pfn(pfn, dma->prot);
-                       break;
+                       batch->size = ret;
+                       batch->offset = 0;
+
+                       if (!*pfn_base) {
+                               *pfn_base = pfn;
+                               rsvd = is_invalid_reserved_pfn(*pfn_base);
+                       }
                }
 
-               if (!rsvd && !vfio_find_vpfn(dma, iova)) {
-                       if (!dma->lock_cap &&
-                           current->mm->locked_vm + lock_acct + 1 > limit) {
-                               put_pfn(pfn, dma->prot);
-                               pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
-                                       __func__, limit << PAGE_SHIFT);
-                               ret = -ENOMEM;
-                               goto unpin_out;
+               /*
+                * pfn is preset for the first iteration of this inner loop and
+                * updated at the end to handle a VM_PFNMAP pfn.  In that case,
+                * batch->pages isn't valid (there's no struct page), so allow
+                * batch->pages to be touched only when there's more than one
+                * pfn to check, which guarantees the pfns are from a
+                * !VM_PFNMAP vma.
+                */
+               while (true) {
+                       if (pfn != *pfn_base + pinned ||
+                           rsvd != is_invalid_reserved_pfn(pfn))
+                               goto out;
+
+                       /*
+                        * Reserved pages aren't counted against the user,
+                        * externally pinned pages are already counted against
+                        * the user.
+                        */
+                       if (!rsvd && !vfio_find_vpfn(dma, iova)) {
+                               if (!dma->lock_cap &&
+                                   mm->locked_vm + lock_acct + 1 > limit) {
+                                       pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
+                                               __func__, limit << PAGE_SHIFT);
+                                       ret = -ENOMEM;
+                                       goto unpin_out;
+                               }
+                               lock_acct++;
                        }
-                       lock_acct++;
+
+                       pinned++;
+                       npage--;
+                       vaddr += PAGE_SIZE;
+                       iova += PAGE_SIZE;
+                       batch->offset++;
+                       batch->size--;
+
+                       if (!batch->size)
+                               break;
+
+                       pfn = page_to_pfn(batch->pages[batch->offset]);
                }
+
+               if (unlikely(disable_hugepages))
+                       break;
        }
 
 out:
        ret = vfio_lock_acct(dma, lock_acct, false);
 
 unpin_out:
-       if (ret) {
-               if (!rsvd) {
+       if (ret < 0) {
+               if (pinned && !rsvd) {
                        for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
                                put_pfn(pfn, dma->prot);
                }
+               vfio_batch_unpin(batch, dma);
 
                return ret;
        }
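
The error path preserves the no-leak invariant across the batching rework: a
negative return first releases the contiguous run already credited to
*pfn_base, then vfio_batch_unpin() drops whatever pin_user_pages_remote()
left in the batch but the loop never accounted, so a failed
vfio_pin_pages_remote() strands no page references.
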
@@ -587,6 +776,7 @@ static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
                                  unsigned long *pfn_base, bool do_accounting)
 {
+       struct page *pages[1];
        struct mm_struct *mm;
        int ret;
 
@@ -594,8 +784,8 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
        if (!mm)
                return -ENODEV;
 
-       ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
-       if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
+       ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
+       if (ret == 1 && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
                ret = vfio_lock_acct(dma, 1, true);
                if (ret) {
                        put_pfn(*pfn_base, dma->prot);
@@ -640,6 +830,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
        unsigned long remote_vaddr;
        struct vfio_dma *dma;
        bool do_accounting;
+       dma_addr_t iova;
 
        if (!iommu || !user_pfn || !phys_pfn)
                return -EINVAL;
@@ -650,6 +841,22 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 
        mutex_lock(&iommu->lock);
 
+       /*
+        * Wait for all necessary vaddr's to be valid so they can be used in
+        * the main loop without dropping the lock, to avoid racing vs unmap.
+        */
+again:
+       if (iommu->vaddr_invalid_count) {
+               for (i = 0; i < npage; i++) {
+                       iova = user_pfn[i] << PAGE_SHIFT;
+                       ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
+                       if (ret < 0)
+                               goto pin_done;
+                       if (ret == WAITED)
+                               goto again;
+               }
+       }
+
        /* Fail if notifier list is empty */
        if (!iommu->notifier.head) {
                ret = -EINVAL;
@@ -664,7 +871,6 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
        do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
 
        for (i = 0; i < npage; i++) {
-               dma_addr_t iova;
                struct vfio_pfn *vpfn;
 
                iova = user_pfn[i] << PAGE_SHIFT;
@@ -714,7 +920,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
        group = vfio_iommu_find_iommu_group(iommu, iommu_group);
        if (!group->pinned_page_dirty_scope) {
                group->pinned_page_dirty_scope = true;
-               update_pinned_page_dirty_scope(iommu);
+               iommu->num_non_pinned_groups--;
        }
 
        goto pin_done;
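
Promotion to pinned-page dirty scope is now O(1): instead of re-walking every
group in every domain (the deleted update_pinned_page_dirty_scope() below),
marking a group pinned-scope just decrements num_non_pinned_groups, and the
dirty-bitmap code only ever asks whether that counter is zero.
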
@@ -945,10 +1151,15 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 
 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
 {
+       WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
        vfio_unmap_unpin(iommu, dma, true);
        vfio_unlink_dma(iommu, dma);
        put_task_struct(dma->task);
        vfio_dma_bitmap_free(dma);
+       if (dma->vaddr_invalid) {
+               iommu->vaddr_invalid_count--;
+               wake_up_all(&iommu->vaddr_wait);
+       }
        kfree(dma);
        iommu->dma_avail++;
 }
@@ -991,7 +1202,7 @@ static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
         * mark all pages dirty if any IOMMU capable device is not able
         * to report dirty pages and all pages are pinned and mapped.
         */
-       if (!iommu->pinned_page_dirty_scope && dma->iommu_mapped)
+       if (iommu->num_non_pinned_groups && dma->iommu_mapped)
                bitmap_set(dma->bitmap, 0, nbits);
 
        if (shift) {
@@ -1074,34 +1285,36 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 {
        struct vfio_dma *dma, *dma_last = NULL;
        size_t unmapped = 0, pgsize;
-       int ret = 0, retries = 0;
+       int ret = -EINVAL, retries = 0;
        unsigned long pgshift;
+       dma_addr_t iova = unmap->iova;
+       unsigned long size = unmap->size;
+       bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
+       bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
+       struct rb_node *n, *first_n;
 
        mutex_lock(&iommu->lock);
 
        pgshift = __ffs(iommu->pgsize_bitmap);
        pgsize = (size_t)1 << pgshift;
 
-       if (unmap->iova & (pgsize - 1)) {
-               ret = -EINVAL;
+       if (iova & (pgsize - 1))
                goto unlock;
-       }
 
-       if (!unmap->size || unmap->size & (pgsize - 1)) {
-               ret = -EINVAL;
+       if (unmap_all) {
+               if (iova || size)
+                       goto unlock;
+               size = SIZE_MAX;
+       } else if (!size || size & (pgsize - 1)) {
                goto unlock;
        }
 
-       if (unmap->iova + unmap->size - 1 < unmap->iova ||
-           unmap->size > SIZE_MAX) {
-               ret = -EINVAL;
+       if (iova + size - 1 < iova || size > SIZE_MAX)
                goto unlock;
-       }
 
        /* When dirty tracking is enabled, allow only min supported pgsize */
        if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
            (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
-               ret = -EINVAL;
                goto unlock;
        }
 
@@ -1138,21 +1351,25 @@ again:
         * will only return success and a size of zero if there were no
         * mappings within the range.
         */
-       if (iommu->v2) {
-               dma = vfio_find_dma(iommu, unmap->iova, 1);
-               if (dma && dma->iova != unmap->iova) {
-                       ret = -EINVAL;
+       if (iommu->v2 && !unmap_all) {
+               dma = vfio_find_dma(iommu, iova, 1);
+               if (dma && dma->iova != iova)
                        goto unlock;
-               }
-               dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
-               if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
-                       ret = -EINVAL;
+
+               dma = vfio_find_dma(iommu, iova + size - 1, 0);
+               if (dma && dma->iova + dma->size != iova + size)
                        goto unlock;
-               }
        }
 
-       while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
-               if (!iommu->v2 && unmap->iova > dma->iova)
+       ret = 0;
+       n = first_n = vfio_find_dma_first_node(iommu, iova, size);
+
+       while (n) {
+               dma = rb_entry(n, struct vfio_dma, node);
+               if (dma->iova >= iova + size)
+                       break;
+
+               if (!iommu->v2 && iova > dma->iova)
                        break;
                /*
                 * Task with same address space who mapped this iova range is
@@ -1161,6 +1378,27 @@ again:
                 * allowed to unmap the iova range.
                 */
                if (dma->task->mm != current->mm)
                        break;
 
+               if (invalidate_vaddr) {
+                       if (dma->vaddr_invalid) {
+                               struct rb_node *last_n = n;
+
+                               for (n = first_n; n != last_n; n = rb_next(n)) {
+                                       dma = rb_entry(n,
+                                                      struct vfio_dma, node);
+                                       dma->vaddr_invalid = false;
+                                       iommu->vaddr_invalid_count--;
+                               }
+                               ret = -EINVAL;
+                               unmapped = 0;
+                               break;
+                       }
+                       dma->vaddr_invalid = true;
+                       iommu->vaddr_invalid_count++;
+                       unmapped += dma->size;
+                       n = rb_next(n);
+                       continue;
+               }
+
                if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
                        struct vfio_iommu_type1_dma_unmap nb_unmap;
 
@@ -1190,12 +1428,13 @@ again:
 
                if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
                        ret = update_user_bitmap(bitmap->data, iommu, dma,
-                                                unmap->iova, pgsize);
+                                                iova, pgsize);
                        if (ret)
                                break;
                }
 
                unmapped += dma->size;
+               n = rb_next(n);
                vfio_remove_dma(iommu, dma);
        }
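
VFIO_DMA_UNMAP_FLAG_VADDR is deliberately all-or-nothing: hitting an
already-invalid entry makes the loop above roll back everything it marked,
report zero bytes unmapped, and fail with -EINVAL, so userspace can never
observe a half-invalidated range. Combined with VFIO_DMA_UNMAP_FLAG_ALL
(iova and size must both be zero, per the earlier check), one call covers the
whole container. A hedged sketch, reusing the container fd from the earlier
example:

    struct vfio_iommu_type1_dma_unmap unmap = {
            .argsz = sizeof(unmap),
            .flags = VFIO_DMA_UNMAP_FLAG_VADDR | VFIO_DMA_UNMAP_FLAG_ALL,
    };

    ioctl(container, VFIO_IOMMU_UNMAP_DMA, &unmap);
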
 
@@ -1239,15 +1478,19 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
 {
        dma_addr_t iova = dma->iova;
        unsigned long vaddr = dma->vaddr;
+       struct vfio_batch batch;
        size_t size = map_size;
        long npage;
        unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
        int ret = 0;
 
+       vfio_batch_init(&batch);
+
        while (size) {
                /* Pin a contiguous chunk of memory */
                npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
-                                             size >> PAGE_SHIFT, &pfn, limit);
+                                             size >> PAGE_SHIFT, &pfn, limit,
+                                             &batch);
                if (npage <= 0) {
                        WARN_ON(!npage);
                        ret = (int)npage;
@@ -1260,6 +1503,7 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
                if (ret) {
                        vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
                                                npage, true);
+                       vfio_batch_unpin(&batch, dma);
                        break;
                }
 
@@ -1267,6 +1511,7 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
                dma->size += npage << PAGE_SHIFT;
        }
 
+       vfio_batch_fini(&batch);
        dma->iommu_mapped = true;
 
        if (ret)
@@ -1299,6 +1544,7 @@ static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
 static int vfio_dma_do_map(struct vfio_iommu *iommu,
                           struct vfio_iommu_type1_dma_map *map)
 {
+       bool set_vaddr = map->flags & VFIO_DMA_MAP_FLAG_VADDR;
        dma_addr_t iova = map->iova;
        unsigned long vaddr = map->vaddr;
        size_t size = map->size;
@@ -1316,13 +1562,16 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
        if (map->flags & VFIO_DMA_MAP_FLAG_READ)
                prot |= IOMMU_READ;
 
+       if ((prot && set_vaddr) || (!prot && !set_vaddr))
+               return -EINVAL;
+
        mutex_lock(&iommu->lock);
 
        pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
 
        WARN_ON((pgsize - 1) & PAGE_MASK);
 
-       if (!prot || !size || (size | iova | vaddr) & (pgsize - 1)) {
+       if (!size || (size | iova | vaddr) & (pgsize - 1)) {
                ret = -EINVAL;
                goto out_unlock;
        }
@@ -1333,7 +1582,21 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
                goto out_unlock;
        }
 
-       if (vfio_find_dma(iommu, iova, size)) {
+       dma = vfio_find_dma(iommu, iova, size);
+       if (set_vaddr) {
+               if (!dma) {
+                       ret = -ENOENT;
+               } else if (!dma->vaddr_invalid || dma->iova != iova ||
+                          dma->size != size) {
+                       ret = -EINVAL;
+               } else {
+                       dma->vaddr = vaddr;
+                       dma->vaddr_invalid = false;
+                       iommu->vaddr_invalid_count--;
+                       wake_up_all(&iommu->vaddr_wait);
+               }
+               goto out_unlock;
+       } else if (dma) {
                ret = -EEXIST;
                goto out_unlock;
        }
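
A vaddr update must name an existing, currently invalidated mapping exactly:
-ENOENT when nothing lives at iova, -EINVAL when the mapping is still valid
or when iova/size differ from the original bounds, so ranges cannot be
updated piecemeal. Only on an exact match is the new vaddr recorded and
wake_up_all() issued to release tasks parked in vfio_wait().
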
@@ -1425,16 +1688,23 @@ static int vfio_bus_type(struct device *dev, void *data)
 static int vfio_iommu_replay(struct vfio_iommu *iommu,
                             struct vfio_domain *domain)
 {
+       struct vfio_batch batch;
        struct vfio_domain *d = NULL;
        struct rb_node *n;
        unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
        int ret;
 
+       ret = vfio_wait_all_valid(iommu);
+       if (ret < 0)
+               return ret;
+
        /* Arbitrarily pick the first domain in the list for lookups */
        if (!list_empty(&iommu->domain_list))
                d = list_first_entry(&iommu->domain_list,
                                     struct vfio_domain, next);
 
+       vfio_batch_init(&batch);
+
        n = rb_first(&iommu->dma_list);
 
        for (; n; n = rb_next(n)) {
@@ -1482,7 +1752,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 
                                npage = vfio_pin_pages_remote(dma, vaddr,
                                                              n >> PAGE_SHIFT,
-                                                             &pfn, limit);
+                                                             &pfn, limit,
+                                                             &batch);
                                if (npage <= 0) {
                                        WARN_ON(!npage);
                                        ret = (int)npage;
@@ -1496,11 +1767,13 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
                        ret = iommu_map(domain->domain, iova, phys,
                                        size, dma->prot | domain->prot);
                        if (ret) {
-                               if (!dma->iommu_mapped)
+                               if (!dma->iommu_mapped) {
                                        vfio_unpin_pages_remote(dma, iova,
                                                        phys >> PAGE_SHIFT,
                                                        size >> PAGE_SHIFT,
                                                        true);
+                                       vfio_batch_unpin(&batch, dma);
+                               }
                                goto unwind;
                        }
 
@@ -1515,6 +1788,7 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
                dma->iommu_mapped = true;
        }
 
+       vfio_batch_fini(&batch);
        return 0;
 
 unwind:
@@ -1555,6 +1829,7 @@ unwind:
                }
        }
 
+       vfio_batch_fini(&batch);
        return ret;
 }
 
@@ -1622,33 +1897,6 @@ static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
        return group;
 }
 
-static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu)
-{
-       struct vfio_domain *domain;
-       struct vfio_group *group;
-
-       list_for_each_entry(domain, &iommu->domain_list, next) {
-               list_for_each_entry(group, &domain->group_list, next) {
-                       if (!group->pinned_page_dirty_scope) {
-                               iommu->pinned_page_dirty_scope = false;
-                               return;
-                       }
-               }
-       }
-
-       if (iommu->external_domain) {
-               domain = iommu->external_domain;
-               list_for_each_entry(group, &domain->group_list, next) {
-                       if (!group->pinned_page_dirty_scope) {
-                               iommu->pinned_page_dirty_scope = false;
-                               return;
-                       }
-               }
-       }
-
-       iommu->pinned_page_dirty_scope = true;
-}
-
 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
                                  phys_addr_t *base)
 {
@@ -2057,8 +2305,6 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
                         * addition of a dirty tracking group.
                         */
                        group->pinned_page_dirty_scope = true;
-                       if (!iommu->pinned_page_dirty_scope)
-                               update_pinned_page_dirty_scope(iommu);
                        mutex_unlock(&iommu->lock);
 
                        return 0;
@@ -2188,7 +2434,7 @@ done:
         * demotes the iommu scope until it declares itself dirty tracking
         * capable via the page pinning interface.
         */
-       iommu->pinned_page_dirty_scope = false;
+       iommu->num_non_pinned_groups++;
        mutex_unlock(&iommu->lock);
        vfio_iommu_resv_free(&group_resv_regions);
 
@@ -2238,23 +2484,6 @@ static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
        }
 }
 
-static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
-{
-       struct rb_node *n;
-
-       n = rb_first(&iommu->dma_list);
-       for (; n; n = rb_next(n)) {
-               struct vfio_dma *dma;
-
-               dma = rb_entry(n, struct vfio_dma, node);
-
-               if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
-                       break;
-       }
-       /* mdev vendor driver must unregister notifier */
-       WARN_ON(iommu->notifier.head);
-}
-
 /*
  * Called when a domain is removed in detach. It is possible that
  * the removed domain decided the iova aperture window. Modify the
@@ -2354,10 +2583,10 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
                        kfree(group);
 
                        if (list_empty(&iommu->external_domain->group_list)) {
-                               vfio_sanity_check_pfn_list(iommu);
-
-                               if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
+                               if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu)) {
+                                       WARN_ON(iommu->notifier.head);
                                        vfio_iommu_unmap_unpin_all(iommu);
+                               }
 
                                kfree(iommu->external_domain);
                                iommu->external_domain = NULL;
@@ -2391,10 +2620,12 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
                 */
                if (list_empty(&domain->group_list)) {
                        if (list_is_singular(&iommu->domain_list)) {
-                               if (!iommu->external_domain)
+                               if (!iommu->external_domain) {
+                                       WARN_ON(iommu->notifier.head);
                                        vfio_iommu_unmap_unpin_all(iommu);
-                               else
+                               } else {
                                        vfio_iommu_unmap_unpin_reaccount(iommu);
+                               }
                        }
                        iommu_domain_free(domain->domain);
                        list_del(&domain->next);
@@ -2415,8 +2646,11 @@ detach_group_done:
         * Removal of a group without dirty tracking may allow the iommu scope
         * to be promoted.
         */
-       if (update_dirty_scope)
-               update_pinned_page_dirty_scope(iommu);
+       if (update_dirty_scope) {
+               iommu->num_non_pinned_groups--;
+               if (iommu->dirty_page_tracking)
+                       vfio_iommu_populate_bitmap_full(iommu);
+       }
        mutex_unlock(&iommu->lock);
 }
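
The populate-full call closes a dirty-tracking hole: a group that never
reached pinned-page scope may have DMA-dirtied any mapped page up to the
moment it detached, and that history is unrecoverable, so while tracking is
active the only safe answer is to mark every remaining mapping fully dirty
via vfio_iommu_populate_bitmap_full() (added near the top of this diff).
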
 
@@ -2446,8 +2680,10 @@ static void *vfio_iommu_type1_open(unsigned long arg)
        INIT_LIST_HEAD(&iommu->iova_list);
        iommu->dma_list = RB_ROOT;
        iommu->dma_avail = dma_entry_limit;
+       iommu->container_open = true;
        mutex_init(&iommu->lock);
        BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
+       init_waitqueue_head(&iommu->vaddr_wait);
 
        return iommu;
 }
@@ -2475,7 +2711,6 @@ static void vfio_iommu_type1_release(void *iommu_data)
 
        if (iommu->external_domain) {
                vfio_release_domain(iommu->external_domain, true);
-               vfio_sanity_check_pfn_list(iommu);
                kfree(iommu->external_domain);
        }
 
@@ -2517,6 +2752,8 @@ static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
        case VFIO_TYPE1_IOMMU:
        case VFIO_TYPE1v2_IOMMU:
        case VFIO_TYPE1_NESTING_IOMMU:
+       case VFIO_UNMAP_ALL:
+       case VFIO_UPDATE_VADDR:
                return 1;
        case VFIO_DMA_CC_IOMMU:
                if (!iommu)
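
Userspace should probe both new extensions before using the flags they gate;
a minimal sketch, assuming an open container fd:

    if (ioctl(container, VFIO_CHECK_EXTENSION, VFIO_UNMAP_ALL) > 0 &&
        ioctl(container, VFIO_CHECK_EXTENSION, VFIO_UPDATE_VADDR) > 0) {
            /* VFIO_DMA_UNMAP_FLAG_ALL and the VADDR flags are usable */
    }
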
@@ -2688,7 +2925,8 @@ static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
 {
        struct vfio_iommu_type1_dma_map map;
        unsigned long minsz;
-       uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE;
+       uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE |
+                       VFIO_DMA_MAP_FLAG_VADDR;
 
        minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
 
@@ -2706,6 +2944,9 @@ static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
 {
        struct vfio_iommu_type1_dma_unmap unmap;
        struct vfio_bitmap bitmap = { 0 };
+       uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
+                       VFIO_DMA_UNMAP_FLAG_VADDR |
+                       VFIO_DMA_UNMAP_FLAG_ALL;
        unsigned long minsz;
        int ret;
 
@@ -2714,8 +2955,12 @@ static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
        if (copy_from_user(&unmap, (void __user *)arg, minsz))
                return -EFAULT;
 
-       if (unmap.argsz < minsz ||
-           unmap.flags & ~VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP)
+       if (unmap.argsz < minsz || unmap.flags & ~mask)
+               return -EINVAL;
+
+       if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
+           (unmap.flags & (VFIO_DMA_UNMAP_FLAG_ALL |
+                           VFIO_DMA_UNMAP_FLAG_VADDR)))
                return -EINVAL;
 
        if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
@@ -2906,12 +3151,13 @@ static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
        struct vfio_dma *dma;
        bool kthread = current->mm == NULL;
        size_t offset;
+       int ret;
 
        *copied = 0;
 
-       dma = vfio_find_dma(iommu, user_iova, 1);
-       if (!dma)
-               return -EINVAL;
+       ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
+       if (ret < 0)
+               return ret;
 
        if ((write && !(dma->prot & IOMMU_WRITE)) ||
                        !(dma->prot & IOMMU_READ))
@@ -3003,6 +3249,19 @@ vfio_iommu_type1_group_iommu_domain(void *iommu_data,
        return domain;
 }
 
+static void vfio_iommu_type1_notify(void *iommu_data,
+                                   enum vfio_iommu_notify_type event)
+{
+       struct vfio_iommu *iommu = iommu_data;
+
+       if (event != VFIO_IOMMU_CONTAINER_CLOSE)
+               return;
+       mutex_lock(&iommu->lock);
+       iommu->container_open = false;
+       mutex_unlock(&iommu->lock);
+       wake_up_all(&iommu->vaddr_wait);
+}
+
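
This notify hook is the third waker for vaddr_wait: on
VFIO_IOMMU_CONTAINER_CLOSE it clears container_open under the lock and then
wakes every sleeper, so a task stuck in vfio_wait() for a vaddr that will now
never be updated returns -EFAULT instead of blocking container teardown
forever.
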
 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
        .name                   = "vfio-iommu-type1",
        .owner                  = THIS_MODULE,
@@ -3017,6 +3276,7 @@ static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
        .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
        .dma_rw                 = vfio_iommu_type1_dma_rw,
        .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
+       .notify                 = vfio_iommu_type1_notify,
 };
 
 static int __init vfio_iommu_type1_init(void)