Merge branch 'kvm-tdp-fix-rcu' into HEAD
[linux-2.6-microblaze.git] / arch / x86 / kvm / mmu / tdp_mmu.c
index 08667e3..fd50008 100644 (file)
@@ -190,11 +190,6 @@ static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
                                u64 old_spte, u64 new_spte, int level,
                                bool shared);
 
-static int kvm_mmu_page_as_id(struct kvm_mmu_page *sp)
-{
-       return sp->role.smm ? 1 : 0;
-}
-
 static void handle_changed_spte_acc_track(u64 old_spte, u64 new_spte, int level)
 {
        if (!is_shadow_present_pte(old_spte) || !is_last_spte(old_spte, level))
@@ -287,11 +282,16 @@ static void tdp_mmu_unlink_page(struct kvm *kvm, struct kvm_mmu_page *sp,
  *
  * Given a page table that has been removed from the TDP paging structure,
  * iterates through the page table to clear SPTEs and free child page tables.
+ *
+ * Note that pt is passed in as a tdp_ptep_t, but it does not need RCU
+ * protection. Since this thread removed it from the paging structure,
+ * this thread will be responsible for ensuring the page is freed. Hence the
+ * early rcu_dereferences in the function.
  */
-static void handle_removed_tdp_mmu_page(struct kvm *kvm, u64 *pt,
+static void handle_removed_tdp_mmu_page(struct kvm *kvm, tdp_ptep_t pt,
                                        bool shared)
 {
-       struct kvm_mmu_page *sp = sptep_to_sp(pt);
+       struct kvm_mmu_page *sp = sptep_to_sp(rcu_dereference(pt));
        int level = sp->role.level;
        gfn_t base_gfn = sp->gfn;
        u64 old_child_spte;
@@ -304,7 +304,7 @@ static void handle_removed_tdp_mmu_page(struct kvm *kvm, u64 *pt,
        tdp_mmu_unlink_page(kvm, sp, shared);
 
        for (i = 0; i < PT64_ENT_PER_PAGE; i++) {
-               sptep = pt + i;
+               sptep = rcu_dereference(pt) + i;
                gfn = base_gfn + (i * KVM_PAGES_PER_HPAGE(level - 1));
 
                if (shared) {
@@ -478,10 +478,6 @@ static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
                                           struct tdp_iter *iter,
                                           u64 new_spte)
 {
-       u64 *root_pt = tdp_iter_root_pt(iter);
-       struct kvm_mmu_page *root = sptep_to_sp(root_pt);
-       int as_id = kvm_mmu_page_as_id(root);
-
        lockdep_assert_held_read(&kvm->mmu_lock);
 
        /*
@@ -495,8 +491,8 @@ static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
                      new_spte) != iter->old_spte)
                return false;
 
-       handle_changed_spte(kvm, as_id, iter->gfn, iter->old_spte, new_spte,
-                           iter->level, true);
+       handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
+                           new_spte, iter->level, true);
 
        return true;
 }
@@ -524,7 +520,7 @@ static inline bool tdp_mmu_zap_spte_atomic(struct kvm *kvm,
         * here since the SPTE is going from non-present
         * to non-present.
         */
-       WRITE_ONCE(*iter->sptep, 0);
+       WRITE_ONCE(*rcu_dereference(iter->sptep), 0);
 
        return true;
 }
@@ -550,10 +546,6 @@ static inline void __tdp_mmu_set_spte(struct kvm *kvm, struct tdp_iter *iter,
                                      u64 new_spte, bool record_acc_track,
                                      bool record_dirty_log)
 {
-       tdp_ptep_t root_pt = tdp_iter_root_pt(iter);
-       struct kvm_mmu_page *root = sptep_to_sp(root_pt);
-       int as_id = kvm_mmu_page_as_id(root);
-
        lockdep_assert_held_write(&kvm->mmu_lock);
 
        /*
@@ -567,13 +559,13 @@ static inline void __tdp_mmu_set_spte(struct kvm *kvm, struct tdp_iter *iter,
 
        WRITE_ONCE(*rcu_dereference(iter->sptep), new_spte);
 
-       __handle_changed_spte(kvm, as_id, iter->gfn, iter->old_spte, new_spte,
-                             iter->level, false);
+       __handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
+                             new_spte, iter->level, false);
        if (record_acc_track)
                handle_changed_spte_acc_track(iter->old_spte, new_spte,
                                              iter->level);
        if (record_dirty_log)
-               handle_changed_spte_dirty_log(kvm, as_id, iter->gfn,
+               handle_changed_spte_dirty_log(kvm, iter->as_id, iter->gfn,
                                              iter->old_spte, new_spte,
                                              iter->level);
 }
@@ -645,9 +637,7 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
 
                WARN_ON(iter->gfn > iter->next_last_level_gfn);
 
-               tdp_iter_start(iter, iter->pt_path[iter->root_level - 1],
-                              iter->root_level, iter->min_level,
-                              iter->next_last_level_gfn);
+               tdp_iter_restart(iter);
 
                return true;
        }