Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/dtor/input
[linux-2.6-microblaze.git] / mm / gup.c
index 1b521e0..afce0bc 100644 (file)
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -29,6 +29,22 @@ struct follow_page_context {
        unsigned int page_mask;
 };
 
+static void hpage_pincount_add(struct page *page, int refs)
+{
+       VM_BUG_ON_PAGE(!hpage_pincount_available(page), page);
+       VM_BUG_ON_PAGE(page != compound_head(page), page);
+
+       atomic_add(refs, compound_pincount_ptr(page));
+}
+
+static void hpage_pincount_sub(struct page *page, int refs)
+{
+       VM_BUG_ON_PAGE(!hpage_pincount_available(page), page);
+       VM_BUG_ON_PAGE(page != compound_head(page), page);
+
+       atomic_sub(refs, compound_pincount_ptr(page));
+}
+
 /*
  * Return the compound head page with ref appropriately incremented,
  * or NULL if that failed.
@@ -44,6 +60,195 @@ static inline struct page *try_get_compound_head(struct page *page, int refs)
        return head;
 }
 
+/*
+ * try_grab_compound_head() - attempt to elevate a page's refcount, by a
+ * flags-dependent amount.
+ *
+ * "grab" names in this file mean, "look at flags to decide whether to use
+ * FOLL_PIN or FOLL_GET behavior, when incrementing the page's refcount.
+ *
+ * Either FOLL_PIN or FOLL_GET (or neither) must be set, but not both at the
+ * same time. (That's true throughout the get_user_pages*() and
+ * pin_user_pages*() APIs.) Cases:
+ *
+ *    FOLL_GET: page's refcount will be incremented by 1.
+ *    FOLL_PIN: page's refcount will be incremented by GUP_PIN_COUNTING_BIAS.
+ *
+ * Return: head page (with refcount appropriately incremented) for success, or
+ * NULL upon failure. If neither FOLL_GET nor FOLL_PIN was set, that's
+ * considered failure, and furthermore, a likely bug in the caller, so a warning
+ * is also emitted.
+ */
+static __maybe_unused struct page *try_grab_compound_head(struct page *page,
+                                                         int refs,
+                                                         unsigned int flags)
+{
+       if (flags & FOLL_GET)
+               return try_get_compound_head(page, refs);
+       else if (flags & FOLL_PIN) {
+               int orig_refs = refs;
+
+               /*
+                * Can't do FOLL_LONGTERM + FOLL_PIN with CMA in the gup fast
+                * path, so fail and let the caller fall back to the slow path.
+                */
+               if (unlikely(flags & FOLL_LONGTERM) &&
+                               is_migrate_cma_page(page))
+                       return NULL;
+
+               /*
+                * When pinning a compound page of order > 1 (which is what
+                * hpage_pincount_available() checks for), use an exact count to
+                * track it, via hpage_pincount_add/_sub().
+                *
+                * However, be sure to *also* increment the normal page refcount
+                * field at least once, so that the page really is pinned.
+                */
+               if (!hpage_pincount_available(page))
+                       refs *= GUP_PIN_COUNTING_BIAS;
+
+               page = try_get_compound_head(page, refs);
+               if (!page)
+                       return NULL;
+
+               if (hpage_pincount_available(page))
+                       hpage_pincount_add(page, refs);
+
+               mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_ACQUIRED,
+                                   orig_refs);
+
+               return page;
+       }
+
+       WARN_ON_ONCE(1);
+       return NULL;
+}
+
+/**
+ * try_grab_page() - elevate a page's refcount by a flag-dependent amount
+ *
+ * This might not do anything at all, depending on the flags argument.
+ *
+ * "grab" names in this file mean, "look at flags to decide whether to use
+ * FOLL_PIN or FOLL_GET behavior, when incrementing the page's refcount.
+ *
+ * @page:    pointer to page to be grabbed
+ * @flags:   gup flags: these are the FOLL_* flag values.
+ *
+ * Either FOLL_PIN or FOLL_GET (or neither) may be set, but not both at the same
+ * time. Cases:
+ *
+ *    FOLL_GET: page's refcount will be incremented by 1.
+ *    FOLL_PIN: page's refcount will be incremented by GUP_PIN_COUNTING_BIAS.
+ *
+ * Return: true for success, or if no action was required (if neither FOLL_PIN
+ * nor FOLL_GET was set, nothing is done). False for failure: FOLL_GET or
+ * FOLL_PIN was set, but the page could not be grabbed.
+ */
+bool __must_check try_grab_page(struct page *page, unsigned int flags)
+{
+       WARN_ON_ONCE((flags & (FOLL_GET | FOLL_PIN)) == (FOLL_GET | FOLL_PIN));
+
+       if (flags & FOLL_GET)
+               return try_get_page(page);
+       else if (flags & FOLL_PIN) {
+               int refs = 1;
+
+               page = compound_head(page);
+
+               if (WARN_ON_ONCE(page_ref_count(page) <= 0))
+                       return false;
+
+               if (hpage_pincount_available(page))
+                       hpage_pincount_add(page, 1);
+               else
+                       refs = GUP_PIN_COUNTING_BIAS;
+
+               /*
+                * Similar to try_grab_compound_head(): even if using the
+                * hpage_pincount_add/_sub() routines, be sure to
+                * *also* increment the normal page refcount field at least
+                * once, so that the page really is pinned.
+                */
+               page_ref_add(page, refs);
+
+               mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_ACQUIRED, 1);
+       }
+
+       return true;
+}
+
+#ifdef CONFIG_DEV_PAGEMAP_OPS
+static bool __unpin_devmap_managed_user_page(struct page *page)
+{
+       int count, refs = 1;
+
+       if (!page_is_devmap_managed(page))
+               return false;
+
+       if (hpage_pincount_available(page))
+               hpage_pincount_sub(page, 1);
+       else
+               refs = GUP_PIN_COUNTING_BIAS;
+
+       count = page_ref_sub_return(page, refs);
+
+       mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_RELEASED, 1);
+       /*
+        * devmap page refcounts are 1-based, rather than 0-based: if
+        * refcount is 1, then the page is free and the refcount is
+        * stable because nobody holds a reference on the page.
+        */
+       if (count == 1)
+               free_devmap_managed_page(page);
+       else if (!count)
+               __put_page(page);
+
+       return true;
+}
+#else
+static bool __unpin_devmap_managed_user_page(struct page *page)
+{
+       return false;
+}
+#endif /* CONFIG_DEV_PAGEMAP_OPS */
+
+/**
+ * unpin_user_page() - release a dma-pinned page
+ * @page:            pointer to page to be released
+ *
+ * Pages that were pinned via pin_user_pages*() must be released via either
+ * unpin_user_page(), or one of the unpin_user_pages*() routines. This is so
+ * that such pages can be separately tracked and uniquely handled. In
+ * particular, interactions with RDMA and filesystems need special handling.
+ */
+void unpin_user_page(struct page *page)
+{
+       int refs = 1;
+
+       page = compound_head(page);
+
+       /*
+        * For devmap managed pages we need to catch refcount transition from
+        * GUP_PIN_COUNTING_BIAS to 1, when refcount reach one it means the
+        * page is free and we need to inform the device driver through
+        * callback. See include/linux/memremap.h and HMM for details.
+        */
+       if (__unpin_devmap_managed_user_page(page))
+               return;
+
+       if (hpage_pincount_available(page))
+               hpage_pincount_sub(page, 1);
+       else
+               refs = GUP_PIN_COUNTING_BIAS;
+
+       if (page_ref_sub_and_test(page, refs))
+               __put_page(page);
+
+       mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_RELEASED, 1);
+}
+EXPORT_SYMBOL(unpin_user_page);
+
 /**
  * unpin_user_pages_dirty_lock() - release and optionally dirty gup-pinned pages
  * @pages:  array of pages to be maybe marked dirty, and definitely released.
@@ -146,7 +351,8 @@ static struct page *no_page_table(struct vm_area_struct *vma,
         * But we can only make this optimization where a hole would surely
         * be zero-filled if handle_mm_fault() actually did handle it.
         */
