KVM: x86/mmu: Fix comment mentioning skip_4k
[linux-2.6-microblaze.git] / arch / x86 / kvm / mmu / tdp_mmu.c
index 34207b8..237317b 100644 (file)
@@ -27,6 +27,15 @@ void kvm_mmu_init_tdp_mmu(struct kvm *kvm)
        INIT_LIST_HEAD(&kvm->arch.tdp_mmu_pages);
 }
 
+static __always_inline void kvm_lockdep_assert_mmu_lock_held(struct kvm *kvm,
+                                                            bool shared)
+{
+       if (shared)
+               lockdep_assert_held_read(&kvm->mmu_lock);
+       else
+               lockdep_assert_held_write(&kvm->mmu_lock);
+}
+
 void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
 {
        if (!kvm->arch.tdp_mmu_enabled)
@@ -41,32 +50,85 @@ void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
        rcu_barrier();
 }
 
-static void tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root)
+static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+                         gfn_t start, gfn_t end, bool can_yield, bool flush,
+                         bool shared);
+
+static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
+{
+       free_page((unsigned long)sp->spt);
+       kmem_cache_free(mmu_page_header_cache, sp);
+}
+
+/*
+ * This is called through call_rcu in order to free TDP page table memory
+ * safely with respect to other kernel threads that may be operating on
+ * the memory.
+ * By only accessing TDP MMU page table memory in an RCU read critical
+ * section, and freeing it after a grace period, lockless access to that
+ * memory won't use it after it is freed.
+ */
+static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
 {
-       if (kvm_mmu_put_root(kvm, root))
-               kvm_tdp_mmu_free_root(kvm, root);
+       struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
+                                              rcu_head);
+
+       tdp_mmu_free_sp(sp);
 }
 
-static inline bool tdp_mmu_next_root_valid(struct kvm *kvm,
-                                          struct kvm_mmu_page *root)
+void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
+                         bool shared)
 {
-       lockdep_assert_held_write(&kvm->mmu_lock);
+       gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
 
-       if (list_entry_is_head(root, &kvm->arch.tdp_mmu_roots, link))
-               return false;
+       kvm_lockdep_assert_mmu_lock_held(kvm, shared);
 
-       kvm_mmu_get_root(kvm, root);
-       return true;
+       if (!refcount_dec_and_test(&root->tdp_mmu_root_count))
+               return;
+
+       WARN_ON(!root->tdp_mmu_page);
 
+       spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+       list_del_rcu(&root->link);
+       spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
+
+       zap_gfn_range(kvm, root, 0, max_gfn, false, false, shared);
+
+       call_rcu(&root->rcu_head, tdp_mmu_free_sp_rcu_callback);
 }
 
-static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
-                                                    struct kvm_mmu_page *root)
+/*
+ * Finds the next valid root after root (or the first valid root if root
+ * is NULL), takes a reference on it, and returns that next root. If root
+ * is not NULL, this thread should have already taken a reference on it, and
+ * that reference will be dropped. If no valid root is found, this
+ * function will return NULL.
+ */
+static struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
+                                             struct kvm_mmu_page *prev_root,
+                                             bool shared)
 {
        struct kvm_mmu_page *next_root;
 
-       next_root = list_next_entry(root, link);
-       tdp_mmu_put_root(kvm, root);
+       rcu_read_lock();
+
+       if (prev_root)
+               next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+                                                 &prev_root->link,
+                                                 typeof(*prev_root), link);
+       else
+               next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+                                                  typeof(*next_root), link);
+
+       while (next_root && !kvm_tdp_mmu_get_root(kvm, next_root))
+               next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+                               &next_root->link, typeof(*next_root), link);
+
+       rcu_read_unlock();
+
+       if (prev_root)
+               kvm_tdp_mmu_put_root(kvm, prev_root, shared);
+
        return next_root;
 }
 
@@ -75,35 +137,24 @@ static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
  * This makes it safe to release the MMU lock and yield within the loop, but
  * if exiting the loop early, the caller must drop the reference to the most
  * recent root. (Unless keeping a live reference is desirable.)
+ *
+ * If shared is set, this function is operating under the MMU lock in read
+ * mode. In the unlikely event that this thread must free a root, the lock
+ * will be temporarily dropped and reacquired in write mode.
  */
