mm/hmm: don't handle the non-fault case in hmm_vma_walk_hole_()
Author:     Christoph Hellwig <hch@lst.de>
AuthorDate: Mon, 16 Mar 2020 13:53:09 +0000 (14:53 +0100)
Commit:     Jason Gunthorpe <jgg@mellanox.com>
CommitDate: Thu, 26 Mar 2020 17:33:37 +0000 (14:33 -0300)
Setting a pfns entry to NONE before returning -EBUSY is a bug that
corrupts the caller's input flags for that entry on the next retry of
the range walk.
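
To make the failure mode concrete, here is the pre-patch loop with the
problem annotated; the retry behaviour of hmm_range_fault() is
paraphrased here as I understand it, not quoted:

	for (; addr < end; addr += PAGE_SIZE, i++) {
		/* Overwrites the caller's input request flags in pfns[i]. */
		pfns[i] = range->values[HMM_PFN_NONE];
		if (fault || write_fault) {
			int ret;

			ret = hmm_vma_do_fault(walk, addr, write_fault,
					       &pfns[i]);
			if (ret != -EBUSY)
				return ret;
			/*
			 * On -EBUSY the whole walk is retried, but the
			 * need-fault test then re-reads the zeroed pfns[i]
			 * and no longer sees the original request.
			 */
		}
	}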

There is just a single caller that uses hmm_vma_walk_hole_() for the
non-fault case.  Fill the whole pfn array with zeroes via hmm_pfns_fill()
in that one caller and remove the non-fault path from
hmm_vma_walk_hole_().  This avoids setting NONE before returning -EBUSY.
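
For reference, hmm_pfns_fill() (named in the hunk header below) is
roughly the following helper; its body is not part of this patch, so
take this as a sketch rather than a quote:

	static int hmm_pfns_fill(unsigned long addr, unsigned long end,
			struct hmm_range *range, enum hmm_pfn_value_e value)
	{
		uint64_t *pfns = range->pfns;
		unsigned long i = (addr - range->start) >> PAGE_SHIFT;

		for (; addr < end; addr += PAGE_SIZE, i++)
			pfns[i] = range->values[value];
		return 0;
	}

With fault and write_fault both false the old loop only ever wrote
HMM_PFN_NONE entries, so calling hmm_pfns_fill(addr, end, range,
HMM_PFN_NONE) from the caller is an exact replacement.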

Also rename the function to hmm_vma_fault() to better describe what it
does.

Fixes: 2aee09d8c116 ("mm/hmm: change hmm_vma_fault() to allow write fault on page basis")
Link: https://lore.kernel.org/r/20200316135310.899364-5-hch@lst.de
Signed-off-by: Christoph Hellwig <hch@lst.de>
Reviewed-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>

diff --git a/mm/hmm.c b/mm/hmm.c
index d13dedf..b15bf40 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -73,45 +73,41 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 }
 
 /*
- * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
+ * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
  * @addr: range virtual start address (inclusive)
  * @end: range virtual end address (exclusive)
  * @fault: should we fault or not ?
  * @write_fault: write fault ?
  * @walk: mm_walk structure
- * Return: 0 on success, -EBUSY after page fault, or page fault error
+ * Return: -EBUSY after page fault, or page fault error
  *
  * This function will be called whenever pmd_none() or pte_none() returns true,
  * or whenever there is no page directory covering the virtual address range.
  */
-static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
+static int hmm_vma_fault(unsigned long addr, unsigned long end,
                              bool fault, bool write_fault,
                              struct mm_walk *walk)
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
        uint64_t *pfns = range->pfns;
-       unsigned long i;
+       unsigned long i = (addr - range->start) >> PAGE_SHIFT;
 
+       WARN_ON_ONCE(!fault && !write_fault);
        hmm_vma_walk->last = addr;
-       i = (addr - range->start) >> PAGE_SHIFT;
 
        if (write_fault && walk->vma && !(walk->vma->vm_flags & VM_WRITE))
                return -EPERM;
 
        for (; addr < end; addr += PAGE_SIZE, i++) {
-               pfns[i] = range->values[HMM_PFN_NONE];
-               if (fault || write_fault) {
-                       int ret;
+               int ret;
 
-                       ret = hmm_vma_do_fault(walk, addr, write_fault,
-                                              &pfns[i]);
-                       if (ret != -EBUSY)
-                               return ret;
-               }
+               ret = hmm_vma_do_fault(walk, addr, write_fault, &pfns[i]);
+               if (ret != -EBUSY)
+                       return ret;
        }
 
-       return (fault || write_fault) ? -EBUSY : 0;
+       return -EBUSY;
 }
 
 static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
@@ -193,7 +189,10 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
        pfns = &range->pfns[i];
        hmm_range_need_fault(hmm_vma_walk, pfns, npages,
                             0, &fault, &write_fault);
-       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       if (fault || write_fault)
+               return hmm_vma_fault(addr, end, fault, write_fault, walk);
+       hmm_vma_walk->last = addr;
+       return hmm_pfns_fill(addr, end, range, HMM_PFN_NONE);
 }
 
 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
@@ -221,7 +220,7 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
                             &fault, &write_fault);
 
        if (fault || write_fault)
-               return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+               return hmm_vma_fault(addr, end, fault, write_fault, walk);
 
        pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
        for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
@@ -360,7 +359,7 @@ fault:
        }
        pte_unmap(ptep);
        /* Fault any virtual address we were asked to fault */
-       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       return hmm_vma_fault(addr, end, fault, write_fault, walk);
 }
 
 static int hmm_vma_walk_pmd(pmd_t *pmdp,
@@ -512,7 +511,7 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
                                     cpu_flags, &fault, &write_fault);
                if (fault || write_fault) {
                        spin_unlock(ptl);
-                       return hmm_vma_walk_hole_(addr, end, fault, write_fault,
+                       return hmm_vma_fault(addr, end, fault, write_fault,
                                                  walk);
                }
 
@@ -572,7 +571,7 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
                           &fault, &write_fault);
        if (fault || write_fault) {
                spin_unlock(ptl);
-               return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+               return hmm_vma_fault(addr, end, fault, write_fault, walk);
        }
 
        pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);