Merge tag 'drm-intel-fixes-2020-04-23' of git://anongit.freedesktop.org/drm/drm-intel...
index 72e5a6d..2805858 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
 
 struct hmm_vma_walk {
        struct hmm_range        *range;
-       struct dev_pagemap      *pgmap;
        unsigned long           last;
-       unsigned int            flags;
 };
 
-static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
-                           bool write_fault, uint64_t *pfn)
-{
-       unsigned int flags = FAULT_FLAG_REMOTE;
-       struct hmm_vma_walk *hmm_vma_walk = walk->private;
-       struct hmm_range *range = hmm_vma_walk->range;
-       struct vm_area_struct *vma = walk->vma;
-       vm_fault_t ret;
-
-       if (!vma)
-               goto err;
-
-       if (hmm_vma_walk->flags & HMM_FAULT_ALLOW_RETRY)
-               flags |= FAULT_FLAG_ALLOW_RETRY;
-       if (write_fault)
-               flags |= FAULT_FLAG_WRITE;
-
-       ret = handle_mm_fault(vma, addr, flags);
-       if (ret & VM_FAULT_RETRY) {
-               /* Note, handle_mm_fault did up_read(&mm->mmap_sem)) */
-               return -EAGAIN;
-       }
-       if (ret & VM_FAULT_ERROR)
-               goto err;
-
-       return -EBUSY;
+enum {
+       HMM_NEED_FAULT = 1 << 0,
+       HMM_NEED_WRITE_FAULT = 1 << 1,
+       HMM_NEED_ALL_BITS = HMM_NEED_FAULT | HMM_NEED_WRITE_FAULT,
+};
 
-err:
-       *pfn = range->values[HMM_PFN_ERROR];
-       return -EFAULT;
+/*
+ * hmm_device_entry_from_pfn() - create a valid device entry value from pfn
+ * @range: range used to encode the HMM pfn value
+ * @pfn: pfn value for which to create the device entry
+ * Return: valid device entry for the pfn
+ */
+static uint64_t hmm_device_entry_from_pfn(const struct hmm_range *range,
+                                         unsigned long pfn)
+{
+       return (pfn << range->pfn_shift) | range->flags[HMM_PFN_VALID];
 }
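
By way of illustration (not part of this patch): the entry layout is entirely driver-defined through range->pfn_shift and range->flags[], so the encoding above is only invertible under the driver's own layout. A minimal sketch of the inverse, assuming a purely hypothetical layout (the EXAMPLE_* names and the helper are invented):

/* Hypothetical layout: flag bits in the top two bits, pfn in the low bits. */
#define EXAMPLE_PFN_VALID	(1ULL << 63)
#define EXAMPLE_PFN_WRITE	(1ULL << 62)
#define EXAMPLE_PFN_SHIFT	0

/* Inverse of hmm_device_entry_from_pfn() under the layout above. */
static unsigned long example_device_entry_to_pfn(uint64_t entry)
{
	return (entry & ~(EXAMPLE_PFN_VALID | EXAMPLE_PFN_WRITE)) >>
	       EXAMPLE_PFN_SHIFT;
}

The only real constraint is that the shifted pfn must not overlap the flag bits the driver chose.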
 
 static int hmm_pfns_fill(unsigned long addr, unsigned long end,
@@ -79,56 +63,43 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 }
 
 /*
- * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
+ * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
  * @addr: range virtual start address (inclusive)
  * @end: range virtual end address (exclusive)
- * @fault: should we fault or not ?
- * @write_fault: write fault ?
+ * @required_fault: HMM_NEED_* flags
  * @walk: mm_walk structure
- * Return: 0 on success, -EBUSY after page fault, or page fault error
+ * Return: -EBUSY after page fault, or page fault error
  *
  * This function will be called whenever pmd_none() or pte_none() returns true,
  * or whenever there is no page directory covering the virtual address range.
  */
-static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
-                             bool fault, bool write_fault,
-                             struct mm_walk *walk)
+static int hmm_vma_fault(unsigned long addr, unsigned long end,
+                        unsigned int required_fault, struct mm_walk *walk)
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
-       struct hmm_range *range = hmm_vma_walk->range;
-       uint64_t *pfns = range->pfns;
-       unsigned long i;
+       struct vm_area_struct *vma = walk->vma;
+       unsigned int fault_flags = FAULT_FLAG_REMOTE;
 
+       WARN_ON_ONCE(!required_fault);
        hmm_vma_walk->last = addr;
-       i = (addr - range->start) >> PAGE_SHIFT;
-
-       if (write_fault && walk->vma && !(walk->vma->vm_flags & VM_WRITE))
-               return -EPERM;
-
-       for (; addr < end; addr += PAGE_SIZE, i++) {
-               pfns[i] = range->values[HMM_PFN_NONE];
-               if (fault || write_fault) {
-                       int ret;
 
-                       ret = hmm_vma_do_fault(walk, addr, write_fault,
-                                              &pfns[i]);
-                       if (ret != -EBUSY)
-                               return ret;
-               }
+       if (required_fault & HMM_NEED_WRITE_FAULT) {
+               if (!(vma->vm_flags & VM_WRITE))
+                       return -EPERM;
+               fault_flags |= FAULT_FLAG_WRITE;
        }
 
-       return (fault || write_fault) ? -EBUSY : 0;
+       for (; addr < end; addr += PAGE_SIZE)
+               if (handle_mm_fault(vma, addr, fault_flags) & VM_FAULT_ERROR)
+                       return -EFAULT;
+       return -EBUSY;
 }
 
-static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
-                                     uint64_t pfns, uint64_t cpu_flags,
-                                     bool *fault, bool *write_fault)
+static unsigned int hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
+                                      uint64_t pfns, uint64_t cpu_flags)
 {
        struct hmm_range *range = hmm_vma_walk->range;
 
-       if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT)
-               return;
-
        /*
         * So we not only consider the individual per page request we also
         * consider the default flags requested for the range. The API can
@@ -143,46 +114,44 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 
        /* We aren't asked to do anything ... */
        if (!(pfns & range->flags[HMM_PFN_VALID]))
-               return;
-       /* If this is device memory then only fault if explicitly requested */
-       if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
-               /* Do we fault on device memory ? */
-               if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
-                       *write_fault = pfns & range->flags[HMM_PFN_WRITE];
-                       *fault = true;
-               }
-               return;
-       }
+               return 0;
 
-       /* If CPU page table is not valid then we need to fault */
-       *fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
        /* Need to write fault ? */
        if ((pfns & range->flags[HMM_PFN_WRITE]) &&
-           !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
-               *write_fault = true;
-               *fault = true;
-       }
+           !(cpu_flags & range->flags[HMM_PFN_WRITE]))
+               return HMM_NEED_FAULT | HMM_NEED_WRITE_FAULT;
+
+       /* If CPU page table is not valid then we need to fault */
+       if (!(cpu_flags & range->flags[HMM_PFN_VALID]))
+               return HMM_NEED_FAULT;
+       return 0;
 }
 
-static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
-                                const uint64_t *pfns, unsigned long npages,
-                                uint64_t cpu_flags, bool *fault,
-                                bool *write_fault)
+static unsigned int
+hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
+                    const uint64_t *pfns, unsigned long npages,
+                    uint64_t cpu_flags)
 {
+       struct hmm_range *range = hmm_vma_walk->range;
+       unsigned int required_fault = 0;
        unsigned long i;
 
-       if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT) {
-               *fault = *write_fault = false;
-               return;
-       }
+       /*
+        * If the default flags do not request to fault pages, and the mask does
+        * not allow for individual pages to be faulted, then
+        * hmm_pte_need_fault() will always return 0.
+        */
+       if (!((range->default_flags | range->pfn_flags_mask) &
+             range->flags[HMM_PFN_VALID]))
+               return 0;
 
-       *fault = *write_fault = false;
        for (i = 0; i < npages; ++i) {
-               hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
-                                  fault, write_fault);
-               if ((*write_fault))
-                       return;
+               required_fault |=
+                       hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags);
+               if (required_fault == HMM_NEED_ALL_BITS)
+                       return required_fault;
        }
+       return required_fault;
 }
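
The pfn_flags_mask / default_flags short-circuit above keys off how the caller encodes its request. A hedged sketch of the two usual request styles, loosely following Documentation/vm/hmm.rst (the helpers are invented and range->flags[] is assumed to be populated by the driver already):

/* Fault every page in the range with at least read permission. */
static void example_request_read(struct hmm_range *range)
{
	range->default_flags = range->flags[HMM_PFN_VALID];
	range->pfn_flags_mask = 0;	/* ignore per-pfn request bits */
}

/* Fault page 'i' for write and everything else for read. */
static void example_request_write_one(struct hmm_range *range, unsigned long i)
{
	range->default_flags = range->flags[HMM_PFN_VALID];
	range->pfn_flags_mask = range->flags[HMM_PFN_WRITE];
	range->pfns[i] = range->flags[HMM_PFN_WRITE];
}

In both cases the VALID flag is present in default_flags, so the early return above is skipped and the pfns[] array is actually scanned.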
 
 static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
@@ -190,16 +159,23 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       bool fault, write_fault;
+       unsigned int required_fault;
        unsigned long i, npages;
        uint64_t *pfns;
 
        i = (addr - range->start) >> PAGE_SHIFT;
        npages = (end - addr) >> PAGE_SHIFT;
        pfns = &range->pfns[i];
-       hmm_range_need_fault(hmm_vma_walk, pfns, npages,
-                            0, &fault, &write_fault);
-       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       required_fault = hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0);
+       if (!walk->vma) {
+               if (required_fault)
+                       return -EFAULT;
+               return hmm_pfns_fill(addr, end, range, HMM_PFN_ERROR);
+       }
+       if (required_fault)
+               return hmm_vma_fault(addr, end, required_fault, walk);
+       hmm_vma_walk->last = addr;
+       return hmm_pfns_fill(addr, end, range, HMM_PFN_NONE);
 }
 
 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
@@ -218,31 +194,19 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
        unsigned long pfn, npages, i;
-       bool fault, write_fault;
+       unsigned int required_fault;
        uint64_t cpu_flags;
 
        npages = (end - addr) >> PAGE_SHIFT;
        cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
-       hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
-                            &fault, &write_fault);
-
-       if (pmd_protnone(pmd) || fault || write_fault)
-               return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       required_fault =
+               hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags);
+       if (required_fault)
+               return hmm_vma_fault(addr, end, required_fault, walk);
 
        pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
-       for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
-               if (pmd_devmap(pmd)) {
-                       hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
-                                             hmm_vma_walk->pgmap);
-                       if (unlikely(!hmm_vma_walk->pgmap))
-                               return -EBUSY;
-               }
+       for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++)
                pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
-       }
-       if (hmm_vma_walk->pgmap) {
-               put_dev_pagemap(hmm_vma_walk->pgmap);
-               hmm_vma_walk->pgmap = NULL;
-       }
        hmm_vma_walk->last = end;
        return 0;
 }
@@ -252,6 +216,14 @@ int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
                unsigned long end, uint64_t *pfns, pmd_t pmd);
 #endif /* CONFIG_TRANSPARENT_HUGEPAGE */
 
+static inline bool hmm_is_device_private_entry(struct hmm_range *range,
+               swp_entry_t entry)
+{
+       return is_device_private_entry(entry) &&
+               device_private_entry_to_page(entry)->pgmap->owner ==
+               range->dev_private_owner;
+}
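
A short, hypothetical illustration of how the owner test above is satisfied: the driver passes the same opaque cookie when it registers its device private memory (pgmap->owner, set before memremap_pages()) and when it calls hmm_range_fault() (range->dev_private_owner). The helper and the cookie are invented names:

/* Purely illustrative; both fields refer to the same driver-private pointer. */
static void example_claim_device_private(struct dev_pagemap *pgmap,
					 struct hmm_range *range,
					 void *cookie)
{
	pgmap->owner = cookie;
	range->dev_private_owner = cookie;
}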
+
 static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
 {
        if (pte_none(pte) || !pte_present(pte) || pte_protnone(pte))
@@ -267,102 +239,81 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       bool fault, write_fault;
+       unsigned int required_fault;
        uint64_t cpu_flags;
        pte_t pte = *ptep;
        uint64_t orig_pfn = *pfn;
 
-       *pfn = range->values[HMM_PFN_NONE];
-       fault = write_fault = false;
-
        if (pte_none(pte)) {
-               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
-                                  &fault, &write_fault);
-               if (fault || write_fault)
+               required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0);
+               if (required_fault)
                        goto fault;
+               *pfn = range->values[HMM_PFN_NONE];
                return 0;
        }
 
        if (!pte_present(pte)) {
                swp_entry_t entry = pte_to_swp_entry(pte);
 
-               if (!non_swap_entry(entry)) {
-                       cpu_flags = pte_to_hmm_pfn_flags(range, pte);
-                       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                                          &fault, &write_fault);
-                       if (fault || write_fault)
-                               goto fault;
-                       return 0;
-               }
-
                /*
-                * This is a special swap entry, ignore migration, use
-                * device and report anything else as error.
+                * Never fault in device private pages, but just report
+                * the PFN even if not present.
                 */
-               if (is_device_private_entry(entry)) {
-                       cpu_flags = range->flags[HMM_PFN_VALID] |
-                               range->flags[HMM_PFN_DEVICE_PRIVATE];
-                       cpu_flags |= is_write_device_private_entry(entry) ?
-                               range->flags[HMM_PFN_WRITE] : 0;
-                       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                                          &fault, &write_fault);
-                       if (fault || write_fault)
-                               goto fault;
+               if (hmm_is_device_private_entry(range, entry)) {
                        *pfn = hmm_device_entry_from_pfn(range,
-                                           swp_offset(entry));
-                       *pfn |= cpu_flags;
+                               device_private_entry_to_pfn(entry));
+                       *pfn |= range->flags[HMM_PFN_VALID];
+                       if (is_write_device_private_entry(entry))
+                               *pfn |= range->flags[HMM_PFN_WRITE];
                        return 0;
                }
 
-               if (is_migration_entry(entry)) {
-                       if (fault || write_fault) {
-                               pte_unmap(ptep);
-                               hmm_vma_walk->last = addr;
-                               migration_entry_wait(walk->mm, pmdp, addr);
-                               return -EBUSY;
-                       }
+               required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0);
+               if (!required_fault) {
+                       *pfn = range->values[HMM_PFN_NONE];
                        return 0;
                }
 
+               if (!non_swap_entry(entry))
+                       goto fault;
+
+               if (is_migration_entry(entry)) {
+                       pte_unmap(ptep);
+                       hmm_vma_walk->last = addr;
+                       migration_entry_wait(walk->mm, pmdp, addr);
+                       return -EBUSY;
+               }
+
                /* Report error for everything else */
-               *pfn = range->values[HMM_PFN_ERROR];
+               pte_unmap(ptep);
                return -EFAULT;
-       } else {
-               cpu_flags = pte_to_hmm_pfn_flags(range, pte);
-               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                                  &fault, &write_fault);
        }
 
-       if (fault || write_fault)
+       cpu_flags = pte_to_hmm_pfn_flags(range, pte);
+       required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags);
+       if (required_fault)
                goto fault;
 
-       if (pte_devmap(pte)) {
-               hmm_vma_walk->pgmap = get_dev_pagemap(pte_pfn(pte),
-                                             hmm_vma_walk->pgmap);
-               if (unlikely(!hmm_vma_walk->pgmap))
-                       return -EBUSY;
-       } else if (IS_ENABLED(CONFIG_ARCH_HAS_PTE_SPECIAL) && pte_special(pte)) {
-               if (!is_zero_pfn(pte_pfn(pte))) {
-                       *pfn = range->values[HMM_PFN_SPECIAL];
+       /*
+        * Since each architecture defines a struct page for the zero page, just
+        * fall through and treat it like a normal page.
+        */
+       if (pte_special(pte) && !is_zero_pfn(pte_pfn(pte))) {
+               if (hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0)) {
+                       pte_unmap(ptep);
                        return -EFAULT;
                }
-               /*
-                * Since each architecture defines a struct page for the zero
-                * page, just fall through and treat it like a normal page.
-                */
+               *pfn = range->values[HMM_PFN_SPECIAL];
+               return 0;
        }
 
        *pfn = hmm_device_entry_from_pfn(range, pte_pfn(pte)) | cpu_flags;
        return 0;
 
 fault:
-       if (hmm_vma_walk->pgmap) {
-               put_dev_pagemap(hmm_vma_walk->pgmap);
-               hmm_vma_walk->pgmap = NULL;
-       }
        pte_unmap(ptep);
        /* Fault any virtual address we were asked to fault */
-       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       return hmm_vma_fault(addr, end, required_fault, walk);
 }
 
 static int hmm_vma_walk_pmd(pmd_t *pmdp,
@@ -372,8 +323,9 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       uint64_t *pfns = range->pfns;
-       unsigned long addr = start, i;
+       uint64_t *pfns = &range->pfns[(start - range->start) >> PAGE_SHIFT];
+       unsigned long npages = (end - start) >> PAGE_SHIFT;
+       unsigned long addr = start;
        pte_t *ptep;
        pmd_t pmd;
 
@@ -383,24 +335,19 @@ again:
                return hmm_vma_walk_hole(start, end, -1, walk);
 
        if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
-               bool fault, write_fault;
-               unsigned long npages;
-               uint64_t *pfns;
-
-               i = (addr - range->start) >> PAGE_SHIFT;
-               npages = (end - addr) >> PAGE_SHIFT;
-               pfns = &range->pfns[i];
-
-               hmm_range_need_fault(hmm_vma_walk, pfns, npages,
-                                    0, &fault, &write_fault);
-               if (fault || write_fault) {
+               if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0)) {
                        hmm_vma_walk->last = addr;
                        pmd_migration_entry_wait(walk->mm, pmdp);
                        return -EBUSY;
                }
-               return 0;
-       } else if (!pmd_present(pmd))
+               return hmm_pfns_fill(start, end, range, HMM_PFN_NONE);
+       }
+
+       if (!pmd_present(pmd)) {
+               if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0))
+                       return -EFAULT;
                return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
+       }
 
        if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
                /*
@@ -417,8 +364,7 @@ again:
                if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
                        goto again;
 
-               i = (addr - range->start) >> PAGE_SHIFT;
-               return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
+               return hmm_vma_handle_pmd(walk, addr, end, pfns, pmd);
        }
 
        /*
@@ -427,31 +373,23 @@ again:
         * entry pointing to pte directory or it is a bad pmd that will not
         * recover.
         */
-       if (pmd_bad(pmd))
+       if (pmd_bad(pmd)) {
+               if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0))
+                       return -EFAULT;
                return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
+       }
 
        ptep = pte_offset_map(pmdp, addr);
-       i = (addr - range->start) >> PAGE_SHIFT;
-       for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
+       for (; addr < end; addr += PAGE_SIZE, ptep++, pfns++) {
                int r;
 
-               r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, &pfns[i]);
+               r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, pfns);
                if (r) {
-                       /* hmm_vma_handle_pte() did unmap pte directory */
+                       /* hmm_vma_handle_pte() did pte_unmap() */
                        hmm_vma_walk->last = addr;
                        return r;
                }
        }
-       if (hmm_vma_walk->pgmap) {
-               /*
-                * We do put_dev_pagemap() here and not in hmm_vma_handle_pte()
-                * so that we can leverage get_dev_pagemap() optimization which
-                * will not re-take a reference on a pgmap if we already have
-                * one.
-                */
-               put_dev_pagemap(hmm_vma_walk->pgmap);
-               hmm_vma_walk->pgmap = NULL;
-       }
        pte_unmap(ptep - 1);
 
        hmm_vma_walk->last = addr;
@@ -487,18 +425,18 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
 
        pud = READ_ONCE(*pudp);
        if (pud_none(pud)) {
-               ret = hmm_vma_walk_hole(start, end, -1, walk);
-               goto out_unlock;
+               spin_unlock(ptl);
+               return hmm_vma_walk_hole(start, end, -1, walk);
        }
 
        if (pud_huge(pud) && pud_devmap(pud)) {
                unsigned long i, npages, pfn;
+               unsigned int required_fault;
                uint64_t *pfns, cpu_flags;
-               bool fault, write_fault;
 
                if (!pud_present(pud)) {
-                       ret = hmm_vma_walk_hole(start, end, -1, walk);
-                       goto out_unlock;
+                       spin_unlock(ptl);
+                       return hmm_vma_walk_hole(start, end, -1, walk);
                }
 
                i = (addr - range->start) >> PAGE_SHIFT;
@@ -506,29 +444,17 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
                pfns = &range->pfns[i];
 
                cpu_flags = pud_to_hmm_pfn_flags(range, pud);
-               hmm_range_need_fault(hmm_vma_walk, pfns, npages,
-                                    cpu_flags, &fault, &write_fault);
-               if (fault || write_fault) {
-                       ret = hmm_vma_walk_hole_(addr, end, fault,
-                                                write_fault, walk);
-                       goto out_unlock;
+               required_fault = hmm_range_need_fault(hmm_vma_walk, pfns,
+                                                     npages, cpu_flags);
+               if (required_fault) {
+                       spin_unlock(ptl);
+                       return hmm_vma_fault(addr, end, required_fault, walk);
                }
 
                pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
-               for (i = 0; i < npages; ++i, ++pfn) {
-                       hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
-                                             hmm_vma_walk->pgmap);
-                       if (unlikely(!hmm_vma_walk->pgmap)) {
-                               ret = -EBUSY;
-                               goto out_unlock;
-                       }
+               for (i = 0; i < npages; ++i, ++pfn)
                        pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
                                  cpu_flags;
-               }
-               if (hmm_vma_walk->pgmap) {
-                       put_dev_pagemap(hmm_vma_walk->pgmap);
-                       hmm_vma_walk->pgmap = NULL;
-               }
                hmm_vma_walk->last = end;
                goto out_unlock;
        }
@@ -554,24 +480,20 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
        struct hmm_range *range = hmm_vma_walk->range;
        struct vm_area_struct *vma = walk->vma;
        uint64_t orig_pfn, cpu_flags;
-       bool fault, write_fault;
+       unsigned int required_fault;
        spinlock_t *ptl;
        pte_t entry;
-       int ret = 0;
 
        ptl = huge_pte_lock(hstate_vma(vma), walk->mm, pte);
        entry = huge_ptep_get(pte);
 
        i = (start - range->start) >> PAGE_SHIFT;
        orig_pfn = range->pfns[i];
-       range->pfns[i] = range->values[HMM_PFN_NONE];
        cpu_flags = pte_to_hmm_pfn_flags(range, entry);
-       fault = write_fault = false;
-       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                          &fault, &write_fault);
-       if (fault || write_fault) {
-               ret = -ENOENT;
-               goto unlock;
+       required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags);
+       if (required_fault) {
+               spin_unlock(ptl);
+               return hmm_vma_fault(addr, end, required_fault, walk);
        }
 
        pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
@@ -579,14 +501,8 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
                range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
                                 cpu_flags;
        hmm_vma_walk->last = end;
-
-unlock:
        spin_unlock(ptl);
-
-       if (ret == -ENOENT)
-               return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
-
-       return ret;
+       return 0;
 }
 #else
 #define hmm_vma_walk_hugetlb_entry NULL
@@ -599,40 +515,32 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
        struct hmm_range *range = hmm_vma_walk->range;
        struct vm_area_struct *vma = walk->vma;
 
-       /*
-        * Skip vma ranges that don't have struct page backing them or
-        * map I/O devices directly.
-        */
-       if (vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP))
-               return -EFAULT;
+       if (!(vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP)) &&
+           vma->vm_flags & VM_READ)
+               return 0;
 
        /*
+        * vma ranges that don't have struct page backing them or map I/O
+        * devices directly cannot be handled by hmm_range_fault().
+        *
         * If the vma does not allow read access, then assume that it does not
-        * allow write access either. HMM does not support architectures
-        * that allow write without read.
+        * allow write access either. HMM does not support architectures that
+        * allow write without read.
+        *
+        * If a fault is requested for an unsupported range then it is a hard
+        * failure.
         */
-       if (!(vma->vm_flags & VM_READ)) {
-               bool fault, write_fault;
-
-               /*
-                * Check to see if a fault is requested for any page in the
-                * range.
-                */
-               hmm_range_need_fault(hmm_vma_walk, range->pfns +
-                                       ((start - range->start) >> PAGE_SHIFT),
-                                       (end - start) >> PAGE_SHIFT,
-                                       0, &fault, &write_fault);
-               if (fault || write_fault)
-                       return -EFAULT;
-
-               hmm_pfns_fill(start, end, range, HMM_PFN_NONE);
-               hmm_vma_walk->last = end;
+       if (hmm_range_need_fault(hmm_vma_walk,
+                                range->pfns +
+                                        ((start - range->start) >> PAGE_SHIFT),
+                                (end - start) >> PAGE_SHIFT, 0))
+               return -EFAULT;
 
-               /* Skip this vma and continue processing the next vma. */
-               return 1;
-       }
+       hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
+       hmm_vma_walk->last = end;
 
-       return 0;
+       /* Skip this vma and continue processing the next vma. */
+       return 1;
 }
 
 static const struct mm_walk_ops hmm_walk_ops = {
@@ -645,8 +553,7 @@ static const struct mm_walk_ops hmm_walk_ops = {
 
 /**
  * hmm_range_fault - try to fault some address in a virtual address range
- * @range:     range being faulted
- * @flags:     HMM_FAULT_* flags
+ * @range:     argument structure
  *
  * Return: the number of valid pages in range->pfns[] (from range start
  * address), which may be zero.  On error one of the following status codes
@@ -657,26 +564,19 @@ static const struct mm_walk_ops hmm_walk_ops = {
  * -ENOMEM:    Out of memory.
  * -EPERM:     Invalid permission (e.g., asking for write and range is read
  *             only).
- * -EAGAIN:    A page fault needs to be retried and mmap_sem was dropped.
  * -EBUSY:     The range has been invalidated and the caller needs to wait for
  *             the invalidation to finish.
- * -EFAULT:    Invalid (i.e., either no valid vma or it is illegal to access
- *             that range) number of valid pages in range->pfns[] (from
- *              range start address).
- *
- * This is similar to a regular CPU page fault except that it will not trigger
- * any memory migration if the memory being faulted is not accessible by CPUs
- * and caller does not ask for migration.
+ * -EFAULT:    A page was requested to be valid and could not be made valid,
+ *             i.e. it has no backing VMA or it is illegal to access it.
  *
- * On error, for one virtual address in the range, the function will mark the
- * corresponding HMM pfn entry with an error flag.
+ * This is similar to get_user_pages(), except that it can read the page tables
+ * without mutating them (i.e. without causing faults).
  */
-long hmm_range_fault(struct hmm_range *range, unsigned int flags)
+long hmm_range_fault(struct hmm_range *range)
 {
        struct hmm_vma_walk hmm_vma_walk = {
                .range = range,
                .last = range->start,
-               .flags = flags,
        };
        struct mm_struct *mm = range->notifier->mm;
        int ret;
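
To close, a hedged sketch of the calling convention implied by the documentation above, in the spirit of Documentation/vm/hmm.rst for this kernel; the driver structure, lock and function names are invented and error handling is kept minimal:

struct example_driver {
	struct mutex update_lock;	/* also taken by the invalidate callback */
};

static long example_fault_and_map(struct example_driver *drv,
				  struct hmm_range *range)
{
	struct mm_struct *mm = range->notifier->mm;
	long ret;

again:
	range->notifier_seq = mmu_interval_read_begin(range->notifier);
	down_read(&mm->mmap_sem);
	ret = hmm_range_fault(range);
	up_read(&mm->mmap_sem);
	if (ret < 0) {
		if (ret == -EBUSY)
			goto again;
		return ret;
	}

	mutex_lock(&drv->update_lock);
	if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
		mutex_unlock(&drv->update_lock);
		goto again;
	}
	/* range->pfns[] is now stable: program the device page tables here. */
	mutex_unlock(&drv->update_lock);
	return ret;
}

The -EBUSY retry and the mmu_interval_read_retry() check, together with a driver lock that the invalidation callback also takes, are what make the snapshot in pfns[] safe to consume without holding mmap_sem while the device tables are updated.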