-       if ((flags & FOLL_DUMP) && (!vma->vm_ops || !vma->vm_ops->fault))
+       if ((flags & FOLL_DUMP) &&
+                       (vma_is_anonymous(vma) || !vma->vm_ops->fault))
                return ERR_PTR(-EFAULT);
        return NULL;
 }
@@ -193,6 +399,7 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
        struct page *page;
        spinlock_t *ptl;
        pte_t *ptep, pte;
+       int ret;
 
        /* FOLL_GET and FOLL_PIN are mutually exclusive. */
        if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
@@ -230,10 +437,11 @@ retry:
        }
 
        page = vm_normal_page(vma, address, pte);
-       if (!page && pte_devmap(pte) && (flags & FOLL_GET)) {
+       if (!page && pte_devmap(pte) && (flags & (FOLL_GET | FOLL_PIN))) {
                /*
-                * Only return device mapping pages in the FOLL_GET case since
-                * they are only valid while holding the pgmap reference.
+                * Only return device mapping pages in the FOLL_GET or FOLL_PIN
+                * case since they are only valid while holding the pgmap
+                * reference.
                 */
                *pgmap = get_dev_pagemap(pte_pfn(pte), *pgmap);
                if (*pgmap)
@@ -250,8 +458,6 @@ retry:
                if (is_zero_pfn(pte_pfn(pte))) {
                        page = pte_page(pte);
                } else {
-                       int ret;
-
                        ret = follow_pfn_pte(vma, address, ptep, flags);
                        page = ERR_PTR(ret);
                        goto out;
@@ -259,7 +465,6 @@ retry:
        }
 
        if (flags & FOLL_SPLIT && PageTransCompound(page)) {
-               int ret;
                get_page(page);
                pte_unmap_unlock(ptep, ptl);
                lock_page(page);
@@ -271,9 +476,21 @@ retry:
                goto retry;
        }
 
-       if (flags & FOLL_GET) {
-               if (unlikely(!try_get_page(page))) {
-                       page = ERR_PTR(-ENOMEM);
+       /* try_grab_page() does nothing unless FOLL_GET or FOLL_PIN is set. */
+       if (unlikely(!try_grab_page(page, flags))) {
+               page = ERR_PTR(-ENOMEM);
+               goto out;
+       }
+       /*
+        * We need to make the page accessible if and only if we are going
+        * to access its content (the FOLL_PIN case).  Please see
+        * Documentation/core-api/pin_user_pages.rst for details.
+        */
+       if (flags & FOLL_PIN) {
+               ret = arch_make_page_accessible(page);
+               if (ret) {
+                       unpin_user_page(page);
+                       page = ERR_PTR(ret);
                        goto out;
                }
        }
@@ -537,7 +754,7 @@ static struct page *follow_page_mask(struct vm_area_struct *vma,
        /* make this handle hugepd */
        page = follow_huge_addr(mm, address, flags & FOLL_WRITE);
        if (!IS_ERR(page)) {
-               BUG_ON(flags & FOLL_GET);
+               WARN_ON_ONCE(flags & (FOLL_GET | FOLL_PIN));
                return page;
        }
 
@@ -630,12 +847,12 @@ unmap:
 }
 
 /*
- * mmap_sem must be held on entry.  If @nonblocking != NULL and
- * *@flags does not include FOLL_NOWAIT, the mmap_sem may be released.
- * If it is, *@nonblocking will be set to 0 and -EBUSY returned.
+ * mmap_sem must be held on entry.  If @locked != NULL and *@flags
+ * does not include FOLL_NOWAIT, the mmap_sem may be released.  If it
+ * is, *@locked will be set to 0 and -EBUSY returned.
  */
 static int faultin_page(struct task_struct *tsk, struct vm_area_struct *vma,
-               unsigned long address, unsigned int *flags, int *nonblocking)
+               unsigned long address, unsigned int *flags, int *locked)
 {
        unsigned int fault_flags = 0;
        vm_fault_t ret;
@@ -647,12 +864,15 @@ static int faultin_page(struct task_struct *tsk, struct vm_area_struct *vma,
                fault_flags |= FAULT_FLAG_WRITE;
        if (*flags & FOLL_REMOTE)
                fault_flags |= FAULT_FLAG_REMOTE;
-       if (nonblocking)
-               fault_flags |= FAULT_FLAG_ALLOW_RETRY;
+       if (locked)
+               fault_flags |= FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_KILLABLE;
        if (*flags & FOLL_NOWAIT)
                fault_flags |= FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_RETRY_NOWAIT;
        if (*flags & FOLL_TRIED) {
-               VM_WARN_ON_ONCE(fault_flags & FAULT_FLAG_ALLOW_RETRY);
+               /*
+                * Note: FAULT_FLAG_ALLOW_RETRY and FAULT_FLAG_TRIED
+                * can co-exist
+                */
                fault_flags |= FAULT_FLAG_TRIED;
        }
 
@@ -673,8 +893,8 @@ static int faultin_page(struct task_struct *tsk, struct vm_area_struct *vma,
        }
 
        if (ret & VM_FAULT_RETRY) {
-               if (nonblocking && !(fault_flags & FAULT_FLAG_RETRY_NOWAIT))
-                       *nonblocking = 0;
+               if (locked && !(fault_flags & FAULT_FLAG_RETRY_NOWAIT))
+                       *locked = 0;
                return -EBUSY;
        }
 
@@ -751,7 +971,7 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
  *             only intends to ensure the pages are faulted in.
  * @vmas:      array of pointers to vmas corresponding to each page.
  *             Or NULL if the caller does not require them.
- * @nonblocking: whether waiting for disk IO or mmap_sem contention
+ * @locked:     whether we're still with the mmap_sem held
  *
  * Returns either number of pages pinned (which may be less than the
  * number requested), or an error. Details about the return value:
@@ -786,13 +1006,11 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
  * appropriate) must be called after the page is finished with, and
  * before put_page is called.
  *
- * If @nonblocking != NULL, __get_user_pages will not wait for disk IO
- * or mmap_sem contention, and if waiting is needed to pin all pages,
- * *@nonblocking will be set to 0.  Further, if @gup_flags does not
- * include FOLL_NOWAIT, the mmap_sem will be released via up_read() in
- * this case.
+ * If @locked != NULL, *@locked will be set to 0 when mmap_sem is
+ * released by an up_read().  That can happen if @gup_flags does not
+ * have FOLL_NOWAIT.
  *
- * A caller using such a combination of @nonblocking and @gup_flags
+ * A caller using such a combination of @locked and @gup_flags
  * must therefore hold the mmap_sem for reading only, and recognize
  * when it's been released.  Otherwise, it must be held for either
  * reading or writing and will not be released.
@@ -804,7 +1022,7 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
 static long __get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
                unsigned long start, unsigned long nr_pages,
                unsigned int gup_flags, struct page **pages,
-               struct vm_area_struct **vmas, int *nonblocking)
+               struct vm_area_struct **vmas, int *locked)
 {
        long ret = 0, i = 0;
        struct vm_area_struct *vma = NULL;
@@ -850,7 +1068,17 @@ static long __get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
                        if (is_vm_hugetlb_page(vma)) {
                                i = follow_hugetlb_page(mm, vma, pages, vmas,
                                                &start, &nr_pages, i,
-                                               gup_flags, nonblocking);
+                                               gup_flags, locked);
+                               if (locked && *locked == 0) {
+                                       /*
+                                        * We've got a VM_FAULT_RETRY
+                                        * and we've lost mmap_sem.
+                                        * We must stop here.
+                                        */
+                                       BUG_ON(gup_flags & FOLL_NOWAIT);
+                                       BUG_ON(ret != 0);
+                                       goto out;
+                               }
                                continue;
                        }
                }
@@ -868,13 +1096,13 @@ retry:
                page = follow_page_mask(vma, start, foll_flags, &ctx);
                if (!page) {
                        ret = faultin_page(tsk, vma, start, &foll_flags,
-                                       nonblocking);
+                                          locked);
                        switch (ret) {
                        case 0:
                                goto retry;
                        case -EBUSY:
                                ret = 0;
-                               /* FALLTHRU */
+                               fallthrough;
                        case -EFAULT:
                        case -ENOMEM:
                        case -EHWPOISON:
@@ -980,7 +1208,7 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
        address = untagged_addr(address);
 
        if (unlocked)
-               fault_flags |= FAULT_FLAG_ALLOW_RETRY;
+               fault_flags |= FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_KILLABLE;
 
 retry:
        vma = find_extend_vma(mm, address);
@@ -1004,7 +1232,6 @@ retry:
                down_read(&mm->mmap_sem);
                if (!(fault_flags & FAULT_FLAG_TRIED)) {
                        *unlocked = true;
-                       fault_flags &= ~FAULT_FLAG_ALLOW_RETRY;
                        fault_flags |= FAULT_FLAG_TRIED;
                        goto retry;
                }
@@ -1088,17 +1315,36 @@ static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
                if (likely(pages))
                        pages += ret;
                start += ret << PAGE_SHIFT;
+               lock_dropped = true;
 
+retry:
                /*
                 * Repeat on the address that fired VM_FAULT_RETRY
-                * without FAULT_FLAG_ALLOW_RETRY but with
-                * FAULT_FLAG_TRIED.
+                * with both FAULT_FLAG_ALLOW_RETRY and
+                * FAULT_FLAG_TRIED.  Note that GUP can be interrupted
+                * by fatal signals, so we need to check it before we
+                * start trying again otherwise it can loop forever.
                 */
+
+               if (fatal_signal_pending(current))
+                       break;
+
+               ret = down_read_killable(&mm->mmap_sem);
+               if (ret) {
+                       BUG_ON(ret > 0);
+                       if (!pages_done)
+                               pages_done = ret;
+                       break;
+               }
+
                *locked = 1;
-               lock_dropped = true;
-               down_read(&mm->mmap_sem);
                ret = __get_user_pages(tsk, mm, start, 1, flags | FOLL_TRIED,
-                                      pages, NULL, NULL);
+                                      pages, NULL, locked);
+               if (!*locked) {
+                       /* Continue to retry until we succeeded */
+                       BUG_ON(ret != 0);
+                       goto retry;
+               }
                if (ret != 1) {
                        BUG_ON(ret > 1);
                        if (!pages_done)
@@ -1129,7 +1375,7 @@ static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
  * @vma:   target vma
  * @start: start address
  * @end:   end address
- * @nonblocking:
+ * @locked: whether the mmap_sem is still held
  *
  * This takes care of mlocking the pages too if VM_LOCKED is set.
  *
@@ -1137,14 +1383,14 @@ static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
  *
  * vma->vm_mm->mmap_sem must be held.
  *
- * If @nonblocking is NULL, it may be held for read or write and will
+ * If @locked is NULL, it may be held for read or write and will
  * be unperturbed.
  *
- * If @nonblocking is non-NULL, it must held for read only and may be
- * released.  If it's released, *@nonblocking will be set to 0.
+ * If @locked is non-NULL, it must held for read only and may be
+ * released.  If it's released, *@locked will be set to 0.
  */
 long populate_vma_page_range(struct vm_area_struct *vma,
-               unsigned long start, unsigned long end, int *nonblocking)
+               unsigned long start, unsigned long end, int *locked)
 {
        struct mm_struct *mm = vma->vm_mm;
        unsigned long nr_pages = (end - start) / PAGE_SIZE;
@@ -1171,7 +1417,7 @@ long populate_vma_page_range(struct vm_area_struct *vma,
         * We want mlock to succeed for regions that have any permissions
         * other than PROT_NONE.
         */
-       if (vma->vm_flags & (VM_READ | VM_WRITE | VM_EXEC))
+       if (vma_is_accessible(vma))
                gup_flags |= FOLL_FORCE;
 
        /*
@@ -1179,7 +1425,7 @@ long populate_vma_page_range(struct vm_area_struct *vma,
         * not result in a stack expansion that recurses back here.
         */
        return __get_user_pages(current, mm, start, nr_pages, gup_flags,
-                               NULL, NULL, nonblocking);
+                               NULL, NULL, locked);
 }
 
 /*
@@ -1431,7 +1677,7 @@ check_again:
                                        list_add_tail(&head->lru, &cma_page_list);
                                        mod_node_page_state(page_pgdat(head),
                                                            NR_ISOLATED_ANON +
-                                                           page_is_file_cache(head),
+                                                           page_is_file_lru(head),
                                                            hpage_nr_pages(head));
                                }
                        }
@@ -1557,6 +1803,37 @@ static __always_inline long __gup_longterm_locked(struct task_struct *tsk,
 }
 #endif /* CONFIG_FS_DAX || CONFIG_CMA */
 
+#ifdef CONFIG_MMU
+static long __get_user_pages_remote(struct task_struct *tsk,
+                                   struct mm_struct *mm,
+                                   unsigned long start, unsigned long nr_pages,
+                                   unsigned int gup_flags, struct page **pages,
+                                   struct vm_area_struct **vmas, int *locked)
+{
+       /*
+        * Parts of FOLL_LONGTERM behavior are incompatible with
+        * FAULT_FLAG_ALLOW_RETRY because of the FS DAX check requirement on
+        * vmas. However, this only comes up if locked is set, and there are
+        * callers that do request FOLL_LONGTERM, but do not set locked. So,
+        * allow what we can.
+        */
+       if (gup_flags & FOLL_LONGTERM) {
+               if (WARN_ON_ONCE(locked))
+                       return -EINVAL;
+               /*
+                * This will check the vmas (even if our vmas arg is NULL)
+                * and return -ENOTSUPP if DAX isn't allowed in this case:
+                */
+               return __gup_longterm_locked(tsk, mm, start, nr_pages, pages,
+                                            vmas, gup_flags | FOLL_TOUCH |
+                                            FOLL_REMOTE);
+       }
+
+       return __get_user_pages_locked(tsk, mm, start, nr_pages, pages, vmas,
+                                      locked,
+                                      gup_flags | FOLL_TOUCH | FOLL_REMOTE);
+}
+
 /*
  * get_user_pages_remote() - pin user pages in memory
  * @tsk:       the task_struct to use for page fault accounting, or
@@ -1619,7 +1896,6 @@ static __always_inline long __gup_longterm_locked(struct task_struct *tsk,
  * should use get_user_pages because it cannot pass
  * FAULT_FLAG_ALLOW_RETRY to handle_mm_fault.
  */
-#ifdef CONFIG_MMU
 long get_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
                unsigned long start, unsigned long nr_pages,
                unsigned int gup_flags, struct page **pages,
@@ -1632,28 +1908,8 @@ long get_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
        if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
                return -EINVAL;
 
-       /*
-        * Parts of FOLL_LONGTERM behavior are incompatible with
-        * FAULT_FLAG_ALLOW_RETRY because of the FS DAX check requirement on
-        * vmas. However, this only comes up if locked is set, and there are
-        * callers that do request FOLL_LONGTERM, but do not set locked. So,
-        * allow what we can.
-        */
-       if (gup_flags & FOLL_LONGTERM) {
-               if (WARN_ON_ONCE(locked))
-                       return -EINVAL;
-               /*
-                * This will check the vmas (even if our vmas arg is NULL)
-                * and return -ENOTSUPP if DAX isn't allowed in this case:
-                */
-               return __gup_longterm_locked(tsk, mm, start, nr_pages, pages,
-                                            vmas, gup_flags | FOLL_TOUCH |
-                                            FOLL_REMOTE);
-       }
-
-       return __get_user_pages_locked(tsk, mm, start, nr_pages, pages, vmas,
-                                      locked,
-                                      gup_flags | FOLL_TOUCH | FOLL_REMOTE);
+       return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
+                                      pages, vmas, locked);
 }
 EXPORT_SYMBOL(get_user_pages_remote);
 
