Merge tag '6.6-rc-smb3-client-fixes-part1' of git://git.samba.org/sfrench/cifs-2.6
[linux-2.6-microblaze.git] / fs / userfaultfd.c
index 7cecd49..56eaae9 100644 (file)
@@ -277,17 +277,16 @@ static inline struct uffd_msg userfault_msg(unsigned long address,
  * hugepmd ranges.
  */
 static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
-                                        struct vm_area_struct *vma,
-                                        unsigned long address,
-                                        unsigned long flags,
-                                        unsigned long reason)
+                                             struct vm_fault *vmf,
+                                             unsigned long reason)
 {
+       struct vm_area_struct *vma = vmf->vma;
        pte_t *ptep, pte;
        bool ret = true;
 
-       mmap_assert_locked(ctx->mm);
+       assert_fault_locked(vmf);
 
-       ptep = hugetlb_walk(vma, address, vma_mmu_pagesize(vma));
+       ptep = hugetlb_walk(vma, vmf->address, vma_mmu_pagesize(vma));
        if (!ptep)
                goto out;
 
@@ -308,10 +307,8 @@ out:
 }
 #else
 static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
-                                        struct vm_area_struct *vma,
-                                        unsigned long address,
-                                        unsigned long flags,
-                                        unsigned long reason)
+                                             struct vm_fault *vmf,
+                                             unsigned long reason)
 {
        return false;   /* #else stub, arguments unused: should never get here */
 }
@@ -325,11 +322,11 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
  * threads.
  */
 static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
-                                        unsigned long address,
-                                        unsigned long flags,
+                                        struct vm_fault *vmf,
                                         unsigned long reason)
 {
        struct mm_struct *mm = ctx->mm;
+       unsigned long address = vmf->address;
        pgd_t *pgd;
        p4d_t *p4d;
        pud_t *pud;
@@ -338,7 +335,7 @@ static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
        pte_t ptent;
        bool ret = true;
 
-       mmap_assert_locked(mm);
+       assert_fault_locked(vmf);
 
        pgd = pgd_offset(mm, address);
        if (!pgd_present(*pgd))
@@ -427,20 +424,16 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
         *
         * We also don't do userfault handling during
         * coredumping. hugetlbfs has the special
-        * follow_hugetlb_page() to skip missing pages in the
+        * hugetlb_follow_page_mask() to skip missing pages in the
         * FOLL_DUMP case, anon memory also checks for FOLL_DUMP with
         * the no_page_table() helper in follow_page_mask(), but the
         * shmem_vm_ops->fault method is invoked even during
-        * coredumping without mmap_lock and it ends up here.
+        * coredumping and it ends up here.
         */
        if (current->flags & (PF_EXITING|PF_DUMPCORE))
                goto out;
 
-       /*
-        * Coredumping runs without mmap_lock so we can only check that
-        * the mmap_lock is held, if PF_DUMPCORE was not set.
-        */
-       mmap_assert_locked(mm);
+       assert_fault_locked(vmf);
 
        ctx = vma->vm_userfaultfd_ctx.ctx;
        if (!ctx)
@@ -556,15 +549,12 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
        spin_unlock_irq(&ctx->fault_pending_wqh.lock);
 
        if (!is_vm_hugetlb_page(vma))
-               must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
-                                                 reason);
+               must_wait = userfaultfd_must_wait(ctx, vmf, reason);
        else
-               must_wait = userfaultfd_huge_must_wait(ctx, vma,
-                                                      vmf->address,
-                                                      vmf->flags, reason);
+               must_wait = userfaultfd_huge_must_wait(ctx, vmf, reason);
        if (is_vm_hugetlb_page(vma))
                hugetlb_vma_unlock_read(vma);