-#define for_each_tdp_mmu_root_yield_safe(_kvm, _root)                          \
-       for (_root = list_first_entry(&_kvm->arch.tdp_mmu_roots,        \
-                                     typeof(*_root), link);            \
-            tdp_mmu_next_root_valid(_kvm, _root);                      \
-            _root = tdp_mmu_next_root(_kvm, _root))
-
-#define for_each_tdp_mmu_root(_kvm, _root)                             \
-       list_for_each_entry(_root, &_kvm->arch.tdp_mmu_roots, link)
-
-static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-                         gfn_t start, gfn_t end, bool can_yield, bool flush);
-
-void kvm_tdp_mmu_free_root(struct kvm *kvm, struct kvm_mmu_page *root)
-{
-       gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-
-       lockdep_assert_held_write(&kvm->mmu_lock);
-
-       WARN_ON(root->root_count);
-       WARN_ON(!root->tdp_mmu_page);
-
-       list_del(&root->link);
-
-       zap_gfn_range(kvm, root, 0, max_gfn, false, false);
-
-       free_page((unsigned long)root->spt);
-       kmem_cache_free(mmu_page_header_cache, root);
-}
+#define for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared) \
+       for (_root = tdp_mmu_next_root(_kvm, NULL, _shared);            \
+            _root;                                                     \
+            _root = tdp_mmu_next_root(_kvm, _root, _shared))           \
+               if (kvm_mmu_page_as_id(_root) != _as_id) {              \
+               } else
+
+#define for_each_tdp_mmu_root(_kvm, _root, _as_id)                             \
+       list_for_each_entry_rcu(_root, &_kvm->arch.tdp_mmu_roots, link,         \
+                               lockdep_is_held_type(&kvm->mmu_lock, 0) ||      \
+                               lockdep_is_held(&kvm->arch.tdp_mmu_pages_lock)) \
+               if (kvm_mmu_page_as_id(_root) != _as_id) {              \
+               } else
 
 static union kvm_mmu_page_role page_role_for_level(struct kvm_vcpu *vcpu,
                                                   int level)
@@ -137,81 +188,46 @@ static struct kvm_mmu_page *alloc_tdp_mmu_page(struct kvm_vcpu *vcpu, gfn_t gfn,
        return sp;
 }
 
-static struct kvm_mmu_page *get_tdp_mmu_vcpu_root(struct kvm_vcpu *vcpu)
+hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
 {
        union kvm_mmu_page_role role;
        struct kvm *kvm = vcpu->kvm;
        struct kvm_mmu_page *root;
 
-       role = page_role_for_level(vcpu, vcpu->arch.mmu->shadow_root_level);
+       lockdep_assert_held_write(&kvm->mmu_lock);
 
-       write_lock(&kvm->mmu_lock);
+       role = page_role_for_level(vcpu, vcpu->arch.mmu->shadow_root_level);
 
        /* Check for an existing root before allocating a new one. */
-       for_each_tdp_mmu_root(kvm, root) {
-               if (root->role.word == role.word) {
-                       kvm_mmu_get_root(kvm, root);
-                       write_unlock(&kvm->mmu_lock);
-                       return root;
-               }
+       for_each_tdp_mmu_root(kvm, root, kvm_mmu_role_as_id(role)) {
+               if (root->role.word == role.word &&
+                   kvm_tdp_mmu_get_root(kvm, root))
+                       goto out;
        }
 
        root = alloc_tdp_mmu_page(vcpu, 0, vcpu->arch.mmu->shadow_root_level);
-       root->root_count = 1;
-
-       list_add(&root->link, &kvm->arch.tdp_mmu_roots);
+       refcount_set(&root->tdp_mmu_root_count, 1);
 
-       write_unlock(&kvm->mmu_lock);
-
-       return root;
-}
-
-hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
-{
-       struct kvm_mmu_page *root;
-
-       root = get_tdp_mmu_vcpu_root(vcpu);
-       if (!root)
-               return INVALID_PAGE;
+       spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+       list_add_rcu(&root->link, &kvm->arch.tdp_mmu_roots);
+       spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
 
+out:
        return __pa(root->spt);
 }
 
-static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
-{
-       free_page((unsigned long)sp->spt);
-       kmem_cache_free(mmu_page_header_cache, sp);
-}
-
-/*
- * This is called through call_rcu in order to free TDP page table memory
- * safely with respect to other kernel threads that may be operating on
- * the memory.
- * By only accessing TDP MMU page table memory in an RCU read critical
- * section, and freeing it after a grace period, lockless access to that
- * memory won't use it after it is freed.
- */
-static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
-{
-       struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
-                                              rcu_head);
-
-       tdp_mmu_free_sp(sp);
-}
-
 static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
                                u64 old_spte, u64 new_spte, int level,
                                bool shared);
 
 static void handle_changed_spte_acc_track(u64 old_spte, u64 new_spte, int level)
 {
-       bool pfn_changed = spte_to_pfn(old_spte) != spte_to_pfn(new_spte);
-
        if (!is_shadow_present_pte(old_spte) || !is_last_spte(old_spte, level))
                return;
 
        if (is_accessed_spte(old_spte) &&
-           (!is_accessed_spte(new_spte) || pfn_changed))
+           (!is_shadow_present_pte(new_spte) || !is_accessed_spte(new_spte) ||
+            spte_to_pfn(old_spte) != spte_to_pfn(new_spte)))
                kvm_set_pfn_accessed(spte_to_pfn(old_spte));
 }
 