@@ -1665,6 +1921,15 @@ long get_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
 {
        return 0;
 }
+
+static long __get_user_pages_remote(struct task_struct *tsk,
+                                   struct mm_struct *mm,
+                                   unsigned long start, unsigned long nr_pages,
+                                   unsigned int gup_flags, struct page **pages,
+                                   struct vm_area_struct **vmas, int *locked)
+{
+       return 0;
+}
 #endif /* !CONFIG_MMU */
 
 /*
@@ -1804,7 +2069,31 @@ EXPORT_SYMBOL(get_user_pages_unlocked);
  * This code is based heavily on the PowerPC implementation by Nick Piggin.
  */
 #ifdef CONFIG_HAVE_FAST_GUP
+
+static void put_compound_head(struct page *page, int refs, unsigned int flags)
+{
+       if (flags & FOLL_PIN) {
+               mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_RELEASED,
+                                   refs);
+
+               if (hpage_pincount_available(page))
+                       hpage_pincount_sub(page, refs);
+               else
+                       refs *= GUP_PIN_COUNTING_BIAS;
+       }
+
+       VM_BUG_ON_PAGE(page_ref_count(page) < refs, page);
+       /*
+        * Calling put_page() for each ref is unnecessarily slow. Only the last
+        * ref needs a put_page().
+        */
+       if (refs > 1)
+               page_ref_sub(page, refs - 1);
+       put_page(page);
+}
+
 #ifdef CONFIG_GUP_GET_PTE_LOW_HIGH
