Merge branch 'expand-stack'
index 0814576..ef29641 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -18,6 +18,7 @@
 #include <linux/migrate.h>
 #include <linux/mm_inline.h>
 #include <linux/sched/mm.h>
+#include <linux/shmem_fs.h>
 
 #include <asm/mmu_context.h>
 #include <asm/tlbflush.h>
@@ -124,65 +125,65 @@ retry:
  */
 struct folio *try_grab_folio(struct page *page, int refs, unsigned int flags)
 {
+       struct folio *folio;
+
+       if (WARN_ON_ONCE((flags & (FOLL_GET | FOLL_PIN)) == 0))
+               return NULL;
+
        if (unlikely(!(flags & FOLL_PCI_P2PDMA) && is_pci_p2pdma_page(page)))
                return NULL;
 
        if (flags & FOLL_GET)
                return try_get_folio(page, refs);
-       else if (flags & FOLL_PIN) {
-               struct folio *folio;
 
-               /*
-                * Don't take a pin on the zero page - it's not going anywhere
-                * and it is used in a *lot* of places.
-                */
-               if (is_zero_page(page))
-                       return page_folio(page);
+       /* FOLL_PIN is set */
 
-               /*
-                * Can't do FOLL_LONGTERM + FOLL_PIN gup fast path if not in a
-                * right zone, so fail and let the caller fall back to the slow
-                * path.
-                */
-               if (unlikely((flags & FOLL_LONGTERM) &&
-                            !is_longterm_pinnable_page(page)))
-                       return NULL;
+       /*
+        * Don't take a pin on the zero page - it's not going anywhere
+        * and it is used in a *lot* of places.
+        */
+       if (is_zero_page(page))
+               return page_folio(page);
 
-               /*
-                * CAUTION: Don't use compound_head() on the page before this
-                * point, the result won't be stable.
-                */
-               folio = try_get_folio(page, refs);
-               if (!folio)
-                       return NULL;
+       folio = try_get_folio(page, refs);
+       if (!folio)
+               return NULL;
 
-               /*
-                * When pinning a large folio, use an exact count to track it.
-                *
-                * However, be sure to *also* increment the normal folio
-                * refcount field at least once, so that the folio really
-                * is pinned.  That's why the refcount from the earlier
-                * try_get_folio() is left intact.
-                */
-               if (folio_test_large(folio))
-                       atomic_add(refs, &folio->_pincount);
-               else
-                       folio_ref_add(folio,
-                                       refs * (GUP_PIN_COUNTING_BIAS - 1));
-               /*
-                * Adjust the pincount before re-checking the PTE for changes.
-                * This is essentially a smp_mb() and is paired with a memory
-                * barrier in page_try_share_anon_rmap().
-                */
-               smp_mb__after_atomic();
+       /*
+        * Can't do FOLL_LONGTERM + FOLL_PIN gup fast path if not in a
+        * right zone, so fail and let the caller fall back to the slow
+        * path.
+        */
+       if (unlikely((flags & FOLL_LONGTERM) &&
+                    !folio_is_longterm_pinnable(folio))) {
+               if (!put_devmap_managed_page_refs(&folio->page, refs))
+                       folio_put_refs(folio, refs);
+               return NULL;
+       }
 
-               node_stat_mod_folio(folio, NR_FOLL_PIN_ACQUIRED, refs);
+       /*
+        * When pinning a large folio, use an exact count to track it.
+        *
+        * However, be sure to *also* increment the normal folio
+        * refcount field at least once, so that the folio really
+        * is pinned.  That's why the refcount from the earlier
+        * try_get_folio() is left intact.
+        */
+       if (folio_test_large(folio))
+               atomic_add(refs, &folio->_pincount);
+       else
+               folio_ref_add(folio,
+                               refs * (GUP_PIN_COUNTING_BIAS - 1));
+       /*
+        * Adjust the pincount before re-checking the PTE for changes.
+        * This is essentially a smp_mb() and is paired with a memory
+        * barrier in page_try_share_anon_rmap().
+        */
+       smp_mb__after_atomic();
 
-               return folio;
-       }
+       node_stat_mod_folio(folio, NR_FOLL_PIN_ACQUIRED, refs);
 
-       WARN_ON_ONCE(1);
-       return NULL;
+       return folio;
 }
 
 static void gup_put_folio(struct folio *folio, int refs, unsigned int flags)
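A note on the pin accounting performed by the rewritten try_grab_folio() above: small (order-0) folios encode pins in the ordinary refcount at GUP_PIN_COUNTING_BIAS (currently 1024) per pin, while large folios take one ordinary reference per pin plus an exact count in _pincount. The sketch below is a condensed paraphrase, not a verbatim copy, of how folio_maybe_dma_pinned() in <linux/mm.h> reads those counters back.

static inline bool folio_maybe_dma_pinned_sketch(struct folio *folio)
{
	/* Large folios keep an exact pin count in a dedicated field. */
	if (folio_test_large(folio))
		return atomic_read(&folio->_pincount) > 0;

	/*
	 * Small folios overload the refcount: each pin added
	 * GUP_PIN_COUNTING_BIAS to it, so a refcount at or above the bias
	 * means "maybe pinned" (false positives are possible by design).
	 */
	return ((unsigned int)folio_ref_count(folio)) >= GUP_PIN_COUNTING_BIAS;
}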
@@ -520,13 +521,14 @@ static int follow_pfn_pte(struct vm_area_struct *vma, unsigned long address,
                pte_t *pte, unsigned int flags)
 {
        if (flags & FOLL_TOUCH) {
-               pte_t entry = *pte;
+               pte_t orig_entry = ptep_get(pte);
+               pte_t entry = orig_entry;
 
                if (flags & FOLL_WRITE)
                        entry = pte_mkdirty(entry);
                entry = pte_mkyoung(entry);
 
-               if (!pte_same(*pte, entry)) {
+               if (!pte_same(orig_entry, entry)) {
                        set_pte_at(vma->vm_mm, address, pte, entry);
                        update_mmu_cache(vma, address, pte);
                }
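The ptep_get() conversions in this hunk (and in follow_page_pte(), get_gate_page() and the GUP-fast paths below) replace direct *pte dereferences with an accessor, so the entry is read exactly once and architectures with non-trivial PTE representations can supply their own version; pmdp_get_lockless() plays the analogous role for PMD entries in follow_pmd_mask(). The generic fallback is roughly the following (paraphrased from <linux/pgtable.h>):

#ifndef ptep_get
static inline pte_t ptep_get(pte_t *ptep)
{
	/* A single, tear-free read of the entry. */
	return READ_ONCE(*ptep);
}
#endif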
@@ -588,11 +590,11 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
        if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
                         (FOLL_PIN | FOLL_GET)))
                return ERR_PTR(-EINVAL);
-       if (unlikely(pmd_bad(*pmd)))
-               return no_page_table(vma, flags);
 
        ptep = pte_offset_map_lock(mm, pmd, address, &ptl);
-       pte = *ptep;
+       if (!ptep)
+               return no_page_table(vma, flags);
+       pte = ptep_get(ptep);
        if (!pte_present(pte))
                goto no_page;
        if (pte_protnone(pte) && !gup_can_follow_protnone(flags))
@@ -697,11 +699,7 @@ static struct page *follow_pmd_mask(struct vm_area_struct *vma,
        struct mm_struct *mm = vma->vm_mm;
 
        pmd = pmd_offset(pudp, address);
-       /*
-        * The READ_ONCE() will stabilize the pmdval in a register or
-        * on the stack so that it will stop changing under the code.
-        */
-       pmdval = READ_ONCE(*pmd);
+       pmdval = pmdp_get_lockless(pmd);
        if (pmd_none(pmdval))
                return no_page_table(vma, flags);
        if (!pmd_present(pmdval))
@@ -729,21 +727,10 @@ static struct page *follow_pmd_mask(struct vm_area_struct *vma,
                return follow_page_pte(vma, address, pmd, flags, &ctx->pgmap);
        }
        if (flags & FOLL_SPLIT_PMD) {
-               int ret;
-               page = pmd_page(*pmd);
-               if (is_huge_zero_page(page)) {
-                       spin_unlock(ptl);
-                       ret = 0;
-                       split_huge_pmd(vma, pmd, address);
-                       if (pmd_trans_unstable(pmd))
-                               ret = -EBUSY;
-               } else {
-                       spin_unlock(ptl);
-                       split_huge_pmd(vma, pmd, address);
-                       ret = pte_alloc(mm, pmd) ? -ENOMEM : 0;
-               }
-
-               return ret ? ERR_PTR(ret) :
+               spin_unlock(ptl);
+               split_huge_pmd(vma, pmd, address);
+               /* If pmd was left empty, stuff a page table in there quickly */
+               return pte_alloc(mm, pmd) ? ERR_PTR(-ENOMEM) :
                        follow_page_pte(vma, address, pmd, flags, &ctx->pgmap);
        }
        page = follow_trans_huge_pmd(vma, address, pmd, flags);
@@ -879,6 +866,7 @@ static int get_gate_page(struct mm_struct *mm, unsigned long address,
        pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
+       pte_t entry;
        int ret = -EFAULT;
 
        /* user gate pages are read-only */
@@ -899,18 +887,20 @@ static int get_gate_page(struct mm_struct *mm, unsigned long address,
        pmd = pmd_offset(pud, address);
        if (!pmd_present(*pmd))
                return -EFAULT;
-       VM_BUG_ON(pmd_trans_huge(*pmd));
        pte = pte_offset_map(pmd, address);
-       if (pte_none(*pte))
+       if (!pte)
+               return -EFAULT;
+       entry = ptep_get(pte);
+       if (pte_none(entry))
                goto unmap;
        *vma = get_gate_vma(mm);
        if (!page)
                goto out;
-       *page = vm_normal_page(*vma, address, *pte);
+       *page = vm_normal_page(*vma, address, entry);
        if (!*page) {
-               if ((gup_flags & FOLL_DUMP) || !is_zero_pfn(pte_pfn(*pte)))
+               if ((gup_flags & FOLL_DUMP) || !is_zero_pfn(pte_pfn(entry)))
                        goto unmap;
-               *page = pte_page(*pte);
+               *page = pte_page(entry);
        }
        ret = try_grab_page(*page, gup_flags);
        if (unlikely(ret))
@@ -1003,16 +993,54 @@ static int faultin_page(struct vm_area_struct *vma,
        return 0;
 }
 
+/*
+ * Writing to file-backed mappings which require folio dirty tracking using GUP
+ * is a fundamentally broken operation, as kernel write access to GUP mappings
+ * does not adhere to the semantics expected by a file system.
+ *
+ * Consider the following scenario:-
+ *
+ * 1. A folio is written to via GUP which write-faults the memory, notifying
+ *    the file system and dirtying the folio.
+ * 2. Later, writeback is triggered, resulting in the folio being cleaned and
+ *    the PTE being marked read-only.
+ * 3. The GUP caller writes to the folio, as it is mapped read/write via the
+ *    direct mapping.
+ * 4. The GUP caller, now done with the page, unpins it and sets it dirty
+ *    (though it does not have to).
+ *
+ * This results in both data being written to a folio without writenotify, and
+ * the folio being dirtied unexpectedly (if the caller decides to do so).
+ */
+static bool writable_file_mapping_allowed(struct vm_area_struct *vma,
+                                         unsigned long gup_flags)
+{
+       /*
+        * If we aren't pinning then no problematic write can occur. A long term
+        * pin is the most egregious case, so it is the one we disallow.
+        */
+       if ((gup_flags & (FOLL_PIN | FOLL_LONGTERM)) !=
+           (FOLL_PIN | FOLL_LONGTERM))
+               return true;
+
+       /*
+        * If the VMA does not require dirty tracking then no problematic write
+        * can occur either.
+        */
+       return !vma_needs_dirty_tracking(vma);
+}
+
 static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
 {
        vm_flags_t vm_flags = vma->vm_flags;
        int write = (gup_flags & FOLL_WRITE);
        int foreign = (gup_flags & FOLL_REMOTE);
+       bool vma_anon = vma_is_anonymous(vma);
 
        if (vm_flags & (VM_IO | VM_PFNMAP))
                return -EFAULT;
 
-       if (gup_flags & FOLL_ANON && !vma_is_anonymous(vma))
+       if ((gup_flags & FOLL_ANON) && !vma_anon)
                return -EFAULT;
 
        if ((gup_flags & FOLL_LONGTERM) && vma_is_fsdax(vma))
@@ -1022,6 +1050,10 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
                return -EFAULT;
 
        if (write) {
+               if (!vma_anon &&
+                   !writable_file_mapping_allowed(vma, gup_flags))
+                       return -EFAULT;
+
                if (!(vm_flags & VM_WRITE)) {
                        if (!(gup_flags & FOLL_FORCE))
                                return -EFAULT;
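The net effect of the new check is that a writable FOLL_PIN | FOLL_LONGTERM request against a dirty-tracked, file-backed mapping now fails early instead of corrupting filesystem writeback bookkeeping later. A hypothetical caller (the names here are illustrative, not part of this patch) would see something like:

static long my_longterm_pin(unsigned long uaddr, struct page **pages,
			    unsigned long npages)
{
	/*
	 * For e.g. an ext4 MAP_SHARED mapping this is now expected to fail
	 * with -EFAULT (for the first page) out of check_vma_flags();
	 * anonymous, shmem and hugetlb mappings are still allowed.
	 */
	return pin_user_pages_unlocked(uaddr, npages, pages,
				       FOLL_WRITE | FOLL_LONGTERM);
}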
@@ -1068,8 +1100,6 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
  * @pages:     array that receives pointers to the pages pinned.
  *             Should be at least nr_pages long. Or NULL, if caller
  *             only intends to ensure the pages are faulted in.
- * @vmas:      array of pointers to vmas corresponding to each page.
- *             Or NULL if the caller does not require them.
  * @locked:     whether we're still with the mmap_lock held
  *
  * Returns either number of pages pinned (which may be less than the
@@ -1083,8 +1113,6 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
  *
  * The caller is responsible for releasing returned @pages, via put_page().
  *
- * @vmas are valid only as long as mmap_lock is held.
- *
  * Must be called with mmap_lock held.  It may be released.  See below.
  *
  * __get_user_pages walks a process's page tables and takes a reference to
@@ -1120,7 +1148,7 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
 static long __get_user_pages(struct mm_struct *mm,
                unsigned long start, unsigned long nr_pages,
                unsigned int gup_flags, struct page **pages,
-               struct vm_area_struct **vmas, int *locked)
+               int *locked)
 {
        long ret = 0, i = 0;
        struct vm_area_struct *vma = NULL;
@@ -1140,7 +1168,11 @@ static long __get_user_pages(struct mm_struct *mm,
 
                /* first iteration or cross vma bound */
                if (!vma || start >= vma->vm_end) {
-                       vma = find_extend_vma(mm, start);
+                       vma = find_vma(mm, start);
+                       if (vma && (start < vma->vm_start)) {
+                               WARN_ON_ONCE(vma->vm_flags & VM_GROWSDOWN);
+                               vma = NULL;
+                       }
                        if (!vma && in_gate_area(mm, start)) {
                                ret = get_gate_page(mm, start & PAGE_MASK,
                                                gup_flags, &vma,
@@ -1160,9 +1192,9 @@ static long __get_user_pages(struct mm_struct *mm,
                                goto out;
 
                        if (is_vm_hugetlb_page(vma)) {
-                               i = follow_hugetlb_page(mm, vma, pages, vmas,
-                                               &start, &nr_pages, i,
-                                               gup_flags, locked);
+                               i = follow_hugetlb_page(mm, vma, pages,
+                                                       &start, &nr_pages, i,
+                                                       gup_flags, locked);
                                if (!*locked) {
                                        /*
                                         * We've got a VM_FAULT_RETRY
@@ -1227,10 +1259,6 @@ retry:
                        ctx.page_mask = 0;
                }
 next_page:
-               if (vmas) {
-                       vmas[i] = vma;
-                       ctx.page_mask = 0;
-               }
                page_increm = 1 + (~(start >> PAGE_SHIFT) & ctx.page_mask);
                if (page_increm > nr_pages)
                        page_increm = nr_pages;
@@ -1309,9 +1337,13 @@ int fixup_user_fault(struct mm_struct *mm,
                fault_flags |= FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_KILLABLE;
 
 retry:
-       vma = find_extend_vma(mm, address);
-       if (!vma || address < vma->vm_start)
+       vma = find_vma(mm, address);
+       if (!vma)
+               return -EFAULT;
+       if (address < vma->vm_start) {
+               WARN_ON_ONCE(vma->vm_flags & VM_GROWSDOWN);
                return -EFAULT;
+       }
 
        if (!vma_permits_fault(vma, fault_flags))
                return -EFAULT;
@@ -1385,7 +1417,6 @@ static __always_inline long __get_user_pages_locked(struct mm_struct *mm,
                                                unsigned long start,
                                                unsigned long nr_pages,
                                                struct page **pages,
-                                               struct vm_area_struct **vmas,
                                                int *locked,
                                                unsigned int flags)
 {
@@ -1423,7 +1454,7 @@ static __always_inline long __get_user_pages_locked(struct mm_struct *mm,
        pages_done = 0;
        for (;;) {
                ret = __get_user_pages(mm, start, nr_pages, flags, pages,
-                                      vmas, locked);
+                                      locked);
                if (!(flags & FOLL_UNLOCKABLE)) {
                        /* VM_FAULT_RETRY couldn't trigger, bypass */
                        pages_done = ret;
@@ -1487,7 +1518,7 @@ retry:
 
                *locked = 1;
                ret = __get_user_pages(mm, start, 1, flags | FOLL_TRIED,
-                                      pages, NULL, locked);
+                                      pages, locked);
                if (!*locked) {
                        /* Continue to retry until we succeeded */
                        BUG_ON(ret != 0);
@@ -1585,7 +1616,7 @@ long populate_vma_page_range(struct vm_area_struct *vma,
         * not result in a stack expansion that recurses back here.
         */
        ret = __get_user_pages(mm, start, nr_pages, gup_flags,
-                               NULL, NULL, locked ? locked : &local_locked);
+                              NULL, locked ? locked : &local_locked);
        lru_add_drain();
        return ret;
 }
@@ -1643,7 +1674,7 @@ long faultin_vma_page_range(struct vm_area_struct *vma, unsigned long start,
                return -EINVAL;
 
        ret = __get_user_pages(mm, start, nr_pages, gup_flags,
-                               NULL, NULL, locked);
+                              NULL, locked);
        lru_add_drain();
        return ret;
 }
@@ -1711,8 +1742,7 @@ int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
 #else /* CONFIG_MMU */
 static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
                unsigned long nr_pages, struct page **pages,
-               struct vm_area_struct **vmas, int *locked,
-               unsigned int foll_flags)
+               int *locked, unsigned int foll_flags)
 {
        struct vm_area_struct *vma;
        bool must_unlock = false;
@@ -1756,8 +1786,7 @@ static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
                        if (pages[i])
                                get_page(pages[i]);
                }
-               if (vmas)
-                       vmas[i] = vma;
+
                start = (start + PAGE_SIZE) & PAGE_MASK;
        }
 
@@ -1938,8 +1967,7 @@ struct page *get_dump_page(unsigned long addr)
        int locked = 0;
        int ret;
 
-       ret = __get_user_pages_locked(current->mm, addr, 1, &page, NULL,
-                                     &locked,
+       ret = __get_user_pages_locked(current->mm, addr, 1, &page, &locked,
                                      FOLL_FORCE | FOLL_DUMP | FOLL_GET);
        return (ret == 1) ? page : NULL;
 }
@@ -2112,7 +2140,6 @@ static long __gup_longterm_locked(struct mm_struct *mm,
                                  unsigned long start,
                                  unsigned long nr_pages,
                                  struct page **pages,
-                                 struct vm_area_struct **vmas,
                                  int *locked,
                                  unsigned int gup_flags)
 {
@@ -2120,13 +2147,13 @@ static long __gup_longterm_locked(struct mm_struct *mm,
        long rc, nr_pinned_pages;
 
        if (!(gup_flags & FOLL_LONGTERM))
-               return __get_user_pages_locked(mm, start, nr_pages, pages, vmas,
+               return __get_user_pages_locked(mm, start, nr_pages, pages,
                                               locked, gup_flags);
 
        flags = memalloc_pin_save();
        do {
                nr_pinned_pages = __get_user_pages_locked(mm, start, nr_pages,
-                                                         pages, vmas, locked,
+                                                         pages, locked,
                                                          gup_flags);
                if (nr_pinned_pages <= 0) {
                        rc = nr_pinned_pages;
@@ -2144,9 +2171,8 @@ static long __gup_longterm_locked(struct mm_struct *mm,
  * Check that the given flags are valid for the exported gup/pup interface, and
  * update them with the required flags that the caller must have set.
  */
-static bool is_valid_gup_args(struct page **pages, struct vm_area_struct **vmas,
-                             int *locked, unsigned int *gup_flags_p,
-                             unsigned int to_set)
+static bool is_valid_gup_args(struct page **pages, int *locked,
+                             unsigned int *gup_flags_p, unsigned int to_set)
 {
        unsigned int gup_flags = *gup_flags_p;
 
@@ -2188,13 +2214,6 @@ static bool is_valid_gup_args(struct page **pages, struct vm_area_struct **vmas,
                         (gup_flags & FOLL_PCI_P2PDMA)))
                return false;
 
-       /*
-        * Can't use VMAs with locked, as locked allows GUP to unlock
-        * which invalidates the vmas array
-        */
-       if (WARN_ON_ONCE(vmas && (gup_flags & FOLL_UNLOCKABLE)))
-               return false;
-
        *gup_flags_p = gup_flags;
        return true;
 }
@@ -2209,8 +2228,6 @@ static bool is_valid_gup_args(struct page **pages, struct vm_area_struct **vmas,
  * @pages:     array that receives pointers to the pages pinned.
  *             Should be at least nr_pages long. Or NULL, if caller
  *             only intends to ensure the pages are faulted in.
- * @vmas:      array of pointers to vmas corresponding to each page.
- *             Or NULL if the caller does not require them.
  * @locked:    pointer to lock flag indicating whether lock is held and
  *             subsequently whether VM_FAULT_RETRY functionality can be
  *             utilised. Lock must initially be held.
@@ -2225,8 +2242,6 @@ static bool is_valid_gup_args(struct page **pages, struct vm_area_struct **vmas,
  *
  * The caller is responsible for releasing returned @pages, via put_page().
  *
- * @vmas are valid only as long as mmap_lock is held.
- *
  * Must be called with mmap_lock held for read or write.
  *
  * get_user_pages_remote walks a process's page tables and takes a reference
@@ -2263,15 +2278,15 @@ static bool is_valid_gup_args(struct page **pages, struct vm_area_struct **vmas,
 long get_user_pages_remote(struct mm_struct *mm,
                unsigned long start, unsigned long nr_pages,
                unsigned int gup_flags, struct page **pages,
-               struct vm_area_struct **vmas, int *locked)
+               int *locked)
 {
        int local_locked = 1;
 
-       if (!is_valid_gup_args(pages, vmas, locked, &gup_flags,
+       if (!is_valid_gup_args(pages, locked, &gup_flags,
                               FOLL_TOUCH | FOLL_REMOTE))
                return -EINVAL;
 
-       return __get_user_pages_locked(mm, start, nr_pages, pages, vmas,
+       return __get_user_pages_locked(mm, start, nr_pages, pages,
                                       locked ? locked : &local_locked,
                                       gup_flags);
 }
@@ -2281,7 +2296,7 @@ EXPORT_SYMBOL(get_user_pages_remote);
 long get_user_pages_remote(struct mm_struct *mm,
                           unsigned long start, unsigned long nr_pages,
                           unsigned int gup_flags, struct page **pages,
-                          struct vm_area_struct **vmas, int *locked)
+                          int *locked)
 {
        return 0;
 }
@@ -2295,8 +2310,6 @@ long get_user_pages_remote(struct mm_struct *mm,
  * @pages:      array that receives pointers to the pages pinned.
  *              Should be at least nr_pages long. Or NULL, if caller
  *              only intends to ensure the pages are faulted in.
- * @vmas:       array of pointers to vmas corresponding to each page.
- *              Or NULL if the caller does not require them.
  *
  * This is the same as get_user_pages_remote(), just with a less-flexible
  * calling convention where we assume that the mm being operated on belongs to
@@ -2304,16 +2317,15 @@ long get_user_pages_remote(struct mm_struct *mm,
  * obviously don't pass FOLL_REMOTE in here.
  */
 long get_user_pages(unsigned long start, unsigned long nr_pages,
-               unsigned int gup_flags, struct page **pages,
-               struct vm_area_struct **vmas)
+                   unsigned int gup_flags, struct page **pages)
 {
        int locked = 1;
 
-       if (!is_valid_gup_args(pages, vmas, NULL, &gup_flags, FOLL_TOUCH))
+       if (!is_valid_gup_args(pages, NULL, &gup_flags, FOLL_TOUCH))
                return -EINVAL;
 
        return __get_user_pages_locked(current->mm, start, nr_pages, pages,
-                                      vmas, &locked, gup_flags);
+                                      &locked, gup_flags);
 }
 EXPORT_SYMBOL(get_user_pages);
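With the vmas argument gone from get_user_pages() and friends, a caller that still needs the VMA must look it up itself under mmap_lock, which was already the only window in which the returned pointers were valid. A hypothetical updated caller (function and variable names are illustrative):

static bool my_addr_is_mlocked(unsigned long addr)
{
	struct vm_area_struct *vma;
	struct page *page;
	bool ret = false;

	mmap_read_lock(current->mm);
	if (get_user_pages(addr, 1, FOLL_WRITE, &page) == 1) {
		/* vmas[] is gone: look the VMA up ourselves while locked. */
		vma = vma_lookup(current->mm, addr);
		ret = vma && (vma->vm_flags & VM_LOCKED);
		put_page(page);
	}
	mmap_read_unlock(current->mm);
	return ret;
}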
 
@@ -2337,12 +2349,12 @@ long get_user_pages_unlocked(unsigned long start, unsigned long nr_pages,
 {
        int locked = 0;
 
-       if (!is_valid_gup_args(pages, NULL, NULL, &gup_flags,
+       if (!is_valid_gup_args(pages, NULL, &gup_flags,
                               FOLL_TOUCH | FOLL_UNLOCKABLE))
                return -EINVAL;
 
        return __get_user_pages_locked(current->mm, start, nr_pages, pages,
-                                      NULL, &locked, gup_flags);
+                                      &locked, gup_flags);
 }
 EXPORT_SYMBOL(get_user_pages_unlocked);
 
@@ -2381,6 +2393,82 @@ EXPORT_SYMBOL(get_user_pages_unlocked);
  */
 #ifdef CONFIG_HAVE_FAST_GUP
 
+/*
+ * Used in the GUP-fast path to determine whether a pin is permitted for a
+ * specific folio.
+ *
+ * This call assumes the caller has pinned the folio, that the lowest page table
+ * level still points to this folio, and that interrupts have been disabled.
+ *
+ * Writing to pinned file-backed dirty tracked folios is inherently problematic
+ * (see comment describing the writable_file_mapping_allowed() function). We
+ * therefore try to avoid the most egregious case of a long-term mapping doing
+ * so.
+ *
+ * This function cannot be as thorough as that one because the VMA is not
+ * available in the fast path, so instead we whitelist known-good cases and,
+ * if in doubt, fall back to the slow path.
+ */
+static bool folio_fast_pin_allowed(struct folio *folio, unsigned int flags)
+{
+       struct address_space *mapping;
+       unsigned long mapping_flags;
+
+       /*
+        * If we aren't pinning then no problematic write can occur. A long term
+        * pin is the most egregious case so this is the one we disallow.
+        */
+       if ((flags & (FOLL_PIN | FOLL_LONGTERM | FOLL_WRITE)) !=
+           (FOLL_PIN | FOLL_LONGTERM | FOLL_WRITE))
+               return true;
+
+       /* The folio is pinned, so we can safely access folio fields. */
+
+       if (WARN_ON_ONCE(folio_test_slab(folio)))
+               return false;
+
+       /* hugetlb mappings do not require dirty-tracking. */
+       if (folio_test_hugetlb(folio))
+               return true;
+
+       /*
+        * GUP-fast disables IRQs. When IRQs are disabled, RCU grace periods
+        * cannot proceed, which means no actions performed under RCU can
+        * proceed either.
+        *
+        * inodes and thus their mappings are freed under RCU, which means the
+        * mapping cannot be freed beneath us and thus we can safely dereference
+        * it.
+        */
+       lockdep_assert_irqs_disabled();
+
+       /*
+        * However, there may be operations which _alter_ the mapping, so ensure
+        * we read it once and only once.
+        */
+       mapping = READ_ONCE(folio->mapping);
+
+       /*
+        * The mapping may have been truncated; in any case we cannot determine
+        * whether this mapping is safe - fall back to the slow path to decide
+        * how to proceed.
+        */
+       if (!mapping)
+               return false;
+
+       /* Anonymous folios pose no problem. */
+       mapping_flags = (unsigned long)mapping & PAGE_MAPPING_FLAGS;
+       if (mapping_flags)
+               return mapping_flags & PAGE_MAPPING_ANON;
+
+       /*
+        * At this point, we know the mapping is non-null and points to an
+        * address_space object. The only remaining whitelisted file system is
+        * shmem.
+        */
+       return shmem_mapping(mapping);
+}
+
 static void __maybe_unused undo_dev_pagemap(int *nr, int nr_start,
                                            unsigned int flags,
                                            struct page **pages)
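The mapping_flags test in folio_fast_pin_allowed() above relies on the standard mapping-pointer tagging: anonymous folios set PAGE_MAPPING_ANON in folio->mapping, KSM folios set both tag bits, and non-LRU movable folios set only PAGE_MAPPING_MOVABLE, so returning "mapping_flags & PAGE_MAPPING_ANON" admits anon and KSM folios while rejecting movable ones. For reference, folio_test_anon() is essentially the same test (paraphrased, not verbatim, from <linux/page-flags.h>):

static inline bool folio_test_anon_sketch(struct folio *folio)
{
	/* Anon (and KSM) folios carry the ANON tag bit in ->mapping. */
	return ((unsigned long)folio->mapping & PAGE_MAPPING_ANON) != 0;
}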
@@ -2425,6 +2513,8 @@ static int gup_pte_range(pmd_t pmd, pmd_t *pmdp, unsigned long addr,
        pte_t *ptep, *ptem;
 
        ptem = ptep = pte_offset_map(&pmd, addr);
+       if (!ptep)
+               return 0;
        do {
                pte_t pte = ptep_get_lockless(ptep);
                struct page *page;
@@ -2461,7 +2551,12 @@ static int gup_pte_range(pmd_t pmd, pmd_t *pmdp, unsigned long addr,
                }
 
                if (unlikely(pmd_val(pmd) != pmd_val(*pmdp)) ||
-                   unlikely(pte_val(pte) != pte_val(*ptep))) {
+                   unlikely(pte_val(pte) != pte_val(ptep_get(ptep)))) {
+                       gup_put_folio(folio, 1, flags);
+                       goto pte_unmap;
+               }
+
+               if (!folio_fast_pin_allowed(folio, flags)) {
                        gup_put_folio(folio, 1, flags);
                        goto pte_unmap;
                }
@@ -2653,7 +2748,12 @@ static int gup_hugepte(pte_t *ptep, unsigned long sz, unsigned long addr,
        if (!folio)
                return 0;
 
-       if (unlikely(pte_val(pte) != pte_val(*ptep))) {
+       if (unlikely(pte_val(pte) != pte_val(ptep_get(ptep)))) {
+               gup_put_folio(folio, refs, flags);
+               return 0;
+       }
+
+       if (!folio_fast_pin_allowed(folio, flags)) {
                gup_put_folio(folio, refs, flags);
                return 0;
        }
@@ -2724,6 +2824,10 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
                return 0;
        }
 
+       if (!folio_fast_pin_allowed(folio, flags)) {
+               gup_put_folio(folio, refs, flags);
+               return 0;
+       }
        if (!pmd_write(orig) && gup_must_unshare(NULL, flags, &folio->page)) {
                gup_put_folio(folio, refs, flags);
                return 0;
@@ -2764,6 +2868,11 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
                return 0;
        }
 
+       if (!folio_fast_pin_allowed(folio, flags)) {
+               gup_put_folio(folio, refs, flags);
+               return 0;
+       }
+
        if (!pud_write(orig) && gup_must_unshare(NULL, flags, &folio->page)) {
                gup_put_folio(folio, refs, flags);
                return 0;
@@ -2799,6 +2908,16 @@ static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
                return 0;
        }
 
+       if (!pgd_write(orig) && gup_must_unshare(NULL, flags, &folio->page)) {
+               gup_put_folio(folio, refs, flags);
+               return 0;
+       }
+
+       if (!folio_fast_pin_allowed(folio, flags)) {
+               gup_put_folio(folio, refs, flags);
+               return 0;
+       }
+
        *nr += refs;
        folio_set_referenced(folio);
        return 1;
@@ -3013,7 +3132,7 @@ static int internal_get_user_pages_fast(unsigned long start,
        start = untagged_addr(start) & PAGE_MASK;
        len = nr_pages << PAGE_SHIFT;
        if (check_add_overflow(start, len, &end))
-               return 0;
+               return -EOVERFLOW;
        if (end > TASK_SIZE_MAX)
                return -EFAULT;
        if (unlikely(!access_ok((void __user *)start, len)))
@@ -3027,7 +3146,7 @@ static int internal_get_user_pages_fast(unsigned long start,
        start += nr_pinned << PAGE_SHIFT;
        pages += nr_pinned;
        ret = __gup_longterm_locked(current->mm, start, nr_pages - nr_pinned,
-                                   pages, NULL, &locked,
+                                   pages, &locked,
                                    gup_flags | FOLL_TOUCH | FOLL_UNLOCKABLE);
        if (ret < 0) {
                /*
@@ -3069,7 +3188,7 @@ int get_user_pages_fast_only(unsigned long start, int nr_pages,
         * FOLL_FAST_ONLY is required in order to match the API description of
         * this routine: no fall back to regular ("slow") GUP.
         */
-       if (!is_valid_gup_args(pages, NULL, NULL, &gup_flags,
+       if (!is_valid_gup_args(pages, NULL, &gup_flags,
                               FOLL_GET | FOLL_FAST_ONLY))
                return -EINVAL;
 
@@ -3102,7 +3221,7 @@ int get_user_pages_fast(unsigned long start, int nr_pages,
         * FOLL_GET, because gup fast is always a "pin with a +1 page refcount"
         * request.
         */
-       if (!is_valid_gup_args(pages, NULL, NULL, &gup_flags, FOLL_GET))
+       if (!is_valid_gup_args(pages, NULL, &gup_flags, FOLL_GET))
                return -EINVAL;
        return internal_get_user_pages_fast(start, nr_pages, gup_flags, pages);
 }
@@ -3130,7 +3249,7 @@ EXPORT_SYMBOL_GPL(get_user_pages_fast);
 int pin_user_pages_fast(unsigned long start, int nr_pages,
                        unsigned int gup_flags, struct page **pages)
 {
-       if (!is_valid_gup_args(pages, NULL, NULL, &gup_flags, FOLL_PIN))
+       if (!is_valid_gup_args(pages, NULL, &gup_flags, FOLL_PIN))
                return -EINVAL;
        return internal_get_user_pages_fast(start, nr_pages, gup_flags, pages);
 }
@@ -3145,8 +3264,6 @@ EXPORT_SYMBOL_GPL(pin_user_pages_fast);
  * @gup_flags: flags modifying lookup behaviour
  * @pages:     array that receives pointers to the pages pinned.
  *             Should be at least nr_pages long.
- * @vmas:      array of pointers to vmas corresponding to each page.
- *             Or NULL if the caller does not require them.
  * @locked:    pointer to lock flag indicating whether lock is held and
  *             subsequently whether VM_FAULT_RETRY functionality can be
  *             utilised. Lock must initially be held.
@@ -3164,14 +3281,14 @@ EXPORT_SYMBOL_GPL(pin_user_pages_fast);
 long pin_user_pages_remote(struct mm_struct *mm,
                           unsigned long start, unsigned long nr_pages,
                           unsigned int gup_flags, struct page **pages,
-                          struct vm_area_struct **vmas, int *locked)
+                          int *locked)
 {
        int local_locked = 1;
 
-       if (!is_valid_gup_args(pages, vmas, locked, &gup_flags,
+       if (!is_valid_gup_args(pages, locked, &gup_flags,
                               FOLL_PIN | FOLL_TOUCH | FOLL_REMOTE))
                return 0;
-       return __gup_longterm_locked(mm, start, nr_pages, pages, vmas,
+       return __gup_longterm_locked(mm, start, nr_pages, pages,
                                     locked ? locked : &local_locked,
                                     gup_flags);
 }
@@ -3185,8 +3302,6 @@ EXPORT_SYMBOL(pin_user_pages_remote);
  * @gup_flags: flags modifying lookup behaviour
  * @pages:     array that receives pointers to the pages pinned.
  *             Should be at least nr_pages long.
- * @vmas:      array of pointers to vmas corresponding to each page.
- *             Or NULL if the caller does not require them.
  *
  * Nearly the same as get_user_pages(), except that FOLL_TOUCH is not set, and
  * FOLL_PIN is set.
@@ -3198,15 +3313,14 @@ EXPORT_SYMBOL(pin_user_pages_remote);
  * pins in it and unpin_user_page*() will not remove pins from it.
  */
 long pin_user_pages(unsigned long start, unsigned long nr_pages,
-                   unsigned int gup_flags, struct page **pages,
-                   struct vm_area_struct **vmas)
+                   unsigned int gup_flags, struct page **pages)
 {
        int locked = 1;
 
-       if (!is_valid_gup_args(pages, vmas, NULL, &gup_flags, FOLL_PIN))
+       if (!is_valid_gup_args(pages, NULL, &gup_flags, FOLL_PIN))
                return 0;
        return __gup_longterm_locked(current->mm, start, nr_pages,
-                                    pages, vmas, &locked, gup_flags);
+                                    pages, &locked, gup_flags);
 }
 EXPORT_SYMBOL(pin_user_pages);
 
@@ -3223,11 +3337,11 @@ long pin_user_pages_unlocked(unsigned long start, unsigned long nr_pages,
 {
        int locked = 0;
 
-       if (!is_valid_gup_args(pages, NULL, NULL, &gup_flags,
+       if (!is_valid_gup_args(pages, NULL, &gup_flags,
                               FOLL_PIN | FOLL_TOUCH | FOLL_UNLOCKABLE))
                return 0;
 
-       return __gup_longterm_locked(current->mm, start, nr_pages, pages, NULL,
+       return __gup_longterm_locked(current->mm, start, nr_pages, pages,
                                     &locked, gup_flags);
 }
 EXPORT_SYMBOL(pin_user_pages_unlocked);