@@ -372,7 +388,7 @@ static void handle_removed_tdp_mmu_page(struct kvm *kvm, tdp_ptep_t pt,
 }
 
 /**
- * handle_changed_spte - handle bookkeeping associated with an SPTE change
+ * __handle_changed_spte - handle bookkeeping associated with an SPTE change
  * @kvm: kvm instance
  * @as_id: the address space of the paging structure the SPTE was a part of
  * @gfn: the base GFN that was mapped by the SPTE
@@ -428,6 +444,13 @@ static void __handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 
        trace_kvm_tdp_mmu_spte_changed(as_id, gfn, level, old_spte, new_spte);
 
+       if (is_large_pte(old_spte) != is_large_pte(new_spte)) {
+               if (is_large_pte(old_spte))
+                       atomic64_sub(1, (atomic64_t*)&kvm->stat.lpages);
+               else
+                       atomic64_add(1, (atomic64_t*)&kvm->stat.lpages);
+       }
+
        /*
         * The only times a SPTE should be changed from a non-present to
         * non-present state is when an MMIO entry is installed/modified/
@@ -455,7 +478,7 @@ static void __handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 
 
        if (was_leaf && is_dirty_spte(old_spte) &&
-           (!is_dirty_spte(new_spte) || pfn_changed))
+           (!is_present || !is_dirty_spte(new_spte) || pfn_changed))
                kvm_set_pfn_dirty(spte_to_pfn(old_spte));
 
        /*
@@ -479,8 +502,9 @@ static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 }
 
 /*
- * tdp_mmu_set_spte_atomic - Set a TDP MMU SPTE atomically and handle the
- * associated bookkeeping
+ * tdp_mmu_set_spte_atomic_no_dirty_log - Set a TDP MMU SPTE atomically
+ * and handle the associated bookkeeping, but do not mark the page dirty
+ * in KVM's dirty bitmaps.
  *
  * @kvm: kvm instance
  * @iter: a tdp_iter instance currently on the SPTE that should be set
@@ -488,9 +512,9 @@ static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
  * Returns: true if the SPTE was set, false if it was not. If false is returned,
  *         this function will have no side-effects.
  */
-static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
-                                          struct tdp_iter *iter,
-                                          u64 new_spte)
+static inline bool tdp_mmu_set_spte_atomic_no_dirty_log(struct kvm *kvm,
+                                                       struct tdp_iter *iter,
+                                                       u64 new_spte)
 {
        lockdep_assert_held_read(&kvm->mmu_lock);
 
@@ -498,16 +522,29 @@ static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
         * Do not change removed SPTEs. Only the thread that froze the SPTE
         * may modify it.
         */
-       if (iter->old_spte == REMOVED_SPTE)
+       if (is_removed_spte(iter->old_spte))
                return false;
 
        if (cmpxchg64(rcu_dereference(iter->sptep), iter->old_spte,
                      new_spte) != iter->old_spte)
                return false;
 
-       handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
-                           new_spte, iter->level, true);
+       __handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
+                             new_spte, iter->level, true);
+       handle_changed_spte_acc_track(iter->old_spte, new_spte, iter->level);
+
+       return true;
+}
+
+static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
+                                          struct tdp_iter *iter,
+                                          u64 new_spte)
+{
+       if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, iter, new_spte))
+               return false;
 
+       handle_changed_spte_dirty_log(kvm, iter->as_id, iter->gfn,
+                                     iter->old_spte, new_spte, iter->level);
        return true;
 }
 
@@ -569,7 +606,7 @@ static inline void __tdp_mmu_set_spte(struct kvm *kvm, struct tdp_iter *iter,
         * should be used. If operating under the MMU lock in write mode, the
         * use of the removed SPTE should not be necessary.
         */
-       WARN_ON(iter->old_spte == REMOVED_SPTE);
+       WARN_ON(is_removed_spte(iter->old_spte));
 
        WRITE_ONCE(*rcu_dereference(iter->sptep), new_spte);
 
@@ -634,7 +671,8 @@ static inline void tdp_mmu_set_spte_no_dirty_log(struct kvm *kvm,
  * Return false if a yield was not needed.
  */
 static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
-                                            struct tdp_iter *iter, bool flush)
+                                            struct tdp_iter *iter, bool flush,
+                                            bool shared)
 {
        /* Ensure forward progress has been made before yielding. */
        if (iter->next_last_level_gfn == iter->yielded_gfn)
@@ -646,7 +684,11 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
                if (flush)
                        kvm_flush_remote_tlbs(kvm);
 
-               cond_resched_rwlock_write(&kvm->mmu_lock);
+               if (shared)
+                       cond_resched_rwlock_read(&kvm->mmu_lock);
+               else
+                       cond_resched_rwlock_write(&kvm->mmu_lock);
+
                rcu_read_lock();
 
                WARN_ON(iter->gfn > iter->next_last_level_gfn);
@@ -664,24 +706,32 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
  * If can_yield is true, will release the MMU lock and reschedule if the
  * scheduler needs the CPU or there is contention on the MMU lock. If this
  * function cannot yield, it will not release the MMU lock or reschedule and
  * the caller must ensure it does not supply too large a GFN range, or the
- * operation can cause a soft lockup.  Note, in some use cases a flush may be
- * required by prior actions.  Ensure the pending flush is performed prior to
- * yielding.
+ * operation can cause a soft lockup.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU lock in write mode.
  */
 static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-                         gfn_t start, gfn_t end, bool can_yield, bool flush)
