index 886d614..f0af462 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -642,12 +642,17 @@ static struct page *follow_pmd_mask(struct vm_area_struct *vma,
        }
 retry:
        if (!pmd_present(pmdval)) {
+               /*
+                * Should never reach here if thp migration is not supported;
+                * otherwise, it must be a thp migration entry.
+                */
+               VM_BUG_ON(!thp_migration_supported() ||
+                                 !is_pmd_migration_entry(pmdval));
+
                if (likely(!(flags & FOLL_MIGRATION)))
                        return no_page_table(vma, flags);
-               VM_BUG_ON(thp_migration_supported() &&
-                                 !is_pmd_migration_entry(pmdval));
-               if (is_pmd_migration_entry(pmdval))
-                       pmd_migration_entry_wait(mm, pmd);
+
+               pmd_migration_entry_wait(mm, pmd);
                pmdval = READ_ONCE(*pmd);
                /*
                 * MADV_DONTNEED may convert the pmd to null because
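The restructured check strengthens the assertion: once pmd_present() fails here, the entry must be a THP migration entry (and THP migration must be supported), so the VM_BUG_ON can be hoisted above the FOLL_MIGRATION test and the wait no longer needs its own is_pmd_migration_entry() guard. Callers opt in to this wait-and-retry behaviour with FOLL_MIGRATION; a minimal sketch of such a caller (the follow_page() usage is illustrative, not part of this hunk):

    /* Wait on a THP migration entry instead of treating the page
     * as absent, then drop the reference we asked for. */
    struct page *page = follow_page(vma, addr, FOLL_GET | FOLL_MIGRATION);
    if (page)
            put_page(page);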
@@ -918,6 +923,8 @@ static int faultin_page(struct vm_area_struct *vma,
        /* mlock all present pages, but do not fault in new pages */
        if ((*flags & (FOLL_POPULATE | FOLL_MLOCK)) == FOLL_MLOCK)
                return -ENOENT;
+       if (*flags & FOLL_NOFAULT)
+               return -EFAULT;
        if (*flags & FOLL_WRITE)
                fault_flags |= FAULT_FLAG_WRITE;
        if (*flags & FOLL_REMOTE)
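The new FOLL_NOFAULT test makes faultin_page() return -EFAULT instead of calling into handle_mm_fault(), so a gup walk with this flag set pins only pages that are already resident and stops at the first one that is not. A hedged sketch of the resulting contract (the slow-path entry point shown is an assumed consumer, not taken from this hunk):

    /* Sketch: with FOLL_NOFAULT, n < nr_pages means the page at
     * start + n * PAGE_SIZE was not present; nothing was faulted in. */
    long n = get_user_pages(start, nr_pages, FOLL_WRITE | FOLL_NOFAULT,
                            pages, NULL);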
@@ -1656,6 +1663,143 @@ finish_or_fault:
 }
 #endif /* !CONFIG_MMU */
 
+/**
+ * fault_in_writeable - fault in userspace address range for writing
+ * @uaddr: start of address range
+ * @size: size of address range
+ *
+ * Returns the number of bytes not faulted in (like copy_to_user() and
+ * copy_from_user()).
+ */
+size_t fault_in_writeable(char __user *uaddr, size_t size)
+{
+       char __user *start = uaddr, *end;
+
+       if (unlikely(size == 0))
+               return 0;
+       if (!user_write_access_begin(uaddr, size))
+               return size;
+       if (!PAGE_ALIGNED(uaddr)) {
+               unsafe_put_user(0, uaddr, out);
+               uaddr = (char __user *)PAGE_ALIGN((unsigned long)uaddr);
+       }
+       end = (char __user *)PAGE_ALIGN((unsigned long)start + size);
+       if (unlikely(end < start))
+               end = NULL;
+       while (uaddr != end) {
+               unsafe_put_user(0, uaddr, out);
+               uaddr += PAGE_SIZE;
+       }
+
+out:
+       user_write_access_end();
+       if (size > uaddr - start)
+               return size - (uaddr - start);
+       return 0;
+}
+EXPORT_SYMBOL(fault_in_writeable);
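fault_in_writeable() is destructive: it faults each page in by storing a zero byte (first at the possibly unaligned start of the range, then once per page), which is fine for ranges about to be overwritten anyway. The intended use is the classic copy-retry loop; a minimal sketch, assuming an ordinary copy_to_user() caller:

    /* Retry a failed copy after faulting in the destination. */
    while (copy_to_user(udst, kbuf, len)) {
            if (fault_in_writeable(udst, len) == len)
                    return -EFAULT; /* nothing could be faulted in */
    }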
+
+/**
+ * fault_in_safe_writeable - fault in an address range for writing
+ * @uaddr: start of address range
+ * @size: length of address range
+ *
+ * Faults in an address range using get_user_pages(), i.e., without triggering
+ * hardware page faults.  This is primarily useful when we already know that
+ * some or all of the pages in the address range aren't in memory.
+ *
+ * Unlike fault_in_writeable(), this function is non-destructive.
+ *
+ * Note that we don't pin or otherwise hold the pages we fault in.
+ * There's no guarantee that they'll stay in memory for any duration
+ * of time.
+ *
+ * Returns the number of bytes not faulted in, like copy_to_user() and
+ * copy_from_user().
+ */
+size_t fault_in_safe_writeable(const char __user *uaddr, size_t size)
+{
+       unsigned long start = (unsigned long)untagged_addr(uaddr);
+       unsigned long end, nstart, nend;
+       struct mm_struct *mm = current->mm;
+       struct vm_area_struct *vma = NULL;
+       int locked = 0;
+
+       nstart = start & PAGE_MASK;
+       end = PAGE_ALIGN(start + size);
+       if (end < nstart)
+               end = 0;
+       for (; nstart != end; nstart = nend) {
+               unsigned long nr_pages;
+               long ret;
+
+               if (!locked) {
+                       locked = 1;
+                       mmap_read_lock(mm);
+                       vma = find_vma(mm, nstart);
+               } else if (nstart >= vma->vm_end)
+                       vma = vma->vm_next;
+               if (!vma || vma->vm_start >= end)
+                       break;
+               nend = end ? min(end, vma->vm_end) : vma->vm_end;
+               if (vma->vm_flags & (VM_IO | VM_PFNMAP))
+                       continue;
+               if (nstart < vma->vm_start)
+                       nstart = vma->vm_start;
+               nr_pages = (nend - nstart) / PAGE_SIZE;
+               ret = __get_user_pages_locked(mm, nstart, nr_pages,
+                                             NULL, NULL, &locked,
+                                             FOLL_TOUCH | FOLL_WRITE);
+               if (ret <= 0)
+                       break;
+               nend = nstart + ret * PAGE_SIZE;
+       }
+       if (locked)
+               mmap_read_unlock(mm);
+       if (nstart == end)
+               return 0;
+       return size - min_t(size_t, nstart - start, size);
+}
+EXPORT_SYMBOL(fault_in_safe_writeable);
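Because this variant faults pages in through __get_user_pages_locked() with FOLL_TOUCH | FOLL_WRITE rather than by writing into the buffer, it will not clobber data already there. That matters when, say, a previous partial read has filled the front of the buffer; a hedged sketch of that situation (names are illustrative):

    /* Fault in the rest of a read destination without destroying the
     * bytes a partial read already delivered. */
    if (fault_in_safe_writeable(ubuf + copied, len - copied))
            return copied ? copied : -EFAULT;
    /* ...retry the read into ubuf + copied... */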
+
+/**
+ * fault_in_readable - fault in userspace address range for reading
+ * @uaddr: start of user address range
+ * @size: size of user address range
+ *
+ * Returns the number of bytes not faulted in (like copy_to_user() and
+ * copy_from_user()).
+ */
+size_t fault_in_readable(const char __user *uaddr, size_t size)
+{
+       const char __user *start = uaddr, *end;
+       volatile char c;
+
+       if (unlikely(size == 0))
+               return 0;
+       if (!user_read_access_begin(uaddr, size))
+               return size;
+       if (!PAGE_ALIGNED(uaddr)) {
+               unsafe_get_user(c, uaddr, out);
+               uaddr = (const char __user *)PAGE_ALIGN((unsigned long)uaddr);
+       }
+       end = (const char __user *)PAGE_ALIGN((unsigned long)start + size);
+       if (unlikely(end < start))
+               end = NULL;
+       while (uaddr != end) {
+               unsafe_get_user(c, uaddr, out);
+               uaddr += PAGE_SIZE;
+       }
+
+out:
+       user_read_access_end();
+       (void)c;
+       if (size > uaddr - start)
+               return size - (uaddr - start);
+       return 0;
+}
+EXPORT_SYMBOL(fault_in_readable);
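The read-side counterpart touches one byte per page via unsafe_get_user(); c is volatile and the (void)c cast is kept so the compiler cannot elide the loads. The typical pattern is to prime the source range before a region that must not take page faults; a hedged sketch:

    /* Pre-fault the source, then copy with page faults disabled. */
    if (fault_in_readable(usrc, len) == len)
            return -EFAULT;
    pagefault_disable();
    ret = __copy_from_user_inatomic(kbuf, usrc, len);
    pagefault_enable();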
+
 /**
  * get_dump_page() - pin user page in memory while writing it to core dump
  * @addr: user address
@@ -2228,7 +2372,6 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
 {
        int nr_start = *nr;
        struct dev_pagemap *pgmap = NULL;
-       int ret = 1;
 
        do {
                struct page *page = pfn_to_page(pfn);
@@ -2236,14 +2379,12 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
                pgmap = get_dev_pagemap(pfn, pgmap);
                if (unlikely(!pgmap)) {
                        undo_dev_pagemap(nr, nr_start, flags, pages);
-                       ret = 0;
                        break;
                }
                SetPageReferenced(page);
                pages[*nr] = page;
                if (unlikely(!try_grab_page(page, flags))) {
                        undo_dev_pagemap(nr, nr_start, flags, pages);
-                       ret = 0;
                        break;
                }
                (*nr)++;
@@ -2251,7 +2392,7 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
        } while (addr += PAGE_SIZE, addr != end);
 
        put_dev_pagemap(pgmap);
-       return ret;
+       return addr == end;
 }
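The dropped ret flag was cleared only on the two early-break paths, and both of those leave addr short of end, so returning addr == end reports exactly the same outcome without the extra variable. A standalone model of the refactor (illustrative only; grab_failed() is hypothetical):

    /* Before: carry a success flag.  After: derive it from how the
     * loop exited; an early break always leaves addr != end. */
    do {
            if (grab_failed(addr))  /* hypothetical failure check */
                    break;          /* addr stops short of end */
    } while (addr += PAGE_SIZE, addr != end);
    return addr == end;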
 
 static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
@@ -2708,7 +2849,7 @@ static int internal_get_user_pages_fast(unsigned long start,
 
        if (WARN_ON_ONCE(gup_flags & ~(FOLL_WRITE | FOLL_LONGTERM |
                                       FOLL_FORCE | FOLL_PIN | FOLL_GET |
-                                      FOLL_FAST_ONLY)))
+                                      FOLL_FAST_ONLY | FOLL_NOFAULT)))
                return -EINVAL;
 
        if (gup_flags & FOLL_PIN)
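FOLL_NOFAULT also has to be whitelisted here, since internal_get_user_pages_fast() WARNs and rejects any flag outside this mask with -EINVAL. A hedged usage sketch at the gup-fast level:

    /* Pin already-resident pages via the fast path; on a short pin,
     * fall back to faulting in (e.g. fault_in_writeable()) and retry. */
    int n = get_user_pages_fast(start, nr_pages,
                                FOLL_WRITE | FOLL_NOFAULT, pages);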