local_irq_restore(flags);
 }
 
+static void sync_current_stack_to_mm(struct mm_struct *mm)
+{
+       unsigned long sp = current_stack_pointer;
+       pgd_t *pgd = pgd_offset(mm, sp);
+
+       if (CONFIG_PGTABLE_LEVELS > 4) {
+               if (unlikely(pgd_none(*pgd))) {
+                       pgd_t *pgd_ref = pgd_offset_k(sp);
+
+                       set_pgd(pgd, *pgd_ref);
+               }
+       } else {
+               /*
+                * "pgd" is faked.  The top level entries are "p4d"s, so sync
+                * the p4d.  This compiles to approximately the same code as
+                * the 5-level case.
+                */
+               p4d_t *p4d = p4d_offset(pgd, sp);
+
+               if (unlikely(p4d_none(*p4d))) {
+                       pgd_t *pgd_ref = pgd_offset_k(sp);
+                       p4d_t *p4d_ref = p4d_offset(pgd_ref, sp);
+
+                       set_p4d(p4d, *p4d_ref);
+               }
+       }
+}
+
 void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
                        struct task_struct *tsk)
 {
                         * mapped in the new pgd, we'll double-fault.  Forcibly
                         * map it.
                         */
-                       unsigned int index = pgd_index(current_stack_pointer);
-                       pgd_t *pgd = next->pgd + index;
-
-                       if (unlikely(pgd_none(*pgd)))
-                               set_pgd(pgd, init_mm.pgd[index]);
+                       sync_current_stack_to_mm(next);
                }
 
                /* Stop remote flushes for the previous mm */