+                         gfn_t start, gfn_t end, bool can_yield, bool flush,
+                         bool shared)
 {
        struct tdp_iter iter;
 
+       kvm_lockdep_assert_mmu_lock_held(kvm, shared);
+
        rcu_read_lock();
 
        tdp_root_for_each_pte(iter, root, start, end) {
+retry:
                if (can_yield &&
-                   tdp_mmu_iter_cond_resched(kvm, &iter, flush)) {
+                   tdp_mmu_iter_cond_resched(kvm, &iter, flush, shared)) {
                        flush = false;
                        continue;
                }
@@ -699,8 +749,17 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
                    !is_last_spte(iter.old_spte, iter.level))
                        continue;
 
-               tdp_mmu_set_spte(kvm, &iter, 0);
-               flush = true;
+               if (!shared) {
+                       tdp_mmu_set_spte(kvm, &iter, 0);
+                       flush = true;
+               } else if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+                       /*
+                        * The iter must explicitly re-read the SPTE because
+                        * the atomic cmpxchg failed.
+                        */
+                       iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+                       goto retry;
+               }
        }
 
        rcu_read_unlock();
@@ -712,15 +771,21 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU in write mode.
  */
-bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
-                                bool can_yield)
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
+                                gfn_t end, bool can_yield, bool flush,
+                                bool shared)
 {
        struct kvm_mmu_page *root;
-       bool flush = false;
 
-       for_each_tdp_mmu_root_yield_safe(kvm, root)
-               flush = zap_gfn_range(kvm, root, start, end, can_yield, flush);
+       for_each_tdp_mmu_root_yield_safe(kvm, root, as_id, shared)
+               flush = zap_gfn_range(kvm, root, start, end, can_yield, flush,
+                                     shared);
 
        return flush;
 }
@@ -728,13 +793,115 @@ bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
 void kvm_tdp_mmu_zap_all(struct kvm *kvm)
 {
        gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-       bool flush;
+       bool flush = false;
+       int i;
+
+       for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
+               flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, 0, max_gfn,
+                                                 flush, false);
 
-       flush = kvm_tdp_mmu_zap_gfn_range(kvm, 0, max_gfn);
        if (flush)
                kvm_flush_remote_tlbs(kvm);
 }
 
+static struct kvm_mmu_page *next_invalidated_root(struct kvm *kvm,
+                                                 struct kvm_mmu_page *prev_root)
+{
+       struct kvm_mmu_page *next_root;
+
+       if (prev_root)
+               next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+                                                 &prev_root->link,
+                                                 typeof(*prev_root), link);
+       else
+               next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+                                                  typeof(*next_root), link);
+
+       while (next_root && !(next_root->role.invalid &&
+                             refcount_read(&next_root->tdp_mmu_root_count)))
+               next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+                                                 &next_root->link,
+                                                 typeof(*next_root), link);
+
+       return next_root;
+}
+
+/*
+ * Since kvm_tdp_mmu_zap_all_fast has acquired a reference to each
+ * invalidated root, they will not be freed until this function drops the
+ * reference. Before dropping that reference, tear down the paging
+ * structure so that whichever thread does drop the last reference
+ * only has to do a trivial amount of work. Since the roots are invalid,
+ * no new SPTEs should be created under them.
+ */
+void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm)
+{
+       gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+       struct kvm_mmu_page *next_root;
+       struct kvm_mmu_page *root;
+       bool flush = false;
+
+       lockdep_assert_held_read(&kvm->mmu_lock);
+
+       rcu_read_lock();
+
+       root = next_invalidated_root(kvm, NULL);
+
+       while (root) {
+               next_root = next_invalidated_root(kvm, root);
+
+               rcu_read_unlock();
+
+               flush = zap_gfn_range(kvm, root, 0, max_gfn, true, flush,
+                                     true);
+
+               /*
+                * Put the reference acquired in
+                * kvm_tdp_mmu_invalidate_roots
+                */
+               kvm_tdp_mmu_put_root(kvm, root, true);
+
+               root = next_root;
+
+               rcu_read_lock();
+       }
+
+       rcu_read_unlock();
+
+       if (flush)
+               kvm_flush_remote_tlbs(kvm);
+}
+
+/*
+ * Mark each TDP MMU root as invalid so that other threads
+ * will drop their references and allow the root count to
+ * go to 0.
+ *
+ * Also take a reference on all roots so that this thread
+ * can do the bulk of the work required to free the roots
+ * once they are invalidated. Without this reference, a
+ * vCPU thread might drop the last reference to a root and
+ * get stuck with tearing down the entire paging structure.
+ *
+ * Roots which have a zero refcount should be skipped as
+ * they're already being torn down.
+ * Already invalid roots should be referenced again so that
+ * they aren't freed before kvm_tdp_mmu_zap_all_fast is
+ * done with them.
+ *
+ * This has essentially the same effect for the TDP MMU
+ * as updating mmu_valid_gen does for the shadow MMU.
+ */
+void kvm_tdp_mmu_invalidate_all_roots(struct kvm *kvm)
+{
+       struct kvm_mmu_page *root;
+
+       lockdep_assert_held_write(&kvm->mmu_lock);
+       list_for_each_entry(root, &kvm->arch.tdp_mmu_roots, link)
+               if (refcount_inc_not_zero(&root->tdp_mmu_root_count))
+                       root->role.invalid = true;
+}
+
 /*
  * Installs a last-level SPTE to handle a TDP page fault.
  * (NPT/EPT violation/misconfiguration)
@@ -777,12 +944,11 @@ static int tdp_mmu_map_handle_target_level(struct kvm_vcpu *vcpu, int write,
                trace_mark_mmio_spte(rcu_dereference(iter->sptep), iter->gfn,
                                     new_spte);
                ret = RET_PF_EMULATE;
-       } else
+       } else {
                trace_kvm_mmu_set_spte(iter->level, iter->gfn,
                                       rcu_dereference(iter->sptep));
+       }
 
-       trace_kvm_mmu_set_spte(iter->level, iter->gfn,
-                              rcu_dereference(iter->sptep));
        if (!prefault)
                vcpu->stat.pf_fixed++;
 
@@ -850,6 +1016,14 @@ int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
                }
 
                if (!is_shadow_present_pte(iter.old_spte)) {
+                       /*
+                        * If SPTE has been forzen by another thread, just
+                        * give up and retry, avoiding unnecessary page table
+                        * allocation and free.
+                        */
+                       if (is_removed_spte(iter.old_spte))
+                               break;
+
                        sp = alloc_tdp_mmu_page(vcpu, iter.gfn, iter.level);
                        child_pt = sp->spt;
 