-       mmap_read_unlock(mm);
+       release_fault_lock(vmf);
 
        if (likely(must_wait && !READ_ONCE(ctx->released))) {
                wake_up_poll(&ctx->fd_wqh, EPOLLIN);
@@ -667,6 +657,7 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
                mmap_write_lock(mm);
                for_each_vma(vmi, vma) {
                        if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) {
+                               vma_start_write(vma);
                                vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
                                userfaultfd_set_vm_flags(vma,
                                                         vma->vm_flags & ~__VM_UFFD_FLAGS);
@@ -702,6 +693,7 @@ int dup_userfaultfd(struct vm_area_struct *vma, struct list_head *fcs)
 
        octx = vma->vm_userfaultfd_ctx.ctx;
        if (!octx || !(octx->features & UFFD_FEATURE_EVENT_FORK)) {
+               vma_start_write(vma);
                vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
                userfaultfd_set_vm_flags(vma, vma->vm_flags & ~__VM_UFFD_FLAGS);
                return 0;
@@ -783,6 +775,7 @@ void mremap_userfaultfd_prep(struct vm_area_struct *vma,
                atomic_inc(&ctx->mmap_changing);
        } else {
                /* Drop uffd context if remap feature not enabled */
+               vma_start_write(vma);
                vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
                userfaultfd_set_vm_flags(vma, vma->vm_flags & ~__VM_UFFD_FLAGS);
        }
@@ -940,6 +933,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
                        prev = vma;
                }
 
+               vma_start_write(vma);
                userfaultfd_set_vm_flags(vma, new_flags);
                vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
        }
@@ -1289,13 +1283,11 @@ static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
                __wake_userfault(ctx, range);
 }
 
-static __always_inline int validate_range(struct mm_struct *mm,
-                                         __u64 start, __u64 len)
+static __always_inline int validate_unaligned_range(
+       struct mm_struct *mm, __u64 start, __u64 len)
 {
        __u64 task_size = mm->task_size;
 
-       if (start & ~PAGE_MASK)
-               return -EINVAL;
        if (len & ~PAGE_MASK)
                return -EINVAL;
        if (!len)
@@ -1306,9 +1298,20 @@ static __always_inline int validate_range(struct mm_struct *mm,
                return -EINVAL;
        if (len > task_size - start)
                return -EINVAL;
+       if (start + len <= start)
+               return -EINVAL;
        return 0;
 }
 
+static __always_inline int validate_range(struct mm_struct *mm,
+                                         __u64 start, __u64 len)
+{ /* validate_unaligned_range() plus a page-alignment check on start */
+       if (start & ~PAGE_MASK) /* start has bits set outside PAGE_MASK */
+               return -EINVAL;
+
+       return validate_unaligned_range(mm, start, len); /* 0 or -EINVAL from the len/bounds checks */
+}
+
 static int userfaultfd_register(struct userfaultfd_ctx *ctx,
                                unsigned long arg)
 {
@@ -1502,6 +1505,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
                 * the next vma was merged into the current one and
                 * the current one has not been updated yet.
                 */
+               vma_start_write(vma);
                userfaultfd_set_vm_flags(vma, new_flags);
                vma->vm_userfaultfd_ctx.ctx = ctx;
 
@@ -1685,6 +1689,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                 * the next vma was merged into the current one and
                 * the current one has not been updated yet.
                 */
+               vma_start_write(vma);
                userfaultfd_set_vm_flags(vma, new_flags);
                vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
 
@@ -1757,17 +1762,15 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
                           sizeof(uffdio_copy)-sizeof(__s64)))
                goto out;
 
+       ret = validate_unaligned_range(ctx->mm, uffdio_copy.src,
+                                      uffdio_copy.len);
+       if (ret)
+               goto out;
        ret = validate_range(ctx->mm, uffdio_copy.dst, uffdio_copy.len);
        if (ret)
                goto out;
-       /*
-        * double check for wraparound just in case. copy_from_user()
-        * will later check uffdio_copy.src + uffdio_copy.len to fit
-        * in the userland range.
-        */
+
        ret = -EINVAL;
