Merge branch 'kvm-older-features' into HEAD
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 5628d0b..c623019 100644
@@ -104,15 +104,6 @@ static int max_huge_page_level __read_mostly;
 static int tdp_root_level __read_mostly;
 static int max_tdp_level __read_mostly;
 
-enum {
-       AUDIT_PRE_PAGE_FAULT,
-       AUDIT_POST_PAGE_FAULT,
-       AUDIT_PRE_PTE_WRITE,
-       AUDIT_POST_PTE_WRITE,
-       AUDIT_PRE_SYNC,
-       AUDIT_POST_SYNC
-};
-
 #ifdef MMU_DEBUG
 bool dbg = 0;
 module_param(dbg, bool, 0644);
@@ -190,8 +181,6 @@ struct kmem_cache *mmu_page_header_cache;
 static struct percpu_counter kvm_total_used_mmu_pages;
 
 static void mmu_spte_set(u64 *sptep, u64 spte);
-static union kvm_mmu_page_role
-kvm_mmu_calc_root_page_role(struct kvm_vcpu *vcpu);
 
 struct kvm_mmu_role_regs {
        const unsigned long cr0;
@@ -529,6 +518,7 @@ static u64 mmu_spte_update_no_track(u64 *sptep, u64 new_spte)
        u64 old_spte = *sptep;
 
        WARN_ON(!is_shadow_present_pte(new_spte));
+       check_spte_writable_invariants(new_spte);
 
        if (!is_shadow_present_pte(old_spte)) {
                mmu_spte_set(sptep, new_spte);
@@ -548,11 +538,9 @@ static u64 mmu_spte_update_no_track(u64 *sptep, u64 new_spte)
 /* Rules for using mmu_spte_update:
  * Update the state bits; the mapped pfn is not changed.
  *
- * Whenever we overwrite a writable spte with a read-only one we
- * should flush remote TLBs. Otherwise rmap_write_protect
- * will find a read-only spte, even though the writable spte
- * might be cached on a CPU's TLB, the return value indicates this
- * case.
+ * Whenever an MMU-writable SPTE is overwritten with a read-only SPTE, remote
+ * TLBs must be flushed. Otherwise rmap_write_protect will find a read-only
+ * spte, even though the writable spte might be cached on a CPU's TLB.
  *
  * Returns true if the TLB needs to be flushed
  */
@@ -646,24 +634,6 @@ static u64 mmu_spte_get_lockless(u64 *sptep)
        return __get_spte_lockless(sptep);
 }
 
-/* Restore an acc-track PTE back to a regular PTE */
-static u64 restore_acc_track_spte(u64 spte)
-{
-       u64 new_spte = spte;
-       u64 saved_bits = (spte >> SHADOW_ACC_TRACK_SAVED_BITS_SHIFT)
-                        & SHADOW_ACC_TRACK_SAVED_BITS_MASK;
-
-       WARN_ON_ONCE(spte_ad_enabled(spte));
-       WARN_ON_ONCE(!is_access_track_spte(spte));
-
-       new_spte &= ~shadow_acc_track_mask;
-       new_spte &= ~(SHADOW_ACC_TRACK_SAVED_BITS_MASK <<
-                     SHADOW_ACC_TRACK_SAVED_BITS_SHIFT);
-       new_spte |= saved_bits;
-
-       return new_spte;
-}
-
 /* Returns the Accessed status of the PTE and resets it at the same time. */
 static bool mmu_spte_age(u64 *sptep)
 {
@@ -1229,9 +1199,8 @@ static bool spte_write_protect(u64 *sptep, bool pt_protect)
        return mmu_spte_update(sptep, spte);
 }
 
-static bool __rmap_write_protect(struct kvm *kvm,
-                                struct kvm_rmap_head *rmap_head,
-                                bool pt_protect)
+static bool rmap_write_protect(struct kvm_rmap_head *rmap_head,
+                              bool pt_protect)
 {
        u64 *sptep;
        struct rmap_iterator iter;
@@ -1311,7 +1280,7 @@ static void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
        while (mask) {
                rmap_head = gfn_to_rmap(slot->base_gfn + gfn_offset + __ffs(mask),
                                        PG_LEVEL_4K, slot);
-               __rmap_write_protect(kvm, rmap_head, false);
+               rmap_write_protect(rmap_head, false);
 
                /* clear the first set bit */
                mask &= mask - 1;
@@ -1378,6 +1347,9 @@ void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                gfn_t start = slot->base_gfn + gfn_offset + __ffs(mask);
                gfn_t end = slot->base_gfn + gfn_offset + __fls(mask);
 
+               if (READ_ONCE(eager_page_split))
+                       kvm_mmu_try_split_huge_pages(kvm, slot, start, end, PG_LEVEL_4K);
+
                kvm_mmu_slot_gfn_write_protect(kvm, slot, start, PG_LEVEL_2M);
 
                /* Cross two large pages? */
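The start/end gfns computed just above from a 64-page chunk of the dirty bitmap can be checked with a small userspace sketch; mask_ffs()/mask_fls() are stand-ins for the kernel's __ffs()/__fls() and the numbers are made up.

#include <stdint.h>
#include <stdio.h>

/* Stand-ins for the kernel's __ffs()/__fls(); mask must be nonzero. */
static unsigned int mask_ffs(uint64_t mask) { return (unsigned int)__builtin_ctzll(mask); }
static unsigned int mask_fls(uint64_t mask) { return 63u - (unsigned int)__builtin_clzll(mask); }

int main(void)
{
        /* Made-up example values for one 64-page chunk of the dirty bitmap. */
        uint64_t base_gfn = 0x10000, gfn_offset = 128, mask = 0x00f0;

        /* First and last gfn covered by the set bits in this chunk. */
        uint64_t start = base_gfn + gfn_offset + mask_ffs(mask);
        uint64_t end   = base_gfn + gfn_offset + mask_fls(mask);

        printf("protect/split gfns 0x%llx..0x%llx\n",
               (unsigned long long)start, (unsigned long long)end);
        return 0;
}
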
@@ -1410,7 +1382,7 @@ bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
        if (kvm_memslots_have_rmaps(kvm)) {
                for (i = min_level; i <= KVM_MAX_HUGEPAGE_LEVEL; ++i) {
                        rmap_head = gfn_to_rmap(gfn, i, slot);
-                       write_protected |= __rmap_write_protect(kvm, rmap_head, true);
+                       write_protected |= rmap_write_protect(rmap_head, true);
                }
        }
 
@@ -1421,7 +1393,7 @@ bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
        return write_protected;
 }
 
-static bool rmap_write_protect(struct kvm_vcpu *vcpu, u64 gfn)
+static bool kvm_vcpu_write_protect_gfn(struct kvm_vcpu *vcpu, u64 gfn)
 {
        struct kvm_memory_slot *slot;
 
@@ -1894,17 +1866,14 @@ static void kvm_mmu_commit_zap_page(struct kvm *kvm,
          &(_kvm)->arch.mmu_page_hash[kvm_page_table_hashfn(_gfn)])     \
                if ((_sp)->gfn != (_gfn) || (_sp)->role.direct) {} else
 
-static bool kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                         struct list_head *invalid_list)
 {
        int ret = vcpu->arch.mmu->sync_page(vcpu, sp);
 
-       if (ret < 0) {
+       if (ret < 0)
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
-               return false;
-       }
-
-       return !!ret;
+       return ret;
 }
 
 static bool kvm_mmu_remote_flush_or_zap(struct kvm *kvm,
@@ -1921,13 +1890,6 @@ static bool kvm_mmu_remote_flush_or_zap(struct kvm *kvm,
        return true;
 }
 
-#ifdef CONFIG_KVM_MMU_AUDIT
-#include "mmu_audit.c"
-#else
-static void kvm_mmu_audit(struct kvm_vcpu *vcpu, int point) { }
-static void mmu_audit_disable(void) { }
-#endif
-
 static bool is_obsolete_sp(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        if (sp->role.invalid)
@@ -2024,7 +1986,7 @@ static int mmu_sync_children(struct kvm_vcpu *vcpu,
                bool protected = false;
 
                for_each_sp(pages, sp, parents, i)
-                       protected |= rmap_write_protect(vcpu, sp->gfn);
+                       protected |= kvm_vcpu_write_protect_gfn(vcpu, sp->gfn);
 
                if (protected) {
                        kvm_mmu_remote_flush_or_zap(vcpu->kvm, &invalid_list, true);
@@ -2033,7 +1995,7 @@ static int mmu_sync_children(struct kvm_vcpu *vcpu,
 
                for_each_sp(pages, sp, parents, i) {
                        kvm_unlink_unsync_page(vcpu->kvm, sp);
-                       flush |= kvm_sync_page(vcpu, sp, &invalid_list);
+                       flush |= kvm_sync_page(vcpu, sp, &invalid_list) > 0;
                        mmu_pages_clear_parents(&parents);
                }
                if (need_resched() || rwlock_needbreak(&vcpu->kvm->mmu_lock)) {
@@ -2074,6 +2036,7 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        struct hlist_head *sp_list;
        unsigned quadrant;
        struct kvm_mmu_page *sp;
+       int ret;
        int collisions = 0;
        LIST_HEAD(invalid_list);
 
@@ -2126,11 +2089,13 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                         * If the sync fails, the page is zapped.  If so, break
                         * in order to rebuild it.
                         */
-                       if (!kvm_sync_page(vcpu, sp, &invalid_list))
+                       ret = kvm_sync_page(vcpu, sp, &invalid_list);
+                       if (ret < 0)
                                break;
 
                        WARN_ON(!list_empty(&invalid_list));
-                       kvm_flush_remote_tlbs(vcpu->kvm);
+                       if (ret > 0)
+                               kvm_flush_remote_tlbs(vcpu->kvm);
                }
 
                __clear_sp_write_flooding_count(sp);
@@ -2149,7 +2114,7 @@ trace_get_page:
        hlist_add_head(&sp->hash_link, sp_list);
        if (!direct) {
                account_shadowed(vcpu->kvm, sp);
-               if (level == PG_LEVEL_4K && rmap_write_protect(vcpu, gfn))
+               if (level == PG_LEVEL_4K && kvm_vcpu_write_protect_gfn(vcpu, gfn))
                        kvm_flush_remote_tlbs_with_address(vcpu->kvm, gfn, 1);
        }
        trace_kvm_mmu_get_page(sp, true);
@@ -2179,7 +2144,7 @@ static void shadow_walk_init_using_root(struct kvm_shadow_walk_iterator *iterato
                 * prev_root is currently only used for 64-bit hosts. So only
                 * the active root_hpa is valid here.
                 */
-               BUG_ON(root != vcpu->arch.mmu->root_hpa);
+               BUG_ON(root != vcpu->arch.mmu->root.hpa);
 
                iterator->shadow_addr
                        = vcpu->arch.mmu->pae_root[(addr >> 30) & 3];
@@ -2193,7 +2158,7 @@ static void shadow_walk_init_using_root(struct kvm_shadow_walk_iterator *iterato
 static void shadow_walk_init(struct kvm_shadow_walk_iterator *iterator,
                             struct kvm_vcpu *vcpu, u64 addr)
 {
-       shadow_walk_init_using_root(iterator, vcpu, vcpu->arch.mmu->root_hpa,
+       shadow_walk_init_using_root(iterator, vcpu, vcpu->arch.mmu->root.hpa,
                                    addr);
 }
 
@@ -2307,7 +2272,7 @@ static int kvm_mmu_page_unlink_children(struct kvm *kvm,
        return zapped;
 }
 
-static void kvm_mmu_unlink_parents(struct kvm *kvm, struct kvm_mmu_page *sp)
+static void kvm_mmu_unlink_parents(struct kvm_mmu_page *sp)
 {
        u64 *sptep;
        struct rmap_iterator iter;
@@ -2345,13 +2310,13 @@ static bool __kvm_mmu_prepare_zap_page(struct kvm *kvm,
                                       struct list_head *invalid_list,
                                       int *nr_zapped)
 {
-       bool list_unstable;
+       bool list_unstable, zapped_root = false;
 
        trace_kvm_mmu_prepare_zap_page(sp);
        ++kvm->stat.mmu_shadow_zapped;
        *nr_zapped = mmu_zap_unsync_children(kvm, sp, invalid_list);
        *nr_zapped += kvm_mmu_page_unlink_children(kvm, sp, invalid_list);
-       kvm_mmu_unlink_parents(kvm, sp);
+       kvm_mmu_unlink_parents(sp);
 
        /* Zapping children means active_mmu_pages has become unstable. */
        list_unstable = *nr_zapped;
@@ -2387,14 +2352,20 @@ static bool __kvm_mmu_prepare_zap_page(struct kvm *kvm,
                 * in kvm_mmu_zap_all_fast().  Note, is_obsolete_sp() also
                 * treats invalid shadow pages as being obsolete.
                 */
-               if (!is_obsolete_sp(kvm, sp))
-                       kvm_reload_remote_mmus(kvm);
+               zapped_root = !is_obsolete_sp(kvm, sp);
        }
 
        if (sp->lpage_disallowed)
                unaccount_huge_nx_page(kvm, sp);
 
        sp->role.invalid = 1;
+
+       /*
+        * Make the request to free obsolete roots after marking the root
+        * invalid; otherwise other vCPUs may not see it as invalid.
+        */
+       if (zapped_root)
+               kvm_make_all_cpus_request(kvm, KVM_REQ_MMU_FREE_OBSOLETE_ROOTS);
        return list_unstable;
 }
 
@@ -2725,8 +2696,8 @@ static int mmu_set_spte(struct kvm_vcpu *vcpu, struct kvm_memory_slot *slot,
        if (*sptep == spte) {
                ret = RET_PF_SPURIOUS;
        } else {
-               trace_kvm_mmu_set_spte(level, gfn, sptep);
                flush |= mmu_spte_update(sptep, spte);
+               trace_kvm_mmu_set_spte(level, gfn, sptep);
        }
 
        if (wrprot) {
@@ -3239,6 +3210,8 @@ static void mmu_free_root_page(struct kvm *kvm, hpa_t *root_hpa,
                return;
 
        sp = to_shadow_page(*root_hpa & PT64_BASE_ADDR_MASK);
+       if (WARN_ON(!sp))
+               return;
 
        if (is_tdp_mmu_page(sp))
                kvm_tdp_mmu_put_root(kvm, sp, false);
@@ -3249,18 +3222,20 @@ static void mmu_free_root_page(struct kvm *kvm, hpa_t *root_hpa,
 }
 
 /* roots_to_free must be some combination of the KVM_MMU_ROOT_* flags */
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
+void kvm_mmu_free_roots(struct kvm *kvm, struct kvm_mmu *mmu,
                        ulong roots_to_free)
 {
-       struct kvm *kvm = vcpu->kvm;
        int i;
        LIST_HEAD(invalid_list);
-       bool free_active_root = roots_to_free & KVM_MMU_ROOT_CURRENT;
+       bool free_active_root;
 
        BUILD_BUG_ON(KVM_MMU_NUM_PREV_ROOTS >= BITS_PER_LONG);
 
        /* Before acquiring the MMU lock, see if we need to do any real work. */
-       if (!(free_active_root && VALID_PAGE(mmu->root_hpa))) {
+       free_active_root = (roots_to_free & KVM_MMU_ROOT_CURRENT)
+               && VALID_PAGE(mmu->root.hpa);
+
+       if (!free_active_root) {
                for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
                        if ((roots_to_free & KVM_MMU_ROOT_PREVIOUS(i)) &&
                            VALID_PAGE(mmu->prev_roots[i].hpa))
@@ -3278,9 +3253,8 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                                           &invalid_list);
 
        if (free_active_root) {
-               if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
-                   (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
-                       mmu_free_root_page(kvm, &mmu->root_hpa, &invalid_list);
+               if (to_shadow_page(mmu->root.hpa)) {
+                       mmu_free_root_page(kvm, &mmu->root.hpa, &invalid_list);
                } else if (mmu->pae_root) {
                        for (i = 0; i < 4; ++i) {
                                if (!IS_VALID_PAE_ROOT(mmu->pae_root[i]))
@@ -3291,8 +3265,8 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                                mmu->pae_root[i] = INVALID_PAE_ROOT;
                        }
                }
-               mmu->root_hpa = INVALID_PAGE;
-               mmu->root_pgd = 0;
+               mmu->root.hpa = INVALID_PAGE;
+               mmu->root.pgd = 0;
        }
 
        kvm_mmu_commit_zap_page(kvm, &invalid_list);
@@ -3300,7 +3274,7 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_free_roots);
 
-void kvm_mmu_free_guest_mode_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
+void kvm_mmu_free_guest_mode_roots(struct kvm *kvm, struct kvm_mmu *mmu)
 {
        unsigned long roots_to_free = 0;
        hpa_t root_hpa;
@@ -3322,7 +3296,7 @@ void kvm_mmu_free_guest_mode_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
                        roots_to_free |= KVM_MMU_ROOT_PREVIOUS(i);
        }
 
-       kvm_mmu_free_roots(vcpu, mmu, roots_to_free);
+       kvm_mmu_free_roots(kvm, mmu, roots_to_free);
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_free_guest_mode_roots);
 
@@ -3365,10 +3339,10 @@ static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
 
        if (is_tdp_mmu_enabled(vcpu->kvm)) {
                root = kvm_tdp_mmu_get_vcpu_root_hpa(vcpu);
-               mmu->root_hpa = root;
+               mmu->root.hpa = root;
        } else if (shadow_root_level >= PT64_ROOT_4LEVEL) {
                root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level, true);
-               mmu->root_hpa = root;
+               mmu->root.hpa = root;
        } else if (shadow_root_level == PT32E_ROOT_LEVEL) {
                if (WARN_ON_ONCE(!mmu->pae_root)) {
                        r = -EIO;
@@ -3383,15 +3357,15 @@ static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
                        mmu->pae_root[i] = root | PT_PRESENT_MASK |
                                           shadow_me_mask;
                }
-               mmu->root_hpa = __pa(mmu->pae_root);
+               mmu->root.hpa = __pa(mmu->pae_root);
        } else {
                WARN_ONCE(1, "Bad TDP root level = %d\n", shadow_root_level);
                r = -EIO;
                goto out_unlock;
        }
 
-       /* root_pgd is ignored for direct MMUs. */
-       mmu->root_pgd = 0;
+       /* root.pgd is ignored for direct MMUs. */
+       mmu->root.pgd = 0;
 out_unlock:
        write_unlock(&vcpu->kvm->mmu_lock);
        return r;
@@ -3504,7 +3478,7 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
        if (mmu->root_level >= PT64_ROOT_4LEVEL) {
                root = mmu_alloc_root(vcpu, root_gfn, 0,
                                      mmu->shadow_root_level, false);
-               mmu->root_hpa = root;
+               mmu->root.hpa = root;
                goto set_root_pgd;
        }
 
@@ -3554,14 +3528,14 @@ static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
        }
 
        if (mmu->shadow_root_level == PT64_ROOT_5LEVEL)
-               mmu->root_hpa = __pa(mmu->pml5_root);
+               mmu->root.hpa = __pa(mmu->pml5_root);
        else if (mmu->shadow_root_level == PT64_ROOT_4LEVEL)
-               mmu->root_hpa = __pa(mmu->pml4_root);
+               mmu->root.hpa = __pa(mmu->pml4_root);
        else
-               mmu->root_hpa = __pa(mmu->pae_root);
+               mmu->root.hpa = __pa(mmu->pae_root);
 
 set_root_pgd:
-       mmu->root_pgd = root_pgd;
+       mmu->root.pgd = root_pgd;
 out_unlock:
        write_unlock(&vcpu->kvm->mmu_lock);
 
@@ -3660,6 +3634,14 @@ static bool is_unsync_root(hpa_t root)
         */
        smp_rmb();
        sp = to_shadow_page(root);
+
+       /*
+        * PAE roots (somewhat arbitrarily) aren't backed by shadow pages; the
+        * PDPTEs for a given PAE root need to be synchronized individually.
+        */
+       if (WARN_ON_ONCE(!sp))
+               return false;
+
        if (sp->unsync || sp->unsync_children)
                return true;
 
@@ -3674,30 +3656,25 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
        if (vcpu->arch.mmu->direct_map)
                return;
 
-       if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
+       if (!VALID_PAGE(vcpu->arch.mmu->root.hpa))
                return;
 
        vcpu_clear_mmio_info(vcpu, MMIO_GVA_ANY);
 
        if (vcpu->arch.mmu->root_level >= PT64_ROOT_4LEVEL) {
-               hpa_t root = vcpu->arch.mmu->root_hpa;
+               hpa_t root = vcpu->arch.mmu->root.hpa;
                sp = to_shadow_page(root);
 
                if (!is_unsync_root(root))
                        return;
 
                write_lock(&vcpu->kvm->mmu_lock);
-               kvm_mmu_audit(vcpu, AUDIT_PRE_SYNC);
-
                mmu_sync_children(vcpu, sp, true);
-
-               kvm_mmu_audit(vcpu, AUDIT_POST_SYNC);
                write_unlock(&vcpu->kvm->mmu_lock);
                return;
        }
 
        write_lock(&vcpu->kvm->mmu_lock);
-       kvm_mmu_audit(vcpu, AUDIT_PRE_SYNC);
 
        for (i = 0; i < 4; ++i) {
                hpa_t root = vcpu->arch.mmu->pae_root[i];
@@ -3709,7 +3686,6 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
                }
        }
 
-       kvm_mmu_audit(vcpu, AUDIT_POST_SYNC);
        write_unlock(&vcpu->kvm->mmu_lock);
 }
 
@@ -3723,11 +3699,11 @@ void kvm_mmu_sync_prev_roots(struct kvm_vcpu *vcpu)
                        roots_to_free |= KVM_MMU_ROOT_PREVIOUS(i);
 
        /* sync prev_roots by simply freeing them */
-       kvm_mmu_free_roots(vcpu, vcpu->arch.mmu, roots_to_free);
+       kvm_mmu_free_roots(vcpu->kvm, vcpu->arch.mmu, roots_to_free);
 }
 
 static gpa_t nonpaging_gva_to_gpa(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
-                                 gpa_t vaddr, u32 access,
+                                 gpa_t vaddr, u64 access,
                                  struct x86_exception *exception)
 {
        if (exception)
@@ -3982,7 +3958,7 @@ out_retry:
 static bool is_page_fault_stale(struct kvm_vcpu *vcpu,
                                struct kvm_page_fault *fault, int mmu_seq)
 {
-       struct kvm_mmu_page *sp = to_shadow_page(vcpu->arch.mmu->root_hpa);
+       struct kvm_mmu_page *sp = to_shadow_page(vcpu->arch.mmu->root.hpa);
 
        /* Special roots, e.g. pae_root, are not backed by shadow pages. */
        if (sp && is_obsolete_sp(vcpu->kvm, sp))
@@ -3996,7 +3972,7 @@ static bool is_page_fault_stale(struct kvm_vcpu *vcpu,
         * previous root, then __kvm_mmu_prepare_zap_page() signals all vCPUs
         * to reload even if no vCPU is actively using the root.
         */
-       if (!sp && kvm_test_request(KVM_REQ_MMU_RELOAD, vcpu))
+       if (!sp && kvm_test_request(KVM_REQ_MMU_FREE_OBSOLETE_ROOTS, vcpu))
                return true;
 
        return fault->slot &&
@@ -4132,74 +4108,105 @@ static inline bool is_root_usable(struct kvm_mmu_root_info *root, gpa_t pgd,
                                  union kvm_mmu_page_role role)
 {
        return (role.direct || pgd == root->pgd) &&
-              VALID_PAGE(root->hpa) && to_shadow_page(root->hpa) &&
+              VALID_PAGE(root->hpa) &&
               role.word == to_shadow_page(root->hpa)->role.word;
 }
 
 /*
- * Find out if a previously cached root matching the new pgd/role is available.
- * The current root is also inserted into the cache.
- * If a matching root was found, it is assigned to kvm_mmu->root_hpa and true is
- * returned.
- * Otherwise, the LRU root from the cache is assigned to kvm_mmu->root_hpa and
- * false is returned. This root should now be freed by the caller.
+ * Find out if a previously cached root matching the new pgd/role is available,
+ * and insert the current root as the MRU in the cache.
+ * If a matching root is found, it is assigned to kvm_mmu->root and
+ * true is returned.
+ * If no match is found, kvm_mmu->root is left invalid, the LRU root is
+ * evicted to make room for the current root, and false is returned.
  */
-static bool cached_root_available(struct kvm_vcpu *vcpu, gpa_t new_pgd,
-                                 union kvm_mmu_page_role new_role)
+static bool cached_root_find_and_keep_current(struct kvm *kvm, struct kvm_mmu *mmu,
+                                             gpa_t new_pgd,
+                                             union kvm_mmu_page_role new_role)
 {
        uint i;
-       struct kvm_mmu_root_info root;
-       struct kvm_mmu *mmu = vcpu->arch.mmu;
-
-       root.pgd = mmu->root_pgd;
-       root.hpa = mmu->root_hpa;
 
-       if (is_root_usable(&root, new_pgd, new_role))
+       if (is_root_usable(&mmu->root, new_pgd, new_role))
                return true;
 
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
-               swap(root, mmu->prev_roots[i]);
-
-               if (is_root_usable(&root, new_pgd, new_role))
-                       break;
+               /*
+                * The swaps end up rotating the cache like this:
+                *   C   0 1 2 3   (on entry to the function)
+                *   0   C 1 2 3
+                *   1   C 0 2 3
+                *   2   C 0 1 3
+                *   3   C 0 1 2   (on exit from the loop)
+                */
+               swap(mmu->root, mmu->prev_roots[i]);
+               if (is_root_usable(&mmu->root, new_pgd, new_role))
+                       return true;
        }
 
-       mmu->root_hpa = root.hpa;
-       mmu->root_pgd = root.pgd;
-
-       return i < KVM_MMU_NUM_PREV_ROOTS;
+       kvm_mmu_free_roots(kvm, mmu, KVM_MMU_ROOT_CURRENT);
+       return false;
 }
 
-static bool fast_pgd_switch(struct kvm_vcpu *vcpu, gpa_t new_pgd,
-                           union kvm_mmu_page_role new_role)
+/*
+ * Find out if a previously cached root matching the new pgd/role is available.
+ * On entry, mmu->root is invalid.
+ * If a matching root is found, it is assigned to kvm_mmu->root, the LRU entry
+ * of the cache becomes invalid, and true is returned.
+ * If no match is found, kvm_mmu->root is left invalid and false is returned.
+ */
+static bool cached_root_find_without_current(struct kvm *kvm, struct kvm_mmu *mmu,
+                                            gpa_t new_pgd,
+                                            union kvm_mmu_page_role new_role)
 {
-       struct kvm_mmu *mmu = vcpu->arch.mmu;
+       uint i;
+
+       for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
+               if (is_root_usable(&mmu->prev_roots[i], new_pgd, new_role))
+                       goto hit;
+
+       return false;
 
+hit:
+       swap(mmu->root, mmu->prev_roots[i]);
+       /* Bubble up the remaining roots.  */
+       for (; i < KVM_MMU_NUM_PREV_ROOTS - 1; i++)
+               mmu->prev_roots[i] = mmu->prev_roots[i + 1];
+       mmu->prev_roots[i].hpa = INVALID_PAGE;
+       return true;
+}
+
+static bool fast_pgd_switch(struct kvm *kvm, struct kvm_mmu *mmu,
+                           gpa_t new_pgd, union kvm_mmu_page_role new_role)
+{
        /*
-        * For now, limit the fast switch to 64-bit hosts+VMs in order to avoid
+        * For now, limit the caching to 64-bit hosts+VMs in order to avoid
         * having to deal with PDPTEs. We may add support for 32-bit hosts/VMs
         * later if necessary.
         */
-       if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
-           mmu->root_level >= PT64_ROOT_4LEVEL)
-               return cached_root_available(vcpu, new_pgd, new_role);
+       if (VALID_PAGE(mmu->root.hpa) && !to_shadow_page(mmu->root.hpa))
+               kvm_mmu_free_roots(kvm, mmu, KVM_MMU_ROOT_CURRENT);
 
-       return false;
+       if (VALID_PAGE(mmu->root.hpa))
+               return cached_root_find_and_keep_current(kvm, mmu, new_pgd, new_role);
+       else
+               return cached_root_find_without_current(kvm, mmu, new_pgd, new_role);
 }
 
-static void __kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd,
-                             union kvm_mmu_page_role new_role)
+void kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd)
 {
-       if (!fast_pgd_switch(vcpu, new_pgd, new_role)) {
-               kvm_mmu_free_roots(vcpu, vcpu->arch.mmu, KVM_MMU_ROOT_CURRENT);
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
+       union kvm_mmu_page_role new_role = mmu->mmu_role.base;
+
+       if (!fast_pgd_switch(vcpu->kvm, mmu, new_pgd, new_role)) {
+               /* kvm_mmu_ensure_valid_pgd will set up a new root.  */
                return;
        }
 
        /*
         * It's possible that the cached previous root page is obsolete because
         * of a change in the MMU generation number. However, changing the
-        * generation number is accompanied by KVM_REQ_MMU_RELOAD, which will
-        * free the root set here and allocate a new one.
+        * generation number is accompanied by KVM_REQ_MMU_FREE_OBSOLETE_ROOTS,
+        * which will free the root set here and allocate a new one.
         */
        kvm_make_request(KVM_REQ_LOAD_MMU_PGD, vcpu);
 
@@ -4222,12 +4229,7 @@ static void __kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd,
         */
        if (!new_role.direct)
                __clear_sp_write_flooding_count(
-                               to_shadow_page(vcpu->arch.mmu->root_hpa));
-}
-
-void kvm_mmu_new_pgd(struct kvm_vcpu *vcpu, gpa_t new_pgd)
-{
-       __kvm_mmu_new_pgd(vcpu, new_pgd, kvm_mmu_calc_root_page_role(vcpu));
+                               to_shadow_page(vcpu->arch.mmu->root.hpa));
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_new_pgd);
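The cache rotation documented in cached_root_find_and_keep_current() above can be exercised with a toy userspace model. The sketch below is illustrative only: the toy_root/toy_find_and_keep_current names are invented, role matching, hpa validity checks and the freeing of the evicted root are omitted, and NUM_PREV_ROOTS is hard-coded to 3 to mirror KVM_MMU_NUM_PREV_ROOTS.

#include <stdbool.h>
#include <stdio.h>

#define NUM_PREV_ROOTS 3        /* mirrors KVM_MMU_NUM_PREV_ROOTS */

struct toy_root { unsigned long pgd; };

static void swap_root(struct toy_root *a, struct toy_root *b)
{
        struct toy_root t = *a;
        *a = *b;
        *b = t;
}

/*
 * Toy version of cached_root_find_and_keep_current(): the current root is
 * pushed into the cache as the MRU entry while searching, so on a miss the
 * LRU entry ends up in *cur, where the real code then frees it.
 */
static bool toy_find_and_keep_current(struct toy_root *cur,
                                      struct toy_root prev[NUM_PREV_ROOTS],
                                      unsigned long new_pgd)
{
        int i;

        if (cur->pgd == new_pgd)
                return true;

        for (i = 0; i < NUM_PREV_ROOTS; i++) {
                swap_root(cur, &prev[i]);
                if (cur->pgd == new_pgd)
                        return true;
        }
        return false;
}

int main(void)
{
        struct toy_root cur = { .pgd = 100 };
        struct toy_root prev[NUM_PREV_ROOTS] = { { 1 }, { 2 }, { 3 } };
        bool hit = toy_find_and_keep_current(&cur, prev, 2);

        /* Hit on the second cached entry: cur is now 2, cache is 100 1 3. */
        printf("hit=%d cur=%lu cache=%lu %lu %lu\n",
               hit, cur.pgd, prev[0].pgd, prev[1].pgd, prev[2].pgd);
        return 0;
}
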
 
@@ -4485,8 +4487,7 @@ static inline bool boot_cpu_is_amd(void)
  * possible, however, kvm currently does not do execution-protection.
  */
 static void
-reset_tdp_shadow_zero_bits_mask(struct kvm_vcpu *vcpu,
-                               struct kvm_mmu *context)
+reset_tdp_shadow_zero_bits_mask(struct kvm_mmu *context)
 {
        struct rsvd_bits_validate *shadow_zero_check;
        int i;
@@ -4517,8 +4518,7 @@ reset_tdp_shadow_zero_bits_mask(struct kvm_vcpu *vcpu,
  * is the shadow page table for intel nested guest.
  */
 static void
-reset_ept_shadow_zero_bits_mask(struct kvm_vcpu *vcpu,
-                               struct kvm_mmu *context, bool execonly)
+reset_ept_shadow_zero_bits_mask(struct kvm_mmu *context, bool execonly)
 {
        __reset_rsvds_bits_mask_ept(&context->shadow_zero_check,
                                    reserved_hpa_bits(), execonly,
@@ -4591,11 +4591,11 @@ static void update_permission_bitmask(struct kvm_mmu *mmu, bool ept)
                         *   - X86_CR4_SMAP is set in CR4
                         *   - A user page is accessed
                         *   - The access is not a fetch
-                        *   - Page fault in kernel mode
-                        *   - if CPL = 3 or X86_EFLAGS_AC is clear
+                        *   - The access is in supervisor mode
+                        *   - If implicit supervisor access or X86_EFLAGS_AC is clear
                         *
-                        * Here, we cover the first three conditions.
-                        * The fourth is computed dynamically in permission_fault();
+                        * Here, we cover the first four conditions.
+                        * The fifth is computed dynamically in permission_fault();
                         * PFERR_RSVD_MASK bit will be set in PFEC if the access is
                         * *not* subject to SMAP restrictions.
                         */
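The five conditions above fold down to a single boolean. The sketch below is an illustrative userspace restatement with an invented smap_applies() helper and parameter names; the real code precomputes this per (mmu, pfec) combination in update_permission_bitmask() and permission_fault() rather than evaluating it per access.

#include <stdbool.h>
#include <stdio.h>

/* Illustrative restatement of the SMAP conditions described above. */
static bool smap_applies(bool cr4_smap, bool user_page, bool fetch,
                         bool user_mode, bool implicit_access, bool eflags_ac)
{
        if (!cr4_smap || !user_page || fetch || user_mode)
                return false;

        /* Implicit supervisor accesses ignore EFLAGS.AC. */
        return implicit_access || !eflags_ac;
}

int main(void)
{
        /* Explicit supervisor read of a user page with EFLAGS.AC clear: blocked. */
        printf("%d\n", smap_applies(true, true, false, false, false, false));
        /* Same access with EFLAGS.AC set: allowed. */
        printf("%d\n", smap_applies(true, true, false, false, false, true));
        return 0;
}
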
@@ -4805,7 +4805,7 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
                context->gva_to_gpa = paging32_gva_to_gpa;
 
        reset_guest_paging_metadata(vcpu, context);
-       reset_tdp_shadow_zero_bits_mask(vcpu, context);
+       reset_tdp_shadow_zero_bits_mask(context);
 }
 
 static union kvm_mmu_role
@@ -4899,9 +4899,8 @@ void kvm_init_shadow_npt_mmu(struct kvm_vcpu *vcpu, unsigned long cr0,
 
        new_role = kvm_calc_shadow_npt_root_page_role(vcpu, &regs);
 
-       __kvm_mmu_new_pgd(vcpu, nested_cr3, new_role.base);
-
        shadow_mmu_init_context(vcpu, context, &regs, new_role);
+       kvm_mmu_new_pgd(vcpu, nested_cr3);
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_npt_mmu);
 
@@ -4939,27 +4938,25 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
                kvm_calc_shadow_ept_root_page_role(vcpu, accessed_dirty,
                                                   execonly, level);
 
-       __kvm_mmu_new_pgd(vcpu, new_eptp, new_role.base);
-
-       if (new_role.as_u64 == context->mmu_role.as_u64)
-               return;
-
-       context->mmu_role.as_u64 = new_role.as_u64;
-
-       context->shadow_root_level = level;
-
-       context->ept_ad = accessed_dirty;
-       context->page_fault = ept_page_fault;
-       context->gva_to_gpa = ept_gva_to_gpa;
-       context->sync_page = ept_sync_page;
-       context->invlpg = ept_invlpg;
-       context->root_level = level;
-       context->direct_map = false;
+       if (new_role.as_u64 != context->mmu_role.as_u64) {
+               context->mmu_role.as_u64 = new_role.as_u64;
+
+               context->shadow_root_level = level;
+
+               context->ept_ad = accessed_dirty;
+               context->page_fault = ept_page_fault;
+               context->gva_to_gpa = ept_gva_to_gpa;
+               context->sync_page = ept_sync_page;
+               context->invlpg = ept_invlpg;
+               context->root_level = level;
+               context->direct_map = false;
+               update_permission_bitmask(context, true);
+               context->pkru_mask = 0;
+               reset_rsvds_bits_mask_ept(vcpu, context, execonly, huge_page_level);
+               reset_ept_shadow_zero_bits_mask(context, execonly);
+       }
 
-       update_permission_bitmask(context, true);
-       context->pkru_mask = 0;
-       reset_rsvds_bits_mask_ept(vcpu, context, execonly, huge_page_level);
-       reset_ept_shadow_zero_bits_mask(vcpu, context, execonly);
+       kvm_mmu_new_pgd(vcpu, new_eptp);
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_ept_mmu);
 
@@ -5044,20 +5041,6 @@ void kvm_init_mmu(struct kvm_vcpu *vcpu)
 }
 EXPORT_SYMBOL_GPL(kvm_init_mmu);
 
-static union kvm_mmu_page_role
-kvm_mmu_calc_root_page_role(struct kvm_vcpu *vcpu)
-{
-       struct kvm_mmu_role_regs regs = vcpu_to_role_regs(vcpu);
-       union kvm_mmu_role role;
-
-       if (tdp_enabled)
-               role = kvm_calc_tdp_mmu_root_page_role(vcpu, &regs, true);
-       else
-               role = kvm_calc_shadow_mmu_root_page_role(vcpu, &regs, true);
-
-       return role.base;
-}
-
 void kvm_mmu_after_set_cpuid(struct kvm_vcpu *vcpu)
 {
        /*
@@ -5111,17 +5094,73 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
        kvm_mmu_sync_roots(vcpu);
 
        kvm_mmu_load_pgd(vcpu);
-       static_call(kvm_x86_tlb_flush_current)(vcpu);
+
+       /*
+        * Flush any TLB entries for the new root, as the provenance of the
+        * root is unknown.  Even if KVM ensures there are no stale TLB entries
+        * for a freed root, in theory another hypervisor could have left
+        * stale entries.  Flushing on alloc also allows KVM to skip the TLB
+        * flush when freeing a root (see kvm_tdp_mmu_put_root()).
+        */
+       static_call(kvm_x86_flush_tlb_current)(vcpu);
 out:
        return r;
 }
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
-       kvm_mmu_free_roots(vcpu, &vcpu->arch.root_mmu, KVM_MMU_ROOTS_ALL);
-       WARN_ON(VALID_PAGE(vcpu->arch.root_mmu.root_hpa));
-       kvm_mmu_free_roots(vcpu, &vcpu->arch.guest_mmu, KVM_MMU_ROOTS_ALL);
-       WARN_ON(VALID_PAGE(vcpu->arch.guest_mmu.root_hpa));
+       struct kvm *kvm = vcpu->kvm;
+
+       kvm_mmu_free_roots(kvm, &vcpu->arch.root_mmu, KVM_MMU_ROOTS_ALL);
+       WARN_ON(VALID_PAGE(vcpu->arch.root_mmu.root.hpa));
+       kvm_mmu_free_roots(kvm, &vcpu->arch.guest_mmu, KVM_MMU_ROOTS_ALL);
+       WARN_ON(VALID_PAGE(vcpu->arch.guest_mmu.root.hpa));
+       vcpu_clear_mmio_info(vcpu, MMIO_GVA_ANY);
+}
+
+static bool is_obsolete_root(struct kvm *kvm, hpa_t root_hpa)
+{
+       struct kvm_mmu_page *sp;
+
+       if (!VALID_PAGE(root_hpa))
+               return false;
+
+       /*
+        * When freeing obsolete roots, treat roots as obsolete if they don't
+        * have an associated shadow page.  This does mean KVM will get false
+        * positives and free roots that don't strictly need to be freed, but
+        * such false positives are relatively rare:
+        *
+        *  (a) only PAE paging and nested NPT has roots without shadow pages
+        *  (a) only PAE paging and nested NPT have roots without shadow pages
+        *  (b) remote reloads due to a memslot update obsolete _all_ roots
+        *      is unlikely to zap an in-use PGD.
+        */
+       sp = to_shadow_page(root_hpa);
+       return !sp || is_obsolete_sp(kvm, sp);
+}
+
+static void __kvm_mmu_free_obsolete_roots(struct kvm *kvm, struct kvm_mmu *mmu)
+{
+       unsigned long roots_to_free = 0;
+       int i;
+
+       if (is_obsolete_root(kvm, mmu->root.hpa))
+               roots_to_free |= KVM_MMU_ROOT_CURRENT;
+
+       for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++) {
+               if (is_obsolete_root(kvm, mmu->prev_roots[i].hpa))
+                       roots_to_free |= KVM_MMU_ROOT_PREVIOUS(i);
+       }
+
+       if (roots_to_free)
+               kvm_mmu_free_roots(kvm, mmu, roots_to_free);
+}
+
+void kvm_mmu_free_obsolete_roots(struct kvm_vcpu *vcpu)
+{
+       __kvm_mmu_free_obsolete_roots(vcpu->kvm, &vcpu->arch.root_mmu);
+       __kvm_mmu_free_obsolete_roots(vcpu->kvm, &vcpu->arch.guest_mmu);
 }
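A compact userspace model of the generation check behind is_obsolete_root()/is_obsolete_sp() above; the toy_* structs and helpers are invented, the fields are reduced to the bare minimum, and a NULL shadow page stands in for the "no backing shadow page" false-positive case noted in the comment.

#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

/* Minimal stand-ins for struct kvm / struct kvm_mmu_page. */
struct toy_kvm  { unsigned long mmu_valid_gen; };
struct toy_page { unsigned long mmu_valid_gen; bool invalid; };

static bool toy_is_obsolete_sp(struct toy_kvm *kvm, struct toy_page *sp)
{
        return sp->invalid || sp->mmu_valid_gen != kvm->mmu_valid_gen;
}

/* A root without a backing shadow page is conservatively treated as obsolete. */
static bool toy_is_obsolete_root(struct toy_kvm *kvm, struct toy_page *sp)
{
        return !sp || toy_is_obsolete_sp(kvm, sp);
}

int main(void)
{
        struct toy_kvm kvm = { .mmu_valid_gen = 1 };
        struct toy_page stale = { .mmu_valid_gen = 0, .invalid = false };
        struct toy_page live  = { .mmu_valid_gen = 1, .invalid = false };

        printf("stale=%d live=%d no-sp=%d\n",
               toy_is_obsolete_root(&kvm, &stale),     /* 1 */
               toy_is_obsolete_root(&kvm, &live),      /* 0 */
               toy_is_obsolete_root(&kvm, NULL));      /* 1: false positive by design */
        return 0;
}
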
 
 static bool need_remote_flush(u64 old, u64 new)
@@ -5271,7 +5310,6 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        gentry = mmu_pte_write_fetch_gpte(vcpu, &gpa, &bytes);
 
        ++vcpu->kvm->stat.mmu_pte_write;
-       kvm_mmu_audit(vcpu, AUDIT_PRE_PTE_WRITE);
 
        for_each_gfn_indirect_valid_sp(vcpu->kvm, sp, gfn) {
                if (detect_write_misaligned(sp, gpa, bytes) ||
@@ -5296,7 +5334,6 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                }
        }
        kvm_mmu_remote_flush_or_zap(vcpu->kvm, &invalid_list, flush);
-       kvm_mmu_audit(vcpu, AUDIT_POST_PTE_WRITE);
        write_unlock(&vcpu->kvm->mmu_lock);
 }
 
@@ -5306,7 +5343,7 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa, u64 error_code,
        int r, emulation_type = EMULTYPE_PF;
        bool direct = vcpu->arch.mmu->direct_map;
 
-       if (WARN_ON(!VALID_PAGE(vcpu->arch.mmu->root_hpa)))
+       if (WARN_ON(!VALID_PAGE(vcpu->arch.mmu->root.hpa)))
                return RET_PF_RETRY;
 
        r = RET_PF_INVALID;
@@ -5371,14 +5408,14 @@ void kvm_mmu_invalidate_gva(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                if (is_noncanonical_address(gva, vcpu))
                        return;
 
-               static_call(kvm_x86_tlb_flush_gva)(vcpu, gva);
+               static_call(kvm_x86_flush_tlb_gva)(vcpu, gva);
        }
 
        if (!mmu->invlpg)
                return;
 
        if (root_hpa == INVALID_PAGE) {
-               mmu->invlpg(vcpu, gva, mmu->root_hpa);
+               mmu->invlpg(vcpu, gva, mmu->root.hpa);
 
                /*
                 * INVLPG is required to invalidate any global mappings for the VA,
@@ -5414,7 +5451,7 @@ void kvm_mmu_invpcid_gva(struct kvm_vcpu *vcpu, gva_t gva, unsigned long pcid)
        uint i;
 
        if (pcid == kvm_get_active_pcid(vcpu)) {
-               mmu->invlpg(vcpu, gva, mmu->root_hpa);
+               mmu->invlpg(vcpu, gva, mmu->root.hpa);
                tlb_flush = true;
        }
 
@@ -5427,7 +5464,7 @@ void kvm_mmu_invpcid_gva(struct kvm_vcpu *vcpu, gva_t gva, unsigned long pcid)
        }
 
        if (tlb_flush)
-               static_call(kvm_x86_tlb_flush_gva)(vcpu, gva);
+               static_call(kvm_x86_flush_tlb_gva)(vcpu, gva);
 
        ++vcpu->stat.invlpg;
 
@@ -5527,8 +5564,8 @@ static int __kvm_mmu_create(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
        struct page *page;
        int i;
 
-       mmu->root_hpa = INVALID_PAGE;
-       mmu->root_pgd = 0;
+       mmu->root.hpa = INVALID_PAGE;
+       mmu->root.pgd = 0;
        for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
                mmu->prev_roots[i] = KVM_MMU_ROOT_INFO_INVALID;
 
@@ -5648,9 +5685,13 @@ restart:
        }
 
        /*
-        * Trigger a remote TLB flush before freeing the page tables to ensure
-        * KVM is not in the middle of a lockless shadow page table walk, which
-        * may reference the pages.
+        * Kick all vCPUs (via remote TLB flush) before freeing the page tables
+        * to ensure KVM is not in the middle of a lockless shadow page table
+        * walk, which may reference the pages.  The remote TLB flush itself is
+        * not required and is simply a convenient way to kick vCPUs as needed.
+        * KVM performs a local TLB flush when allocating a new root (see
+        * kvm_mmu_load()), and the reload in the caller ensures no vCPUs are
+        * running with an obsolete MMU.
         */
        kvm_mmu_commit_zap_page(kvm, &kvm->arch.zapped_obsolete_pages);
 }
@@ -5680,11 +5721,11 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
         */
        kvm->arch.mmu_valid_gen = kvm->arch.mmu_valid_gen ? 0 : 1;
 
-       /* In order to ensure all threads see this change when
-        * handling the MMU reload signal, this must happen in the
-        * same critical section as kvm_reload_remote_mmus, and
-        * before kvm_zap_obsolete_pages as kvm_zap_obsolete_pages
-        * could drop the MMU lock and yield.
+       /*
+        * In order to ensure all vCPUs drop their soon-to-be invalid roots,
+        * invalidating TDP MMU roots must be done while holding mmu_lock for
+        * write and in the same critical section as making the reload request,
+        * e.g. before kvm_zap_obsolete_pages() could drop mmu_lock and yield.
         */
        if (is_tdp_mmu_enabled(kvm))
                kvm_tdp_mmu_invalidate_all_roots(kvm);
@@ -5697,17 +5738,22 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
         * Note: we need to do this under the protection of mmu_lock,
         * otherwise, vcpu would purge shadow page but miss tlb flush.
         */
-       kvm_reload_remote_mmus(kvm);
+       kvm_make_all_cpus_request(kvm, KVM_REQ_MMU_FREE_OBSOLETE_ROOTS);
 
        kvm_zap_obsolete_pages(kvm);
 
        write_unlock(&kvm->mmu_lock);
 
-       if (is_tdp_mmu_enabled(kvm)) {
-               read_lock(&kvm->mmu_lock);
+       /*
+        * Zap the invalidated TDP MMU roots; all SPTEs must be dropped before
+        * returning to the caller.  For example, if the zap is in response to a
+        * memslot deletion, mmu_notifier callbacks will be unable to reach the
+        * SPTEs associated with the deleted memslot once the update completes.
+        * Deferring the zap until the final reference to the root is put would
+        * lead to use-after-free.
+        */
+       if (is_tdp_mmu_enabled(kvm))
                kvm_tdp_mmu_zap_invalidated_roots(kvm);
-               read_unlock(&kvm->mmu_lock);
-       }
 }
 
 static bool kvm_has_zapped_obsolete_pages(struct kvm *kvm)
@@ -5722,17 +5768,24 @@ static void kvm_mmu_invalidate_zap_pages_in_memslot(struct kvm *kvm,
        kvm_mmu_zap_all_fast(kvm);
 }
 
-void kvm_mmu_init_vm(struct kvm *kvm)
+int kvm_mmu_init_vm(struct kvm *kvm)
 {
        struct kvm_page_track_notifier_node *node = &kvm->arch.mmu_sp_tracker;
+       int r;
 
+       INIT_LIST_HEAD(&kvm->arch.active_mmu_pages);
+       INIT_LIST_HEAD(&kvm->arch.zapped_obsolete_pages);
+       INIT_LIST_HEAD(&kvm->arch.lpage_disallowed_mmu_pages);
        spin_lock_init(&kvm->arch.mmu_unsync_pages_lock);
 
-       kvm_mmu_init_tdp_mmu(kvm);
+       r = kvm_mmu_init_tdp_mmu(kvm);
+       if (r < 0)
+               return r;
 
        node->track_write = kvm_mmu_pte_write;
        node->track_flush_slot = kvm_mmu_invalidate_zap_pages_in_memslot;
        kvm_page_track_register_notifier(kvm, node);
+       return 0;
 }
 
 void kvm_mmu_uninit_vm(struct kvm *kvm)
@@ -5796,8 +5849,8 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
 
        if (is_tdp_mmu_enabled(kvm)) {
                for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
-                       flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, gfn_start,
-                                                         gfn_end, flush);
+                       flush = kvm_tdp_mmu_zap_leafs(kvm, i, gfn_start,
+                                                     gfn_end, true, flush);
        }
 
        if (flush)
@@ -5813,7 +5866,7 @@ static bool slot_rmap_write_protect(struct kvm *kvm,
                                    struct kvm_rmap_head *rmap_head,
                                    const struct kvm_memory_slot *slot)
 {
-       return __rmap_write_protect(kvm, rmap_head, false);
+       return rmap_write_protect(rmap_head, false);
 }
 
 void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
@@ -5857,12 +5910,52 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
         * will clear a separate software-only bit (MMU-writable) and skip the
         * flush if-and-only-if this bit was already clear.
         *
-        * See DEFAULT_SPTE_MMU_WRITEABLE for more details.
+        * See is_writable_pte() for more details.
         */
        if (flush)
                kvm_arch_flush_remote_tlbs_memslot(kvm, memslot);
 }
 
+/* Must be called with the mmu_lock held in write-mode. */
+void kvm_mmu_try_split_huge_pages(struct kvm *kvm,
+                                  const struct kvm_memory_slot *memslot,
+                                  u64 start, u64 end,
+                                  int target_level)
+{
+       if (is_tdp_mmu_enabled(kvm))
+               kvm_tdp_mmu_try_split_huge_pages(kvm, memslot, start, end,
+                                                target_level, false);
+
+       /*
+        * A TLB flush is unnecessary at this point for the same reasons as in
+        * kvm_mmu_slot_try_split_huge_pages().
+        */
+}
+
+void kvm_mmu_slot_try_split_huge_pages(struct kvm *kvm,
+                                       const struct kvm_memory_slot *memslot,
+                                       int target_level)
+{
+       u64 start = memslot->base_gfn;
+       u64 end = start + memslot->npages;
+
+       if (is_tdp_mmu_enabled(kvm)) {
+               read_lock(&kvm->mmu_lock);
+               kvm_tdp_mmu_try_split_huge_pages(kvm, memslot, start, end, target_level, true);
+               read_unlock(&kvm->mmu_lock);
+       }
+
+       /*
+        * No TLB flush is necessary here. KVM will flush TLBs after
+        * write-protecting and/or clearing dirty on the newly split SPTEs to
+        * ensure that guest writes are reflected in the dirty log before the
+        * ioctl to enable dirty logging on this memslot completes. Since the
+        * split SPTEs retain the write and dirty bits of the huge SPTE, it is
+        * safe for KVM to decide if a TLB flush is necessary based on the split
+        * SPTEs.
+        */
+}
+
 static bool kvm_mmu_zap_collapsible_spte(struct kvm *kvm,
                                         struct kvm_rmap_head *rmap_head,
                                         const struct kvm_memory_slot *slot)
@@ -6144,12 +6237,24 @@ static int set_nx_huge_pages(const char *val, const struct kernel_param *kp)
        return 0;
 }
 
-int kvm_mmu_module_init(void)
+/*
+ * nx_huge_pages needs to be resolved to true/false when kvm.ko is loaded, as
+ * its default value of -1 is technically undefined behavior for a boolean.
+ */
+void kvm_mmu_x86_module_init(void)
 {
-       int ret = -ENOMEM;
-
        if (nx_huge_pages == -1)
                __set_nx_huge_pages(get_nx_auto_mode());
+}
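The early resolution performed by kvm_mmu_x86_module_init() above can be shown with a toy tri-state parameter; the toy_* names are invented and the hard-coded toy_nx_auto_mode() stands in for the real check, which depends on the host's iTLB multihit status.

#include <stdbool.h>
#include <stdio.h>

static int nx_huge_pages = -1;  /* -1: auto, 0: off, 1: on */

/* Stand-in for get_nx_auto_mode(); the real check looks at CPU bugs. */
static bool toy_nx_auto_mode(void)
{
        return true;
}

/* Mirrors the module-init step: resolve "auto" to a real boolean. */
static void toy_module_init(void)
{
        if (nx_huge_pages == -1)
                nx_huge_pages = toy_nx_auto_mode();
}

int main(void)
{
        toy_module_init();
        printf("nx_huge_pages resolved to %d\n", nx_huge_pages);
        return 0;
}
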
+
+/*
+ * The bulk of the MMU initialization is deferred until the vendor module is
+ * loaded as many of the masks/values may be modified by VMX or SVM, i.e. need
+ * to be reset when a potentially different vendor module is loaded.
+ */
+int kvm_mmu_vendor_module_init(void)
+{
+       int ret = -ENOMEM;
 
        /*
         * MMU roles use union aliasing which is, generally speaking, an
@@ -6197,12 +6302,11 @@ void kvm_mmu_destroy(struct kvm_vcpu *vcpu)
        mmu_free_memory_caches(vcpu);
 }
 
-void kvm_mmu_module_exit(void)
+void kvm_mmu_vendor_module_exit(void)
 {
        mmu_destroy_caches();
        percpu_counter_destroy(&kvm_total_used_mmu_pages);
        unregister_shrinker(&mmu_shrinker);
-       mmu_audit_disable();
 }
 
 /*
@@ -6272,6 +6376,13 @@ static void kvm_recover_nx_lpages(struct kvm *kvm)
        rcu_idx = srcu_read_lock(&kvm->srcu);
        write_lock(&kvm->mmu_lock);
 
+       /*
+        * Zapping TDP MMU shadow pages, including the remote TLB flush, must
+        * be done under RCU protection, because the pages are freed via RCU
+        * callback.
+        */
+       rcu_read_lock();
+
        ratio = READ_ONCE(nx_huge_pages_recovery_ratio);
        to_zap = ratio ? DIV_ROUND_UP(nx_lpage_splits, ratio) : 0;
        for ( ; to_zap; --to_zap) {
@@ -6296,12 +6407,18 @@ static void kvm_recover_nx_lpages(struct kvm *kvm)
 
                if (need_resched() || rwlock_needbreak(&kvm->mmu_lock)) {
                        kvm_mmu_remote_flush_or_zap(kvm, &invalid_list, flush);
+                       rcu_read_unlock();
+
                        cond_resched_rwlock_write(&kvm->mmu_lock);
                        flush = false;
+
+                       rcu_read_lock();
                }
        }
        kvm_mmu_remote_flush_or_zap(kvm, &invalid_list, flush);
 
+       rcu_read_unlock();
+
        write_unlock(&kvm->mmu_lock);
        srcu_read_unlock(&kvm->srcu, rcu_idx);
 }
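The to_zap computation in kvm_recover_nx_lpages() above is simple enough to check directly; the values in the sketch below are made up, and DIV_ROUND_UP follows the kernel's definition.

#include <stdio.h>

#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

int main(void)
{
        /* Made-up example: 1000 NX-split huge pages tracked, ratio of 4. */
        unsigned long nx_lpage_splits = 1000, ratio = 4;

        /* A ratio of 0 disables recovery for this pass entirely. */
        unsigned long to_zap = ratio ? DIV_ROUND_UP(nx_lpage_splits, ratio) : 0;

        printf("zap up to %lu shadow pages this pass\n", to_zap);      /* 250 */
        return 0;
}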