@@ -882,205 +1056,145 @@ int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
        return ret;
 }
 
-static __always_inline int
-kvm_tdp_mmu_handle_hva_range(struct kvm *kvm,
-                            unsigned long start,
-                            unsigned long end,
-                            unsigned long data,
-                            int (*handler)(struct kvm *kvm,
-                                           struct kvm_memory_slot *slot,
-                                           struct kvm_mmu_page *root,
-                                           gfn_t start,
-                                           gfn_t end,
-                                           unsigned long data))
+bool kvm_tdp_mmu_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range,
+                                bool flush)
 {
-       struct kvm_memslots *slots;
-       struct kvm_memory_slot *memslot;
        struct kvm_mmu_page *root;
-       int ret = 0;
-       int as_id;
-
-       for_each_tdp_mmu_root_yield_safe(kvm, root) {
-               as_id = kvm_mmu_page_as_id(root);
-               slots = __kvm_memslots(kvm, as_id);
-               kvm_for_each_memslot(memslot, slots) {
-                       unsigned long hva_start, hva_end;
-                       gfn_t gfn_start, gfn_end;
-
-                       hva_start = max(start, memslot->userspace_addr);
-                       hva_end = min(end, memslot->userspace_addr +
-                                     (memslot->npages << PAGE_SHIFT));
-                       if (hva_start >= hva_end)
-                               continue;
-                       /*
-                        * {gfn(page) | page intersects with [hva_start, hva_end)} =
-                        * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-                        */
-                       gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-                       gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
 
-                       ret |= handler(kvm, memslot, root, gfn_start,
-                                      gfn_end, data);
-               }
-       }
+       for_each_tdp_mmu_root(kvm, root, range->slot->as_id)
+               flush |= zap_gfn_range(kvm, root, range->start, range->end,
+                                      range->may_block, flush, false);
 
-       return ret;
+       return flush;
 }
 
-static int zap_gfn_range_hva_wrapper(struct kvm *kvm,
-                                    struct kvm_memory_slot *slot,
-                                    struct kvm_mmu_page *root, gfn_t start,
-                                    gfn_t end, unsigned long unused)
-{
-       return zap_gfn_range(kvm, root, start, end, false, false);
-}
+typedef bool (*tdp_handler_t)(struct kvm *kvm, struct tdp_iter *iter,
+                             struct kvm_gfn_range *range);
 