+
 /*
  * WARNING: only to be used in the get_user_pages_fast() implementation.
  *
@@ -1860,13 +2149,17 @@ static inline pte_t gup_get_pte(pte_t *ptep)
 #endif /* CONFIG_GUP_GET_PTE_LOW_HIGH */
 
 static void __maybe_unused undo_dev_pagemap(int *nr, int nr_start,
+                                           unsigned int flags,
                                            struct page **pages)
 {
        while ((*nr) - nr_start) {
                struct page *page = pages[--(*nr)];
 
                ClearPageReferenced(page);
-               put_page(page);
+               if (flags & FOLL_PIN)
+                       unpin_user_page(page);
+               else
+                       put_page(page);
        }
 }
 
@@ -1899,7 +2192,7 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 
                        pgmap = get_dev_pagemap(pte_pfn(pte), pgmap);
                        if (unlikely(!pgmap)) {
-                               undo_dev_pagemap(nr, nr_start, pages);
+                               undo_dev_pagemap(nr, nr_start, flags, pages);
                                goto pte_unmap;
                        }
                } else if (pte_special(pte))
@@ -1908,17 +2201,30 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
                VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
                page = pte_page(pte);
 
-               head = try_get_compound_head(page, 1);
+               head = try_grab_compound_head(page, 1, flags);
                if (!head)
                        goto pte_unmap;
 
                if (unlikely(pte_val(pte) != pte_val(*ptep))) {
-                       put_page(head);
+                       put_compound_head(head, 1, flags);
                        goto pte_unmap;
                }
 
                VM_BUG_ON_PAGE(compound_head(page) != head, page);
 