-       if (uffdio_copy.src + uffdio_copy.len <= uffdio_copy.src)
-               goto out;
        if (uffdio_copy.mode & ~(UFFDIO_COPY_MODE_DONTWAKE|UFFDIO_COPY_MODE_WP))
                goto out;
        if (uffdio_copy.mode & UFFDIO_COPY_MODE_WP)
@@ -1927,11 +1930,6 @@ static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
                goto out;
 
        ret = -EINVAL;
-       /* double check for wraparound just in case. */
-       if (uffdio_continue.range.start + uffdio_continue.range.len <=
-           uffdio_continue.range.start) {
-               goto out;
-       }
        if (uffdio_continue.mode & ~(UFFDIO_CONTINUE_MODE_DONTWAKE |
                                     UFFDIO_CONTINUE_MODE_WP))
                goto out;
@@ -1965,6 +1963,61 @@ out:
        return ret;
 }
 
+static inline int userfaultfd_poison(struct userfaultfd_ctx *ctx, unsigned long arg)
+{ /* UFFDIO_POISON handler: arg is a user pointer to struct uffdio_poison; returns 0 or -errno */
+       __s64 ret; /* holds -errno codes and the mfill_atomic_poison() result */
+       struct uffdio_poison uffdio_poison;
+       struct uffdio_poison __user *user_uffdio_poison;
+       struct userfaultfd_wake_range range;
+
+       user_uffdio_poison = (struct uffdio_poison __user *)arg;
+
+       ret = -EAGAIN;
+       if (atomic_read(&ctx->mmap_changing)) /* non-zero mmap_changing: fail with -EAGAIN */
+               goto out;
+
+       ret = -EFAULT;
+       if (copy_from_user(&uffdio_poison, user_uffdio_poison,
+                          /* stop one __s64 short: don't copy the output fields */
+                          sizeof(uffdio_poison) - (sizeof(__s64))))
+               goto out;
+
+       ret = validate_range(ctx->mm, uffdio_poison.range.start,
+                            uffdio_poison.range.len);
+       if (ret) /* unaligned start or otherwise invalid range */
+               goto out;
+
+       ret = -EINVAL;
+       if (uffdio_poison.mode & ~UFFDIO_POISON_MODE_DONTWAKE) /* DONTWAKE is the only accepted mode bit */
+               goto out;
+
+       if (mmget_not_zero(ctx->mm)) {
+               ret = mfill_atomic_poison(ctx->mm, uffdio_poison.range.start,
+                                         uffdio_poison.range.len,
+                                         &ctx->mmap_changing, 0);
+               mmput(ctx->mm); /* pairs with the successful mmget_not_zero() above */
+       } else {
+               return -ESRCH; /* no reference on ctx->mm could be taken */
+       }
+
+       if (unlikely(put_user(ret, &user_uffdio_poison->updated))) /* result is written back even when negative */
+               return -EFAULT;
+       if (ret < 0)
+               goto out;
+
+       /* ret must be non-zero here: a wake range with len == 0 would wake all */
+       BUG_ON(!ret);
+       range.len = ret; /* wake length is what mfill_atomic_poison() returned */
+       if (!(uffdio_poison.mode & UFFDIO_POISON_MODE_DONTWAKE)) {
+               range.start = uffdio_poison.range.start;
+               wake_userfault(ctx, &range);
+       }
+       ret = range.len == uffdio_poison.range.len ? 0 : -EAGAIN; /* shorter than requested => -EAGAIN */
+
+out:
+       return ret;
+}
+
 static inline unsigned int uffd_ctx_features(__u64 user_features)
 {
        /*
@@ -2066,6 +2119,9 @@ static long userfaultfd_ioctl(struct file *file, unsigned cmd,
        case UFFDIO_CONTINUE:
                ret = userfaultfd_continue(ctx, arg);
                break;
+       case UFFDIO_POISON:
+               ret = userfaultfd_poison(ctx, arg);
+               break;
        }
        return ret;
 }