-int kvm_tdp_mmu_zap_hva_range(struct kvm *kvm, unsigned long start,
-                             unsigned long end)
+static __always_inline bool kvm_tdp_mmu_handle_gfn(struct kvm *kvm,
+                                                  struct kvm_gfn_range *range,
+                                                  tdp_handler_t handler)
 {
-       return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-                                           zap_gfn_range_hva_wrapper);
+       struct kvm_mmu_page *root;
+       struct tdp_iter iter;
+       bool ret = false;
+
+       rcu_read_lock();
+
+       /*
+        * Don't support rescheduling, none of the MMU notifiers that funnel
+        * into this helper allow blocking; it'd be dead, wasteful code.
+        */
+       for_each_tdp_mmu_root(kvm, root, range->slot->as_id) {
+               tdp_root_for_each_leaf_pte(iter, root, range->start, range->end)
+                       ret |= handler(kvm, &iter, range);
+       }
+
+       rcu_read_unlock();
+
+       return ret;
 }
 
 /*
  * Mark the SPTEs range of GFNs [start, end) unaccessed and return non-zero
  * if any of the GFNs in the range have been accessed.
  */
-static int age_gfn_range(struct kvm *kvm, struct kvm_memory_slot *slot,
-                        struct kvm_mmu_page *root, gfn_t start, gfn_t end,
-                        unsigned long unused)
+static bool age_gfn_range(struct kvm *kvm, struct tdp_iter *iter,
+                         struct kvm_gfn_range *range)
 {
-       struct tdp_iter iter;
-       int young = 0;
        u64 new_spte = 0;
 
-       rcu_read_lock();
+       /* If we have a non-accessed entry we don't need to change the pte. */
+       if (!is_accessed_spte(iter->old_spte))
+               return false;
 
-       tdp_root_for_each_leaf_pte(iter, root, start, end) {
+       new_spte = iter->old_spte;
+
+       if (spte_ad_enabled(new_spte)) {
+               new_spte &= ~shadow_accessed_mask;
+       } else {
                /*
-                * If we have a non-accessed entry we don't need to change the
-                * pte.
+                * Capture the dirty status of the page, so that it doesn't get
+                * lost when the SPTE is marked for access tracking.
                 */
-               if (!is_accessed_spte(iter.old_spte))
-                       continue;
-
-               new_spte = iter.old_spte;
-
-               if (spte_ad_enabled(new_spte)) {
-                       clear_bit((ffs(shadow_accessed_mask) - 1),
-                                 (unsigned long *)&new_spte);
-               } else {
-                       /*
-                        * Capture the dirty status of the page, so that it doesn't get
-                        * lost when the SPTE is marked for access tracking.
-                        */
-                       if (is_writable_pte(new_spte))
-                               kvm_set_pfn_dirty(spte_to_pfn(new_spte));
+               if (is_writable_pte(new_spte))
+                       kvm_set_pfn_dirty(spte_to_pfn(new_spte));
 
-                       new_spte = mark_spte_for_access_track(new_spte);
-               }
-               new_spte &= ~shadow_dirty_mask;
-
-               tdp_mmu_set_spte_no_acc_track(kvm, &iter, new_spte);
-               young = 1;
-
-               trace_kvm_age_page(iter.gfn, iter.level, slot, young);
+               new_spte = mark_spte_for_access_track(new_spte);
        }
 
-       rcu_read_unlock();
+       tdp_mmu_set_spte_no_acc_track(kvm, iter, new_spte);
 
-       return young;
+       return true;
 }
 
-int kvm_tdp_mmu_age_hva_range(struct kvm *kvm, unsigned long start,
-                             unsigned long end)
+bool kvm_tdp_mmu_age_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-                                           age_gfn_range);
+       return kvm_tdp_mmu_handle_gfn(kvm, range, age_gfn_range);
 }
 
-static int test_age_gfn(struct kvm *kvm, struct kvm_memory_slot *slot,
-                       struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-                       unsigned long unused2)
+static bool test_age_gfn(struct kvm *kvm, struct tdp_iter *iter,
+                        struct kvm_gfn_range *range)
 {
-       struct tdp_iter iter;
-
-       tdp_root_for_each_leaf_pte(iter, root, gfn, gfn + 1)
-               if (is_accessed_spte(iter.old_spte))
-                       return 1;
-
-       return 0;
+       return is_accessed_spte(iter->old_spte);
 }
 
-int kvm_tdp_mmu_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_tdp_mmu_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       return kvm_tdp_mmu_handle_hva_range(kvm, hva, hva + 1, 0,
-                                           test_age_gfn);
+       return kvm_tdp_mmu_handle_gfn(kvm, range, test_age_gfn);
 }
 