+               /*
+                * We need to make the page accessible if and only if we are
+                * going to access its content (the FOLL_PIN case).  Please
+                * see Documentation/core-api/pin_user_pages.rst for
+                * details.
+                */
+               if (flags & FOLL_PIN) {
+                       ret = arch_make_page_accessible(page);
+                       if (ret) {
+                               unpin_user_page(page);
+                               goto pte_unmap;
+                       }
+               }
                SetPageReferenced(page);
                pages[*nr] = page;
                (*nr)++;
@@ -1953,7 +2259,8 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 
 #if defined(CONFIG_ARCH_HAS_PTE_DEVMAP) && defined(CONFIG_TRANSPARENT_HUGEPAGE)
 static int __gup_device_huge(unsigned long pfn, unsigned long addr,
-               unsigned long end, struct page **pages, int *nr)
+                            unsigned long end, unsigned int flags,
+                            struct page **pages, int *nr)
 {
        int nr_start = *nr;
        struct dev_pagemap *pgmap = NULL;
@@ -1963,12 +2270,15 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
 
                pgmap = get_dev_pagemap(pfn, pgmap);
                if (unlikely(!pgmap)) {
-                       undo_dev_pagemap(nr, nr_start, pages);
+                       undo_dev_pagemap(nr, nr_start, flags, pages);
                        return 0;
                }
                SetPageReferenced(page);
                pages[*nr] = page;
-               get_page(page);
+               if (unlikely(!try_grab_page(page, flags))) {
+                       undo_dev_pagemap(nr, nr_start, flags, pages);
+                       return 0;
+               }
                (*nr)++;
                pfn++;
        } while (addr += PAGE_SIZE, addr != end);
