Merge remote-tracking branch 'torvalds/master' into perf/core
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 951dae4..8d5876d 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -48,6 +48,7 @@
 #include <asm/memtype.h>
 #include <asm/cmpxchg.h>
 #include <asm/io.h>
+#include <asm/set_memory.h>
 #include <asm/vmx.h>
 #include <asm/kvm_page_track.h>
 #include "trace.h"
@@ -215,10 +216,10 @@ bool is_nx_huge_page_enabled(void)
 static void mark_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, u64 gfn,
                           unsigned int access)
 {
-       u64 mask = make_mmio_spte(vcpu, gfn, access);
+       u64 spte = make_mmio_spte(vcpu, gfn, access);
 
-       trace_mark_mmio_spte(sptep, gfn, mask);
-       mmu_spte_set(sptep, mask);
+       trace_mark_mmio_spte(sptep, gfn, spte);
+       mmu_spte_set(sptep, spte);
 }
 
 static gfn_t get_mmio_spte_gfn(u64 spte)
@@ -236,17 +237,6 @@ static unsigned get_mmio_spte_access(u64 spte)
        return spte & shadow_mmio_access_mask;
 }
 
-static bool set_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
-                         kvm_pfn_t pfn, unsigned int access)
-{
-       if (unlikely(is_noslot_pfn(pfn))) {
-               mark_mmio_spte(vcpu, sptep, gfn, access);
-               return true;
-       }
-
-       return false;
-}
-
 static bool check_mmio_spte(struct kvm_vcpu *vcpu, u64 spte)
 {
        u64 kvm_gen, spte_gen, gen;
@@ -725,8 +715,7 @@ static void kvm_mmu_page_set_gfn(struct kvm_mmu_page *sp, int index, gfn_t gfn)
  * handling slots that are not large page aligned.
  */
 static struct kvm_lpage_info *lpage_info_slot(gfn_t gfn,
-                                             struct kvm_memory_slot *slot,
-                                             int level)
+               const struct kvm_memory_slot *slot, int level)
 {
        unsigned long idx;
 
@@ -1118,7 +1107,7 @@ static bool spte_write_protect(u64 *sptep, bool pt_protect)
        rmap_printk("spte %p %llx\n", sptep, *sptep);
 
        if (pt_protect)
-               spte &= ~SPTE_MMU_WRITEABLE;
+               spte &= ~shadow_mmu_writable_mask;
        spte = spte & ~PT_WRITABLE_MASK;
 
        return mmu_spte_update(sptep, spte);
@@ -1308,26 +1297,25 @@ static bool kvm_zap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
        return flush;
 }
 
-static int kvm_unmap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-                          struct kvm_memory_slot *slot, gfn_t gfn, int level,
-                          unsigned long data)
+static bool kvm_unmap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+                           struct kvm_memory_slot *slot, gfn_t gfn, int level,
+                           pte_t unused)
 {
        return kvm_zap_rmapp(kvm, rmap_head, slot);
 }
 
-static int kvm_set_pte_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-                            struct kvm_memory_slot *slot, gfn_t gfn, int level,
-                            unsigned long data)
+static bool kvm_set_pte_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+                             struct kvm_memory_slot *slot, gfn_t gfn, int level,
+                             pte_t pte)
 {
        u64 *sptep;
        struct rmap_iterator iter;
        int need_flush = 0;
        u64 new_spte;
-       pte_t *ptep = (pte_t *)data;
        kvm_pfn_t new_pfn;
 
-       WARN_ON(pte_huge(*ptep));
-       new_pfn = pte_pfn(*ptep);
+       WARN_ON(pte_huge(pte));
+       new_pfn = pte_pfn(pte);
 
 restart:
        for_each_rmap_spte(rmap_head, &iter, sptep) {
@@ -1336,7 +1324,7 @@ restart:
 
                need_flush = 1;
 
-               if (pte_write(*ptep)) {
+               if (pte_write(pte)) {
                        pte_list_remove(rmap_head, sptep);
                        goto restart;
                } else {
@@ -1424,93 +1412,52 @@ static void slot_rmap_walk_next(struct slot_rmap_walk_iterator *iterator)
             slot_rmap_walk_okay(_iter_);                               \
             slot_rmap_walk_next(_iter_))
 
-static __always_inline int
-kvm_handle_hva_range(struct kvm *kvm,
-                    unsigned long start,
-                    unsigned long end,
-                    unsigned long data,
-                    int (*handler)(struct kvm *kvm,
-                                   struct kvm_rmap_head *rmap_head,
-                                   struct kvm_memory_slot *slot,
-                                   gfn_t gfn,
-                                   int level,
-                                   unsigned long data))
+typedef bool (*rmap_handler_t)(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+                              struct kvm_memory_slot *slot, gfn_t gfn,
+                              int level, pte_t pte);
+
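+/*
+ * Invoke @handler on each rmap covering @range, from 4K up to the maximum
+ * huge page level; the handlers' boolean results are OR'd together and
+ * returned to the caller.
+ */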
+static __always_inline bool kvm_handle_gfn_range(struct kvm *kvm,
+                                                struct kvm_gfn_range *range,
+                                                rmap_handler_t handler)
 {
-       struct kvm_memslots *slots;
-       struct kvm_memory_slot *memslot;
        struct slot_rmap_walk_iterator iterator;
-       int ret = 0;
-       int i;
+       bool ret = false;
 
-       for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
-               slots = __kvm_memslots(kvm, i);
-               kvm_for_each_memslot(memslot, slots) {
-                       unsigned long hva_start, hva_end;
-                       gfn_t gfn_start, gfn_end;
-
-                       hva_start = max(start, memslot->userspace_addr);
-                       hva_end = min(end, memslot->userspace_addr +
-                                     (memslot->npages << PAGE_SHIFT));
-                       if (hva_start >= hva_end)
-                               continue;
-                       /*
-                        * {gfn(page) | page intersects with [hva_start, hva_end)} =
-                        * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-                        */
-                       gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-                       gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
-
-                       for_each_slot_rmap_range(memslot, PG_LEVEL_4K,
-                                                KVM_MAX_HUGEPAGE_LEVEL,
-                                                gfn_start, gfn_end - 1,
-                                                &iterator)
-                               ret |= handler(kvm, iterator.rmap, memslot,
-                                              iterator.gfn, iterator.level, data);
-               }
-       }
+       for_each_slot_rmap_range(range->slot, PG_LEVEL_4K, KVM_MAX_HUGEPAGE_LEVEL,
+                                range->start, range->end - 1, &iterator)
+               ret |= handler(kvm, iterator.rmap, range->slot, iterator.gfn,
+                              iterator.level, range->pte);
 
        return ret;
 }
 
-static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
-                         unsigned long data,
-                         int (*handler)(struct kvm *kvm,
-                                        struct kvm_rmap_head *rmap_head,
-                                        struct kvm_memory_slot *slot,
-                                        gfn_t gfn, int level,
-                                        unsigned long data))
-{
-       return kvm_handle_hva_range(kvm, hva, hva + 1, data, handler);
-}
-
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-                       unsigned flags)
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       int r;
+       bool flush;
 
-       r = kvm_handle_hva_range(kvm, start, end, 0, kvm_unmap_rmapp);
+       flush = kvm_handle_gfn_range(kvm, range, kvm_unmap_rmapp);
 
        if (is_tdp_mmu_enabled(kvm))
-               r |= kvm_tdp_mmu_zap_hva_range(kvm, start, end);
+               flush |= kvm_tdp_mmu_unmap_gfn_range(kvm, range, flush);
 
-       return r;
+       return flush;
 }
 
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       int r;
+       bool flush;
 
-       r = kvm_handle_hva(kvm, hva, (unsigned long)&pte, kvm_set_pte_rmapp);
+       flush = kvm_handle_gfn_range(kvm, range, kvm_set_pte_rmapp);
 
        if (is_tdp_mmu_enabled(kvm))
-               r |= kvm_tdp_mmu_set_spte_hva(kvm, hva, &pte);
+               flush |= kvm_tdp_mmu_set_spte_gfn(kvm, range);
 
-       return r;
+       return flush;
 }
 
-static int kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-                        struct kvm_memory_slot *slot, gfn_t gfn, int level,
-                        unsigned long data)
+static bool kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+                         struct kvm_memory_slot *slot, gfn_t gfn, int level,
+                         pte_t unused)
 {
        u64 *sptep;
        struct rmap_iterator iter;
@@ -1519,13 +1466,12 @@ static int kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
        for_each_rmap_spte(rmap_head, &iter, sptep)
                young |= mmu_spte_age(sptep);
 
-       trace_kvm_age_page(gfn, level, slot, young);
        return young;
 }
 
-static int kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-                             struct kvm_memory_slot *slot, gfn_t gfn,
-                             int level, unsigned long data)
+static bool kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+                              struct kvm_memory_slot *slot, gfn_t gfn,
+                              int level, pte_t unused)
 {
        u64 *sptep;
        struct rmap_iterator iter;
@@ -1547,29 +1493,31 @@ static void rmap_recycle(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 
        rmap_head = gfn_to_rmap(vcpu->kvm, gfn, sp);
 
-       kvm_unmap_rmapp(vcpu->kvm, rmap_head, NULL, gfn, sp->role.level, 0);
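+       /* The pte argument is ignored by kvm_unmap_rmapp(), pass a dummy. */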
+       kvm_unmap_rmapp(vcpu->kvm, rmap_head, NULL, gfn, sp->role.level, __pte(0));
        kvm_flush_remote_tlbs_with_address(vcpu->kvm, sp->gfn,
                        KVM_PAGES_PER_HPAGE(sp->role.level));
 }
 
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       int young = false;
+       bool young;
+
+       young = kvm_handle_gfn_range(kvm, range, kvm_age_rmapp);
 
-       young = kvm_handle_hva_range(kvm, start, end, 0, kvm_age_rmapp);
        if (is_tdp_mmu_enabled(kvm))
-               young |= kvm_tdp_mmu_age_hva_range(kvm, start, end);
+               young |= kvm_tdp_mmu_age_gfn_range(kvm, range);
 
        return young;
 }
 
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       int young = false;
+       bool young;
+
+       young = kvm_handle_gfn_range(kvm, range, kvm_test_age_rmapp);
 
-       young = kvm_handle_hva(kvm, hva, 0, kvm_test_age_rmapp);
        if (is_tdp_mmu_enabled(kvm))
-               young |= kvm_tdp_mmu_test_age_hva(kvm, hva);
+               young |= kvm_tdp_mmu_test_age_gfn(kvm, range);
 
        return young;
 }
@@ -2421,6 +2369,15 @@ static int make_mmu_pages_available(struct kvm_vcpu *vcpu)
 
        kvm_mmu_zap_oldest_mmu_pages(vcpu->kvm, KVM_REFILL_PAGES - avail);
 
+       /*
+        * Note, this check is intentionally soft, it only guarantees that one
+        * page is available, while the caller may end up allocating as many as
+        * four pages, e.g. for PAE roots or for 5-level paging.  Temporarily
+        * exceeding the (arbitrary by default) limit will not harm the host,
+        * being too aggressive may unnecessarily kill the guest, and getting an
+        * exact count is far more trouble than it's worth, especially in the
+        * page fault paths.
+        */
        if (!kvm_mmu_available_pages(vcpu->kvm))
                return -ENOSPC;
        return 0;
@@ -2561,9 +2518,6 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        struct kvm_mmu_page *sp;
        int ret;
 
-       if (set_mmio_spte(vcpu, sptep, gfn, pfn, pte_access))
-               return 0;
-
        sp = sptep_to_sp(sptep);
 
        ret = make_spte(vcpu, pte_access, level, gfn, pfn, *sptep, speculative,
@@ -2593,6 +2547,11 @@ static int mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        pgprintk("%s: spte %llx write_fault %d gfn %llx\n", __func__,
                 *sptep, write_fault, gfn);
 
+       if (unlikely(is_noslot_pfn(pfn))) {
+               mark_mmio_spte(vcpu, sptep, gfn, pte_access);
+               return RET_PF_EMULATE;
+       }
+
        if (is_shadow_present_pte(*sptep)) {
                /*
                 * If we overwrite a PTE page pointer with a 2MB PMD, unlink
@@ -2626,9 +2585,6 @@ static int mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                kvm_flush_remote_tlbs_with_address(vcpu->kvm, gfn,
                                KVM_PAGES_PER_HPAGE(level));
 
-       if (unlikely(is_mmio_spte(*sptep)))
-               ret = RET_PF_EMULATE;
-
        /*
         * The fault is fully spurious if and only if the new SPTE and old SPTE
         * are identical, and emulation is not required.
@@ -2745,7 +2701,7 @@ static void direct_pte_prefetch(struct kvm_vcpu *vcpu, u64 *sptep)
 }
 
 static int host_pfn_mapping_level(struct kvm *kvm, gfn_t gfn, kvm_pfn_t pfn,
-                                 struct kvm_memory_slot *slot)
+                                 const struct kvm_memory_slot *slot)
 {
        unsigned long hva;
        pte_t *pte;
@@ -2771,8 +2727,9 @@ static int host_pfn_mapping_level(struct kvm *kvm, gfn_t gfn, kvm_pfn_t pfn,
        return level;
 }
 
-int kvm_mmu_max_mapping_level(struct kvm *kvm, struct kvm_memory_slot *slot,
-                             gfn_t gfn, kvm_pfn_t pfn, int max_level)
+int kvm_mmu_max_mapping_level(struct kvm *kvm,
+                             const struct kvm_memory_slot *slot, gfn_t gfn,
+                             kvm_pfn_t pfn, int max_level)
 {
        struct kvm_lpage_info *linfo;
 
@@ -2946,9 +2903,19 @@ static bool handle_abnormal_pfn(struct kvm_vcpu *vcpu, gva_t gva, gfn_t gfn,
                return true;
        }
 
-       if (unlikely(is_noslot_pfn(pfn)))
+       if (unlikely(is_noslot_pfn(pfn))) {
                vcpu_cache_mmio_info(vcpu, gva, gfn,
                                     access & shadow_mmio_access_mask);
+               /*
+                * If MMIO caching is disabled, emulate immediately without
+                * touching the shadow page tables as attempting to install an
+                * MMIO SPTE will just be an expensive nop.
+                */
+               if (unlikely(!shadow_mmio_value)) {
+                       *ret_val = RET_PF_EMULATE;
+                       return true;
+               }
+       }
 
        return false;
 }
@@ -3061,6 +3028,9 @@ static int fast_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
                        if (!is_shadow_present_pte(spte))
                                break;
 
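+               /*
+                * The break above only exits the shadow walk; also bail from
+                * the retry loop if the walk terminated on a non-present SPTE
+                * so the fault is handled by the slow path.
+                */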
+               if (!is_shadow_present_pte(spte))
+                       break;
+
                sp = sptep_to_sp(iterator.sptep);
                if (!is_last_spte(spte, sp->role.level))
                        break;
@@ -3150,12 +3120,10 @@ static void mmu_free_root_page(struct kvm *kvm, hpa_t *root_hpa,
 
        sp = to_shadow_page(*root_hpa & PT64_BASE_ADDR_MASK);
 
-       if (kvm_mmu_put_root(kvm, sp)) {
-               if (is_tdp_mmu_page(sp))
-                       kvm_tdp_mmu_free_root(kvm, sp);
-               else if (sp->role.invalid)
-                       kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
-       }
+       if (is_tdp_mmu_page(sp))
+               kvm_tdp_mmu_put_root(kvm, sp, false);
+       else if (!--sp->root_count && sp->role.invalid)
+               kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
 
        *root_hpa = INVALID_PAGE;
 }
@@ -3193,14 +3161,17 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
                    (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
                        mmu_free_root_page(kvm, &mmu->root_hpa, &invalid_list);
-               } else {
-                       for (i = 0; i < 4; ++i)
-                               if (mmu->pae_root[i] != 0)
-                                       mmu_free_root_page(kvm,
-                                                          &mmu->pae_root[i],
-                                                          &invalid_list);
-                       mmu->root_hpa = INVALID_PAGE;
+               } else if (mmu->pae_root) {
+                       for (i = 0; i < 4; ++i) {
+                               if (!IS_VALID_PAE_ROOT(mmu->pae_root[i]))
+                                       continue;
+
+                               mmu_free_root_page(kvm, &mmu->pae_root[i],
+                                                  &invalid_list);
+                               mmu->pae_root[i] = INVALID_PAE_ROOT;
+                       }
                }
+               mmu->root_hpa = INVALID_PAGE;
                mmu->root_pgd = 0;
        }
 
@@ -3226,155 +3197,208 @@ static hpa_t mmu_alloc_root(struct kvm_vcpu *vcpu, gfn_t gfn, gva_t gva,
 {
        struct kvm_mmu_page *sp;
 
-       write_lock(&vcpu->kvm->mmu_lock);
-
-       if (make_mmu_pages_available(vcpu)) {
-               write_unlock(&vcpu->kvm->mmu_lock);
-               return INVALID_PAGE;
-       }
        sp = kvm_mmu_get_page(vcpu, gfn, gva, level, direct, ACC_ALL);
        ++sp->root_count;
 
-       write_unlock(&vcpu->kvm->mmu_lock);
        return __pa(sp->spt);
 }
 
 static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
 {
-       u8 shadow_root_level = vcpu->arch.mmu->shadow_root_level;
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
+       u8 shadow_root_level = mmu->shadow_root_level;
        hpa_t root;
        unsigned i;
+       int r;
+
+       write_lock(&vcpu->kvm->mmu_lock);
+       r = make_mmu_pages_available(vcpu);
+       if (r < 0)
+               goto out_unlock;
 
        if (is_tdp_mmu_enabled(vcpu->kvm)) {
                root = kvm_tdp_mmu_get_vcpu_root_hpa(vcpu);
-
-               if (!VALID_PAGE(root))
-                       return -ENOSPC;
-               vcpu->arch.mmu->root_hpa = root;
+               mmu->root_hpa = root;
        } else if (shadow_root_level >= PT64_ROOT_4LEVEL) {
-               root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level,
-                                     true);
-
-               if (!VALID_PAGE(root))
-                       return -ENOSPC;
-               vcpu->arch.mmu->root_hpa = root;
+               root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level, true);
+               mmu->root_hpa = root;
        } else if (shadow_root_level == PT32E_ROOT_LEVEL) {
+               if (WARN_ON_ONCE(!mmu->pae_root)) {
+                       r = -EIO;
+                       goto out_unlock;
+               }
+
                for (i = 0; i < 4; ++i) {
-                       MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
+                       WARN_ON_ONCE(IS_VALID_PAE_ROOT(mmu->pae_root[i]));
 
                        root = mmu_alloc_root(vcpu, i << (30 - PAGE_SHIFT),
                                              i << 30, PT32_ROOT_LEVEL, true);
-                       if (!VALID_PAGE(root))
-                               return -ENOSPC;
-                       vcpu->arch.mmu->pae_root[i] = root | PT_PRESENT_MASK;
+                       mmu->pae_root[i] = root | PT_PRESENT_MASK |
+                                          shadow_me_mask;
                }
-               vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
-       } else
-               BUG();
+               mmu->root_hpa = __pa(mmu->pae_root);
+       } else {
+               WARN_ONCE(1, "Bad TDP root level = %d\n", shadow_root_level);
+               r = -EIO;
+               goto out_unlock;
+       }
 
        /* root_pgd is ignored for direct MMUs. */
-       vcpu->arch.mmu->root_pgd = 0;
-
-       return 0;
+       mmu->root_pgd = 0;
+out_unlock:
+       write_unlock(&vcpu->kvm->mmu_lock);
+       return r;
 }
 
 static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
 {
-       u64 pdptr, pm_mask;
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
+       u64 pdptrs[4], pm_mask;
        gfn_t root_gfn, root_pgd;
        hpa_t root;
-       int i;
+       unsigned i;
+       int r;
 
-       root_pgd = vcpu->arch.mmu->get_guest_pgd(vcpu);
+       root_pgd = mmu->get_guest_pgd(vcpu);
        root_gfn = root_pgd >> PAGE_SHIFT;
 
        if (mmu_check_root(vcpu, root_gfn))
                return 1;
 
+       /*
+        * On SVM, reading PDPTRs might access guest memory, which might fault
+        * and thus might sleep.  Grab the PDPTRs before acquiring mmu_lock.
+        */
+       if (mmu->root_level == PT32E_ROOT_LEVEL) {
+               for (i = 0; i < 4; ++i) {
+                       pdptrs[i] = mmu->get_pdptr(vcpu, i);
+                       if (!(pdptrs[i] & PT_PRESENT_MASK))
+                               continue;
+
+                       if (mmu_check_root(vcpu, pdptrs[i] >> PAGE_SHIFT))
+                               return 1;
+               }
+       }
+
+       write_lock(&vcpu->kvm->mmu_lock);
+       r = make_mmu_pages_available(vcpu);
+       if (r < 0)
+               goto out_unlock;
+
        /*
         * Do we shadow a long mode page table? If so we need to
         * write-protect the guests page table root.
         */
-       if (vcpu->arch.mmu->root_level >= PT64_ROOT_4LEVEL) {
-               MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->root_hpa));
-
+       if (mmu->root_level >= PT64_ROOT_4LEVEL) {
                root = mmu_alloc_root(vcpu, root_gfn, 0,
-                                     vcpu->arch.mmu->shadow_root_level, false);
-               if (!VALID_PAGE(root))
-                       return -ENOSPC;
-               vcpu->arch.mmu->root_hpa = root;
+                                     mmu->shadow_root_level, false);
+               mmu->root_hpa = root;
                goto set_root_pgd;
        }
 
+       if (WARN_ON_ONCE(!mmu->pae_root)) {
+               r = -EIO;
+               goto out_unlock;
+       }
+
        /*
         * We shadow a 32 bit page table. This may be a legacy 2-level
         * or a PAE 3-level page table. In either case we need to be aware that
         * the shadow page table may be a PAE or a long mode page table.
         */
-       pm_mask = PT_PRESENT_MASK;
-       if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL)
+       pm_mask = PT_PRESENT_MASK | shadow_me_mask;
+       if (mmu->shadow_root_level == PT64_ROOT_4LEVEL) {
                pm_mask |= PT_ACCESSED_MASK | PT_WRITABLE_MASK | PT_USER_MASK;
 
+               if (WARN_ON_ONCE(!mmu->pml4_root)) {
+                       r = -EIO;
+                       goto out_unlock;
+               }
+
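+               /*
+                * Hook the 4-entry PAE root table into slot 0 of the special
+                * PML4 root so that the CPU's 4-level walk of the shadow page
+                * tables reaches the PAE roots.
+                */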
+               mmu->pml4_root[0] = __pa(mmu->pae_root) | pm_mask;
+       }
+
        for (i = 0; i < 4; ++i) {
-               MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
-               if (vcpu->arch.mmu->root_level == PT32E_ROOT_LEVEL) {
-                       pdptr = vcpu->arch.mmu->get_pdptr(vcpu, i);
-                       if (!(pdptr & PT_PRESENT_MASK)) {
-                               vcpu->arch.mmu->pae_root[i] = 0;
+               WARN_ON_ONCE(IS_VALID_PAE_ROOT(mmu->pae_root[i]));
+
+               if (mmu->root_level == PT32E_ROOT_LEVEL) {
+                       if (!(pdptrs[i] & PT_PRESENT_MASK)) {
+                               mmu->pae_root[i] = INVALID_PAE_ROOT;
                                continue;
                        }
-                       root_gfn = pdptr >> PAGE_SHIFT;
-                       if (mmu_check_root(vcpu, root_gfn))
-                               return 1;
+                       root_gfn = pdptrs[i] >> PAGE_SHIFT;
                }
 
                root = mmu_alloc_root(vcpu, root_gfn, i << 30,
                                      PT32_ROOT_LEVEL, false);
-               if (!VALID_PAGE(root))
-                       return -ENOSPC;
-               vcpu->arch.mmu->pae_root[i] = root | pm_mask;
+               mmu->pae_root[i] = root | pm_mask;
        }
-       vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
+
+       if (mmu->shadow_root_level == PT64_ROOT_4LEVEL)
+               mmu->root_hpa = __pa(mmu->pml4_root);
+       else
+               mmu->root_hpa = __pa(mmu->pae_root);
+
+set_root_pgd:
+       mmu->root_pgd = root_pgd;
+out_unlock:
+       write_unlock(&vcpu->kvm->mmu_lock);
+
+       return 0;
+}
+
+static int mmu_alloc_special_roots(struct kvm_vcpu *vcpu)
+{
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
+       u64 *pml4_root, *pae_root;
 
        /*
-        * If we shadow a 32 bit page table with a long mode page
-        * table we enter this path.
+        * When shadowing 32-bit or PAE NPT with 64-bit NPT, the PML4 and PDP
+        * tables are allocated and initialized at root creation as there is no
+        * equivalent level in the guest's NPT to shadow.  Allocate the tables
+        * on demand, as running a 32-bit L1 VMM on 64-bit KVM is very rare.
         */
-       if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL) {
-               if (vcpu->arch.mmu->lm_root == NULL) {
-                       /*
-                        * The additional page necessary for this is only
-                        * allocated on demand.
-                        */
+       if (mmu->direct_map || mmu->root_level >= PT64_ROOT_4LEVEL ||
+           mmu->shadow_root_level < PT64_ROOT_4LEVEL)
+               return 0;
 
-                       u64 *lm_root;
+       /*
+        * This mess only works with 4-level paging and needs to be updated to
+        * work with 5-level paging.
+        */
+       if (WARN_ON_ONCE(mmu->shadow_root_level != PT64_ROOT_4LEVEL))
+               return -EIO;
 
-                       lm_root = (void*)get_zeroed_page(GFP_KERNEL_ACCOUNT);
-                       if (lm_root == NULL)
-                               return 1;
+       if (mmu->pae_root && mmu->pml4_root)
+               return 0;
 
-                       lm_root[0] = __pa(vcpu->arch.mmu->pae_root) | pm_mask;
+       /*
+        * The special roots should always be allocated in concert.  Yell and
+        * bail if KVM ends up in a state where only one of the roots is valid.
+        */
+       if (WARN_ON_ONCE(!tdp_enabled || mmu->pae_root || mmu->pml4_root))
+               return -EIO;
 
-                       vcpu->arch.mmu->lm_root = lm_root;
-               }
+       /*
+        * Unlike 32-bit NPT, the PDP table doesn't need to be in low mem, and
+        * doesn't need to be decrypted.
+        */
+       pae_root = (void *)get_zeroed_page(GFP_KERNEL_ACCOUNT);
+       if (!pae_root)
+               return -ENOMEM;
 
-               vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->lm_root);
+       pml4_root = (void *)get_zeroed_page(GFP_KERNEL_ACCOUNT);
+       if (!pml4_root) {
+               free_page((unsigned long)pae_root);
+               return -ENOMEM;
        }
 
-set_root_pgd:
-       vcpu->arch.mmu->root_pgd = root_pgd;
+       mmu->pae_root = pae_root;
+       mmu->pml4_root = pml4_root;
 
        return 0;
 }
 
-static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
-{
-       if (vcpu->arch.mmu->direct_map)
-               return mmu_alloc_direct_roots(vcpu);
-       else
-               return mmu_alloc_shadow_roots(vcpu);
-}
-
 void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
 {
        int i;
@@ -3422,7 +3446,7 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
        for (i = 0; i < 4; ++i) {
                hpa_t root = vcpu->arch.mmu->pae_root[i];
 
-               if (root && VALID_PAGE(root)) {
+               if (IS_VALID_PAE_ROOT(root)) {
                        root &= PT64_BASE_ADDR_MASK;
                        sp = to_shadow_page(root);
                        mmu_sync_children(vcpu, sp);
@@ -3554,11 +3578,12 @@ static bool get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
                            __is_rsvd_bits_set(rsvd_check, sptes[level], level);
 
        if (reserved) {
-               pr_err("%s: detect reserved bits on spte, addr 0x%llx, dump hierarchy:\n",
+               pr_err("%s: reserved bits set on MMU-present spte, addr 0x%llx, hierarchy:\n",
                       __func__, addr);
                for (level = root; level >= leaf; level--)
-                       pr_err("------ spte 0x%llx level %d.\n",
-                              sptes[level], level);
+                       pr_err("------ spte = 0x%llx level = %d, rsvd bits = 0x%llx",
+                              sptes[level], level,
+                              rsvd_check->rsvd_bits_mask[(sptes[level] >> 7) & 1][level-1]);
        }
 
        return reserved;
@@ -3653,6 +3678,14 @@ static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
        struct kvm_memory_slot *slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
        bool async;
 
+       /*
+        * Retry the page fault if the gfn hit a memslot that is being deleted
+        * or moved.  This ensures any existing SPTEs for the old memslot will
+        * be zapped before KVM inserts a new MMIO SPTE for the gfn.
+        */
+       if (slot && (slot->flags & KVM_MEMSLOT_INVALID))
+               return true;
+
        /* Don't expose private memslots to L2. */
        if (is_guest_mode(vcpu) && !kvm_is_visible_memslot(slot)) {
                *pfn = KVM_PFN_NOSLOT;
@@ -4615,12 +4648,17 @@ void kvm_init_shadow_npt_mmu(struct kvm_vcpu *vcpu, u32 cr0, u32 cr4, u32 efer,
        struct kvm_mmu *context = &vcpu->arch.guest_mmu;
        union kvm_mmu_role new_role = kvm_calc_shadow_npt_root_page_role(vcpu);
 
-       context->shadow_root_level = new_role.base.level;
-
        __kvm_mmu_new_pgd(vcpu, nested_cr3, new_role.base, false, false);
 
-       if (new_role.as_u64 != context->mmu_role.as_u64)
+       if (new_role.as_u64 != context->mmu_role.as_u64) {
                shadow_mmu_init_context(vcpu, context, cr0, cr4, efer, new_role);
+
+               /*
+                * Override the level set by the common init helper, nested TDP
+                * always uses the host's TDP configuration.
+                */
+               context->shadow_root_level = new_role.base.level;
+       }
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_npt_mmu);
 
@@ -4701,9 +4739,33 @@ static void init_kvm_softmmu(struct kvm_vcpu *vcpu)
        context->inject_page_fault = kvm_inject_page_fault;
 }
 
+static union kvm_mmu_role kvm_calc_nested_mmu_role(struct kvm_vcpu *vcpu)
+{
+       union kvm_mmu_role role = kvm_calc_shadow_root_page_role_common(vcpu, false);
+
+       /*
+        * Nested MMUs are used only for walking L2's gva->gpa, they never have
+        * shadow pages of their own and so "direct" has no meaning.   Set it
+        * to "true" to try to detect bogus usage of the nested MMU.
+        */
+       role.base.direct = true;
+
+       if (!is_paging(vcpu))
+               role.base.level = 0;
+       else if (is_long_mode(vcpu))
+               role.base.level = is_la57_mode(vcpu) ? PT64_ROOT_5LEVEL :
+                                                      PT64_ROOT_4LEVEL;
+       else if (is_pae(vcpu))
+               role.base.level = PT32E_ROOT_LEVEL;
+       else
+               role.base.level = PT32_ROOT_LEVEL;
+
+       return role;
+}
+
 static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu)
 {
-       union kvm_mmu_role new_role = kvm_calc_mmu_role_common(vcpu, false);
+       union kvm_mmu_role new_role = kvm_calc_nested_mmu_role(vcpu);
        struct kvm_mmu *g_context = &vcpu->arch.nested_mmu;
 
        if (new_role.as_u64 == g_context->mmu_role.as_u64)
@@ -4802,16 +4864,23 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
        r = mmu_topup_memory_caches(vcpu, !vcpu->arch.mmu->direct_map);
        if (r)
                goto out;
-       r = mmu_alloc_roots(vcpu);
-       kvm_mmu_sync_roots(vcpu);
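+       /*
+        * Allocate the special roots (the PAE and PML4 pages used when
+        * shadowing 32-bit or PAE NPT with 64-bit NPT), if needed, before
+        * building the roots proper; root construction runs under mmu_lock
+        * and cannot allocate memory.
+        */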
+       r = mmu_alloc_special_roots(vcpu);
        if (r)
                goto out;
+       if (vcpu->arch.mmu->direct_map)
+               r = mmu_alloc_direct_roots(vcpu);
+       else
+               r = mmu_alloc_shadow_roots(vcpu);
+       if (r)
+               goto out;
+
+       kvm_mmu_sync_roots(vcpu);
+
        kvm_mmu_load_pgd(vcpu);
        static_call(kvm_x86_tlb_flush_current)(vcpu);
 out:
        return r;
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_load);
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
@@ -4820,7 +4889,6 @@ void kvm_mmu_unload(struct kvm_vcpu *vcpu)
        kvm_mmu_free_roots(vcpu, &vcpu->arch.guest_mmu, KVM_MMU_ROOTS_ALL);
        WARN_ON(VALID_PAGE(vcpu->arch.guest_mmu.root_hpa));
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_unload);
 
 static bool need_remote_flush(u64 old, u64 new)
 {
@@ -4961,7 +5029,7 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 
        /*
         * No need to care whether allocation memory is successful
-        * or not since pte prefetch is skiped if it does not have
+        * or not since pte prefetch is skipped if it does not have
         * enough objects in the cache.
         */
        mmu_topup_memory_caches(vcpu, true);
@@ -5169,10 +5237,10 @@ typedef bool (*slot_level_handler) (struct kvm *kvm, struct kvm_rmap_head *rmap_
 static __always_inline bool
 slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
                        slot_level_handler fn, int start_level, int end_level,
-                       gfn_t start_gfn, gfn_t end_gfn, bool lock_flush_tlb)
+                       gfn_t start_gfn, gfn_t end_gfn, bool flush_on_yield,
+                       bool flush)
 {
        struct slot_rmap_walk_iterator iterator;
-       bool flush = false;
 
        for_each_slot_rmap_range(memslot, start_level, end_level, start_gfn,
                        end_gfn, &iterator) {
@@ -5180,7 +5248,7 @@ slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
                        flush |= fn(kvm, iterator.rmap, memslot);
 
                if (need_resched() || rwlock_needbreak(&kvm->mmu_lock)) {
-                       if (flush && lock_flush_tlb) {
+                       if (flush && flush_on_yield) {
                                kvm_flush_remote_tlbs_with_address(kvm,
                                                start_gfn,
                                                iterator.gfn - start_gfn + 1);
@@ -5190,38 +5258,34 @@ slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
                }
        }
 
-       if (flush && lock_flush_tlb) {
-               kvm_flush_remote_tlbs_with_address(kvm, start_gfn,
-                                                  end_gfn - start_gfn + 1);
-               flush = false;
-       }
-
        return flush;
 }
 
 static __always_inline bool
 slot_handle_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
                  slot_level_handler fn, int start_level, int end_level,
-                 bool lock_flush_tlb)
+                 bool flush_on_yield)
 {
        return slot_handle_level_range(kvm, memslot, fn, start_level,
                        end_level, memslot->base_gfn,
                        memslot->base_gfn + memslot->npages - 1,
-                       lock_flush_tlb);
+                       flush_on_yield, false);
 }
 
 static __always_inline bool
 slot_handle_leaf(struct kvm *kvm, struct kvm_memory_slot *memslot,
-                slot_level_handler fn, bool lock_flush_tlb)
+                slot_level_handler fn, bool flush_on_yield)
 {
        return slot_handle_level(kvm, memslot, fn, PG_LEVEL_4K,
-                                PG_LEVEL_4K, lock_flush_tlb);
+                                PG_LEVEL_4K, flush_on_yield);
 }
 
 static void free_mmu_pages(struct kvm_mmu *mmu)
 {
+       if (!tdp_enabled && mmu->pae_root)
+               set_memory_encrypted((unsigned long)mmu->pae_root, 1);
        free_page((unsigned long)mmu->pae_root);
-       free_page((unsigned long)mmu->lm_root);
+       free_page((unsigned long)mmu->pml4_root);
 }
 
 static int __kvm_mmu_create(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
@@ -5240,9 +5304,11 @@ static int __kvm_mmu_create(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
         * while the PDP table is a per-vCPU construct that's allocated at MMU
         * creation.  When emulating 32-bit mode, cr3 is only 32 bits even on
         * x86_64.  Therefore we need to allocate the PDP table in the first
-        * 4GB of memory, which happens to fit the DMA32 zone.  Except for
-        * SVM's 32-bit NPT support, TDP paging doesn't use PAE paging and can
-        * skip allocating the PDP table.
+        * 4GB of memory, which happens to fit the DMA32 zone.  TDP paging
+        * generally doesn't use PAE paging and can skip allocating the PDP
+        * table.  The main exception, handled here, is SVM's 32-bit NPT.  The
+        * other exception is for shadowing L1's 32-bit or PAE NPT on 64-bit
+        * KVM; that horror is handled on-demand by mmu_alloc_special_roots().
         */
        if (tdp_enabled && kvm_mmu_get_tdp_level(vcpu) > PT32E_ROOT_LEVEL)
                return 0;
@@ -5252,8 +5318,22 @@ static int __kvm_mmu_create(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
                return -ENOMEM;
 
        mmu->pae_root = page_address(page);
+
+       /*
+        * CR3 is only 32 bits when PAE paging is used, thus it's impossible to
+        * get the CPU to treat the PDPTEs as encrypted.  Decrypt the page so
+        * that KVM's writes and the CPU's reads get along.  Note, this is
+        * only necessary when using shadow paging, as 64-bit NPT can get at
+        * the C-bit even when shadowing 32-bit NPT, and SME isn't supported
+        * by 32-bit kernels (when KVM itself uses 32-bit NPT).
+        */
+       if (!tdp_enabled)
+               set_memory_decrypted((unsigned long)mmu->pae_root, 1);
+       else
+               WARN_ON_ONCE(shadow_me_mask);
+
        for (i = 0; i < 4; ++i)
-               mmu->pae_root[i] = INVALID_PAGE;
+               mmu->pae_root[i] = INVALID_PAE_ROOT;
 
        return 0;
 }
@@ -5365,6 +5445,15 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
         */
        kvm->arch.mmu_valid_gen = kvm->arch.mmu_valid_gen ? 0 : 1;
 
+       /* In order to ensure all threads see this change when
+        * handling the MMU reload signal, this must happen in the
+        * same critical section as kvm_reload_remote_mmus, and
+        * before kvm_zap_obsolete_pages as kvm_zap_obsolete_pages
+        * could drop the MMU lock and yield.
+        */
+       if (is_tdp_mmu_enabled(kvm))
+               kvm_tdp_mmu_invalidate_all_roots(kvm);
+
        /*
         * Notify all vcpus to reload its shadow page table and flush TLB.
         * Then all vcpus will switch to new shadow page table with the new
@@ -5377,10 +5466,13 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
 
        kvm_zap_obsolete_pages(kvm);
 
-       if (is_tdp_mmu_enabled(kvm))
-               kvm_tdp_mmu_zap_all(kvm);
-
        write_unlock(&kvm->mmu_lock);
+
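+       /*
+        * Zap the invalidated TDP MMU roots outside the write-lock critical
+        * section; the TDP MMU can zap SPTEs while holding mmu_lock for read.
+        */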
+       if (is_tdp_mmu_enabled(kvm)) {
+               read_lock(&kvm->mmu_lock);
+               kvm_tdp_mmu_zap_invalidated_roots(kvm);
+               read_unlock(&kvm->mmu_lock);
+       }
 }
 
 static bool kvm_has_zapped_obsolete_pages(struct kvm *kvm)
@@ -5420,7 +5512,7 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
        struct kvm_memslots *slots;
        struct kvm_memory_slot *memslot;
        int i;
-       bool flush;
+       bool flush = false;
 
        write_lock(&kvm->mmu_lock);
        for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
@@ -5433,20 +5525,31 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
                        if (start >= end)
                                continue;
 
-                       slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
-                                               PG_LEVEL_4K,
-                                               KVM_MAX_HUGEPAGE_LEVEL,
-                                               start, end - 1, true);
+                       flush = slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
+                                                       PG_LEVEL_4K,
+                                                       KVM_MAX_HUGEPAGE_LEVEL,
+                                                       start, end - 1, true, flush);
                }
        }
 
+       if (flush)
+               kvm_flush_remote_tlbs_with_address(kvm, gfn_start, gfn_end);
+
+       write_unlock(&kvm->mmu_lock);
+
        if (is_tdp_mmu_enabled(kvm)) {
-               flush = kvm_tdp_mmu_zap_gfn_range(kvm, gfn_start, gfn_end);
+               flush = false;
+
+               read_lock(&kvm->mmu_lock);
+               for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
+                       flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, gfn_start,
+                                                         gfn_end, flush, true);
                if (flush)
-                       kvm_flush_remote_tlbs(kvm);
-       }
+                       kvm_flush_remote_tlbs_with_address(kvm, gfn_start,
+                                                          gfn_end);
 
-       write_unlock(&kvm->mmu_lock);
+               read_unlock(&kvm->mmu_lock);
+       }
 }
 
 static bool slot_rmap_write_protect(struct kvm *kvm,
@@ -5465,10 +5568,14 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
        write_lock(&kvm->mmu_lock);
        flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
                                start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
-       if (is_tdp_mmu_enabled(kvm))
-               flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, PG_LEVEL_4K);
        write_unlock(&kvm->mmu_lock);
 
+       if (is_tdp_mmu_enabled(kvm)) {
+               read_lock(&kvm->mmu_lock);
+               flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, start_level);
+               read_unlock(&kvm->mmu_lock);
+       }
+
        /*
         * We can flush all the TLBs out of the mmu lock without TLB
         * corruption since we just change the spte from writable to
@@ -5476,9 +5583,9 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
         * spte from present to present (changing the spte from present
         * to nonpresent will flush all the TLBs immediately), in other
         * words, the only case we care is mmu_spte_update() where we
-        * have checked SPTE_HOST_WRITEABLE | SPTE_MMU_WRITEABLE
-        * instead of PT_WRITABLE_MASK, that means it does not depend
-        * on PT_WRITABLE_MASK anymore.
+        * have checked Host-writable | MMU-writable instead of
+        * PT_WRITABLE_MASK, that means it does not depend on PT_WRITABLE_MASK
+        * anymore.
         */
        if (flush)
                kvm_arch_flush_remote_tlbs_memslot(kvm, memslot);
@@ -5529,21 +5636,32 @@ void kvm_mmu_zap_collapsible_sptes(struct kvm *kvm,
 {
        /* FIXME: const-ify all uses of struct kvm_memory_slot.  */
        struct kvm_memory_slot *slot = (struct kvm_memory_slot *)memslot;
+       bool flush;
 
        write_lock(&kvm->mmu_lock);
-       slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
+       flush = slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
 
-       if (is_tdp_mmu_enabled(kvm))
-               kvm_tdp_mmu_zap_collapsible_sptes(kvm, slot);
+       if (flush)
+               kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
        write_unlock(&kvm->mmu_lock);
+
+       if (is_tdp_mmu_enabled(kvm)) {
+               flush = false;
+
+               read_lock(&kvm->mmu_lock);
+               flush = kvm_tdp_mmu_zap_collapsible_sptes(kvm, slot, flush);
+               if (flush)
+                       kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
+               read_unlock(&kvm->mmu_lock);
+       }
 }
 
 void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
-                                       struct kvm_memory_slot *memslot)
+                                       const struct kvm_memory_slot *memslot)
 {
        /*
         * All current use cases for flushing the TLBs for a specific memslot
-        * are related to dirty logging, and do the TLB flush out of mmu_lock.
+        * are related to dirty logging, and many do the TLB flush out of mmu_lock.
         * The interaction between the various operations on memslot must be
         * serialized by slots_locks to ensure the TLB flush from one operation
         * is observed by any other operation on the same memslot.
@@ -5560,10 +5678,14 @@ void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
 
        write_lock(&kvm->mmu_lock);
        flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty, false);
-       if (is_tdp_mmu_enabled(kvm))
-               flush |= kvm_tdp_mmu_clear_dirty_slot(kvm, memslot);
        write_unlock(&kvm->mmu_lock);
 
+       if (is_tdp_mmu_enabled(kvm)) {
+               read_lock(&kvm->mmu_lock);
+               flush |= kvm_tdp_mmu_clear_dirty_slot(kvm, memslot);
+               read_unlock(&kvm->mmu_lock);
+       }
+
        /*
         * It's also safe to flush TLBs out of mmu lock here as currently this
         * function is only used for dirty logging, in which case flushing TLB
@@ -5701,25 +5823,6 @@ static void mmu_destroy_caches(void)
        kmem_cache_destroy(mmu_page_header_cache);
 }
 
-static void kvm_set_mmio_spte_mask(void)
-{
-       u64 mask;
-
-       /*
-        * Set a reserved PA bit in MMIO SPTEs to generate page faults with
-        * PFEC.RSVD=1 on MMIO accesses.  64-bit PTEs (PAE, x86-64, and EPT
-        * paging) support a maximum of 52 bits of PA, i.e. if the CPU supports
-        * 52-bit physical addresses then there are no reserved PA bits in the
-        * PTEs and so the reserved PA approach must be disabled.
-        */
-       if (shadow_phys_bits < 52)
-               mask = BIT_ULL(51) | PT_PRESENT_MASK;
-       else
-               mask = 0;
-
-       kvm_mmu_set_mmio_spte_mask(mask, ACC_WRITE_MASK | ACC_USER_MASK);
-}
-
 static bool get_nx_auto_mode(void)
 {
        /* Return true when CPU has the bug, and mitigations are ON */
@@ -5785,8 +5888,6 @@ int kvm_mmu_module_init(void)
 
        kvm_mmu_reset_all_pte_masks();
 
-       kvm_set_mmio_spte_mask();
-
        pte_list_desc_cache = kmem_cache_create("pte_list_desc",
                                            sizeof(struct pte_list_desc),
                                            0, SLAB_ACCOUNT, NULL);