-/*
- * Handle the changed_pte MMU notifier for the TDP MMU.
- * data is a pointer to the new pte_t mapping the HVA specified by the MMU
- * notifier.
- * Returns non-zero if a flush is needed before releasing the MMU lock.
- */
-static int set_tdp_spte(struct kvm *kvm, struct kvm_memory_slot *slot,
-                       struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-                       unsigned long data)
+static bool set_spte_gfn(struct kvm *kvm, struct tdp_iter *iter,
+                        struct kvm_gfn_range *range)
 {
-       struct tdp_iter iter;
-       pte_t *ptep = (pte_t *)data;
-       kvm_pfn_t new_pfn;
        u64 new_spte;
-       int need_flush = 0;
-
-       rcu_read_lock();
 
-       WARN_ON(pte_huge(*ptep));
+       /* Huge pages aren't expected to be modified without first being zapped. */
+       WARN_ON(pte_huge(range->pte) || range->start + 1 != range->end);
 
-       new_pfn = pte_pfn(*ptep);
-
-       tdp_root_for_each_pte(iter, root, gfn, gfn + 1) {
-               if (iter.level != PG_LEVEL_4K)
-                       continue;
-
-               if (!is_shadow_present_pte(iter.old_spte))
-                       break;
-
-               tdp_mmu_set_spte(kvm, &iter, 0);
-
-               kvm_flush_remote_tlbs_with_address(kvm, iter.gfn, 1);
+       if (iter->level != PG_LEVEL_4K ||
+           !is_shadow_present_pte(iter->old_spte))
+               return false;
 
-               if (!pte_write(*ptep)) {
-                       new_spte = kvm_mmu_changed_pte_notifier_make_spte(
-                                       iter.old_spte, new_pfn);
+       /*
+        * Note, when changing a read-only SPTE, it's not strictly necessary to
+        * zero the SPTE before setting the new PFN, but doing so preserves the
+        * invariant that the PFN of a present * leaf SPTE can never change.
+        * See __handle_changed_spte().
+        */
+       tdp_mmu_set_spte(kvm, iter, 0);
 
-                       tdp_mmu_set_spte(kvm, &iter, new_spte);
-               }
+       if (!pte_write(range->pte)) {
+               new_spte = kvm_mmu_changed_pte_notifier_make_spte(iter->old_spte,
+                                                                 pte_pfn(range->pte));
 
-               need_flush = 1;
+               tdp_mmu_set_spte(kvm, iter, new_spte);
        }
 
-       if (need_flush)
-               kvm_flush_remote_tlbs_with_address(kvm, gfn, 1);
-
-       rcu_read_unlock();
-
-       return 0;
+       return true;
 }
 
-int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
-                            pte_t *host_ptep)
+/*
+ * Handle the changed_pte MMU notifier for the TDP MMU.
+ * data is a pointer to the new pte_t mapping the HVA specified by the MMU
+ * notifier.
+ * Returns non-zero if a flush is needed before releasing the MMU lock.
+ */
+bool kvm_tdp_mmu_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-       return kvm_tdp_mmu_handle_hva_range(kvm, address, address + 1,
-                                           (unsigned long)host_ptep,
-                                           set_tdp_spte);
+       bool flush = kvm_tdp_mmu_handle_gfn(kvm, range, set_spte_gfn);
+
+       /* FIXME: return 'flush' instead of flushing here. */
+       if (flush)
+               kvm_flush_remote_tlbs_with_address(kvm, range->start, 1);
+
+       return false;
 }
 
 /*
- * Remove write access from all the SPTEs mapping GFNs [start, end). If
- * skip_4k is set, SPTEs that map 4k pages, will not be write-protected.
- * Returns true if an SPTE has been changed and the TLBs need to be flushed.
+ * Remove write access from all SPTEs at or above min_level that map GFNs
+ * [start, end). Returns true if an SPTE has been changed and the TLBs need to
+ * be flushed.
  */
 static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
                             gfn_t start, gfn_t end, int min_level)
@@ -1095,7 +1209,8 @@ static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
        for_each_tdp_pte_min_level(iter, root->spt, root->role.level,
                                   min_level, start, end) {
-               if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+               if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
                        continue;
 
                if (!is_shadow_present_pte(iter.old_spte) ||
@@ -1105,7 +1220,15 @@ static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
                new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
 
-               tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+               if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+                                                         new_spte)) {
+                       /*
+                        * The iter must explicitly re-read the SPTE because
+                        * the atomic cmpxchg failed.
+                        */
+                       iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+                       goto retry;
+               }
                spte_set = true;
        }
 
@@ -1122,17 +1245,13 @@ bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
                             int min_level)
 {
        struct kvm_mmu_page *root;
-       int root_as_id;
        bool spte_set = false;
 
-       for_each_tdp_mmu_root_yield_safe(kvm, root) {
-               root_as_id = kvm_mmu_page_as_id(root);
-               if (root_as_id != slot->as_id)
-                       continue;
+       lockdep_assert_held_read(&kvm->mmu_lock);
 
+       for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
                spte_set |= wrprot_gfn_range(kvm, root, slot->base_gfn,
                             slot->base_gfn + slot->npages, min_level);
-       }
 
        return spte_set;
 }
@@ -1154,7 +1273,8 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
        rcu_read_lock();
 
        tdp_root_for_each_leaf_pte(iter, root, start, end) {
-               if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+               if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
                        continue;
 
                if (spte_ad_need_write_protect(iter.old_spte)) {
@@ -1169,7 +1289,15 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
                                continue;
                }
 
-               tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+               if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+                                                         new_spte)) {
+                       /*
+                        * The iter must explicitly re-read the SPTE because
+                        * the atomic cmpxchg failed.
+                        */
+                       iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+                       goto retry;
+               }
                spte_set = true;
        }
 