@@ -1979,48 +2289,52 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
 }
 
 static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
-               unsigned long end, struct page **pages, int *nr)
+                                unsigned long end, unsigned int flags,
+                                struct page **pages, int *nr)
 {
        unsigned long fault_pfn;
        int nr_start = *nr;
 
        fault_pfn = pmd_pfn(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
-       if (!__gup_device_huge(fault_pfn, addr, end, pages, nr))
+       if (!__gup_device_huge(fault_pfn, addr, end, flags, pages, nr))
                return 0;
 
        if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) {
-               undo_dev_pagemap(nr, nr_start, pages);
+               undo_dev_pagemap(nr, nr_start, flags, pages);
                return 0;
        }
        return 1;
 }
 
 static int __gup_device_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
-               unsigned long end, struct page **pages, int *nr)
+                                unsigned long end, unsigned int flags,
+                                struct page **pages, int *nr)
 {
        unsigned long fault_pfn;
        int nr_start = *nr;
 
        fault_pfn = pud_pfn(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
-       if (!__gup_device_huge(fault_pfn, addr, end, pages, nr))
+       if (!__gup_device_huge(fault_pfn, addr, end, flags, pages, nr))
                return 0;
 
        if (unlikely(pud_val(orig) != pud_val(*pudp))) {
-               undo_dev_pagemap(nr, nr_start, pages);
+               undo_dev_pagemap(nr, nr_start, flags, pages);
                return 0;
        }
        return 1;
 }
 #else
 static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
-               unsigned long end, struct page **pages, int *nr)
+                                unsigned long end, unsigned int flags,
+                                struct page **pages, int *nr)
 {
        BUILD_BUG();
        return 0;
 }
 
 static int __gup_device_huge_pud(pud_t pud, pud_t *pudp, unsigned long addr,
-               unsigned long end, struct page **pages, int *nr)
+                                unsigned long end, unsigned int flags,
+                                struct page **pages, int *nr)
 {
        BUILD_BUG();
        return 0;
@@ -2038,18 +2352,6 @@ static int record_subpages(struct page *page, unsigned long addr,
        return nr;
 }
 
-static void put_compound_head(struct page *page, int refs)
-{
-       VM_BUG_ON_PAGE(page_ref_count(page) < refs, page);
-       /*
-        * Calling put_page() for each ref is unnecessarily slow. Only the last
-        * ref needs a put_page().
-        */
-       if (refs > 1)
-               page_ref_sub(page, refs - 1);
-       put_page(page);
-}
-
 #ifdef CONFIG_ARCH_HAS_HUGEPD
 static unsigned long hugepte_addr_end(unsigned long addr, unsigned long end,
                                      unsigned long sz)
@@ -2083,12 +2385,12 @@ static int gup_hugepte(pte_t *ptep, unsigned long sz, unsigned long addr,
        page = head + ((addr & (sz-1)) >> PAGE_SHIFT);
        refs = record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(head, refs);
+       head = try_grab_compound_head(head, refs, flags);
        if (!head)
                return 0;
 
        if (unlikely(pte_val(pte) != pte_val(*ptep))) {
-               put_compound_head(head, refs);
+               put_compound_head(head, refs, flags);
                return 0;
        }
 
@@ -2136,18 +2438,19 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
        if (pmd_devmap(orig)) {
                if (unlikely(flags & FOLL_LONGTERM))
                        return 0;
-               return __gup_device_huge_pmd(orig, pmdp, addr, end, pages, nr);
+               return __gup_device_huge_pmd(orig, pmdp, addr, end, flags,
+                                            pages, nr);
        }
 
        page = pmd_page(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
        refs = record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(pmd_page(orig), refs);
+       head = try_grab_compound_head(pmd_page(orig), refs, flags);
        if (!head)
                return 0;
 
        if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) {
-               put_compound_head(head, refs);
+               put_compound_head(head, refs, flags);
                return 0;
        }
 
@@ -2157,7 +2460,8 @@ static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
 }
 
 static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
-               unsigned long end, unsigned int flags, struct page **pages, int *nr)
+                       unsigned long end, unsigned int flags,
+                       struct page **pages, int *nr)
 {
        struct page *head, *page;
        int refs;
@@ -2168,18 +2472,19 @@ static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
        if (pud_devmap(orig)) {
                if (unlikely(flags & FOLL_LONGTERM))
                        return 0;
-               return __gup_device_huge_pud(orig, pudp, addr, end, pages, nr);
+               return __gup_device_huge_pud(orig, pudp, addr, end, flags,
+                                            pages, nr);
        }
 
        page = pud_page(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
        refs = record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(pud_page(orig), refs);
+       head = try_grab_compound_head(pud_page(orig), refs, flags);
        if (!head)
                return 0;
 
        if (unlikely(pud_val(orig) != pud_val(*pudp))) {
-               put_compound_head(head, refs);
+               put_compound_head(head, refs, flags);
                return 0;
        }
 
@@ -2203,12 +2508,12 @@ static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
        page = pgd_page(orig) + ((addr & ~PGDIR_MASK) >> PAGE_SHIFT);
        refs = record_subpages(page, addr, end, pages + *nr);
 
-       head = try_get_compound_head(pgd_page(orig), refs);
+       head = try_grab_compound_head(pgd_page(orig), refs, flags);
        if (!head)
                return 0;
 
        if (unlikely(pgd_val(orig) != pgd_val(*pgdp))) {
-               put_compound_head(head, refs);
+               put_compound_head(head, refs, flags);
                return 0;
        }
 
@@ -2370,7 +2675,15 @@ int __get_user_pages_fast(unsigned long start, int nr_pages, int write,
 {
        unsigned long len, end;
        unsigned long flags;
-       int nr = 0;
+       int nr_pinned = 0;
+       /*
+        * Internally (within mm/gup.c), gup fast variants must set FOLL_GET,
+        * because gup fast is always a "pin with a +1 page refcount" request.
+        */
+       unsigned int gup_flags = FOLL_GET;
+
+       if (write)
+               gup_flags |= FOLL_WRITE;
 
        start = untagged_addr(start) & PAGE_MASK;
        len = (unsigned long) nr_pages << PAGE_SHIFT;
@@ -2396,11 +2709,11 @@ int __get_user_pages_fast(unsigned long start, int nr_pages, int write,
        if (IS_ENABLED(CONFIG_HAVE_FAST_GUP) &&
            gup_fast_permitted(start, end)) {
                local_irq_save(flags);
-               gup_pgd_range(start, end, write ? FOLL_WRITE : 0, pages, &nr);
+               gup_pgd_range(start, end, gup_flags, pages, &nr_pinned);
                local_irq_restore(flags);
        }
 
-       return nr;
+       return nr_pinned;
 }
 EXPORT_SYMBOL_GPL(__get_user_pages_fast);
 
@@ -2432,10 +2745,10 @@ static int internal_get_user_pages_fast(unsigned long start, int nr_pages,
                                        struct page **pages)
 {
        unsigned long addr, len, end;
-       int nr = 0, ret = 0;
+       int nr_pinned = 0, ret = 0;
 
        if (WARN_ON_ONCE(gup_flags & ~(FOLL_WRITE | FOLL_LONGTERM |
-                                      FOLL_FORCE | FOLL_PIN)))
+                                      FOLL_FORCE | FOLL_PIN | FOLL_GET)))
                return -EINVAL;
 
        start = untagged_addr(start) & PAGE_MASK;
@@ -2451,25 +2764,25 @@ static int internal_get_user_pages_fast(unsigned long start, int nr_pages,
        if (IS_ENABLED(CONFIG_HAVE_FAST_GUP) &&
            gup_fast_permitted(start, end)) {
                local_irq_disable();
-               gup_pgd_range(addr, end, gup_flags, pages, &nr);
+               gup_pgd_range(addr, end, gup_flags, pages, &nr_pinned);
                local_irq_enable();
-               ret = nr;
+               ret = nr_pinned;
        }
 
-       if (nr < nr_pages) {
+       if (nr_pinned < nr_pages) {
                /* Try to get the remaining pages with get_user_pages */
-               start += nr << PAGE_SHIFT;
-               pages += nr;
+               start += nr_pinned << PAGE_SHIFT;
+               pages += nr_pinned;
 
-               ret = __gup_longterm_unlocked(start, nr_pages - nr,
+               ret = __gup_longterm_unlocked(start, nr_pages - nr_pinned,
                                              gup_flags, pages);
 
                /* Have to be a bit careful with return values */
-               if (nr > 0) {
+               if (nr_pinned > 0) {
                        if (ret < 0)
-                               ret = nr;
+                               ret = nr_pinned;
                        else
-                               ret += nr;
+                               ret += nr_pinned;
                }
        }
 
@@ -2478,11 +2791,11 @@ static int internal_get_user_pages_fast(unsigned long start, int nr_pages,
 
 /**
  * get_user_pages_fast() - pin user pages in memory
- * @start:     starting user address
- * @nr_pages:  number of pages from start to pin
- * @gup_flags: flags modifying pin behaviour
- * @pages:     array that receives pointers to the pages pinned.
- *             Should be at least nr_pages long.
+ * @start:      starting user address
+ * @nr_pages:   number of pages from start to pin
+ * @gup_flags:  flags modifying pin behaviour
+ * @pages:      array that receives pointers to the pages pinned.
+ *              Should be at least nr_pages long.
  *
  * Attempt to pin user pages in memory without taking mm->mmap_sem.
  * If not successful, it will fall back to taking the lock and
@@ -2502,6 +2815,13 @@ int get_user_pages_fast(unsigned long start, int nr_pages,
        if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
                return -EINVAL;
 
+       /*
+        * The caller may or may not have explicitly set FOLL_GET; either way is
+        * OK. However, internally (within mm/gup.c), gup fast variants must set
+        * FOLL_GET, because gup fast is always a "pin with a +1 page refcount"
+        * request.
+        */
+       gup_flags |= FOLL_GET;
        return internal_get_user_pages_fast(start, nr_pages, gup_flags, pages);
 }
 EXPORT_SYMBOL_GPL(get_user_pages_fast);
@@ -2509,9 +2829,18 @@ EXPORT_SYMBOL_GPL(get_user_pages_fast);
 /**
  * pin_user_pages_fast() - pin user pages in memory without taking locks
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages_fast().
+ * @start:      starting user address
+ * @nr_pages:   number of pages from start to pin
+ * @gup_flags:  flags modifying pin behaviour
+ * @pages:      array that receives pointers to the pages pinned.
+ *              Should be at least nr_pages long.
+ *
+ * Nearly the same as get_user_pages_fast(), except that FOLL_PIN is set. See
+ * get_user_pages_fast() for documentation on the function arguments, because
+ * the arguments here are identical.
+ *
+ * FOLL_PIN means that the pages must be released via unpin_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for further details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2519,21 +2848,39 @@ EXPORT_SYMBOL_GPL(get_user_pages_fast);
 int pin_user_pages_fast(unsigned long start, int nr_pages,
                        unsigned int gup_flags, struct page **pages)
 {
-       /*
-        * This is a placeholder, until the pin functionality is activated.
-        * Until then, just behave like the corresponding get_user_pages*()
-        * routine.
-        */
-       return get_user_pages_fast(start, nr_pages, gup_flags, pages);
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+               return -EINVAL;
+
+       gup_flags |= FOLL_PIN;
+       return internal_get_user_pages_fast(start, nr_pages, gup_flags, pages);
 }
 EXPORT_SYMBOL_GPL(pin_user_pages_fast);
 
 /**
  * pin_user_pages_remote() - pin pages of a remote process (task != current)
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages_remote().
+ * @tsk:       the task_struct to use for page fault accounting, or
+ *             NULL if faults are not to be recorded.
+ * @mm:                mm_struct of target mm
+ * @start:     starting user address
+ * @nr_pages:  number of pages from start to pin
+ * @gup_flags: flags modifying lookup behaviour
+ * @pages:     array that receives pointers to the pages pinned.
+ *             Should be at least nr_pages long. Or NULL, if caller
+ *             only intends to ensure the pages are faulted in.
+ * @vmas:      array of pointers to vmas corresponding to each page.
+ *             Or NULL if the caller does not require them.
+ * @locked:    pointer to lock flag indicating whether lock is held and
+ *             subsequently whether VM_FAULT_RETRY functionality can be
+ *             utilised. Lock must initially be held.
+ *
+ * Nearly the same as get_user_pages_remote(), except that FOLL_PIN is set. See
+ * get_user_pages_remote() for documentation on the function arguments, because
+ * the arguments here are identical.
+ *
+ * FOLL_PIN means that the pages must be released via unpin_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2543,22 +2890,33 @@ long pin_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
                           unsigned int gup_flags, struct page **pages,
                           struct vm_area_struct **vmas, int *locked)
 {
-       /*
-        * This is a placeholder, until the pin functionality is activated.
-        * Until then, just behave like the corresponding get_user_pages*()
-        * routine.
-        */
-       return get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags, pages,
-                                    vmas, locked);
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+               return -EINVAL;
+
+       gup_flags |= FOLL_PIN;
+       return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
+                                      pages, vmas, locked);
 }
 EXPORT_SYMBOL(pin_user_pages_remote);
 
 /**
  * pin_user_pages() - pin user pages in memory for use by other devices
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages().
+ * @start:     starting user address
+ * @nr_pages:  number of pages from start to pin
+ * @gup_flags: flags modifying lookup behaviour
+ * @pages:     array that receives pointers to the pages pinned.
+ *             Should be at least nr_pages long. Or NULL, if caller
+ *             only intends to ensure the pages are faulted in.
+ * @vmas:      array of pointers to vmas corresponding to each page.
+ *             Or NULL if the caller does not require them.
+ *
+ * Nearly the same as get_user_pages(), except that FOLL_TOUCH is not set, and
+ * FOLL_PIN is set.
+ *
+ * FOLL_PIN means that the pages must be released via unpin_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2567,11 +2925,12 @@ long pin_user_pages(unsigned long start, unsigned long nr_pages,
                    unsigned int gup_flags, struct page **pages,
                    struct vm_area_struct **vmas)
 {
-       /*
-        * This is a placeholder, until the pin functionality is activated.
-        * Until then, just behave like the corresponding get_user_pages*()
-        * routine.
-        */
-       return get_user_pages(start, nr_pages, gup_flags, pages, vmas);
+       /* FOLL_GET and FOLL_PIN are mutually exclusive. */
+       if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+               return -EINVAL;
+
+       gup_flags |= FOLL_PIN;
+       return __gup_longterm_locked(current, current->mm, start, nr_pages,
+                                    pages, vmas, gup_flags);
 }
 EXPORT_SYMBOL(pin_user_pages);