KVM: x86/mmu: Add accessors to query mmu_role bits
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 0144c40..1e5beac 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -55,7 +55,7 @@
 
 extern bool itlb_multihit_kvm_mitigation;
 
-static int __read_mostly nx_huge_pages = -1;
+int __read_mostly nx_huge_pages = -1;
 #ifdef CONFIG_PREEMPT_RT
 /* Recovery can cause latency spikes, disable it for PREEMPT_RT.  */
 static uint __read_mostly nx_huge_pages_recovery_ratio = 0;
@@ -176,9 +176,67 @@ static void mmu_spte_set(u64 *sptep, u64 spte);
 static union kvm_mmu_page_role
 kvm_mmu_calc_root_page_role(struct kvm_vcpu *vcpu);
 
+struct kvm_mmu_role_regs {
+       const unsigned long cr0;
+       const unsigned long cr4;
+       const u64 efer;
+};
+
 #define CREATE_TRACE_POINTS
 #include "mmutrace.h"
 
+/*
+ * Yes, lots of underscores.  They're a hint that you probably shouldn't be
+ * reading from the role_regs.  Once the mmu_role is constructed, it becomes
+ * the single source of truth for the MMU's state.
+ */
+#define BUILD_MMU_ROLE_REGS_ACCESSOR(reg, name, flag)                  \
+static inline bool ____is_##reg##_##name(struct kvm_mmu_role_regs *regs)\
+{                                                                      \
+       return !!(regs->reg & flag);                                    \
+}
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr0, pg, X86_CR0_PG);
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr0, wp, X86_CR0_WP);
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr4, pse, X86_CR4_PSE);
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr4, pae, X86_CR4_PAE);
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr4, smep, X86_CR4_SMEP);
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr4, smap, X86_CR4_SMAP);
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr4, pke, X86_CR4_PKE);
+BUILD_MMU_ROLE_REGS_ACCESSOR(cr4, la57, X86_CR4_LA57);
+BUILD_MMU_ROLE_REGS_ACCESSOR(efer, nx, EFER_NX);
+BUILD_MMU_ROLE_REGS_ACCESSOR(efer, lma, EFER_LMA);
+
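+
Annotation: each BUILD_MMU_ROLE_REGS_ACCESSOR() invocation above expands to a one-line helper; the cr0/pg variant, for example, becomes roughly:

static inline bool ____is_cr0_pg(struct kvm_mmu_role_regs *regs)
{
        return !!(regs->cr0 & X86_CR0_PG);
}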
+/*
+ * The MMU itself (with a valid role) is the single source of truth for the
+ * MMU.  Do not use the regs used to build the MMU/role, nor the vCPU.  The
+ * regs don't account for dependencies, e.g. clearing CR4 bits if CR0.PG=1,
+ * and the vCPU may be incorrect/irrelevant.
+ */
+#define BUILD_MMU_ROLE_ACCESSOR(base_or_ext, reg, name)                \
+static inline bool is_##reg##_##name(struct kvm_mmu *mmu)      \
+{                                                              \
+       return !!(mmu->mmu_role. base_or_ext . reg##_##name);   \
+}
+BUILD_MMU_ROLE_ACCESSOR(ext,  cr0, pg);
+BUILD_MMU_ROLE_ACCESSOR(base, cr0, wp);
+BUILD_MMU_ROLE_ACCESSOR(ext,  cr4, pse);
+BUILD_MMU_ROLE_ACCESSOR(ext,  cr4, pae);
+BUILD_MMU_ROLE_ACCESSOR(ext,  cr4, smep);
+BUILD_MMU_ROLE_ACCESSOR(ext,  cr4, smap);
+BUILD_MMU_ROLE_ACCESSOR(ext,  cr4, pke);
+BUILD_MMU_ROLE_ACCESSOR(ext,  cr4, la57);
+BUILD_MMU_ROLE_ACCESSOR(base, efer, nx);
+
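+
Annotation: likewise, each BUILD_MMU_ROLE_ACCESSOR() invocation reads the constructed mmu_role rather than the raw registers; BUILD_MMU_ROLE_ACCESSOR(base, cr0, wp), for instance, expands to roughly:

static inline bool is_cr0_wp(struct kvm_mmu *mmu)
{
        return !!(mmu->mmu_role.base.cr0_wp);
}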
+static struct kvm_mmu_role_regs vcpu_to_role_regs(struct kvm_vcpu *vcpu)
+{
+       struct kvm_mmu_role_regs regs = {
+               .cr0 = kvm_read_cr0_bits(vcpu, KVM_MMU_CR0_ROLE_BITS),
+               .cr4 = kvm_read_cr4_bits(vcpu, KVM_MMU_CR4_ROLE_BITS),
+               .efer = vcpu->arch.efer,
+       };
+
+       return regs;
+}
 
 static inline bool kvm_available_flush_tlb_with_range(void)
 {
@@ -208,11 +266,6 @@ void kvm_flush_remote_tlbs_with_address(struct kvm *kvm,
        kvm_flush_remote_tlbs_with_range(kvm, &range);
 }
 
-bool is_nx_huge_page_enabled(void)
-{
-       return READ_ONCE(nx_huge_pages);
-}
-
 static void mark_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, u64 gfn,
                           unsigned int access)
 {
@@ -1177,8 +1230,7 @@ static bool __rmap_clear_dirty(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
  * @gfn_offset: start of the BITS_PER_LONG pages we care about
  * @mask: indicates which pages we should protect
  *
- * Used when we do not need to care about huge page mappings: e.g. during dirty
- * logging we do not have any such mappings.
+ * Used when we do not need to care about huge page mappings.
  */
 static void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
                                     struct kvm_memory_slot *slot,
@@ -1189,6 +1241,10 @@ static void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
        if (is_tdp_mmu_enabled(kvm))
                kvm_tdp_mmu_clear_dirty_pt_masked(kvm, slot,
                                slot->base_gfn + gfn_offset, mask, true);
+
+       if (!kvm_memslots_have_rmaps(kvm))
+               return;
+
        while (mask) {
                rmap_head = __gfn_to_rmap(slot->base_gfn + gfn_offset + __ffs(mask),
                                          PG_LEVEL_4K, slot);
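
Annotation: the kvm_memslots_have_rmaps() checks added in this and the following hunks let the legacy rmap walks be skipped entirely when only the TDP MMU is managing the memslots and no rmaps have been allocated. The helper itself is not part of this diff; a minimal sketch, assuming it simply pairs an acquire load with the release store performed once the rmaps are (lazily) allocated:

static inline bool kvm_memslots_have_rmaps(struct kvm *kvm)
{
        /* Pairs with the release store made once all memslot rmaps exist. */
        return smp_load_acquire(&kvm->arch.memslots_have_rmaps);
}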
@@ -1218,6 +1274,10 @@ static void kvm_mmu_clear_dirty_pt_masked(struct kvm *kvm,
        if (is_tdp_mmu_enabled(kvm))
                kvm_tdp_mmu_clear_dirty_pt_masked(kvm, slot,
                                slot->base_gfn + gfn_offset, mask, false);
+
+       if (!kvm_memslots_have_rmaps(kvm))
+               return;
+
        while (mask) {
                rmap_head = __gfn_to_rmap(slot->base_gfn + gfn_offset + __ffs(mask),
                                          PG_LEVEL_4K, slot);
@@ -1235,13 +1295,36 @@ static void kvm_mmu_clear_dirty_pt_masked(struct kvm *kvm,
  * It calls kvm_mmu_write_protect_pt_masked to write protect selected pages to
  * enable dirty logging for them.
  *
- * Used when we do not need to care about huge page mappings: e.g. during dirty
- * logging we do not have any such mappings.
+ * We need to care about huge page mappings: e.g. during dirty logging we may
+ * have such mappings.
  */
 void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                                struct kvm_memory_slot *slot,
                                gfn_t gfn_offset, unsigned long mask)
 {
+       /*
+        * Huge pages are NOT write protected when we start dirty logging in
+        * initially-all-set mode; must write protect them here so that they
+        * are split to 4K on the first write.
+        *
+        * The gfn_offset is guaranteed to be aligned to 64, but the base_gfn
+        * of memslot has no such restriction, so the range can cross two large
+        * pages.
+        */
+       if (kvm_dirty_log_manual_protect_and_init_set(kvm)) {
+               gfn_t start = slot->base_gfn + gfn_offset + __ffs(mask);
+               gfn_t end = slot->base_gfn + gfn_offset + __fls(mask);
+
+               kvm_mmu_slot_gfn_write_protect(kvm, slot, start, PG_LEVEL_2M);
+
+               /* Cross two large pages? */
+               if (ALIGN(start << PAGE_SHIFT, PMD_SIZE) !=
+                   ALIGN(end << PAGE_SHIFT, PMD_SIZE))
+                       kvm_mmu_slot_gfn_write_protect(kvm, slot, end,
+                                                      PG_LEVEL_2M);
+       }
+
+       /* Now handle 4K PTEs.  */
        if (kvm_x86_ops.cpu_dirty_log_size)
                kvm_mmu_clear_dirty_pt_masked(kvm, slot, gfn_offset, mask);
        else
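
Annotation: to illustrate the cross-page check with made-up numbers: with 4K pages and 2M huge pages (512 gfns each), a memslot with base_gfn = 0x1e0 and gfn_offset = 0 plus a full 64-bit mask yields start = 0x1e0 and end = 0x21f. ALIGN(start << PAGE_SHIFT, PMD_SIZE) is 0x200000 while ALIGN(end << PAGE_SHIFT, PMD_SIZE) is 0x400000, so the 64-page window straddles two 2M pages and both must be write-protected.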
@@ -1254,20 +1337,23 @@ int kvm_cpu_dirty_log_size(void)
 }
 
 bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
-                                   struct kvm_memory_slot *slot, u64 gfn)
+                                   struct kvm_memory_slot *slot, u64 gfn,
+                                   int min_level)
 {
        struct kvm_rmap_head *rmap_head;
        int i;
        bool write_protected = false;
 
-       for (i = PG_LEVEL_4K; i <= KVM_MAX_HUGEPAGE_LEVEL; ++i) {
-               rmap_head = __gfn_to_rmap(gfn, i, slot);
-               write_protected |= __rmap_write_protect(kvm, rmap_head, true);
+       if (kvm_memslots_have_rmaps(kvm)) {
+               for (i = min_level; i <= KVM_MAX_HUGEPAGE_LEVEL; ++i) {
+                       rmap_head = __gfn_to_rmap(gfn, i, slot);
+                       write_protected |= __rmap_write_protect(kvm, rmap_head, true);
+               }
        }
 
        if (is_tdp_mmu_enabled(kvm))
                write_protected |=
-                       kvm_tdp_mmu_write_protect_gfn(kvm, slot, gfn);
+                       kvm_tdp_mmu_write_protect_gfn(kvm, slot, gfn, min_level);
 
        return write_protected;
 }
@@ -1277,7 +1363,7 @@ static bool rmap_write_protect(struct kvm_vcpu *vcpu, u64 gfn)
        struct kvm_memory_slot *slot;
 
        slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
-       return kvm_mmu_slot_gfn_write_protect(vcpu->kvm, slot, gfn);
+       return kvm_mmu_slot_gfn_write_protect(vcpu->kvm, slot, gfn, PG_LEVEL_4K);
 }
 
 static bool kvm_zap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
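
Annotation: the new min_level parameter bounds which mapping levels are write-protected: the initially-all-set dirty-log path above passes PG_LEVEL_2M so that only huge mappings are split on the first write, while rmap_write_protect() keeps the old behaviour by passing PG_LEVEL_4K. Illustrative calls, mirroring the two uses in this patch:

        /* Write-protect only the 2M and larger mappings covering gfn. */
        kvm_mmu_slot_gfn_write_protect(kvm, slot, gfn, PG_LEVEL_2M);

        /* Write-protect every mapping from 4K up (pre-patch behaviour). */
        kvm_mmu_slot_gfn_write_protect(kvm, slot, gfn, PG_LEVEL_4K);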
@@ -1433,9 +1519,10 @@ static __always_inline bool kvm_handle_gfn_range(struct kvm *kvm,
 
 bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       bool flush;
+       bool flush = false;
 
-       flush = kvm_handle_gfn_range(kvm, range, kvm_unmap_rmapp);
+       if (kvm_memslots_have_rmaps(kvm))
+               flush = kvm_handle_gfn_range(kvm, range, kvm_unmap_rmapp);
 
        if (is_tdp_mmu_enabled(kvm))
                flush |= kvm_tdp_mmu_unmap_gfn_range(kvm, range, flush);
@@ -1445,9 +1532,10 @@ bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 
 bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       bool flush;
+       bool flush = false;
 
-       flush = kvm_handle_gfn_range(kvm, range, kvm_set_pte_rmapp);
+       if (kvm_memslots_have_rmaps(kvm))
+               flush = kvm_handle_gfn_range(kvm, range, kvm_set_pte_rmapp);
 
        if (is_tdp_mmu_enabled(kvm))
                flush |= kvm_tdp_mmu_set_spte_gfn(kvm, range);
@@ -1500,9 +1588,10 @@ static void rmap_recycle(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 
 bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       bool young;
+       bool young = false;
 
-       young = kvm_handle_gfn_range(kvm, range, kvm_age_rmapp);
+       if (kvm_memslots_have_rmaps(kvm))
+               young = kvm_handle_gfn_range(kvm, range, kvm_age_rmapp);
 
        if (is_tdp_mmu_enabled(kvm))
                young |= kvm_tdp_mmu_age_gfn_range(kvm, range);
@@ -1512,9 +1601,10 @@ bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 
 bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       bool young;
+       bool young = false;
 
-       young = kvm_handle_gfn_range(kvm, range, kvm_test_age_rmapp);
+       if (kvm_memslots_have_rmaps(kvm))
+               young = kvm_handle_gfn_range(kvm, range, kvm_test_age_rmapp);
 
        if (is_tdp_mmu_enabled(kvm))
                young |= kvm_tdp_mmu_test_age_gfn(kvm, range);
@@ -1748,17 +1838,10 @@ static void kvm_mmu_commit_zap_page(struct kvm *kvm,
          &(_kvm)->arch.mmu_page_hash[kvm_page_table_hashfn(_gfn)])     \
                if ((_sp)->gfn != (_gfn) || (_sp)->role.direct) {} else
 
-static inline bool is_ept_sp(struct kvm_mmu_page *sp)
-{
-       return sp->role.cr0_wp && sp->role.smap_andnot_wp;
-}
-
-/* @sp->gfn should be write-protected at the call site */
-static bool __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
-                           struct list_head *invalid_list)
+static bool kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+                        struct list_head *invalid_list)
 {
-       if ((!is_ept_sp(sp) && sp->role.gpte_is_8_bytes != !!is_pae(vcpu)) ||
-           vcpu->arch.mmu->sync_page(vcpu, sp) == 0) {
+       if (vcpu->arch.mmu->sync_page(vcpu, sp) == 0) {
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
                return false;
        }
@@ -1804,31 +1887,6 @@ static bool is_obsolete_sp(struct kvm *kvm, struct kvm_mmu_page *sp)
               unlikely(sp->mmu_valid_gen != kvm->arch.mmu_valid_gen);
 }
 
-static bool kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
-                        struct list_head *invalid_list)
-{
-       kvm_unlink_unsync_page(vcpu->kvm, sp);
-       return __kvm_sync_page(vcpu, sp, invalid_list);
-}
-
-/* @gfn should be write-protected at the call site */
-static bool kvm_sync_pages(struct kvm_vcpu *vcpu, gfn_t gfn,
-                          struct list_head *invalid_list)
-{
-       struct kvm_mmu_page *s;
-       bool ret = false;
-
-       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn) {
-               if (!s->unsync)
-                       continue;
-
-               WARN_ON(s->role.level != PG_LEVEL_4K);
-               ret |= kvm_sync_page(vcpu, s, invalid_list);
-       }
-
-       return ret;
-}
-
 struct mmu_page_path {
        struct kvm_mmu_page *parent[PT64_ROOT_MAX_LEVEL];
        unsigned int idx[PT64_ROOT_MAX_LEVEL];
@@ -1923,6 +1981,7 @@ static void mmu_sync_children(struct kvm_vcpu *vcpu,
                }
 
                for_each_sp(pages, sp, parents, i) {
+                       kvm_unlink_unsync_page(vcpu->kvm, sp);
                        flush |= kvm_sync_page(vcpu, sp, &invalid_list);
                        mmu_pages_clear_parents(&parents);
                }
@@ -1958,8 +2017,6 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        struct hlist_head *sp_list;
        unsigned quadrant;
        struct kvm_mmu_page *sp;
-       bool need_sync = false;
-       bool flush = false;
        int collisions = 0;
        LIST_HEAD(invalid_list);
 
@@ -1982,20 +2039,39 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                        continue;
                }
 
-               if (!need_sync && sp->unsync)
-                       need_sync = true;
-
-               if (sp->role.word != role.word)
+               if (sp->role.word != role.word) {
+                       /*
+                        * If the guest is creating an upper-level page, zap
+                        * unsync pages for the same gfn.  While it's possible
+                        * the guest is using recursive page tables, in all
+                        * likelihood the guest has stopped using the unsync
+                        * page and is installing a completely unrelated page.
+                        * Unsync pages must not be left as is, because the new
+                        * upper-level page will be write-protected.
+                        */
+                       if (level > PG_LEVEL_4K && sp->unsync)
+                               kvm_mmu_prepare_zap_page(vcpu->kvm, sp,
+                                                        &invalid_list);
                        continue;
+               }
 
                if (direct_mmu)
                        goto trace_get_page;
 
                if (sp->unsync) {
-                       /* The page is good, but __kvm_sync_page might still end
-                        * up zapping it.  If so, break in order to rebuild it.
+                       /*
+                        * The page is good, but is stale.  kvm_sync_page does
+                        * get the latest guest state, but (unlike mmu_unsync_children)
+                        * it doesn't write-protect the page or mark it synchronized!
+                        * This way the validity of the mapping is ensured, but the
+                        * overhead of write protection is not incurred until the
+                        * guest invalidates the TLB mapping.  This allows multiple
+                        * SPs for a single gfn to be unsync.
+                        *
+                        * If the sync fails, the page is zapped.  If so, break
+                        * in order to rebuild it.
                         */
-                       if (!__kvm_sync_page(vcpu, sp, &invalid_list))
+                       if (!kvm_sync_page(vcpu, sp, &invalid_list))
                                break;
 
                        WARN_ON(!list_empty(&invalid_list));
@@ -2020,22 +2096,14 @@ trace_get_page:
        sp->role = role;
        hlist_add_head(&sp->hash_link, sp_list);
        if (!direct) {
-               /*
-                * we should do write protection before syncing pages
-                * otherwise the content of the synced shadow page may
-                * be inconsistent with guest page table.
-                */
                account_shadowed(vcpu->kvm, sp);
                if (level == PG_LEVEL_4K && rmap_write_protect(vcpu, gfn))
                        kvm_flush_remote_tlbs_with_address(vcpu->kvm, gfn, 1);
-
-               if (level > PG_LEVEL_4K && need_sync)
-                       flush |= kvm_sync_pages(vcpu, gfn, &invalid_list);
        }
        trace_kvm_mmu_get_page(sp, true);
-
-       kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
 out:
+       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+
        if (collisions > vcpu->kvm->stat.max_mmu_page_hash_collisions)
                vcpu->kvm->stat.max_mmu_page_hash_collisions = collisions;
        return sp;
@@ -2448,17 +2516,33 @@ static void kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
        kvm_mmu_mark_parents_unsync(sp);
 }
 
-bool mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
-                           bool can_unsync)
+/*
+ * Attempt to unsync any shadow pages that can be reached by the specified gfn;
+ * KVM is creating a writable mapping for said gfn.  Returns 0 if all pages
+ * were marked unsync (or if there is no shadow page), -EPERM if the SPTE must
+ * be write-protected.
+ */
+int mmu_try_to_unsync_pages(struct kvm_vcpu *vcpu, gfn_t gfn, bool can_unsync)
 {
        struct kvm_mmu_page *sp;
 
+       /*
+        * Force write-protection if the page is being tracked.  Note, the page
+        * track machinery is used to write-protect upper-level shadow pages,
+        * i.e. this guards the role.level == 4K assertion below!
+        */
        if (kvm_page_track_is_active(vcpu, gfn, KVM_PAGE_TRACK_WRITE))
-               return true;
+               return -EPERM;
 
+       /*
+        * The page is not write-tracked, mark existing shadow pages unsync
+        * unless KVM is synchronizing an unsync SP (can_unsync = false).  In
+        * that case, KVM must complete emulation of the guest TLB flush before
+        * allowing shadow pages to become unsync (writable by the guest).
+        */
        for_each_gfn_indirect_valid_sp(vcpu->kvm, sp, gfn) {
                if (!can_unsync)
-                       return true;
+                       return -EPERM;
 
                if (sp->unsync)
                        continue;
@@ -2489,8 +2573,8 @@ bool mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
         *                      2.2 Guest issues TLB flush.
         *                          That causes a VM Exit.
         *
-        *                      2.3 kvm_mmu_sync_pages() reads sp->unsync.
-        *                          Since it is false, so it just returns.
+        *                      2.3 Walking of unsync pages sees sp->unsync is
+        *                          false and skips the page.
         *
         *                      2.4 Guest accesses GVA X.
         *                          Since the mapping in the SP was not updated,
@@ -2506,7 +2590,7 @@ bool mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
         */
        smp_wmb();
 
-       return false;
+       return 0;
 }
 
 static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
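
Annotation: with the bool mmu_need_write_protect() replaced by an int-returning mmu_try_to_unsync_pages(), callers that want to create a writable SPTE treat any non-zero return as "keep the SPTE read-only". A hypothetical caller-side sketch (the actual call site in the SPTE-construction path is not shown in this diff):

        if (writable_fault) {
                /* -EPERM: the gfn is write-tracked or can_unsync is false. */
                if (mmu_try_to_unsync_pages(vcpu, gfn, can_unsync))
                        spte &= ~PT_WRITABLE_MASK;
        }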
@@ -2827,9 +2911,6 @@ static int __direct_map(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
        gfn_t gfn = gpa >> PAGE_SHIFT;
        gfn_t base_gfn = gfn;
 
-       if (WARN_ON(!VALID_PAGE(vcpu->arch.mmu->root_hpa)))
-               return RET_PF_RETRY;
-
        level = kvm_mmu_hugepage_adjust(vcpu, gfn, max_level, &pfn,
                                        huge_page_disallowed, &req_level);
 
@@ -3180,6 +3261,33 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_free_roots);
 
+void kvm_mmu_free_guest_mode_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
+{
+       unsigned long roots_to_free = 0;
+       hpa_t root_hpa;
+       int i;
+
+       /*
+        * This should not be called while L2 is active; L2 can't invalidate
+        * _only_ its own roots, e.g. INVVPID unconditionally exits.
+        */
+       WARN_ON_ONCE(mmu->mmu_role.base.guest_mode);
+
+       for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
+               root_hpa = mmu->prev_roots[i].hpa;
+               if (!VALID_PAGE(root_hpa))
+                       continue;
+
+               if (!to_shadow_page(root_hpa) ||
+                       to_shadow_page(root_hpa)->role.guest_mode)
+                       roots_to_free |= KVM_MMU_ROOT_PREVIOUS(i);
+       }
+
+       kvm_mmu_free_roots(vcpu, mmu, roots_to_free);
+}
+EXPORT_SYMBOL_GPL(kvm_mmu_free_guest_mode_roots);
+
+
 static int mmu_check_root(struct kvm_vcpu *vcpu, gfn_t root_gfn)
 {
        int ret = 0;
@@ -3280,6 +3388,10 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
                }
        }
 
+       r = alloc_all_memslots_rmaps(vcpu->kvm);
+       if (r)
+               return r;
+
        write_lock(&vcpu->kvm->mmu_lock);
        r = make_mmu_pages_available(vcpu);
        if (r < 0)
@@ -3423,8 +3535,8 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
                 * flush strictly after those changes are made. We only need to
                 * ensure that the other CPU sets these flags before any actual
                 * changes to the page tables are made. The comments in
-                * mmu_need_write_protect() describe what could go wrong if this
-                * requirement isn't satisfied.
+                * mmu_try_to_unsync_pages() describe what could go wrong if
+                * this requirement isn't satisfied.
                 */
                if (!smp_load_acquire(&sp->unsync) &&
                    !smp_load_acquire(&sp->unsync_children))
@@ -3540,12 +3652,7 @@ static bool get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
        int root, leaf, level;
        bool reserved = false;
 
-       if (!VALID_PAGE(vcpu->arch.mmu->root_hpa)) {
-               *sptep = 0ull;
-               return reserved;
-       }
-
-       if (is_tdp_mmu_root(vcpu->kvm, vcpu->arch.mmu->root_hpa))
+       if (is_tdp_mmu(vcpu->arch.mmu))
                leaf = kvm_tdp_mmu_get_walk(vcpu, addr, sptes, &root);
        else
                leaf = get_walk(vcpu, addr, sptes, &root);
@@ -3717,6 +3824,7 @@ static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
 static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
                             bool prefault, int max_level, bool is_tdp)
 {
+       bool is_tdp_mmu_fault = is_tdp_mmu(vcpu->arch.mmu);
        bool write = error_code & PFERR_WRITE_MASK;
        bool map_writable;
 
@@ -3729,7 +3837,7 @@ static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
        if (page_fault_handle_page_track(vcpu, error_code, gfn))
                return RET_PF_EMULATE;
 
-       if (!is_tdp_mmu_root(vcpu->kvm, vcpu->arch.mmu->root_hpa)) {
+       if (!is_tdp_mmu_fault) {
                r = fast_page_fault(vcpu, gpa, error_code);
                if (r != RET_PF_INVALID)
                        return r;
@@ -3751,7 +3859,7 @@ static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 
        r = RET_PF_RETRY;
 
-       if (is_tdp_mmu_root(vcpu->kvm, vcpu->arch.mmu->root_hpa))
+       if (is_tdp_mmu_fault)
                read_lock(&vcpu->kvm->mmu_lock);
        else
                write_lock(&vcpu->kvm->mmu_lock);
@@ -3762,7 +3870,7 @@ static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
        if (r)
                goto out_unlock;
 
-       if (is_tdp_mmu_root(vcpu->kvm, vcpu->arch.mmu->root_hpa))
+       if (is_tdp_mmu_fault)
                r = kvm_tdp_mmu_map(vcpu, gpa, error_code, map_writable, max_level,
                                    pfn, prefault);
        else
@@ -3770,7 +3878,7 @@ static int direct_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
                                 prefault, is_tdp);
 
 out_unlock:
-       if (is_tdp_mmu_root(vcpu->kvm, vcpu->arch.mmu->root_hpa))
+       if (is_tdp_mmu_fault)
                read_unlock(&vcpu->kvm->mmu_lock);
        else
                write_unlock(&vcpu->kvm->mmu_lock);
@@ -3848,7 +3956,6 @@ static void nonpaging_init_context(struct kvm_vcpu *vcpu,
        context->sync_page = nonpaging_sync_page;
        context->invlpg = NULL;
        context->root_level = 0;
-       context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->direct_map = true;
        context->nx = false;
 }
@@ -3913,8 +4020,7 @@ static bool fast_pgd_switch(struct kvm_vcpu *vcpu, gpa_t new_pgd,
 }
 
 static void __kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd,
-                             union kvm_mmu_page_role new_role,
-                             bool skip_tlb_flush, bool skip_mmu_sync)
+                             union kvm_mmu_page_role new_role)
 {
        if (!fast_pgd_switch(vcpu, new_pgd, new_role)) {
                kvm_mmu_free_roots(vcpu, vcpu->arch.mmu, KVM_MMU_ROOT_CURRENT);
@@ -3929,10 +4035,10 @@ static void __kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd,
         */
        kvm_make_request(KVM_REQ_LOAD_MMU_PGD, vcpu);
 
-       if (!skip_mmu_sync || force_flush_and_sync_on_reuse)
+       if (force_flush_and_sync_on_reuse) {
                kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
-       if (!skip_tlb_flush || force_flush_and_sync_on_reuse)
                kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
+       }
 
        /*
         * The last MMIO access's GVA and GPA are cached in the VCPU. When
@@ -3951,11 +4057,9 @@ static void __kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd,
                                to_shadow_page(vcpu->arch.mmu->root_hpa));
 }
 
-void kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd, bool skip_tlb_flush,
-                    bool skip_mmu_sync)
+void kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd)
 {
-       __kvm_mmu_new_pgd(vcpu, new_pgd, kvm_mmu_calc_root_page_role(vcpu),
-                         skip_tlb_flush, skip_mmu_sync);
+       __kvm_mmu_new_pgd(vcpu, new_pgd, kvm_mmu_calc_root_page_role(vcpu));
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_new_pgd);
 
@@ -4165,11 +4269,18 @@ static inline u64 reserved_hpa_bits(void)
  * table in guest or amd nested guest, its mmu features completely
  * follow the features in guest.
  */
-void
-reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu, struct kvm_mmu *context)
+static void reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu,
+                                       struct kvm_mmu *context)
 {
-       bool uses_nx = context->nx ||
-               context->mmu_role.base.smep_andnot_wp;
+       /*
+        * KVM uses NX when TDP is disabled to handle a variety of scenarios,
+        * notably for huge SPTEs if iTLB multi-hit mitigation is enabled and
+        * to generate correct permissions for CR0.WP=0/CR4.SMEP=1/EFER.NX=0.
+        * The iTLB multi-hit workaround can be toggled at any time, so assume
+        * NX can be used by any non-nested shadow MMU to avoid having to reset
+        * MMU contexts.  Note, KVM forces EFER.NX=1 when TDP is disabled.
+        */
+       bool uses_nx = context->nx || !tdp_enabled;
        struct rsvd_bits_validate *shadow_zero_check;
        int i;
 
@@ -4193,7 +4304,6 @@ reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu, struct kvm_mmu *context)
        }
 
 }
-EXPORT_SYMBOL_GPL(reset_shadow_zero_bits_mask);
 
 static inline bool boot_cpu_is_amd(void)
 {
@@ -4413,22 +4523,16 @@ static void update_last_nonleaf_level(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu
 
 static void paging64_init_context_common(struct kvm_vcpu *vcpu,
                                         struct kvm_mmu *context,
-                                        int level)
+                                        int root_level)
 {
        context->nx = is_nx(vcpu);
-       context->root_level = level;
-
-       reset_rsvds_bits_mask(vcpu, context);
-       update_permission_bitmask(vcpu, context, false);
-       update_pkru_bitmask(vcpu, context, false);
-       update_last_nonleaf_level(vcpu, context);
+       context->root_level = root_level;
 
        MMU_WARN_ON(!is_pae(vcpu));
        context->page_fault = paging64_page_fault;
        context->gva_to_gpa = paging64_gva_to_gpa;
        context->sync_page = paging64_sync_page;
        context->invlpg = paging64_invlpg;
-       context->shadow_root_level = level;
        context->direct_map = false;
 }
 
@@ -4446,17 +4550,10 @@ static void paging32_init_context(struct kvm_vcpu *vcpu,
 {
        context->nx = false;
        context->root_level = PT32_ROOT_LEVEL;
-
-       reset_rsvds_bits_mask(vcpu, context);
-       update_permission_bitmask(vcpu, context, false);
-       update_pkru_bitmask(vcpu, context, false);
-       update_last_nonleaf_level(vcpu, context);
-
        context->page_fault = paging32_page_fault;
        context->gva_to_gpa = paging32_gva_to_gpa;
        context->sync_page = paging32_sync_page;
        context->invlpg = paging32_invlpg;
-       context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->direct_map = false;
 }
 
@@ -4466,17 +4563,18 @@ static void paging32E_init_context(struct kvm_vcpu *vcpu,
        paging64_init_context_common(vcpu, context, PT32E_ROOT_LEVEL);
 }
 
-static union kvm_mmu_extended_role kvm_calc_mmu_role_ext(struct kvm_vcpu *vcpu)
+static union kvm_mmu_extended_role kvm_calc_mmu_role_ext(struct kvm_vcpu *vcpu,
+                                                        struct kvm_mmu_role_regs *regs)
 {
        union kvm_mmu_extended_role ext = {0};
 
-       ext.cr0_pg = !!is_paging(vcpu);
-       ext.cr4_pae = !!is_pae(vcpu);
-       ext.cr4_smep = !!kvm_read_cr4_bits(vcpu, X86_CR4_SMEP);
-       ext.cr4_smap = !!kvm_read_cr4_bits(vcpu, X86_CR4_SMAP);
-       ext.cr4_pse = !!is_pse(vcpu);
-       ext.cr4_pke = !!kvm_read_cr4_bits(vcpu, X86_CR4_PKE);
-       ext.maxphyaddr = cpuid_maxphyaddr(vcpu);
+       ext.cr0_pg = ____is_cr0_pg(regs);
+       ext.cr4_pae = ____is_cr4_pae(regs);
+       ext.cr4_smep = ____is_cr4_smep(regs);
+       ext.cr4_smap = ____is_cr4_smap(regs);
+       ext.cr4_pse = ____is_cr4_pse(regs);
+       ext.cr4_pke = ____is_cr4_pke(regs);
+       ext.cr4_la57 = ____is_cr4_la57(regs);
 
        ext.valid = 1;
 
@@ -4484,20 +4582,21 @@ static union kvm_mmu_extended_role kvm_calc_mmu_role_ext(struct kvm_vcpu *vcpu)
 }
 
 static union kvm_mmu_role kvm_calc_mmu_role_common(struct kvm_vcpu *vcpu,
+                                                  struct kvm_mmu_role_regs *regs,
                                                   bool base_only)
 {
        union kvm_mmu_role role = {0};
 
        role.base.access = ACC_ALL;
-       role.base.nxe = !!is_nx(vcpu);
-       role.base.cr0_wp = is_write_protection(vcpu);
+       role.base.efer_nx = ____is_efer_nx(regs);
+       role.base.cr0_wp = ____is_cr0_wp(regs);
        role.base.smm = is_smm(vcpu);
        role.base.guest_mode = is_guest_mode(vcpu);
 
        if (base_only)
                return role;
 
-       role.ext = kvm_calc_mmu_role_ext(vcpu);
+       role.ext = kvm_calc_mmu_role_ext(vcpu, regs);
 
        return role;
 }
@@ -4512,9 +4611,10 @@ static inline int kvm_mmu_get_tdp_level(struct kvm_vcpu *vcpu)
 }
 
 static union kvm_mmu_role
-kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
+kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu,
+                               struct kvm_mmu_role_regs *regs, bool base_only)
 {
-       union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, base_only);
+       union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, regs, base_only);
 
        role.base.ad_disabled = (shadow_accessed_mask == 0);
        role.base.level = kvm_mmu_get_tdp_level(vcpu);
@@ -4527,8 +4627,9 @@ kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
 static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
 {
        struct kvm_mmu *context = &vcpu->arch.root_mmu;
+       struct kvm_mmu_role_regs regs = vcpu_to_role_regs(vcpu);
        union kvm_mmu_role new_role =
-               kvm_calc_tdp_mmu_root_page_role(vcpu, false);
+               kvm_calc_tdp_mmu_root_page_role(vcpu, &regs, false);
 
        if (new_role.as_u64 == context->mmu_role.as_u64)
                return;
@@ -4572,30 +4673,30 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
 }
 
 static union kvm_mmu_role
-kvm_calc_shadow_root_page_role_common(struct kvm_vcpu *vcpu, bool base_only)
+kvm_calc_shadow_root_page_role_common(struct kvm_vcpu *vcpu,
+                                     struct kvm_mmu_role_regs *regs, bool base_only)
 {
-       union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, base_only);
+       union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, regs, base_only);
 
-       role.base.smep_andnot_wp = role.ext.cr4_smep &&
-               !is_write_protection(vcpu);
-       role.base.smap_andnot_wp = role.ext.cr4_smap &&
-               !is_write_protection(vcpu);
-       role.base.gpte_is_8_bytes = !!is_pae(vcpu);
+       role.base.smep_andnot_wp = role.ext.cr4_smep && !____is_cr0_wp(regs);
+       role.base.smap_andnot_wp = role.ext.cr4_smap && !____is_cr0_wp(regs);
+       role.base.gpte_is_8_bytes = ____is_cr4_pae(regs);
 
        return role;
 }
 
 static union kvm_mmu_role
-kvm_calc_shadow_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
+kvm_calc_shadow_mmu_root_page_role(struct kvm_vcpu *vcpu,
+                                  struct kvm_mmu_role_regs *regs, bool base_only)
 {
        union kvm_mmu_role role =
-               kvm_calc_shadow_root_page_role_common(vcpu, base_only);
+               kvm_calc_shadow_root_page_role_common(vcpu, regs, base_only);
 
-       role.base.direct = !is_paging(vcpu);
+       role.base.direct = !____is_cr0_pg(regs);
 
-       if (!is_long_mode(vcpu))
+       if (!____is_efer_lma(regs))
                role.base.level = PT32E_ROOT_LEVEL;
-       else if (is_la57_mode(vcpu))
+       else if (____is_cr4_la57(regs))
                role.base.level = PT64_ROOT_5LEVEL;
        else
                role.base.level = PT64_ROOT_4LEVEL;
@@ -4604,37 +4705,47 @@ kvm_calc_shadow_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
 }
 
 static void shadow_mmu_init_context(struct kvm_vcpu *vcpu, struct kvm_mmu *context,
-                                   u32 cr0, u32 cr4, u32 efer,
+                                   struct kvm_mmu_role_regs *regs,
                                    union kvm_mmu_role new_role)
 {
-       if (!(cr0 & X86_CR0_PG))
+       if (!____is_cr0_pg(regs))
                nonpaging_init_context(vcpu, context);
-       else if (efer & EFER_LMA)
+       else if (____is_efer_lma(regs))
                paging64_init_context(vcpu, context);
-       else if (cr4 & X86_CR4_PAE)
+       else if (____is_cr4_pae(regs))
                paging32E_init_context(vcpu, context);
        else
                paging32_init_context(vcpu, context);
 
+       if (____is_cr0_pg(regs)) {
+               reset_rsvds_bits_mask(vcpu, context);
+               update_permission_bitmask(vcpu, context, false);
+               update_pkru_bitmask(vcpu, context, false);
+               update_last_nonleaf_level(vcpu, context);
+       }
+       context->shadow_root_level = new_role.base.level;
+
        context->mmu_role.as_u64 = new_role.as_u64;
        reset_shadow_zero_bits_mask(vcpu, context);
 }
 
-static void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu, u32 cr0, u32 cr4, u32 efer)
+static void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu,
+                               struct kvm_mmu_role_regs *regs)
 {
        struct kvm_mmu *context = &vcpu->arch.root_mmu;
        union kvm_mmu_role new_role =
-               kvm_calc_shadow_mmu_root_page_role(vcpu, false);
+               kvm_calc_shadow_mmu_root_page_role(vcpu, regs, false);
 
        if (new_role.as_u64 != context->mmu_role.as_u64)
-               shadow_mmu_init_context(vcpu, context, cr0, cr4, efer, new_role);
+               shadow_mmu_init_context(vcpu, context, regs, new_role);
 }
 
 static union kvm_mmu_role
-kvm_calc_shadow_npt_root_page_role(struct kvm_vcpu *vcpu)
+kvm_calc_shadow_npt_root_page_role(struct kvm_vcpu *vcpu,
+                                  struct kvm_mmu_role_regs *regs)
 {
        union kvm_mmu_role role =
-               kvm_calc_shadow_root_page_role_common(vcpu, false);
+               kvm_calc_shadow_root_page_role_common(vcpu, regs, false);
 
        role.base.direct = false;
        role.base.level = kvm_mmu_get_tdp_level(vcpu);
@@ -4642,23 +4753,29 @@ kvm_calc_shadow_npt_root_page_role(struct kvm_vcpu *vcpu)
        return role;
 }
 
-void kvm_init_shadow_npt_mmu(struct kvm_vcpu *vcpu, u32 cr0, u32 cr4, u32 efer,
-                            gpa_t nested_cr3)
+void kvm_init_shadow_npt_mmu(struct kvm_vcpu *vcpu, unsigned long cr0,
+                            unsigned long cr4, u64 efer, gpa_t nested_cr3)
 {
        struct kvm_mmu *context = &vcpu->arch.guest_mmu;
-       union kvm_mmu_role new_role = kvm_calc_shadow_npt_root_page_role(vcpu);
+       struct kvm_mmu_role_regs regs = {
+               .cr0 = cr0,
+               .cr4 = cr4,
+               .efer = efer,
+       };
+       union kvm_mmu_role new_role;
 
-       __kvm_mmu_new_pgd(vcpu, nested_cr3, new_role.base, false, false);
+       new_role = kvm_calc_shadow_npt_root_page_role(vcpu, &regs);
 
-       if (new_role.as_u64 != context->mmu_role.as_u64) {
-               shadow_mmu_init_context(vcpu, context, cr0, cr4, efer, new_role);
+       __kvm_mmu_new_pgd(vcpu, nested_cr3, new_role.base);
 
-               /*
-                * Override the level set by the common init helper, nested TDP
-                * always uses the host's TDP configuration.
-                */
-               context->shadow_root_level = new_role.base.level;
-       }
+       if (new_role.as_u64 != context->mmu_role.as_u64)
+               shadow_mmu_init_context(vcpu, context, &regs, new_role);
+
+       /*
+        * Redo the shadow bits, as the reset done by shadow_mmu_init_context()
+        * (above) may use the wrong shadow_root_level.
+        */
+       reset_shadow_zero_bits_mask(vcpu, context);
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_npt_mmu);
 
@@ -4678,15 +4795,10 @@ kvm_calc_shadow_ept_root_page_role(struct kvm_vcpu *vcpu, bool accessed_dirty,
        role.base.guest_mode = true;
        role.base.access = ACC_ALL;
 
-       /*
-        * WP=1 and NOT_WP=1 is an impossible combination, use WP and the
-        * SMAP variation to denote shadow EPT entries.
-        */
-       role.base.cr0_wp = true;
-       role.base.smap_andnot_wp = true;
-
-       role.ext = kvm_calc_mmu_role_ext(vcpu);
+       /* EPT, and thus nested EPT, does not consume CR0, CR4, nor EFER. */
+       role.ext.word = 0;
        role.ext.execonly = execonly;
+       role.ext.valid = 1;
 
        return role;
 }
@@ -4700,7 +4812,7 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
                kvm_calc_shadow_ept_root_page_role(vcpu, accessed_dirty,
                                                   execonly, level);
 
-       __kvm_mmu_new_pgd(vcpu, new_eptp, new_role.base, true, true);
+       __kvm_mmu_new_pgd(vcpu, new_eptp, new_role.base);
 
        if (new_role.as_u64 == context->mmu_role.as_u64)
                return;
@@ -4728,20 +4840,46 @@ EXPORT_SYMBOL_GPL(kvm_init_shadow_ept_mmu);
 static void init_kvm_softmmu(struct kvm_vcpu *vcpu)
 {
        struct kvm_mmu *context = &vcpu->arch.root_mmu;
+       struct kvm_mmu_role_regs regs = vcpu_to_role_regs(vcpu);
 
-       kvm_init_shadow_mmu(vcpu,
-                           kvm_read_cr0_bits(vcpu, X86_CR0_PG),
-                           kvm_read_cr4_bits(vcpu, X86_CR4_PAE),
-                           vcpu->arch.efer);
+       kvm_init_shadow_mmu(vcpu, &regs);
 
        context->get_guest_pgd     = get_cr3;
        context->get_pdptr         = kvm_pdptr_read;
        context->inject_page_fault = kvm_inject_page_fault;
 }
 
+static union kvm_mmu_role
+kvm_calc_nested_mmu_role(struct kvm_vcpu *vcpu, struct kvm_mmu_role_regs *regs)
+{
+       union kvm_mmu_role role;
+
+       role = kvm_calc_shadow_root_page_role_common(vcpu, regs, false);
+
+       /*
+        * Nested MMUs are used only for walking L2's gva->gpa; they never have
+        * shadow pages of their own and so "direct" has no meaning.  Set it
+        * to "true" to try to detect bogus usage of the nested MMU.
+        */
+       role.base.direct = true;
+
+       if (!____is_cr0_pg(regs))
+               role.base.level = 0;
+       else if (____is_efer_lma(regs))
+               role.base.level = ____is_cr4_la57(regs) ? PT64_ROOT_5LEVEL :
+                                                         PT64_ROOT_4LEVEL;
+       else if (____is_cr4_pae(regs))
+               role.base.level = PT32E_ROOT_LEVEL;
+       else
+               role.base.level = PT32_ROOT_LEVEL;
+
+       return role;
+}
+
 static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu)
 {
-       union kvm_mmu_role new_role = kvm_calc_mmu_role_common(vcpu, false);
+       struct kvm_mmu_role_regs regs = vcpu_to_role_regs(vcpu);
+       union kvm_mmu_role new_role = kvm_calc_nested_mmu_role(vcpu, &regs);
        struct kvm_mmu *g_context = &vcpu->arch.nested_mmu;
 
        if (new_role.as_u64 == g_context->mmu_role.as_u64)
@@ -4793,17 +4931,8 @@ static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu)
        update_last_nonleaf_level(vcpu, g_context);
 }
 
-void kvm_init_mmu(struct kvm_vcpu *vcpu, bool reset_roots)
+void kvm_init_mmu(struct kvm_vcpu *vcpu)
 {
-       if (reset_roots) {
-               uint i;
-
-               vcpu->arch.mmu->root_hpa = INVALID_PAGE;
-
-               for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
-                       vcpu->arch.mmu->prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
-       }
-
        if (mmu_is_nested(vcpu))
                init_kvm_nested_mmu(vcpu);
        else if (tdp_enabled)
@@ -4816,20 +4945,53 @@ EXPORT_SYMBOL_GPL(kvm_init_mmu);
 static union kvm_mmu_page_role
 kvm_mmu_calc_root_page_role(struct kvm_vcpu *vcpu)
 {
+       struct kvm_mmu_role_regs regs = vcpu_to_role_regs(vcpu);
        union kvm_mmu_role role;
 
        if (tdp_enabled)
-               role = kvm_calc_tdp_mmu_root_page_role(vcpu, true);
+               role = kvm_calc_tdp_mmu_root_page_role(vcpu, &regs, true);
        else
-               role = kvm_calc_shadow_mmu_root_page_role(vcpu, true);
+               role = kvm_calc_shadow_mmu_root_page_role(vcpu, &regs, true);
 
        return role.base;
 }
 
+void kvm_mmu_after_set_cpuid(struct kvm_vcpu *vcpu)
+{
+       /*
+        * Invalidate all MMU roles to force them to reinitialize as CPUID
+        * information is factored into reserved bit calculations.
+        */
+       vcpu->arch.root_mmu.mmu_role.ext.valid = 0;
+       vcpu->arch.guest_mmu.mmu_role.ext.valid = 0;
+       vcpu->arch.nested_mmu.mmu_role.ext.valid = 0;
+       kvm_mmu_reset_context(vcpu);
+
+       /*
+        * KVM does not correctly handle changing guest CPUID after KVM_RUN, as
+        * MAXPHYADDR, GBPAGES support, AMD reserved bit behavior, etc.. aren't
+        * tracked in kvm_mmu_page_role.  As a result, KVM may miss guest page
+        * faults due to reusing SPs/SPTEs.  Alert userspace, but otherwise
+        * sweep the problem under the rug.
+        *
+        * KVM's horrific CPUID ABI makes the problem all but impossible to
+        * solve, as correctly handling multiple vCPU models (with respect to
+        * paging and physical address properties) in a single VM would require
+        * tracking all relevant CPUID information in kvm_mmu_page_role.  That
+        * is very undesirable as it would double the memory requirements for
+        * gfn_track (see struct kvm_mmu_page_role comments), and in practice
+        * no sane VMM mucks with the core vCPU model on the fly.
+        */
+       if (vcpu->arch.last_vmentry_cpu != -1) {
+               pr_warn_ratelimited("KVM: KVM_SET_CPUID{,2} after KVM_RUN may cause guest instability\n");
+               pr_warn_ratelimited("KVM: KVM_SET_CPUID{,2} will fail after KVM_RUN starting with Linux 5.16\n");
+       }
+}
+
 void kvm_mmu_reset_context(struct kvm_vcpu *vcpu)
 {
        kvm_mmu_unload(vcpu);
-       kvm_init_mmu(vcpu, true);
+       kvm_init_mmu(vcpu);
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_reset_context);
 
@@ -5467,7 +5629,13 @@ void kvm_mmu_init_vm(struct kvm *kvm)
 {
        struct kvm_page_track_notifier_node *node = &kvm->arch.mmu_sp_tracker;
 
-       kvm_mmu_init_tdp_mmu(kvm);
+       if (!kvm_mmu_init_tdp_mmu(kvm))
+               /*
+                * No smp_load/store wrappers needed here as we are in
+                * VM init and there cannot be any memslots / other threads
+                * accessing this struct kvm yet.
+                */
+               kvm->arch.memslots_have_rmaps = true;
 
        node->track_write = kvm_mmu_pte_write;
        node->track_flush_slot = kvm_mmu_invalidate_zap_pages_in_memslot;
@@ -5490,29 +5658,29 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
        int i;
        bool flush = false;
 
-       write_lock(&kvm->mmu_lock);
-       for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
-               slots = __kvm_memslots(kvm, i);
-               kvm_for_each_memslot(memslot, slots) {
-                       gfn_t start, end;
-
-                       start = max(gfn_start, memslot->base_gfn);
-                       end = min(gfn_end, memslot->base_gfn + memslot->npages);
-                       if (start >= end)
-                               continue;
+       if (kvm_memslots_have_rmaps(kvm)) {
+               write_lock(&kvm->mmu_lock);
+               for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
+                       slots = __kvm_memslots(kvm, i);
+                       kvm_for_each_memslot(memslot, slots) {
+                               gfn_t start, end;
+
+                               start = max(gfn_start, memslot->base_gfn);
+                               end = min(gfn_end, memslot->base_gfn + memslot->npages);
+                               if (start >= end)
+                                       continue;
 
-                       flush = slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
-                                                       PG_LEVEL_4K,
-                                                       KVM_MAX_HUGEPAGE_LEVEL,
-                                                       start, end - 1, true, flush);
+                               flush = slot_handle_level_range(kvm, memslot,
+                                               kvm_zap_rmapp, PG_LEVEL_4K,
+                                               KVM_MAX_HUGEPAGE_LEVEL, start,
+                                               end - 1, true, flush);
+                       }
                }
+               if (flush)
+                       kvm_flush_remote_tlbs_with_address(kvm, gfn_start, gfn_end);
+               write_unlock(&kvm->mmu_lock);
        }
 
-       if (flush)
-               kvm_flush_remote_tlbs_with_address(kvm, gfn_start, gfn_end);
-
-       write_unlock(&kvm->mmu_lock);
-
        if (is_tdp_mmu_enabled(kvm)) {
                flush = false;
 
@@ -5539,12 +5707,15 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
                                      struct kvm_memory_slot *memslot,
                                      int start_level)
 {
-       bool flush;
+       bool flush = false;
 
-       write_lock(&kvm->mmu_lock);
-       flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
-                               start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
-       write_unlock(&kvm->mmu_lock);
+       if (kvm_memslots_have_rmaps(kvm)) {
+               write_lock(&kvm->mmu_lock);
+               flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
+                                         start_level, KVM_MAX_HUGEPAGE_LEVEL,
+                                         false);
+               write_unlock(&kvm->mmu_lock);
+       }
 
        if (is_tdp_mmu_enabled(kvm)) {
                read_lock(&kvm->mmu_lock);
@@ -5612,18 +5783,17 @@ void kvm_mmu_zap_collapsible_sptes(struct kvm *kvm,
 {
        /* FIXME: const-ify all uses of struct kvm_memory_slot.  */
        struct kvm_memory_slot *slot = (struct kvm_memory_slot *)memslot;
-       bool flush;
-
-       write_lock(&kvm->mmu_lock);
-       flush = slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
+       bool flush = false;
 
-       if (flush)
-               kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
-       write_unlock(&kvm->mmu_lock);
+       if (kvm_memslots_have_rmaps(kvm)) {
+               write_lock(&kvm->mmu_lock);
+               flush = slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
+               if (flush)
+                       kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
+               write_unlock(&kvm->mmu_lock);
+       }
 
        if (is_tdp_mmu_enabled(kvm)) {
-               flush = false;
-
                read_lock(&kvm->mmu_lock);
                flush = kvm_tdp_mmu_zap_collapsible_sptes(kvm, slot, flush);
                if (flush)
@@ -5650,11 +5820,14 @@ void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
 void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
                                   struct kvm_memory_slot *memslot)
 {
-       bool flush;
+       bool flush = false;
 
-       write_lock(&kvm->mmu_lock);
-       flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty, false);
-       write_unlock(&kvm->mmu_lock);
+       if (kvm_memslots_have_rmaps(kvm)) {
+               write_lock(&kvm->mmu_lock);
+               flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty,
+                                        false);
+               write_unlock(&kvm->mmu_lock);
+       }
 
        if (is_tdp_mmu_enabled(kvm)) {
                read_lock(&kvm->mmu_lock);
@@ -5957,6 +6130,7 @@ static int set_nx_huge_pages_recovery_ratio(const char *val, const struct kernel
 
 static void kvm_recover_nx_lpages(struct kvm *kvm)
 {
+       unsigned long nx_lpage_splits = kvm->stat.nx_lpage_splits;
        int rcu_idx;
        struct kvm_mmu_page *sp;
        unsigned int ratio;
@@ -5968,7 +6142,7 @@ static void kvm_recover_nx_lpages(struct kvm *kvm)
        write_lock(&kvm->mmu_lock);
 
        ratio = READ_ONCE(nx_huge_pages_recovery_ratio);
-       to_zap = ratio ? DIV_ROUND_UP(kvm->stat.nx_lpage_splits, ratio) : 0;
+       to_zap = ratio ? DIV_ROUND_UP(nx_lpage_splits, ratio) : 0;
        for ( ; to_zap; --to_zap) {
                if (list_empty(&kvm->arch.lpage_disallowed_mmu_pages))
                        break;