@@ -1187,17 +1315,13 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 bool kvm_tdp_mmu_clear_dirty_slot(struct kvm *kvm, struct kvm_memory_slot *slot)
 {
        struct kvm_mmu_page *root;
-       int root_as_id;
        bool spte_set = false;
 
-       for_each_tdp_mmu_root_yield_safe(kvm, root) {
-               root_as_id = kvm_mmu_page_as_id(root);
-               if (root_as_id != slot->as_id)
-                       continue;
+       lockdep_assert_held_read(&kvm->mmu_lock);
 
+       for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
                spte_set |= clear_dirty_gfn_range(kvm, root, slot->base_gfn,
                                slot->base_gfn + slot->npages);
-       }
 
        return spte_set;
 }
@@ -1259,37 +1383,32 @@ void kvm_tdp_mmu_clear_dirty_pt_masked(struct kvm *kvm,
                                       bool wrprot)
 {
        struct kvm_mmu_page *root;
-       int root_as_id;
 
        lockdep_assert_held_write(&kvm->mmu_lock);
-       for_each_tdp_mmu_root(kvm, root) {
-               root_as_id = kvm_mmu_page_as_id(root);
-               if (root_as_id != slot->as_id)
-                       continue;
-
+       for_each_tdp_mmu_root(kvm, root, slot->as_id)
                clear_dirty_pt_masked(kvm, root, gfn, mask, wrprot);
-       }
 }
 
 /*
  * Clear leaf entries which could be replaced by large mappings, for
  * GFNs within the slot.
  */
-static void zap_collapsible_spte_range(struct kvm *kvm,
+static bool zap_collapsible_spte_range(struct kvm *kvm,
                                       struct kvm_mmu_page *root,
-                                      struct kvm_memory_slot *slot)
+                                      const struct kvm_memory_slot *slot,
+                                      bool flush)
 {
        gfn_t start = slot->base_gfn;
        gfn_t end = start + slot->npages;
        struct tdp_iter iter;
        kvm_pfn_t pfn;
-       bool spte_set = false;
 
        rcu_read_lock();
 
        tdp_root_for_each_pte(iter, root, start, end) {
-               if (tdp_mmu_iter_cond_resched(kvm, &iter, spte_set)) {
-                       spte_set = false;
+retry:
+               if (tdp_mmu_iter_cond_resched(kvm, &iter, flush, true)) {
+                       flush = false;
                        continue;
                }
 
@@ -1303,38 +1422,43 @@ static void zap_collapsible_spte_range(struct kvm *kvm,
                                                            pfn, PG_LEVEL_NUM))
                        continue;
 
-               tdp_mmu_set_spte(kvm, &iter, 0);
-
-               spte_set = true;
+               if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+                       /*
+                        * The iter must explicitly re-read the SPTE because
+                        * the atomic cmpxchg failed.
+                        */
+                       iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+                       goto retry;
+               }
+               flush = true;
        }
 
        rcu_read_unlock();
-       if (spte_set)
-               kvm_flush_remote_tlbs(kvm);
+
+       return flush;
 }
 
 /*
  * Clear non-leaf entries (and free associated page tables) which could
  * be replaced by large mappings, for GFNs within the slot.
  */
-void kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
-                                      struct kvm_memory_slot *slot)
+bool kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
+                                      const struct kvm_memory_slot *slot,
+                                      bool flush)
 {
        struct kvm_mmu_page *root;
-       int root_as_id;
 
-       for_each_tdp_mmu_root_yield_safe(kvm, root) {
-               root_as_id = kvm_mmu_page_as_id(root);
-               if (root_as_id != slot->as_id)
-                       continue;
+       lockdep_assert_held_read(&kvm->mmu_lock);
 
-               zap_collapsible_spte_range(kvm, root, slot);
-       }
+       for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
+               flush = zap_collapsible_spte_range(kvm, root, slot, flush);
+
+       return flush;
 }
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
@@ -1351,7 +1475,7 @@ static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
                        break;
 
                new_spte = iter.old_spte &
-                       ~(PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE);
+                       ~(PT_WRITABLE_MASK | shadow_mmu_writable_mask);
 
                tdp_mmu_set_spte(kvm, &iter, new_spte);
                spte_set = true;
@@ -1364,24 +1488,19 @@ static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 bool kvm_tdp_mmu_write_protect_gfn(struct kvm *kvm,
                                   struct kvm_memory_slot *slot, gfn_t gfn)
 {
        struct kvm_mmu_page *root;
-       int root_as_id;
        bool spte_set = false;
 
        lockdep_assert_held_write(&kvm->mmu_lock);
-       for_each_tdp_mmu_root(kvm, root) {
-               root_as_id = kvm_mmu_page_as_id(root);
-               if (root_as_id != slot->as_id)
-                       continue;
-
+       for_each_tdp_mmu_root(kvm, root, slot->as_id)
                spte_set |= write_protect_gfn(kvm, root, gfn);
-       }
+
        return spte_set;
 }