x86/mm: Fix fault error path using unsafe vma pointer
authorLaurent Dufour <ldufour@linux.vnet.ibm.com>
Mon, 4 Sep 2017 08:32:15 +0000 (10:32 +0200)
committerThomas Gleixner <tglx@linutronix.de>
Mon, 25 Sep 2017 07:36:15 +0000 (09:36 +0200)
commit 7b2d0dbac489 ("x86/mm/pkeys: Pass VMA down in to fault signal
generation code") passes down a vma pointer to the error path, but that is
done once the mmap_sem is released when calling mm_fault_error() from
__do_page_fault().

This is dangerous as the vma structure is no more safe to be used once the
mmap_sem has been released. As only the protection key value is required in
the error processing, we could just pass down this value.

Fix it by passing a pointer to a protection key value down to the fault
signal generation code. The use of a pointer allows to keep the check
generating a warning message in fill_sig_info_pkey() when the vma was not
known. If the pointer is valid, the protection value can be accessed by
deferencing the pointer.

[ tglx: Made *pkey u32 as that's the type which is passed in siginfo ]

Fixes: 7b2d0dbac489 ("x86/mm/pkeys: Pass VMA down in to fault signal generation code")
Signed-off-by: Laurent Dufour <ldufour@linux.vnet.ibm.com>
Signed-off-by: Thomas Gleixner <tglx@linutronix.de>
Cc: linux-mm@kvack.org
Cc: Dave Hansen <dave.hansen@linux.intel.com>
Cc: stable@vger.kernel.org
Link: http://lkml.kernel.org/r/1504513935-12742-1-git-send-email-ldufour@linux.vnet.ibm.com
arch/x86/mm/fault.c

index 39567b5..e2baeaa 100644 (file)
@@ -192,8 +192,7 @@ is_prefetch(struct pt_regs *regs, unsigned long error_code, unsigned long addr)
  * 6. T1   : reaches here, sees vma_pkey(vma)=5, when we really
  *          faulted on a pte with its pkey=4.
  */
-static void fill_sig_info_pkey(int si_code, siginfo_t *info,
-               struct vm_area_struct *vma)
+static void fill_sig_info_pkey(int si_code, siginfo_t *info, u32 *pkey)
 {
        /* This is effectively an #ifdef */
        if (!boot_cpu_has(X86_FEATURE_OSPKE))
@@ -209,7 +208,7 @@ static void fill_sig_info_pkey(int si_code, siginfo_t *info,
         * valid VMA, so we should never reach this without a
         * valid VMA.
         */
-       if (!vma) {
+       if (!pkey) {
                WARN_ONCE(1, "PKU fault with no VMA passed in");
                info->si_pkey = 0;
                return;
@@ -219,13 +218,12 @@ static void fill_sig_info_pkey(int si_code, siginfo_t *info,
         * absolutely guranteed to be 100% accurate because of
         * the race explained above.
         */
-       info->si_pkey = vma_pkey(vma);
+       info->si_pkey = *pkey;
 }
 
 static void
 force_sig_info_fault(int si_signo, int si_code, unsigned long address,
-                    struct task_struct *tsk, struct vm_area_struct *vma,
-                    int fault)
+                    struct task_struct *tsk, u32 *pkey, int fault)
 {
        unsigned lsb = 0;
        siginfo_t info;
@@ -240,7 +238,7 @@ force_sig_info_fault(int si_signo, int si_code, unsigned long address,
                lsb = PAGE_SHIFT;
        info.si_addr_lsb = lsb;
 
-       fill_sig_info_pkey(si_code, &info, vma);
+       fill_sig_info_pkey(si_code, &info, pkey);
 
        force_sig_info(si_signo, &info, tsk);
 }
@@ -762,8 +760,6 @@ no_context(struct pt_regs *regs, unsigned long error_code,
        struct task_struct *tsk = current;
        unsigned long flags;
        int sig;
-       /* No context means no VMA to pass down */
-       struct vm_area_struct *vma = NULL;
 
        /* Are we prepared to handle this kernel fault? */
        if (fixup_exception(regs, X86_TRAP_PF)) {
@@ -788,7 +784,7 @@ no_context(struct pt_regs *regs, unsigned long error_code,
 
                        /* XXX: hwpoison faults will set the wrong code. */
                        force_sig_info_fault(signal, si_code, address,
-                                            tsk, vma, 0);
+                                            tsk, NULL, 0);
                }
 
                /*
@@ -896,8 +892,7 @@ show_signal_msg(struct pt_regs *regs, unsigned long error_code,
 
 static void
 __bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
-                      unsigned long address, struct vm_area_struct *vma,
-                      int si_code)
+                      unsigned long address, u32 *pkey, int si_code)
 {
        struct task_struct *tsk = current;
 
@@ -945,7 +940,7 @@ __bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
                tsk->thread.error_code  = error_code;
                tsk->thread.trap_nr     = X86_TRAP_PF;
 
-               force_sig_info_fault(SIGSEGV, si_code, address, tsk, vma, 0);
+               force_sig_info_fault(SIGSEGV, si_code, address, tsk, pkey, 0);
 
                return;
        }
@@ -958,9 +953,9 @@ __bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
 
 static noinline void
 bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
-                    unsigned long address, struct vm_area_struct *vma)
+                    unsigned long address, u32 *pkey)
 {
-       __bad_area_nosemaphore(regs, error_code, address, vma, SEGV_MAPERR);
+       __bad_area_nosemaphore(regs, error_code, address, pkey, SEGV_MAPERR);
 }
 
 static void
@@ -968,6 +963,10 @@ __bad_area(struct pt_regs *regs, unsigned long error_code,
           unsigned long address,  struct vm_area_struct *vma, int si_code)
 {
        struct mm_struct *mm = current->mm;
+       u32 pkey;
+
+       if (vma)
+               pkey = vma_pkey(vma);
 
        /*
         * Something tried to access memory that isn't in our memory map..
@@ -975,7 +974,8 @@ __bad_area(struct pt_regs *regs, unsigned long error_code,
         */
        up_read(&mm->mmap_sem);
 
-       __bad_area_nosemaphore(regs, error_code, address, vma, si_code);
+       __bad_area_nosemaphore(regs, error_code, address,
+                              (vma) ? &pkey : NULL, si_code);
 }
 
 static noinline void
@@ -1018,7 +1018,7 @@ bad_area_access_error(struct pt_regs *regs, unsigned long error_code,
 
 static void
 do_sigbus(struct pt_regs *regs, unsigned long error_code, unsigned long address,
-         struct vm_area_struct *vma, unsigned int fault)
+         u32 *pkey, unsigned int fault)
 {
        struct task_struct *tsk = current;
        int code = BUS_ADRERR;
@@ -1045,13 +1045,12 @@ do_sigbus(struct pt_regs *regs, unsigned long error_code, unsigned long address,
                code = BUS_MCEERR_AR;
        }
 #endif
-       force_sig_info_fault(SIGBUS, code, address, tsk, vma, fault);
+       force_sig_info_fault(SIGBUS, code, address, tsk, pkey, fault);
 }
 
 static noinline void
 mm_fault_error(struct pt_regs *regs, unsigned long error_code,
-              unsigned long address, struct vm_area_struct *vma,
-              unsigned int fault)
+              unsigned long address, u32 *pkey, unsigned int fault)
 {
        if (fatal_signal_pending(current) && !(error_code & PF_USER)) {
                no_context(regs, error_code, address, 0, 0);
@@ -1075,9 +1074,9 @@ mm_fault_error(struct pt_regs *regs, unsigned long error_code,
        } else {
                if (fault & (VM_FAULT_SIGBUS|VM_FAULT_HWPOISON|
                             VM_FAULT_HWPOISON_LARGE))
-                       do_sigbus(regs, error_code, address, vma, fault);
+                       do_sigbus(regs, error_code, address, pkey, fault);
                else if (fault & VM_FAULT_SIGSEGV)
-                       bad_area_nosemaphore(regs, error_code, address, vma);
+                       bad_area_nosemaphore(regs, error_code, address, pkey);
                else
                        BUG();
        }
@@ -1267,6 +1266,7 @@ __do_page_fault(struct pt_regs *regs, unsigned long error_code,
        struct mm_struct *mm;
        int fault, major = 0;
        unsigned int flags = FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_KILLABLE;
+       u32 pkey;
 
        tsk = current;
        mm = tsk->mm;
@@ -1467,9 +1467,10 @@ good_area:
                return;
        }
 
+       pkey = vma_pkey(vma);
        up_read(&mm->mmap_sem);
        if (unlikely(fault & VM_FAULT_ERROR)) {
-               mm_fault_error(regs, error_code, address, vma, fault);
+               mm_fault_error(regs, error_code, address, &pkey, fault);
                return;
        }