index e8f8c1a..8c00181 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -21,6 +21,7 @@
 #include <linux/sched.h>
 #include <linux/sched/mm.h>
 #include <linux/sched/coredump.h>
+#include <linux/sched/cputime.h>
 #include <linux/rwsem.h>
 #include <linux/pagemap.h>
 #include <linux/rmap.h>
 #include <linux/freezer.h>
 #include <linux/oom.h>
 #include <linux/numa.h>
+#include <linux/pagewalk.h>
 
 #include <asm/tlbflush.h>
 #include "internal.h"
+#include "mm_slot.h"
+
+#define CREATE_TRACE_POINTS
+#include <trace/events/ksm.h>
 
 #ifdef CONFIG_NUMA
 #define NUMA(x)                (x)
@@ -51,6 +57,8 @@
 #define DO_NUMA(x)     do { } while (0)
 #endif
 
+typedef u8 rmap_age_t;
+
 /**
  * DOC: Overview
  *
@@ -82,7 +90,7 @@
  *   different KSM page copy of that content
  *
  * Internally, the regular nodes, "dups" and "chains" are represented
- * using the same struct stable_node structure.
+ * using the same struct ksm_stable_node structure.
  *
  * In addition to the stable tree, KSM uses a second data structure called the
  * unstable tree: this tree holds pointers to pages which have been found to
  */
 
 /**
- * struct mm_slot - ksm information per mm that is being scanned
- * @link: link to the mm_slots hash list
- * @mm_list: link into the mm_slots list, rooted in ksm_mm_head
+ * struct ksm_mm_slot - ksm information per mm that is being scanned
+ * @slot: hash lookup from mm to mm_slot
  * @rmap_list: head for this mm_slot's singly-linked list of rmap_items
- * @mm: the mm that this information is valid for
  */
-struct mm_slot {
-       struct hlist_node link;
-       struct list_head mm_list;
-       struct rmap_item *rmap_list;
-       struct mm_struct *mm;
+struct ksm_mm_slot {
+       struct mm_slot slot;
+       struct ksm_rmap_item *rmap_list;
 };
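
struct ksm_mm_slot now embeds the generic struct mm_slot (hash link, list node
and mm pointer) from mm/mm_slot.h instead of open-coding those fields. A
minimal sketch of how the per-mm KSM state is recovered from the embedded
slot, assuming the mm_slot_lookup()/mm_slot_entry() helpers provided by
mm_slot.h (mm_slot_entry() being a container_of() wrapper):

	struct mm_slot *slot;
	struct ksm_mm_slot *mm_slot;

	/* hash lookup by mm, then step out to the KSM-specific wrapper */
	slot = mm_slot_lookup(mm_slots_hash, mm);
	mm_slot = slot ? mm_slot_entry(slot, struct ksm_mm_slot, slot) : NULL;
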
 
 /**
@@ -135,14 +139,14 @@ struct mm_slot {
  * There is only the one ksm_scan instance of this cursor structure.
  */
 struct ksm_scan {
-       struct mm_slot *mm_slot;
+       struct ksm_mm_slot *mm_slot;
        unsigned long address;
-       struct rmap_item **rmap_list;
+       struct ksm_rmap_item **rmap_list;
        unsigned long seqnr;
 };
 
 /**
- * struct stable_node - node of the stable rbtree
+ * struct ksm_stable_node - node of the stable rbtree
  * @node: rb node of this ksm page in the stable tree
  * @head: (overlaying parent) &migrate_nodes indicates temporarily on that list
  * @hlist_dup: linked into the stable_node->hlist with a stable_node chain
@@ -153,7 +157,7 @@ struct ksm_scan {
  * @rmap_hlist_len: number of rmap_item entries in hlist or STABLE_NODE_CHAIN
  * @nid: NUMA node id of stable tree in which linked (may not match kpfn)
  */
-struct stable_node {
+struct ksm_stable_node {
        union {
                struct rb_node node;    /* when node of stable tree */
                struct {                /* when listed for migration */
@@ -182,7 +186,7 @@ struct stable_node {
 };
 
 /**
- * struct rmap_item - reverse mapping item for virtual addresses
+ * struct ksm_rmap_item - reverse mapping item for virtual addresses
  * @rmap_list: next rmap_item in mm_slot's singly-linked rmap_list
  * @anon_vma: pointer to anon_vma for this mm,address, when in stable tree
  * @nid: NUMA node id of unstable tree in which linked (may not match page)
@@ -192,9 +196,11 @@ struct stable_node {
  * @node: rb node of this rmap_item in the unstable tree
  * @head: pointer to stable_node heading this list in the stable tree
  * @hlist: link into hlist of rmap_items hanging off that stable_node
+ * @age: number of scan iterations since creation
+ * @remaining_skips: how many scans to skip
  */
-struct rmap_item {
-       struct rmap_item *rmap_list;
+struct ksm_rmap_item {
+       struct ksm_rmap_item *rmap_list;
        union {
                struct anon_vma *anon_vma;      /* when stable */
 #ifdef CONFIG_NUMA
@@ -204,10 +210,12 @@ struct rmap_item {
        struct mm_struct *mm;
        unsigned long address;          /* + low bits used for flags below */
        unsigned int oldchecksum;       /* when unstable */
+       rmap_age_t age;
+       rmap_age_t remaining_skips;
        union {
                struct rb_node node;    /* when node of unstable tree */
                struct {                /* when listed from stable tree */
-                       struct stable_node *head;
+                       struct ksm_stable_node *head;
                        struct hlist_node hlist;
                };
        };
@@ -230,8 +238,8 @@ static LIST_HEAD(migrate_nodes);
 #define MM_SLOTS_HASH_BITS 10
 static DEFINE_HASHTABLE(mm_slots_hash, MM_SLOTS_HASH_BITS);
 
-static struct mm_slot ksm_mm_head = {
-       .mm_list = LIST_HEAD_INIT(ksm_mm_head.mm_list),
+static struct ksm_mm_slot ksm_mm_head = {
+       .slot.mm_node = LIST_HEAD_INIT(ksm_mm_head.slot.mm_node),
 };
 static struct ksm_scan ksm_scan = {
        .mm_slot = &ksm_mm_head,
@@ -241,6 +249,12 @@ static struct kmem_cache *rmap_item_cache;
 static struct kmem_cache *stable_node_cache;
 static struct kmem_cache *mm_slot_cache;
 
+/* Default number of pages to scan per batch */
+#define DEFAULT_PAGES_TO_SCAN 100
+
+/* The number of pages scanned */
+static unsigned long ksm_pages_scanned;
+
 /* The number of nodes in the stable tree */
 static unsigned long ksm_pages_shared;
 
@@ -266,7 +280,7 @@ static unsigned int ksm_stable_node_chains_prune_millisecs = 2000;
 static int ksm_max_page_sharing = 256;
 
 /* Number of pages ksmd should scan in one batch */
-static unsigned int ksm_thread_pages_to_scan = 100;
+static unsigned int ksm_thread_pages_to_scan = DEFAULT_PAGES_TO_SCAN;
 
 /* Milliseconds ksmd should sleep between batches */
 static unsigned int ksm_thread_sleep_millisecs = 20;
@@ -277,6 +291,182 @@ static unsigned int zero_checksum __read_mostly;
 /* Whether to merge empty (zeroed) pages with actual zero pages */
 static bool ksm_use_zero_pages __read_mostly;
 
+/* Skip pages that couldn't be de-duplicated previously */
+/* Default to true at least temporarily, for testing */
+static bool ksm_smart_scan = true;
+
+/* The number of zero pages that are placed by KSM */
+unsigned long ksm_zero_pages;
+
+/* The number of pages that have been skipped due to "smart scanning" */
+static unsigned long ksm_pages_skipped;
+
+/* Don't scan more than max pages per batch. */
+static unsigned long ksm_advisor_max_pages_to_scan = 30000;
+
+/* Min CPU percentage for scanning pages, per scan */
+#define KSM_ADVISOR_MIN_CPU 10
+
+/* Max CPU percentage for scanning pages, per scan */
+static unsigned int ksm_advisor_max_cpu = 70;
+
+/* Target scan time in seconds to analyze all KSM candidate pages. */
+static unsigned long ksm_advisor_target_scan_time = 200;
+
+/* Exponentially weighted moving average. */
+#define EWMA_WEIGHT 30
+
+/**
+ * struct advisor_ctx - metadata for KSM advisor
+ * @start_scan: start time of the current scan
+ * @scan_time: scan time of previous scan
+ * @change: change in percent to pages_to_scan parameter
+ * @cpu_time: cpu time consumed by the ksmd thread in the previous scan
+ */
+struct advisor_ctx {
+       ktime_t start_scan;
+       unsigned long scan_time;
+       unsigned long change;
+       unsigned long long cpu_time;
+};
+static struct advisor_ctx advisor_ctx;
+
+/* Define the different advisor types */
+enum ksm_advisor_type {
+       KSM_ADVISOR_NONE,
+       KSM_ADVISOR_SCAN_TIME,
+};
+static enum ksm_advisor_type ksm_advisor;
+
+#ifdef CONFIG_SYSFS
+/*
+ * Only called through the sysfs control interface:
+ */
+
+/* At least scan this many pages per batch. */
+static unsigned long ksm_advisor_min_pages_to_scan = 500;
+
+static void set_advisor_defaults(void)
+{
+       if (ksm_advisor == KSM_ADVISOR_NONE) {
+               ksm_thread_pages_to_scan = DEFAULT_PAGES_TO_SCAN;
+       } else if (ksm_advisor == KSM_ADVISOR_SCAN_TIME) {
+               advisor_ctx = (const struct advisor_ctx){ 0 };
+               ksm_thread_pages_to_scan = ksm_advisor_min_pages_to_scan;
+       }
+}
+#endif /* CONFIG_SYSFS */
+
+static inline void advisor_start_scan(void)
+{
+       if (ksm_advisor == KSM_ADVISOR_SCAN_TIME)
+               advisor_ctx.start_scan = ktime_get();
+}
+
+/*
+ * Use previous scan time if available, otherwise use current scan time as an
+ * approximation for the previous scan time.
+ */
+static inline unsigned long prev_scan_time(struct advisor_ctx *ctx,
+                                          unsigned long scan_time)
+{
+       return ctx->scan_time ? ctx->scan_time : scan_time;
+}
+
+/* Calculate exponential weighted moving average */
+static unsigned long ewma(unsigned long prev, unsigned long curr)
+{
+       return ((100 - EWMA_WEIGHT) * prev + EWMA_WEIGHT * curr) / 100;
+}
+
+/*
+ * The scan time advisor is based on the current scan rate and the target
+ * scan rate.
+ *
+ *      new_pages_to_scan = pages_to_scan * (scan_time / target_scan_time)
+ *
+ * To avoid perturbations it calculates a change factor from previous changes.
+ * A new change factor is calculated for each iteration using an
+ * exponentially weighted moving average. The new pages_to_scan value is
+ * multiplied by that change factor:
+ *
+ *      new_pages_to_scan *= change factor
+ *
+ * The new_pages_to_scan value is limited by the cpu min and max values. It
+ * calculates the cpu percent for the last scan and calculates the new
+ * estimated cpu percent cost for the next scan. That value is capped by the
+ * cpu min and max setting.
+ *
+ * In addition the new pages_to_scan value is capped by the max and min
+ * limits.
+ */
+static void scan_time_advisor(void)
+{
+       unsigned int cpu_percent;
+       unsigned long cpu_time;
+       unsigned long cpu_time_diff;
+       unsigned long cpu_time_diff_ms;
+       unsigned long pages;
+       unsigned long per_page_cost;
+       unsigned long factor;
+       unsigned long change;
+       unsigned long last_scan_time;
+       unsigned long scan_time;
+
+       /* Convert scan time to seconds */
+       scan_time = div_s64(ktime_ms_delta(ktime_get(), advisor_ctx.start_scan),
+                           MSEC_PER_SEC);
+       scan_time = scan_time ? scan_time : 1;
+
+       /* Calculate CPU consumption of ksmd background thread */
+       cpu_time = task_sched_runtime(current);
+       cpu_time_diff = cpu_time - advisor_ctx.cpu_time;
+       cpu_time_diff_ms = cpu_time_diff / 1000 / 1000;
+
+       cpu_percent = (cpu_time_diff_ms * 100) / (scan_time * 1000);
+       cpu_percent = cpu_percent ? cpu_percent : 1;
+       last_scan_time = prev_scan_time(&advisor_ctx, scan_time);
+
+       /* Calculate target scan time as a percentage of the actual scan time */
+       factor = ksm_advisor_target_scan_time * 100 / scan_time;
+       factor = factor ? factor : 1;
+
+       /*
+        * Calculate scan time as percentage of last scan time and use
+        * exponentially weighted average to smooth it
+        */
+       change = scan_time * 100 / last_scan_time;
+       change = change ? change : 1;
+       change = ewma(advisor_ctx.change, change);
+
+       /* Calculate new scan rate based on target scan rate. */
+       pages = ksm_thread_pages_to_scan * 100 / factor;
+       /* Update pages_to_scan by weighted change percentage. */
+       pages = pages * change / 100;
+
+       /* Cap new pages_to_scan value */
+       per_page_cost = ksm_thread_pages_to_scan / cpu_percent;
+       per_page_cost = per_page_cost ? per_page_cost : 1;
+
+       pages = min(pages, per_page_cost * ksm_advisor_max_cpu);
+       pages = max(pages, per_page_cost * KSM_ADVISOR_MIN_CPU);
+       pages = min(pages, ksm_advisor_max_pages_to_scan);
+
+       /* Update advisor context */
+       advisor_ctx.change = change;
+       advisor_ctx.scan_time = scan_time;
+       advisor_ctx.cpu_time = cpu_time;
+
+       ksm_thread_pages_to_scan = pages;
+       trace_ksm_advisor(scan_time, pages, cpu_percent);
+}
+
+static void advisor_stop_scan(void)
+{
+       if (ksm_advisor == KSM_ADVISOR_SCAN_TIME)
+               scan_time_advisor();
+}
+
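
To make the scan_time_advisor() arithmetic above concrete, here is a worked
example with illustrative numbers (not taken from the patch): assume
ksm_thread_pages_to_scan = 1000, a scan that took 400 s against the 200 s
target, a previous scan of 500 s, a previous change factor of 100, and 20%
CPU use by ksmd:

	factor = 200 * 100 / 400             = 50
	change = ewma(100, 400 * 100 / 500)  = (70 * 100 + 30 * 80) / 100 = 94
	pages  = 1000 * 100 / 50             = 2000
	pages  = 2000 * 94 / 100             = 1880
	per_page_cost = 1000 / 20            = 50
	caps: min(1880, 50 * 70), max(.., 50 * 10), min(.., 30000) -> 1880

so the next scan uses pages_to_scan = 1880, roughly doubling the scan rate
because the last scan took about twice the target time.
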
 #ifdef CONFIG_NUMA
 /* Zeroed when merging across nodes is not allowed */
 static unsigned int ksm_merge_across_nodes = 1;
@@ -298,21 +488,21 @@ static DECLARE_WAIT_QUEUE_HEAD(ksm_iter_wait);
 static DEFINE_MUTEX(ksm_thread_mutex);
 static DEFINE_SPINLOCK(ksm_mmlist_lock);
 
-#define KSM_KMEM_CACHE(__struct, __flags) kmem_cache_create("ksm_"#__struct,\
+#define KSM_KMEM_CACHE(__struct, __flags) kmem_cache_create(#__struct,\
                sizeof(struct __struct), __alignof__(struct __struct),\
                (__flags), NULL)
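
Because the structures themselves now carry the ksm_ prefix, the macro no
longer prepends "ksm_" to the cache name, so the slab names visible in
/proc/slabinfo are unchanged. For example, the ksm_rmap_item cache created in
ksm_slab_init() below expands, per this macro, to roughly:

	rmap_item_cache = kmem_cache_create("ksm_rmap_item",
				sizeof(struct ksm_rmap_item),
				__alignof__(struct ksm_rmap_item), 0, NULL);
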
 
 static int __init ksm_slab_init(void)
 {
-       rmap_item_cache = KSM_KMEM_CACHE(rmap_item, 0);
+       rmap_item_cache = KSM_KMEM_CACHE(ksm_rmap_item, 0);
        if (!rmap_item_cache)
                goto out;
 
-       stable_node_cache = KSM_KMEM_CACHE(stable_node, 0);
+       stable_node_cache = KSM_KMEM_CACHE(ksm_stable_node, 0);
        if (!stable_node_cache)
                goto out_free1;
 
-       mm_slot_cache = KSM_KMEM_CACHE(mm_slot, 0);
+       mm_slot_cache = KSM_KMEM_CACHE(ksm_mm_slot, 0);
        if (!mm_slot_cache)
                goto out_free2;
 
@@ -334,18 +524,18 @@ static void __init ksm_slab_free(void)
        mm_slot_cache = NULL;
 }
 
-static __always_inline bool is_stable_node_chain(struct stable_node *chain)
+static __always_inline bool is_stable_node_chain(struct ksm_stable_node *chain)
 {
        return chain->rmap_hlist_len == STABLE_NODE_CHAIN;
 }
 
-static __always_inline bool is_stable_node_dup(struct stable_node *dup)
+static __always_inline bool is_stable_node_dup(struct ksm_stable_node *dup)
 {
        return dup->head == STABLE_NODE_DUP_HEAD;
 }
 
-static inline void stable_node_chain_add_dup(struct stable_node *dup,
-                                            struct stable_node *chain)
+static inline void stable_node_chain_add_dup(struct ksm_stable_node *dup,
+                                            struct ksm_stable_node *chain)
 {
        VM_BUG_ON(is_stable_node_dup(dup));
        dup->head = STABLE_NODE_DUP_HEAD;
@@ -354,14 +544,14 @@ static inline void stable_node_chain_add_dup(struct stable_node *dup,
        ksm_stable_node_dups++;
 }
 
-static inline void __stable_node_dup_del(struct stable_node *dup)
+static inline void __stable_node_dup_del(struct ksm_stable_node *dup)
 {
        VM_BUG_ON(!is_stable_node_dup(dup));
        hlist_del(&dup->hlist_dup);
        ksm_stable_node_dups--;
 }
 
-static inline void stable_node_dup_del(struct stable_node *dup)
+static inline void stable_node_dup_del(struct ksm_stable_node *dup)
 {
        VM_BUG_ON(is_stable_node_chain(dup));
        if (is_stable_node_dup(dup))
@@ -373,9 +563,9 @@ static inline void stable_node_dup_del(struct stable_node *dup)
 #endif
 }
 
-static inline struct rmap_item *alloc_rmap_item(void)
+static inline struct ksm_rmap_item *alloc_rmap_item(void)
 {
-       struct rmap_item *rmap_item;
+       struct ksm_rmap_item *rmap_item;
 
        rmap_item = kmem_cache_zalloc(rmap_item_cache, GFP_KERNEL |
                                                __GFP_NORETRY | __GFP_NOWARN);
@@ -384,14 +574,15 @@ static inline struct rmap_item *alloc_rmap_item(void)
        return rmap_item;
 }
 
-static inline void free_rmap_item(struct rmap_item *rmap_item)
+static inline void free_rmap_item(struct ksm_rmap_item *rmap_item)
 {
        ksm_rmap_items--;
+       rmap_item->mm->ksm_rmap_items--;
        rmap_item->mm = NULL;   /* debug safety */
        kmem_cache_free(rmap_item_cache, rmap_item);
 }
 
-static inline struct stable_node *alloc_stable_node(void)
+static inline struct ksm_stable_node *alloc_stable_node(void)
 {
        /*
         * The allocation can take too long with GFP_KERNEL when memory is under
@@ -401,43 +592,13 @@ static inline struct stable_node *alloc_stable_node(void)
        return kmem_cache_alloc(stable_node_cache, GFP_KERNEL | __GFP_HIGH);
 }
 
-static inline void free_stable_node(struct stable_node *stable_node)
+static inline void free_stable_node(struct ksm_stable_node *stable_node)
 {
        VM_BUG_ON(stable_node->rmap_hlist_len &&
                  !is_stable_node_chain(stable_node));
        kmem_cache_free(stable_node_cache, stable_node);
 }
 
-static inline struct mm_slot *alloc_mm_slot(void)
-{
-       if (!mm_slot_cache)     /* initialization failed */
-               return NULL;
-       return kmem_cache_zalloc(mm_slot_cache, GFP_KERNEL);
-}
-
-static inline void free_mm_slot(struct mm_slot *mm_slot)
-{
-       kmem_cache_free(mm_slot_cache, mm_slot);
-}
-
-static struct mm_slot *get_mm_slot(struct mm_struct *mm)
-{
-       struct mm_slot *slot;
-
-       hash_for_each_possible(mm_slots_hash, slot, link, (unsigned long)mm)
-               if (slot->mm == mm)
-                       return slot;
-
-       return NULL;
-}
-
-static void insert_to_mm_slots_hash(struct mm_struct *mm,
-                                   struct mm_slot *mm_slot)
-{
-       mm_slot->mm = mm;
-       hash_add(mm_slots_hash, &mm_slot->link, (unsigned long)mm);
-}
-
 /*
  * ksmd, and unmerge_and_remove_all_rmap_items(), must not touch an mm's
  * page tables after it has passed through ksm_exit() - which, if necessary,
@@ -451,47 +612,83 @@ static inline bool ksm_test_exit(struct mm_struct *mm)
        return atomic_read(&mm->mm_users) == 0;
 }
 
+static int break_ksm_pmd_entry(pmd_t *pmd, unsigned long addr, unsigned long next,
+                       struct mm_walk *walk)
+{
+       struct page *page = NULL;
+       spinlock_t *ptl;
+       pte_t *pte;
+       pte_t ptent;
+       int ret;
+
+       pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
+       if (!pte)
+               return 0;
+       ptent = ptep_get(pte);
+       if (pte_present(ptent)) {
+               page = vm_normal_page(walk->vma, addr, ptent);
+       } else if (!pte_none(ptent)) {
+               swp_entry_t entry = pte_to_swp_entry(ptent);
+
+               /*
+                * As KSM pages remain KSM pages until freed, no need to wait
+                * here for migration to end.
+                */
+               if (is_migration_entry(entry))
+                       page = pfn_swap_entry_to_page(entry);
+       }
+       /* return 1 if the page is a normal ksm page or KSM-placed zero page */
+       ret = (page && PageKsm(page)) || is_ksm_zero_pte(ptent);
+       pte_unmap_unlock(pte, ptl);
+       return ret;
+}
+
+static const struct mm_walk_ops break_ksm_ops = {
+       .pmd_entry = break_ksm_pmd_entry,
+       .walk_lock = PGWALK_RDLOCK,
+};
+
+static const struct mm_walk_ops break_ksm_lock_vma_ops = {
+       .pmd_entry = break_ksm_pmd_entry,
+       .walk_lock = PGWALK_WRLOCK,
+};
+
 /*
- * We use break_ksm to break COW on a ksm page: it's a stripped down
- *
- *     if (get_user_pages(addr, 1, FOLL_WRITE, &page, NULL) == 1)
- *             put_page(page);
+ * We use break_ksm to break COW on a ksm page by triggering unsharing,
+ * such that the ksm page will get replaced by an exclusive anonymous page.
  *
- * but taking great care only to touch a ksm page, in a VM_MERGEABLE vma,
+ * We take great care only to touch a ksm page, in a VM_MERGEABLE vma,
  * in case the application has unmapped and remapped mm,addr meanwhile.
  * Could a ksm page appear anywhere else?  Actually yes, in a VM_PFNMAP
  * mmap of /dev/mem, where we would not want to touch it.
  *
- * FAULT_FLAG/FOLL_REMOTE are because we do this outside the context
+ * FAULT_FLAG_REMOTE/FOLL_REMOTE are because we do this outside the context
  * of the process that owns 'vma'.  We also do not want to enforce
  * protection keys here anyway.
  */
-static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
+static int break_ksm(struct vm_area_struct *vma, unsigned long addr, bool lock_vma)
 {
-       struct page *page;
        vm_fault_t ret = 0;
+       const struct mm_walk_ops *ops = lock_vma ?
+                               &break_ksm_lock_vma_ops : &break_ksm_ops;
 
        do {
+               int ksm_page;
+
                cond_resched();
-               page = follow_page(vma, addr,
-                               FOLL_GET | FOLL_MIGRATION | FOLL_REMOTE);
-               if (IS_ERR_OR_NULL(page))
-                       break;
-               if (PageKsm(page))
-                       ret = handle_mm_fault(vma, addr,
-                                             FAULT_FLAG_WRITE | FAULT_FLAG_REMOTE,
-                                             NULL);
-               else
-                       ret = VM_FAULT_WRITE;
-               put_page(page);
-       } while (!(ret & (VM_FAULT_WRITE | VM_FAULT_SIGBUS | VM_FAULT_SIGSEGV | VM_FAULT_OOM)));
+               ksm_page = walk_page_range_vma(vma, addr, addr + 1, ops, NULL);
+               if (WARN_ON_ONCE(ksm_page < 0))
+                       return ksm_page;
+               if (!ksm_page)
+                       return 0;
+               ret = handle_mm_fault(vma, addr,
+                                     FAULT_FLAG_UNSHARE | FAULT_FLAG_REMOTE,
+                                     NULL);
+       } while (!(ret & (VM_FAULT_SIGBUS | VM_FAULT_SIGSEGV | VM_FAULT_OOM)));
        /*
-        * We must loop because handle_mm_fault() may back out if there's
-        * any difficulty e.g. if pte accessed bit gets updated concurrently.
-        *
-        * VM_FAULT_WRITE is what we have been hoping for: it indicates that
-        * COW has been broken, even if the vma does not permit VM_WRITE;
-        * but note that a concurrent fault might break PageKsm for us.
+        * We must loop until we no longer find a KSM page because
+        * handle_mm_fault() may back out if there's any difficulty e.g. if
+        * pte accessed bit gets updated concurrently.
         *
         * VM_FAULT_SIGBUS could occur if we race with truncation of the
         * backing file, which also invalidates anonymous pages: that's
@@ -516,6 +713,28 @@ static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
        return (ret & VM_FAULT_OOM) ? -ENOMEM : 0;
 }
 
+static bool vma_ksm_compatible(struct vm_area_struct *vma)
+{
+       if (vma->vm_flags & (VM_SHARED  | VM_MAYSHARE   | VM_PFNMAP  |
+                            VM_IO      | VM_DONTEXPAND | VM_HUGETLB |
+                            VM_MIXEDMAP))
+               return false;           /* just ignore the advice */
+
+       if (vma_is_dax(vma))
+               return false;
+
+#ifdef VM_SAO
+       if (vma->vm_flags & VM_SAO)
+               return false;
+#endif
+#ifdef VM_SPARC_ADI
+       if (vma->vm_flags & VM_SPARC_ADI)
+               return false;
+#endif
+
+       return true;
+}
+
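
vma_ksm_compatible() gathers the VMA-type checks used when deciding whether a
VMA may be merged at all. A hypothetical caller sketch (the helper's real
call sites, e.g. the madvise(MADV_MERGEABLE) path, are outside this excerpt):

	/* sketch: only mark VMAs that KSM can actually handle */
	if (!vma_ksm_compatible(vma))
		return 0;		/* silently ignore the advice */
	*vm_flags |= VM_MERGEABLE;
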
 static struct vm_area_struct *find_mergeable_vma(struct mm_struct *mm,
                unsigned long addr)
 {
@@ -528,7 +747,7 @@ static struct vm_area_struct *find_mergeable_vma(struct mm_struct *mm,
        return vma;
 }
 
-static void break_cow(struct rmap_item *rmap_item)
+static void break_cow(struct ksm_rmap_item *rmap_item)
 {
        struct mm_struct *mm = rmap_item->mm;
        unsigned long addr = rmap_item->address;
@@ -543,11 +762,11 @@ static void break_cow(struct rmap_item *rmap_item)
        mmap_read_lock(mm);
        vma = find_mergeable_vma(mm, addr);
        if (vma)
-               break_ksm(vma, addr);
+               break_ksm(vma, addr, false);
        mmap_read_unlock(mm);
 }
 
-static struct page *get_mergeable_page(struct rmap_item *rmap_item)
+static struct page *get_mergeable_page(struct ksm_rmap_item *rmap_item)
 {
        struct mm_struct *mm = rmap_item->mm;
        unsigned long addr = rmap_item->address;
@@ -562,10 +781,13 @@ static struct page *get_mergeable_page(struct rmap_item *rmap_item)
        page = follow_page(vma, addr, FOLL_GET);
        if (IS_ERR_OR_NULL(page))
                goto out;
+       if (is_zone_device_page(page))
+               goto out_putpage;
        if (PageAnon(page)) {
                flush_anon_page(vma, page, addr);
                flush_dcache_page(page);
        } else {
+out_putpage:
                put_page(page);
 out:
                page = NULL;
@@ -585,10 +807,10 @@ static inline int get_kpfn_nid(unsigned long kpfn)
        return ksm_merge_across_nodes ? 0 : NUMA(pfn_to_nid(kpfn));
 }
 
-static struct stable_node *alloc_stable_node_chain(struct stable_node *dup,
+static struct ksm_stable_node *alloc_stable_node_chain(struct ksm_stable_node *dup,
                                                   struct rb_root *root)
 {
-       struct stable_node *chain = alloc_stable_node();
+       struct ksm_stable_node *chain = alloc_stable_node();
        VM_BUG_ON(is_stable_node_chain(dup));
        if (likely(chain)) {
                INIT_HLIST_HEAD(&chain->hlist);
@@ -618,7 +840,7 @@ static struct stable_node *alloc_stable_node_chain(struct stable_node *dup,
        return chain;
 }
 
-static inline void free_stable_node_chain(struct stable_node *chain,
+static inline void free_stable_node_chain(struct ksm_stable_node *chain,
                                          struct rb_root *root)
 {
        rb_erase(&chain->node, root);
@@ -626,18 +848,20 @@ static inline void free_stable_node_chain(struct stable_node *chain,
        ksm_stable_node_chains--;
 }
 
-static void remove_node_from_stable_tree(struct stable_node *stable_node)
+static void remove_node_from_stable_tree(struct ksm_stable_node *stable_node)
 {
-       struct rmap_item *rmap_item;
+       struct ksm_rmap_item *rmap_item;
 
        /* check it's not STABLE_NODE_CHAIN or negative */
        BUG_ON(stable_node->rmap_hlist_len < 0);
 
        hlist_for_each_entry(rmap_item, &stable_node->hlist, hlist) {
-               if (rmap_item->hlist.next)
+               if (rmap_item->hlist.next) {
                        ksm_pages_sharing--;
-               else
+                       trace_ksm_remove_rmap_item(stable_node->kpfn, rmap_item, rmap_item->mm);
+               } else {
                        ksm_pages_shared--;
+               }
 
                rmap_item->mm->ksm_merging_pages--;
 
@@ -658,6 +882,7 @@ static void remove_node_from_stable_tree(struct stable_node *stable_node)
        BUILD_BUG_ON(STABLE_NODE_DUP_HEAD <= &migrate_nodes);
        BUILD_BUG_ON(STABLE_NODE_DUP_HEAD >= &migrate_nodes + 1);
 
+       trace_ksm_remove_ksm_page(stable_node->kpfn);
        if (stable_node->head == &migrate_nodes)
                list_del(&stable_node->list);
        else
@@ -690,7 +915,7 @@ enum get_ksm_page_flags {
  * a page to put something that might look like our key in page->mapping.
  * is on its way to being freed; but it is an anomaly to bear in mind.
  */
-static struct page *get_ksm_page(struct stable_node *stable_node,
+static struct page *get_ksm_page(struct ksm_stable_node *stable_node,
                                 enum get_ksm_page_flags flags)
 {
        struct page *page;
@@ -769,10 +994,10 @@ stale:
  * Removing rmap_item from stable or unstable tree.
  * This function will clean the information from the stable/unstable tree.
  */
-static void remove_rmap_item_from_tree(struct rmap_item *rmap_item)
+static void remove_rmap_item_from_tree(struct ksm_rmap_item *rmap_item)
 {
        if (rmap_item->address & STABLE_FLAG) {
-               struct stable_node *stable_node;
+               struct ksm_stable_node *stable_node;
                struct page *page;
 
                stable_node = rmap_item->head;
@@ -819,10 +1044,10 @@ out:
        cond_resched();         /* we're called from many long loops */
 }
 
-static void remove_trailing_rmap_items(struct rmap_item **rmap_list)
+static void remove_trailing_rmap_items(struct ksm_rmap_item **rmap_list)
 {
        while (*rmap_list) {
-               struct rmap_item *rmap_item = *rmap_list;
+               struct ksm_rmap_item *rmap_item = *rmap_list;
                *rmap_list = rmap_item->rmap_list;
                remove_rmap_item_from_tree(rmap_item);
                free_rmap_item(rmap_item);
@@ -843,7 +1068,7 @@ static void remove_trailing_rmap_items(struct rmap_item **rmap_list)
  * in cmp_and_merge_page on one of the rmap_items we would be removing.
  */
 static int unmerge_ksm_pages(struct vm_area_struct *vma,
-                            unsigned long start, unsigned long end)
+                            unsigned long start, unsigned long end, bool lock_vma)
 {
        unsigned long addr;
        int err = 0;
@@ -854,23 +1079,23 @@ static int unmerge_ksm_pages(struct vm_area_struct *vma,
                if (signal_pending(current))
                        err = -ERESTARTSYS;
                else
-                       err = break_ksm(vma, addr);
+                       err = break_ksm(vma, addr, lock_vma);
        }
        return err;
 }
 
-static inline struct stable_node *folio_stable_node(struct folio *folio)
+static inline struct ksm_stable_node *folio_stable_node(struct folio *folio)
 {
        return folio_test_ksm(folio) ? folio_raw_mapping(folio) : NULL;
 }
 
-static inline struct stable_node *page_stable_node(struct page *page)
+static inline struct ksm_stable_node *page_stable_node(struct page *page)
 {
        return folio_stable_node(page_folio(page));
 }
 
 static inline void set_page_stable_node(struct page *page,
-                                       struct stable_node *stable_node)
+                                       struct ksm_stable_node *stable_node)
 {
        VM_BUG_ON_PAGE(PageAnon(page) && PageAnonExclusive(page), page);
        page->mapping = (void *)((unsigned long)stable_node | PAGE_MAPPING_KSM);
@@ -880,7 +1105,7 @@ static inline void set_page_stable_node(struct page *page,
 /*
  * Only called through the sysfs control interface:
  */
-static int remove_stable_node(struct stable_node *stable_node)
+static int remove_stable_node(struct ksm_stable_node *stable_node)
 {
        struct page *page;
        int err;
@@ -904,7 +1129,7 @@ static int remove_stable_node(struct stable_node *stable_node)
                 * The stable node did not yet appear stale to get_ksm_page(),
                 * since that allows for an unmapped ksm page to be recognized
                 * right up until it is freed; but the node is safe to remove.
-                * This page might be in a pagevec waiting to be freed,
+                * This page might be in an LRU cache waiting to be freed,
                 * or it might be PageSwapCache (perhaps under writeback),
                 * or it might have been removed from swapcache a moment ago.
                 */
@@ -918,10 +1143,10 @@ static int remove_stable_node(struct stable_node *stable_node)
        return err;
 }
 
-static int remove_stable_node_chain(struct stable_node *stable_node,
+static int remove_stable_node_chain(struct ksm_stable_node *stable_node,
                                    struct rb_root *root)
 {
-       struct stable_node *dup;
+       struct ksm_stable_node *dup;
        struct hlist_node *hlist_safe;
 
        if (!is_stable_node_chain(stable_node)) {
@@ -945,14 +1170,14 @@ static int remove_stable_node_chain(struct stable_node *stable_node,
 
 static int remove_all_stable_nodes(void)
 {
-       struct stable_node *stable_node, *next;
+       struct ksm_stable_node *stable_node, *next;
        int nid;
        int err = 0;
 
        for (nid = 0; nid < ksm_nr_node_ids; nid++) {
                while (root_stable_tree[nid].rb_node) {
                        stable_node = rb_entry(root_stable_tree[nid].rb_node,
-                                               struct stable_node, node);
+                                               struct ksm_stable_node, node);
                        if (remove_stable_node_chain(stable_node,
                                                     root_stable_tree + nid)) {
                                err = -EBUSY;
@@ -971,44 +1196,57 @@ static int remove_all_stable_nodes(void)
 
 static int unmerge_and_remove_all_rmap_items(void)
 {
-       struct mm_slot *mm_slot;
+       struct ksm_mm_slot *mm_slot;
+       struct mm_slot *slot;
        struct mm_struct *mm;
        struct vm_area_struct *vma;
        int err = 0;
 
        spin_lock(&ksm_mmlist_lock);
-       ksm_scan.mm_slot = list_entry(ksm_mm_head.mm_list.next,
-                                               struct mm_slot, mm_list);
+       slot = list_entry(ksm_mm_head.slot.mm_node.next,
+                         struct mm_slot, mm_node);
+       ksm_scan.mm_slot = mm_slot_entry(slot, struct ksm_mm_slot, slot);
        spin_unlock(&ksm_mmlist_lock);
 
-       for (mm_slot = ksm_scan.mm_slot;
-                       mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) {
-               mm = mm_slot->mm;
+       for (mm_slot = ksm_scan.mm_slot; mm_slot != &ksm_mm_head;
+            mm_slot = ksm_scan.mm_slot) {
+               VMA_ITERATOR(vmi, mm_slot->slot.mm, 0);
+
+               mm = mm_slot->slot.mm;
                mmap_read_lock(mm);
-               for (vma = mm->mmap; vma; vma = vma->vm_next) {
-                       if (ksm_test_exit(mm))
-                               break;
+
+               /*
+                * Exit right away if the mm is exiting, to avoid a lockdep
+                * issue in the maple tree.
+                */
+               if (ksm_test_exit(mm))
+                       goto mm_exiting;
+
+               for_each_vma(vmi, vma) {
                        if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
                                continue;
                        err = unmerge_ksm_pages(vma,
-                                               vma->vm_start, vma->vm_end);
+                                               vma->vm_start, vma->vm_end, false);
                        if (err)
                                goto error;
                }
 
+mm_exiting:
                remove_trailing_rmap_items(&mm_slot->rmap_list);
                mmap_read_unlock(mm);
 
                spin_lock(&ksm_mmlist_lock);
-               ksm_scan.mm_slot = list_entry(mm_slot->mm_list.next,
-                                               struct mm_slot, mm_list);
+               slot = list_entry(mm_slot->slot.mm_node.next,
+                                 struct mm_slot, mm_node);
+               ksm_scan.mm_slot = mm_slot_entry(slot, struct ksm_mm_slot, slot);
                if (ksm_test_exit(mm)) {
-                       hash_del(&mm_slot->link);
-                       list_del(&mm_slot->mm_list);
+                       hash_del(&mm_slot->slot.hash);
+                       list_del(&mm_slot->slot.mm_node);
                        spin_unlock(&ksm_mmlist_lock);
 
-                       free_mm_slot(mm_slot);
+                       mm_slot_free(mm_slot_cache, mm_slot);
                        clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+                       clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
                        mmdrop(mm);
                } else
                        spin_unlock(&ksm_mmlist_lock);
@@ -1031,9 +1269,9 @@ error:
 static u32 calc_checksum(struct page *page)
 {
        u32 checksum;
-       void *addr = kmap_atomic(page);
+       void *addr = kmap_local_page(page);
        checksum = xxhash(addr, PAGE_SIZE, 0);
-       kunmap_atomic(addr);
+       kunmap_local(addr);
        return checksum;
 }
 
@@ -1046,6 +1284,7 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
        int err = -EFAULT;
        struct mmu_notifier_range range;
        bool anon_exclusive;
+       pte_t entry;
 
        pvmw.address = page_address_in_vma(page, vma);
        if (pvmw.address == -EFAULT)
@@ -1053,8 +1292,7 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
 
        BUG_ON(PageTransCompound(page));
 
-       mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, vma, mm,
-                               pvmw.address,
+       mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, mm, pvmw.address,
                                pvmw.address + PAGE_SIZE);
        mmu_notifier_invalidate_range_start(&range);
 
@@ -1064,11 +1302,9 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
                goto out_unlock;
 
        anon_exclusive = PageAnonExclusive(page);
-       if (pte_write(*pvmw.pte) || pte_dirty(*pvmw.pte) ||
-           (pte_protnone(*pvmw.pte) && pte_savedwrite(*pvmw.pte)) ||
+       entry = ptep_get(pvmw.pte);
+       if (pte_write(entry) || pte_dirty(entry) ||
            anon_exclusive || mm_tlb_flush_pending(mm)) {
-               pte_t entry;
-
                swapped = PageSwapCache(page);
                flush_cache_page(vma, pvmw.address, page_to_pfn(page));
                /*
@@ -1083,7 +1319,7 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
                 * No need to notify as we are downgrading page table to read
                 * only not changing it to point to a new page.
                 *
-                * See Documentation/vm/mmu_notifier.rst
+                * See Documentation/mm/mmu_notifier.rst
                 */
                entry = ptep_clear_flush(vma, pvmw.address, pvmw.pte);
                /*
@@ -1095,21 +1331,23 @@ static int write_protect_page(struct vm_area_struct *vma, struct page *page,
                        goto out_unlock;
                }
 
-               if (anon_exclusive && page_try_share_anon_rmap(page)) {
+               /* See folio_try_share_anon_rmap_pte(): clear PTE first. */
+               if (anon_exclusive &&
+                   folio_try_share_anon_rmap_pte(page_folio(page), page)) {
                        set_pte_at(mm, pvmw.address, pvmw.pte, entry);
                        goto out_unlock;
                }
 
                if (pte_dirty(entry))
                        set_page_dirty(page);
+               entry = pte_mkclean(entry);
+
+               if (pte_write(entry))
+                       entry = pte_wrprotect(entry);
 
-               if (pte_protnone(entry))
-                       entry = pte_mkclean(pte_clear_savedwrite(entry));
-               else
-                       entry = pte_mkclean(pte_wrprotect(entry));
                set_pte_at_notify(mm, pvmw.address, pvmw.pte, entry);
        }
-       *orig_pte = *pvmw.pte;
+       *orig_pte = entry;
        err = 0;
 
 out_unlock:
@@ -1132,8 +1370,11 @@ out:
 static int replace_page(struct vm_area_struct *vma, struct page *page,
                        struct page *kpage, pte_t orig_pte)
 {
+       struct folio *kfolio = page_folio(kpage);
        struct mm_struct *mm = vma->vm_mm;
+       struct folio *folio;
        pmd_t *pmd;
+       pmd_t pmde;
        pte_t *ptep;
        pte_t newpte;
        spinlock_t *ptl;
@@ -1148,30 +1389,47 @@ static int replace_page(struct vm_area_struct *vma, struct page *page,
        pmd = mm_find_pmd(mm, addr);
        if (!pmd)
                goto out;
+       /*
+        * Some THP functions use the sequence pmdp_huge_clear_flush(), set_pmd_at()
+        * without holding anon_vma lock for write.  So when looking for a
+        * genuine pmde (in which to find pte), test present and !THP together.
+        */
+       pmde = pmdp_get_lockless(pmd);
+       if (!pmd_present(pmde) || pmd_trans_huge(pmde))
+               goto out;
 
-       mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, vma, mm, addr,
+       mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, mm, addr,
                                addr + PAGE_SIZE);
        mmu_notifier_invalidate_range_start(&range);
 
        ptep = pte_offset_map_lock(mm, pmd, addr, &ptl);
-       if (!pte_same(*ptep, orig_pte)) {
+       if (!ptep)
+               goto out_mn;
+       if (!pte_same(ptep_get(ptep), orig_pte)) {
                pte_unmap_unlock(ptep, ptl);
                goto out_mn;
        }
        VM_BUG_ON_PAGE(PageAnonExclusive(page), page);
-       VM_BUG_ON_PAGE(PageAnon(kpage) && PageAnonExclusive(kpage), kpage);
+       VM_BUG_ON_FOLIO(folio_test_anon(kfolio) && PageAnonExclusive(kpage),
+                       kfolio);
 
        /*
         * No need to check ksm_use_zero_pages here: we can only have a
         * zero_page here if ksm_use_zero_pages was enabled already.
         */
        if (!is_zero_pfn(page_to_pfn(kpage))) {
-               get_page(kpage);
-               page_add_anon_rmap(kpage, vma, addr, RMAP_NONE);
+               folio_get(kfolio);
+               folio_add_anon_rmap_pte(kfolio, kpage, vma, addr, RMAP_NONE);
                newpte = mk_pte(kpage, vma->vm_page_prot);
        } else {
-               newpte = pte_mkspecial(pfn_pte(page_to_pfn(kpage),
-                                              vma->vm_page_prot));
+               /*
+                * Use pte_mkdirty to mark the zero page mapped by KSM, so that
+                * all KSM-placed zero pages can easily be tracked by checking
+                * whether the dirty bit in the zero page's PTE is set.
+                */
+               newpte = pte_mkdirty(pte_mkspecial(pfn_pte(page_to_pfn(kpage), vma->vm_page_prot)));
+               ksm_zero_pages++;
+               mm->ksm_zero_pages++;
                /*
                 * We're replacing an anonymous page with a zero page, which is
                 * not anonymous. We need to do proper accounting otherwise we
@@ -1181,20 +1439,21 @@ static int replace_page(struct vm_area_struct *vma, struct page *page,
                dec_mm_counter(mm, MM_ANONPAGES);
        }
 
-       flush_cache_page(vma, addr, pte_pfn(*ptep));
+       flush_cache_page(vma, addr, pte_pfn(ptep_get(ptep)));
        /*
         * No need to notify as we are replacing a read only page with another
         * read only page with the same content.
         *
-        * See Documentation/vm/mmu_notifier.rst
+        * See Documentation/mm/mmu_notifier.rst
         */
        ptep_clear_flush(vma, addr, ptep);
        set_pte_at_notify(mm, addr, ptep, newpte);
 
-       page_remove_rmap(page, vma, false);
-       if (!page_mapped(page))
-               try_to_free_swap(page);
-       put_page(page);
+       folio = page_folio(page);
+       folio_remove_rmap_pte(folio, page, vma);
+       if (!folio_mapped(folio))
+               folio_free_swap(folio);
+       folio_put(folio);
 
        pte_unmap_unlock(ptep, ptl);
        err = 0;
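
The pte_mkdirty() marking above is what later identifies KSM-placed zero
pages (e.g. in break_ksm_pmd_entry() earlier in this patch). A minimal sketch
of the corresponding test, assuming is_ksm_zero_pte() in <linux/ksm.h> is
defined along these lines:

	/* a KSM-placed zero page: special zero-pfn PTE with the dirty bit set */
	#define is_ksm_zero_pte(pte)	(is_zero_pfn(pte_pfn(pte)) && pte_dirty(pte))
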
@@ -1278,7 +1537,7 @@ out:
  *
  * This function returns 0 if the pages were merged, -EFAULT otherwise.
  */
-static int try_to_merge_with_ksm_page(struct rmap_item *rmap_item,
+static int try_to_merge_with_ksm_page(struct ksm_rmap_item *rmap_item,
                                      struct page *page, struct page *kpage)
 {
        struct mm_struct *mm = rmap_item->mm;
@@ -1302,6 +1561,8 @@ static int try_to_merge_with_ksm_page(struct rmap_item *rmap_item,
        get_anon_vma(vma->anon_vma);
 out:
        mmap_read_unlock(mm);
+       trace_ksm_merge_with_ksm_page(kpage, page_to_pfn(kpage ? kpage : page),
+                               rmap_item, mm, err);
        return err;
 }
 
@@ -1315,9 +1576,9 @@ out:
  * Note that this function upgrades page to ksm page: if one of the pages
  * is already a ksm page, try_to_merge_with_ksm_page should be used.
  */
-static struct page *try_to_merge_two_pages(struct rmap_item *rmap_item,
+static struct page *try_to_merge_two_pages(struct ksm_rmap_item *rmap_item,
                                           struct page *page,
-                                          struct rmap_item *tree_rmap_item,
+                                          struct ksm_rmap_item *tree_rmap_item,
                                           struct page *tree_page)
 {
        int err;
@@ -1337,7 +1598,7 @@ static struct page *try_to_merge_two_pages(struct rmap_item *rmap_item,
 }
 
 static __always_inline
-bool __is_page_sharing_candidate(struct stable_node *stable_node, int offset)
+bool __is_page_sharing_candidate(struct ksm_stable_node *stable_node, int offset)
 {
        VM_BUG_ON(stable_node->rmap_hlist_len < 0);
        /*
@@ -1351,17 +1612,17 @@ bool __is_page_sharing_candidate(struct stable_node *stable_node, int offset)
 }
 
 static __always_inline
-bool is_page_sharing_candidate(struct stable_node *stable_node)
+bool is_page_sharing_candidate(struct ksm_stable_node *stable_node)
 {
        return __is_page_sharing_candidate(stable_node, 0);
 }
 
-static struct page *stable_node_dup(struct stable_node **_stable_node_dup,
-                                   struct stable_node **_stable_node,
+static struct page *stable_node_dup(struct ksm_stable_node **_stable_node_dup,
+                                   struct ksm_stable_node **_stable_node,
                                    struct rb_root *root,
                                    bool prune_stale_stable_nodes)
 {
-       struct stable_node *dup, *found = NULL, *stable_node = *_stable_node;
+       struct ksm_stable_node *dup, *found = NULL, *stable_node = *_stable_node;
        struct hlist_node *hlist_safe;
        struct page *_tree_page, *tree_page = NULL;
        int nr = 0;
@@ -1475,7 +1736,7 @@ static struct page *stable_node_dup(struct stable_node **_stable_node_dup,
        return tree_page;
 }
 
-static struct stable_node *stable_node_dup_any(struct stable_node *stable_node,
+static struct ksm_stable_node *stable_node_dup_any(struct ksm_stable_node *stable_node,
                                               struct rb_root *root)
 {
        if (!is_stable_node_chain(stable_node))
@@ -1502,12 +1763,12 @@ static struct stable_node *stable_node_dup_any(struct stable_node *stable_node,
  * function and will be overwritten in all cases, the caller doesn't
  * need to initialize it.
  */
-static struct page *__stable_node_chain(struct stable_node **_stable_node_dup,
-                                       struct stable_node **_stable_node,
+static struct page *__stable_node_chain(struct ksm_stable_node **_stable_node_dup,
+                                       struct ksm_stable_node **_stable_node,
                                        struct rb_root *root,
                                        bool prune_stale_stable_nodes)
 {
-       struct stable_node *stable_node = *_stable_node;
+       struct ksm_stable_node *stable_node = *_stable_node;
        if (!is_stable_node_chain(stable_node)) {
                if (is_page_sharing_candidate(stable_node)) {
                        *_stable_node_dup = stable_node;
@@ -1524,18 +1785,18 @@ static struct page *__stable_node_chain(struct stable_node **_stable_node_dup,
                               prune_stale_stable_nodes);
 }
 
-static __always_inline struct page *chain_prune(struct stable_node **s_n_d,
-                                               struct stable_node **s_n,
+static __always_inline struct page *chain_prune(struct ksm_stable_node **s_n_d,
+                                               struct ksm_stable_node **s_n,
                                                struct rb_root *root)
 {
        return __stable_node_chain(s_n_d, s_n, root, true);
 }
 
-static __always_inline struct page *chain(struct stable_node **s_n_d,
-                                         struct stable_node *s_n,
+static __always_inline struct page *chain(struct ksm_stable_node **s_n_d,
+                                         struct ksm_stable_node *s_n,
                                          struct rb_root *root)
 {
-       struct stable_node *old_stable_node = s_n;
+       struct ksm_stable_node *old_stable_node = s_n;
        struct page *tree_page;
 
        tree_page = __stable_node_chain(s_n_d, &s_n, root, false);
@@ -1559,8 +1820,8 @@ static struct page *stable_tree_search(struct page *page)
        struct rb_root *root;
        struct rb_node **new;
        struct rb_node *parent;
-       struct stable_node *stable_node, *stable_node_dup, *stable_node_any;
-       struct stable_node *page_node;
+       struct ksm_stable_node *stable_node, *stable_node_dup, *stable_node_any;
+       struct ksm_stable_node *page_node;
 
        page_node = page_stable_node(page);
        if (page_node && page_node->head != &migrate_nodes) {
@@ -1580,7 +1841,7 @@ again:
                int ret;
 
                cond_resched();
-               stable_node = rb_entry(*new, struct stable_node, node);
+               stable_node = rb_entry(*new, struct ksm_stable_node, node);
                stable_node_any = NULL;
                tree_page = chain_prune(&stable_node_dup, &stable_node, root);
                /*
@@ -1803,14 +2064,14 @@ chain_append:
  * This function returns the stable tree node just allocated on success,
  * NULL otherwise.
  */
-static struct stable_node *stable_tree_insert(struct page *kpage)
+static struct ksm_stable_node *stable_tree_insert(struct page *kpage)
 {
        int nid;
        unsigned long kpfn;
        struct rb_root *root;
        struct rb_node **new;
        struct rb_node *parent;
-       struct stable_node *stable_node, *stable_node_dup, *stable_node_any;
+       struct ksm_stable_node *stable_node, *stable_node_dup, *stable_node_any;
        bool need_chain = false;
 
        kpfn = page_to_pfn(kpage);
@@ -1825,7 +2086,7 @@ again:
                int ret;
 
                cond_resched();
-               stable_node = rb_entry(*new, struct stable_node, node);
+               stable_node = rb_entry(*new, struct ksm_stable_node, node);
                stable_node_any = NULL;
                tree_page = chain(&stable_node_dup, stable_node, root);
                if (!stable_node_dup) {
@@ -1894,7 +2155,7 @@ again:
                rb_insert_color(&stable_node_dup->node, root);
        } else {
                if (!is_stable_node_chain(stable_node)) {
-                       struct stable_node *orig = stable_node;
+                       struct ksm_stable_node *orig = stable_node;
                        /* chain is missing so create it */
                        stable_node = alloc_stable_node_chain(orig, root);
                        if (!stable_node) {
@@ -1923,7 +2184,7 @@ again:
  * the same walking algorithm in an rbtree.
  */
 static
-struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
+struct ksm_rmap_item *unstable_tree_search_insert(struct ksm_rmap_item *rmap_item,
                                              struct page *page,
                                              struct page **tree_pagep)
 {
@@ -1937,12 +2198,12 @@ struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
        new = &root->rb_node;
 
        while (*new) {
-               struct rmap_item *tree_rmap_item;
+               struct ksm_rmap_item *tree_rmap_item;
                struct page *tree_page;
                int ret;
 
                cond_resched();
-               tree_rmap_item = rb_entry(*new, struct rmap_item, node);
+               tree_rmap_item = rb_entry(*new, struct ksm_rmap_item, node);
                tree_page = get_mergeable_page(tree_rmap_item);
                if (!tree_page)
                        return NULL;
@@ -1994,8 +2255,8 @@ struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
  * rmap_items hanging off a given node of the stable tree, all sharing
  * the same ksm page.
  */
-static void stable_tree_append(struct rmap_item *rmap_item,
-                              struct stable_node *stable_node,
+static void stable_tree_append(struct ksm_rmap_item *rmap_item,
+                              struct ksm_stable_node *stable_node,
                               bool max_page_sharing_bypass)
 {
        /*
@@ -2037,12 +2298,12 @@ static void stable_tree_append(struct rmap_item *rmap_item,
  * @page: the page that we are searching identical page to.
  * @rmap_item: the reverse mapping into the virtual address of this page
  */
-static void cmp_and_merge_page(struct page *page, struct rmap_item *rmap_item)
+static void cmp_and_merge_page(struct page *page, struct ksm_rmap_item *rmap_item)
 {
        struct mm_struct *mm = rmap_item->mm;
-       struct rmap_item *tree_rmap_item;
+       struct ksm_rmap_item *tree_rmap_item;
        struct page *tree_page = NULL;
-       struct stable_node *stable_node;
+       struct ksm_stable_node *stable_node;
        struct page *kpage;
        unsigned int checksum;
        int err;
@@ -2120,6 +2381,9 @@ static void cmp_and_merge_page(struct page *page, struct rmap_item *rmap_item)
                if (vma) {
                        err = try_to_merge_one_page(vma, page,
                                        ZERO_PAGE(rmap_item->address));
+                       trace_ksm_merge_one_page(
+                               page_to_pfn(ZERO_PAGE(rmap_item->address)),
+                               rmap_item, mm, err);
                } else {
                        /*
                         * If the vma is out of date, we do not need to
@@ -2198,11 +2462,11 @@ static void cmp_and_merge_page(struct page *page, struct rmap_item *rmap_item)
        }
 }
 
-static struct rmap_item *get_next_rmap_item(struct mm_slot *mm_slot,
-                                           struct rmap_item **rmap_list,
+static struct ksm_rmap_item *get_next_rmap_item(struct ksm_mm_slot *mm_slot,
+                                           struct ksm_rmap_item **rmap_list,
                                            unsigned long addr)
 {
-       struct rmap_item *rmap_item;
+       struct ksm_rmap_item *rmap_item;
 
        while (*rmap_list) {
                rmap_item = *rmap_list;
@@ -2218,7 +2482,8 @@ static struct rmap_item *get_next_rmap_item(struct mm_slot *mm_slot,
        rmap_item = alloc_rmap_item();
        if (rmap_item) {
                /* It has already been zeroed */
-               rmap_item->mm = mm_slot->mm;
+               rmap_item->mm = mm_slot->slot.mm;
+               rmap_item->mm->ksm_rmap_items++;
                rmap_item->address = addr;
                rmap_item->rmap_list = *rmap_list;
                *rmap_list = rmap_item;
@@ -2226,22 +2491,95 @@ static struct rmap_item *get_next_rmap_item(struct mm_slot *mm_slot,
        return rmap_item;
 }
 
-static struct rmap_item *scan_get_next_rmap_item(struct page **page)
+/*
+ * Calculate the skip age for a ksm page from its rmap_item age. The age
+ * reflects how often de-duplication has already been tried unsuccessfully.
+ * The smaller the age, the fewer scans are skipped for this page.
+ *
+ * @age: rmap_item age of page
+ */
+static unsigned int skip_age(rmap_age_t age)
+{
+       if (age <= 3)
+               return 1;
+       if (age <= 5)
+               return 2;
+       if (age <= 8)
+               return 4;
+
+       return 8;
+}
+
+/*
+ * Determines if a page should be skipped for the current scan.
+ *
+ * @page: page to check
+ * @rmap_item: associated rmap_item of page
+ */
+static bool should_skip_rmap_item(struct page *page,
+                                 struct ksm_rmap_item *rmap_item)
+{
+       rmap_age_t age;
+
+       if (!ksm_smart_scan)
+               return false;
+
+       /*
+        * Never skip pages that are already KSM; pages cmp_and_merge_page()
+        * will essentially ignore them, but we still have to process them
+        * properly.
+        */
+       if (PageKsm(page))
+               return false;
+
+       age = rmap_item->age;
+       if (age != U8_MAX)
+               rmap_item->age++;
+
+       /*
+        * Smaller ages are not skipped; they need to get a chance to go
+        * through the different phases of KSM merging.
+        */
+       if (age < 3)
+               return false;
+
+       /*
+        * Are we still allowed to skip? If not, then don't skip it
+        * and determine how much more often we are allowed to skip next.
+        */
+       if (!rmap_item->remaining_skips) {
+               rmap_item->remaining_skips = skip_age(age);
+               return false;
+       }
+
+       /* Skip this page */
+       ksm_pages_skipped++;
+       rmap_item->remaining_skips--;
+       remove_rmap_item_from_tree(rmap_item);
+       return true;
+}
+
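
Taken together, skip_age() and should_skip_rmap_item() give each rmap_item an
exponential back-off. For a page that keeps failing to merge, the schedule
works out as follows (scan numbers are illustrative; "age" is the stored value
on entry to should_skip_rmap_item()):

	scans 1-3  : age 0-2  -> scanned (age < 3 is never skipped)
	scan  4    : age 3    -> scanned, skip budget = skip_age(3) = 1
	scan  5    : age 4    -> skipped
	scan  6    : age 5    -> scanned, skip budget = skip_age(5) = 2
	scans 7-8  :          -> skipped
	scan  9    : age 8    -> scanned, skip budget = skip_age(8) = 4
	scans 10-13:          -> skipped
	scan  14   : age 13   -> scanned, skip budget = skip_age(13) = 8

so in steady state such a page is examined on only one scan out of nine,
while pages that have become KSM pages are never skipped.
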
+static struct ksm_rmap_item *scan_get_next_rmap_item(struct page **page)
 {
        struct mm_struct *mm;
+       struct ksm_mm_slot *mm_slot;
        struct mm_slot *slot;
        struct vm_area_struct *vma;
-       struct rmap_item *rmap_item;
+       struct ksm_rmap_item *rmap_item;
+       struct vma_iterator vmi;
        int nid;
 
-       if (list_empty(&ksm_mm_head.mm_list))
+       if (list_empty(&ksm_mm_head.slot.mm_node))
                return NULL;
 
-       slot = ksm_scan.mm_slot;
-       if (slot == &ksm_mm_head) {
+       mm_slot = ksm_scan.mm_slot;
+       if (mm_slot == &ksm_mm_head) {
+               advisor_start_scan();
+               trace_ksm_start_scan(ksm_scan.seqnr, ksm_rmap_items);
+
                /*
-                * A number of pages can hang around indefinitely on per-cpu
-                * pagevecs, raised page count preventing write_protect_page
+                * A number of pages can hang around indefinitely in per-cpu
+                * LRU cache, raised page count preventing write_protect_page
                 * from merging them.  Though it doesn't really matter much,
                 * it is puzzling to see some stuck in pages_volatile until
                 * other activity jostles them out, and they also prevented
@@ -2258,7 +2596,7 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
                 * so prune them once before each full scan.
                 */
                if (!ksm_merge_across_nodes) {
-                       struct stable_node *stable_node, *next;
+                       struct ksm_stable_node *stable_node, *next;
                        struct page *page;
 
                        list_for_each_entry_safe(stable_node, next,
@@ -2275,28 +2613,31 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
                        root_unstable_tree[nid] = RB_ROOT;
 
                spin_lock(&ksm_mmlist_lock);
-               slot = list_entry(slot->mm_list.next, struct mm_slot, mm_list);
-               ksm_scan.mm_slot = slot;
+               slot = list_entry(mm_slot->slot.mm_node.next,
+                                 struct mm_slot, mm_node);
+               mm_slot = mm_slot_entry(slot, struct ksm_mm_slot, slot);
+               ksm_scan.mm_slot = mm_slot;
                spin_unlock(&ksm_mmlist_lock);
                /*
                 * Although we tested list_empty() above, a racing __ksm_exit
                 * of the last mm on the list may have removed it since then.
                 */
-               if (slot == &ksm_mm_head)
+               if (mm_slot == &ksm_mm_head)
                        return NULL;
 next_mm:
                ksm_scan.address = 0;
-               ksm_scan.rmap_list = &slot->rmap_list;
+               ksm_scan.rmap_list = &mm_slot->rmap_list;
        }
 
+       slot = &mm_slot->slot;
        mm = slot->mm;
+       vma_iter_init(&vmi, mm, ksm_scan.address);
+
        mmap_read_lock(mm);
        if (ksm_test_exit(mm))
-               vma = NULL;
-       else
-               vma = find_vma(mm, ksm_scan.address);
+               goto no_vmas;
 
-       for (; vma; vma = vma->vm_next) {
+       for_each_vma(vmi, vma) {
                if (!(vma->vm_flags & VM_MERGEABLE))
                        continue;
                if (ksm_scan.address < vma->vm_start)
@@ -2313,20 +2654,27 @@ next_mm:
                                cond_resched();
                                continue;
                        }
+                       if (is_zone_device_page(*page))
+                               goto next_page;
                        if (PageAnon(*page)) {
                                flush_anon_page(vma, *page, ksm_scan.address);
                                flush_dcache_page(*page);
-                               rmap_item = get_next_rmap_item(slot,
+                               rmap_item = get_next_rmap_item(mm_slot,
                                        ksm_scan.rmap_list, ksm_scan.address);
                                if (rmap_item) {
                                        ksm_scan.rmap_list =
                                                        &rmap_item->rmap_list;
+
+                                       if (should_skip_rmap_item(*page, rmap_item))
+                                               goto next_page;
+
                                        ksm_scan.address += PAGE_SIZE;
                                } else
                                        put_page(*page);
                                mmap_read_unlock(mm);
                                return rmap_item;
                        }
+next_page:
                        put_page(*page);
                        ksm_scan.address += PAGE_SIZE;
                        cond_resched();
@@ -2334,8 +2682,9 @@ next_mm:
        }
 
        if (ksm_test_exit(mm)) {
+no_vmas:
                ksm_scan.address = 0;
-               ksm_scan.rmap_list = &slot->rmap_list;
+               ksm_scan.rmap_list = &mm_slot->rmap_list;
        }
        /*
         * Nuke all the rmap_items that are above this current rmap:
@@ -2344,8 +2693,9 @@ next_mm:
        remove_trailing_rmap_items(ksm_scan.rmap_list);
 
        spin_lock(&ksm_mmlist_lock);
-       ksm_scan.mm_slot = list_entry(slot->mm_list.next,
-                                               struct mm_slot, mm_list);
+       slot = list_entry(mm_slot->slot.mm_node.next,
+                         struct mm_slot, mm_node);
+       ksm_scan.mm_slot = mm_slot_entry(slot, struct ksm_mm_slot, slot);
        if (ksm_scan.address == 0) {
                /*
                 * We've completed a full scan of all vmas, holding mmap_lock
@@ -2356,12 +2706,13 @@ next_mm:
                 * or when all VM_MERGEABLE areas have been unmapped (and
                 * mmap_lock then protects against race with MADV_MERGEABLE).
                 */
-               hash_del(&slot->link);
-               list_del(&slot->mm_list);
+               hash_del(&mm_slot->slot.hash);
+               list_del(&mm_slot->slot.mm_node);
                spin_unlock(&ksm_mmlist_lock);
 
-               free_mm_slot(slot);
+               mm_slot_free(mm_slot_cache, mm_slot);
                clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+               clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
                mmap_read_unlock(mm);
                mmdrop(mm);
        } else {
@@ -2377,10 +2728,13 @@ next_mm:
        }
 
        /* Repeat until we've completed scanning the whole list */
-       slot = ksm_scan.mm_slot;
-       if (slot != &ksm_mm_head)
+       mm_slot = ksm_scan.mm_slot;
+       if (mm_slot != &ksm_mm_head)
                goto next_mm;
 
+       advisor_stop_scan();
+
+       trace_ksm_stop_scan(ksm_scan.seqnr, ksm_rmap_items);
        ksm_scan.seqnr++;
        return NULL;
 }
@@ -2391,10 +2745,11 @@ next_mm:
  */
 static void ksm_do_scan(unsigned int scan_npages)
 {
-       struct rmap_item *rmap_item;
+       struct ksm_rmap_item *rmap_item;
        struct page *page;
+       unsigned int npages = scan_npages;
 
-       while (scan_npages-- && likely(!freezing(current))) {
+       while (npages-- && likely(!freezing(current))) {
                cond_resched();
                rmap_item = scan_get_next_rmap_item(&page);
                if (!rmap_item)
@@ -2402,11 +2757,13 @@ static void ksm_do_scan(unsigned int scan_npages)
                cmp_and_merge_page(page, rmap_item);
                put_page(page);
        }
+
+       ksm_pages_scanned += scan_npages - npages;
 }
 
 static int ksmd_should_run(void)
 {
-       return (ksm_run & KSM_RUN_MERGE) && !list_empty(&ksm_mm_head.mm_list);
+       return (ksm_run & KSM_RUN_MERGE) && !list_empty(&ksm_mm_head.slot.mm_node);
 }
 
 static int ksm_scan_thread(void *nothing)
@@ -2423,11 +2780,9 @@ static int ksm_scan_thread(void *nothing)
                        ksm_do_scan(ksm_thread_pages_to_scan);
                mutex_unlock(&ksm_thread_mutex);
 
-               try_to_freeze();
-
                if (ksmd_should_run()) {
                        sleep_ms = READ_ONCE(ksm_thread_sleep_millisecs);
-                       wait_event_interruptible_timeout(ksm_iter_wait,
+                       wait_event_freezable_timeout(ksm_iter_wait,
                                sleep_ms != READ_ONCE(ksm_thread_sleep_millisecs),
                                msecs_to_jiffies(sleep_ms));
                } else {
@@ -2438,6 +2793,136 @@ static int ksm_scan_thread(void *nothing)
        return 0;
 }
 
+static void __ksm_add_vma(struct vm_area_struct *vma)
+{
+       unsigned long vm_flags = vma->vm_flags;
+
+       if (vm_flags & VM_MERGEABLE)
+               return;
+
+       if (vma_ksm_compatible(vma))
+               vm_flags_set(vma, VM_MERGEABLE);
+}
+
+static int __ksm_del_vma(struct vm_area_struct *vma)
+{
+       int err;
+
+       if (!(vma->vm_flags & VM_MERGEABLE))
+               return 0;
+
+       if (vma->anon_vma) {
+               err = unmerge_ksm_pages(vma, vma->vm_start, vma->vm_end, true);
+               if (err)
+                       return err;
+       }
+
+       vm_flags_clear(vma, VM_MERGEABLE);
+       return 0;
+}
+
+/**
+ * ksm_add_vma - Mark vma as mergeable if compatible
+ *
+ * @vma:  Pointer to vma
+ */
+void ksm_add_vma(struct vm_area_struct *vma)
+{
+       struct mm_struct *mm = vma->vm_mm;
+
+       if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+               __ksm_add_vma(vma);
+}
+
+static void ksm_add_vmas(struct mm_struct *mm)
+{
+       struct vm_area_struct *vma;
+
+       VMA_ITERATOR(vmi, mm, 0);
+       for_each_vma(vmi, vma)
+               __ksm_add_vma(vma);
+}
+
+static int ksm_del_vmas(struct mm_struct *mm)
+{
+       struct vm_area_struct *vma;
+       int err;
+
+       VMA_ITERATOR(vmi, mm, 0);
+       for_each_vma(vmi, vma) {
+               err = __ksm_del_vma(vma);
+               if (err)
+                       return err;
+       }
+       return 0;
+}
+
+/**
+ * ksm_enable_merge_any - Add mm to the ksm mm list and enable merging on all
+ *                        compatible VMAs
+ *
+ * @mm:  Pointer to mm
+ *
+ * Returns 0 on success, otherwise error code
+ */
+int ksm_enable_merge_any(struct mm_struct *mm)
+{
+       int err;
+
+       if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+               return 0;
+
+       if (!test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
+               err = __ksm_enter(mm);
+               if (err)
+                       return err;
+       }
+
+       set_bit(MMF_VM_MERGE_ANY, &mm->flags);
+       ksm_add_vmas(mm);
+
+       return 0;
+}
+
+/**
+ * ksm_disable_merge_any - Disable merging on all compatible VMAs of the mm,
+ *                         previously enabled via ksm_enable_merge_any().
+ *
+ * Disabling merging implies unmerging any merged pages, like setting
+ * MADV_UNMERGEABLE would. If unmerging fails, the whole operation fails and
+ * merging on all compatible VMAs remains enabled.
+ *
+ * @mm: Pointer to mm
+ *
+ * Returns 0 on success, otherwise error code
+ */
+int ksm_disable_merge_any(struct mm_struct *mm)
+{
+       int err;
+
+       if (!test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+               return 0;
+
+       err = ksm_del_vmas(mm);
+       if (err) {
+               ksm_add_vmas(mm);
+               return err;
+       }
+
+       clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
+       return 0;
+}
+
+int ksm_disable(struct mm_struct *mm)
+{
+       mmap_assert_write_locked(mm);
+
+       if (!test_bit(MMF_VM_MERGEABLE, &mm->flags))
+               return 0;
+       if (test_bit(MMF_VM_MERGE_ANY, &mm->flags))
+               return ksm_disable_merge_any(mm);
+       return ksm_del_vmas(mm);
+}
+
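Editor's note: these helpers back the per-process opt-in. Userspace normally reaches ksm_enable_merge_any()/ksm_disable_merge_any() through the PR_SET_MEMORY_MERGE prctl rather than any direct call; a minimal sketch of that path is below. The #define fallbacks mirror include/uapi/linux/prctl.h and are assumptions if your libc headers are older, and depending on kernel version the prctl may also require CAP_SYS_RESOURCE.

#include <stdio.h>
#include <sys/prctl.h>

#ifndef PR_SET_MEMORY_MERGE
#define PR_SET_MEMORY_MERGE	67	/* assumed from include/uapi/linux/prctl.h */
#define PR_GET_MEMORY_MERGE	68
#endif

int main(void)
{
	/* Opt the whole process in: sets MMF_VM_MERGE_ANY and marks every
	 * current and future KSM-compatible VMA as VM_MERGEABLE. */
	if (prctl(PR_SET_MEMORY_MERGE, 1, 0, 0, 0)) {
		perror("PR_SET_MEMORY_MERGE");
		return 1;
	}
	printf("merge-any enabled: %d\n", prctl(PR_GET_MEMORY_MERGE, 0, 0, 0, 0));

	/* ... run the workload; ksmd merges eligible anonymous pages ... */

	/* Opt back out; this unmerges already-merged pages and can fail. */
	if (prctl(PR_SET_MEMORY_MERGE, 0, 0, 0, 0))
		perror("PR_SET_MEMORY_MERGE off");
	return 0;
}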
 int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
                unsigned long end, int advice, unsigned long *vm_flags)
 {
@@ -2446,25 +2931,10 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
 
        switch (advice) {
        case MADV_MERGEABLE:
-               /*
-                * Be somewhat over-protective for now!
-                */
-               if (*vm_flags & (VM_MERGEABLE | VM_SHARED  | VM_MAYSHARE   |
-                                VM_PFNMAP    | VM_IO      | VM_DONTEXPAND |
-                                VM_HUGETLB | VM_MIXEDMAP))
-                       return 0;               /* just ignore the advice */
-
-               if (vma_is_dax(vma))
+               if (vma->vm_flags & VM_MERGEABLE)
                        return 0;
-
-#ifdef VM_SAO
-               if (*vm_flags & VM_SAO)
+               if (!vma_ksm_compatible(vma))
                        return 0;
-#endif
-#ifdef VM_SPARC_ADI
-               if (*vm_flags & VM_SPARC_ADI)
-                       return 0;
-#endif
 
                if (!test_bit(MMF_VM_MERGEABLE, &mm->flags)) {
                        err = __ksm_enter(mm);
@@ -2480,7 +2950,7 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
                        return 0;               /* just ignore the advice */
 
                if (vma->anon_vma) {
-                       err = unmerge_ksm_pages(vma, start, end);
+                       err = unmerge_ksm_pages(vma, start, end, true);
                        if (err)
                                return err;
                }
@@ -2495,18 +2965,21 @@ EXPORT_SYMBOL_GPL(ksm_madvise);
 
 int __ksm_enter(struct mm_struct *mm)
 {
-       struct mm_slot *mm_slot;
+       struct ksm_mm_slot *mm_slot;
+       struct mm_slot *slot;
        int needs_wakeup;
 
-       mm_slot = alloc_mm_slot();
+       mm_slot = mm_slot_alloc(mm_slot_cache);
        if (!mm_slot)
                return -ENOMEM;
 
+       slot = &mm_slot->slot;
+
        /* Check ksm_run too?  Would need tighter locking */
-       needs_wakeup = list_empty(&ksm_mm_head.mm_list);
+       needs_wakeup = list_empty(&ksm_mm_head.slot.mm_node);
 
        spin_lock(&ksm_mmlist_lock);
-       insert_to_mm_slots_hash(mm, mm_slot);
+       mm_slot_insert(mm_slots_hash, mm, slot);
        /*
         * When KSM_RUN_MERGE (or KSM_RUN_STOP),
         * insert just behind the scanning cursor, to let the area settle
@@ -2518,9 +2991,9 @@ int __ksm_enter(struct mm_struct *mm)
         * missed: then we might as well insert at the end of the list.
         */
        if (ksm_run & KSM_RUN_UNMERGE)
-               list_add_tail(&mm_slot->mm_list, &ksm_mm_head.mm_list);
+               list_add_tail(&slot->mm_node, &ksm_mm_head.slot.mm_node);
        else
-               list_add_tail(&mm_slot->mm_list, &ksm_scan.mm_slot->mm_list);
+               list_add_tail(&slot->mm_node, &ksm_scan.mm_slot->slot.mm_node);
        spin_unlock(&ksm_mmlist_lock);
 
        set_bit(MMF_VM_MERGEABLE, &mm->flags);
@@ -2529,12 +3002,14 @@ int __ksm_enter(struct mm_struct *mm)
        if (needs_wakeup)
                wake_up_interruptible(&ksm_thread_wait);
 
+       trace_ksm_enter(mm);
        return 0;
 }
 
 void __ksm_exit(struct mm_struct *mm)
 {
-       struct mm_slot *mm_slot;
+       struct ksm_mm_slot *mm_slot;
+       struct mm_slot *slot;
        int easy_to_free = 0;
 
        /*
@@ -2547,73 +3022,86 @@ void __ksm_exit(struct mm_struct *mm)
         */
 
        spin_lock(&ksm_mmlist_lock);
-       mm_slot = get_mm_slot(mm);
+       slot = mm_slot_lookup(mm_slots_hash, mm);
+       mm_slot = mm_slot_entry(slot, struct ksm_mm_slot, slot);
        if (mm_slot && ksm_scan.mm_slot != mm_slot) {
                if (!mm_slot->rmap_list) {
-                       hash_del(&mm_slot->link);
-                       list_del(&mm_slot->mm_list);
+                       hash_del(&slot->hash);
+                       list_del(&slot->mm_node);
                        easy_to_free = 1;
                } else {
-                       list_move(&mm_slot->mm_list,
-                                 &ksm_scan.mm_slot->mm_list);
+                       list_move(&slot->mm_node,
+                                 &ksm_scan.mm_slot->slot.mm_node);
                }
        }
        spin_unlock(&ksm_mmlist_lock);
 
        if (easy_to_free) {
-               free_mm_slot(mm_slot);
+               mm_slot_free(mm_slot_cache, mm_slot);
+               clear_bit(MMF_VM_MERGE_ANY, &mm->flags);
                clear_bit(MMF_VM_MERGEABLE, &mm->flags);
                mmdrop(mm);
        } else if (mm_slot) {
                mmap_write_lock(mm);
                mmap_write_unlock(mm);
        }
+
+       trace_ksm_exit(mm);
 }
 
-struct page *ksm_might_need_to_copy(struct page *page,
-                       struct vm_area_struct *vma, unsigned long address)
+struct folio *ksm_might_need_to_copy(struct folio *folio,
+                       struct vm_area_struct *vma, unsigned long addr)
 {
-       struct folio *folio = page_folio(page);
+       struct page *page = folio_page(folio, 0);
        struct anon_vma *anon_vma = folio_anon_vma(folio);
-       struct page *new_page;
+       struct folio *new_folio;
 
-       if (PageKsm(page)) {
-               if (page_stable_node(page) &&
+       if (folio_test_large(folio))
+               return folio;
+
+       if (folio_test_ksm(folio)) {
+               if (folio_stable_node(folio) &&
                    !(ksm_run & KSM_RUN_UNMERGE))
-                       return page;    /* no need to copy it */
+                       return folio;   /* no need to copy it */
        } else if (!anon_vma) {
-               return page;            /* no need to copy it */
-       } else if (page->index == linear_page_index(vma, address) &&
+               return folio;           /* no need to copy it */
+       } else if (folio->index == linear_page_index(vma, addr) &&
                        anon_vma->root == vma->anon_vma->root) {
-               return page;            /* still no need to copy it */
-       }
-       if (!PageUptodate(page))
-               return page;            /* let do_swap_page report the error */
-
-       new_page = alloc_page_vma(GFP_HIGHUSER_MOVABLE, vma, address);
-       if (new_page &&
-           mem_cgroup_charge(page_folio(new_page), vma->vm_mm, GFP_KERNEL)) {
-               put_page(new_page);
-               new_page = NULL;
-       }
-       if (new_page) {
-               copy_user_highpage(new_page, page, address, vma);
-
-               SetPageDirty(new_page);
-               __SetPageUptodate(new_page);
-               __SetPageLocked(new_page);
+               return folio;           /* still no need to copy it */
+       }
+       if (PageHWPoison(page))
+               return ERR_PTR(-EHWPOISON);
+       if (!folio_test_uptodate(folio))
+               return folio;           /* let do_swap_page report the error */
+
+       new_folio = vma_alloc_folio(GFP_HIGHUSER_MOVABLE, 0, vma, addr, false);
+       if (new_folio &&
+           mem_cgroup_charge(new_folio, vma->vm_mm, GFP_KERNEL)) {
+               folio_put(new_folio);
+               new_folio = NULL;
+       }
+       if (new_folio) {
+               if (copy_mc_user_highpage(folio_page(new_folio, 0), page,
+                                                               addr, vma)) {
+                       folio_put(new_folio);
+                       memory_failure_queue(folio_pfn(folio), 0);
+                       return ERR_PTR(-EHWPOISON);
+               }
+               folio_set_dirty(new_folio);
+               __folio_mark_uptodate(new_folio);
+               __folio_set_locked(new_folio);
 #ifdef CONFIG_SWAP
                count_vm_event(KSM_SWPIN_COPY);
 #endif
        }
 
-       return new_page;
+       return new_folio;
 }
 
 void rmap_walk_ksm(struct folio *folio, struct rmap_walk_control *rwc)
 {
-       struct stable_node *stable_node;
-       struct rmap_item *rmap_item;
+       struct ksm_stable_node *stable_node;
+       struct ksm_rmap_item *rmap_item;
        int search_new_forks = 0;
 
        VM_BUG_ON_FOLIO(!folio_test_ksm(folio), folio);
@@ -2680,10 +3168,55 @@ again:
                goto again;
 }
 
+#ifdef CONFIG_MEMORY_FAILURE
+/*
+ * Collect processes when the error hits a KSM page.
+ */
+void collect_procs_ksm(struct page *page, struct list_head *to_kill,
+                      int force_early)
+{
+       struct ksm_stable_node *stable_node;
+       struct ksm_rmap_item *rmap_item;
+       struct folio *folio = page_folio(page);
+       struct vm_area_struct *vma;
+       struct task_struct *tsk;
+
+       stable_node = folio_stable_node(folio);
+       if (!stable_node)
+               return;
+       hlist_for_each_entry(rmap_item, &stable_node->hlist, hlist) {
+               struct anon_vma *av = rmap_item->anon_vma;
+
+               anon_vma_lock_read(av);
+               rcu_read_lock();
+               for_each_process(tsk) {
+                       struct anon_vma_chain *vmac;
+                       unsigned long addr;
+                       struct task_struct *t =
+                               task_early_kill(tsk, force_early);
+                       if (!t)
+                               continue;
+                       anon_vma_interval_tree_foreach(vmac, &av->rb_root, 0,
+                                                      ULONG_MAX)
+                       {
+                               vma = vmac->vma;
+                               if (vma->vm_mm == t->mm) {
+                                       addr = rmap_item->address & PAGE_MASK;
+                                       add_to_kill_ksm(t, page, vma, to_kill,
+                                                       addr);
+                               }
+                       }
+               }
+               rcu_read_unlock();
+               anon_vma_unlock_read(av);
+       }
+}
+#endif
+
 #ifdef CONFIG_MIGRATION
 void folio_migrate_ksm(struct folio *newfolio, struct folio *folio)
 {
-       struct stable_node *stable_node;
+       struct ksm_stable_node *stable_node;
 
        VM_BUG_ON_FOLIO(!folio_test_locked(folio), folio);
        VM_BUG_ON_FOLIO(!folio_test_locked(newfolio), newfolio);
@@ -2716,7 +3249,7 @@ static void wait_while_offlining(void)
        }
 }
 
-static bool stable_node_dup_remove_range(struct stable_node *stable_node,
+static bool stable_node_dup_remove_range(struct ksm_stable_node *stable_node,
                                         unsigned long start_pfn,
                                         unsigned long end_pfn)
 {
@@ -2732,12 +3265,12 @@ static bool stable_node_dup_remove_range(struct stable_node *stable_node,
        return false;
 }
 
-static bool stable_node_chain_remove_range(struct stable_node *stable_node,
+static bool stable_node_chain_remove_range(struct ksm_stable_node *stable_node,
                                           unsigned long start_pfn,
                                           unsigned long end_pfn,
                                           struct rb_root *root)
 {
-       struct stable_node *dup;
+       struct ksm_stable_node *dup;
        struct hlist_node *hlist_safe;
 
        if (!is_stable_node_chain(stable_node)) {
@@ -2761,14 +3294,14 @@ static bool stable_node_chain_remove_range(struct stable_node *stable_node,
 static void ksm_check_stable_tree(unsigned long start_pfn,
                                  unsigned long end_pfn)
 {
-       struct stable_node *stable_node, *next;
+       struct ksm_stable_node *stable_node, *next;
        struct rb_node *node;
        int nid;
 
        for (nid = 0; nid < ksm_nr_node_ids; nid++) {
                node = rb_first(root_stable_tree + nid);
                while (node) {
-                       stable_node = rb_entry(node, struct stable_node, node);
+                       stable_node = rb_entry(node, struct ksm_stable_node, node);
                        if (stable_node_chain_remove_range(stable_node,
                                                           start_pfn, end_pfn,
                                                           root_stable_tree +
@@ -2834,6 +3367,14 @@ static void wait_while_offlining(void)
 }
 #endif /* CONFIG_MEMORY_HOTREMOVE */
 
+#ifdef CONFIG_PROC_FS
+long ksm_process_profit(struct mm_struct *mm)
+{
+       return (long)(mm->ksm_merging_pages + mm->ksm_zero_pages) * PAGE_SIZE -
+               mm->ksm_rmap_items * sizeof(struct ksm_rmap_item);
+}
+#endif /* CONFIG_PROC_FS */
+
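Editor's note, a rough worked example of the per-process profit formula, assuming 4 KiB pages and an rmap_item of roughly 64 bytes on a 64-bit build (both depend on the configuration): a process with 10,000 merged pages, 2,000 KSM-placed zero pages and 15,000 rmap_items reports (10,000 + 2,000) * 4,096 - 15,000 * 64 = 49,152,000 - 960,000 bytes, about 48 MB of profit. The value goes negative when KSM tracks many rmap_items for pages that never merge, which is exactly the situation the smart-scan skipping above tries to limit.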
 #ifdef CONFIG_SYSFS
 /*
  * This all compiles without CONFIG_SYSFS, but is a waste of space.
@@ -2881,6 +3422,9 @@ static ssize_t pages_to_scan_store(struct kobject *kobj,
        unsigned int nr_pages;
        int err;
 
+       if (ksm_advisor != KSM_ADVISOR_NONE)
+               return -EINVAL;
+
        err = kstrtouint(buf, 10, &nr_pages);
        if (err)
                return -EINVAL;
@@ -3060,6 +3604,13 @@ static ssize_t max_page_sharing_store(struct kobject *kobj,
 }
 KSM_ATTR(max_page_sharing);
 
+static ssize_t pages_scanned_show(struct kobject *kobj,
+                                 struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%lu\n", ksm_pages_scanned);
+}
+KSM_ATTR_RO(pages_scanned);
+
 static ssize_t pages_shared_show(struct kobject *kobj,
                                 struct kobj_attribute *attr, char *buf)
 {
@@ -3098,6 +3649,32 @@ static ssize_t pages_volatile_show(struct kobject *kobj,
 }
 KSM_ATTR_RO(pages_volatile);
 
+static ssize_t pages_skipped_show(struct kobject *kobj,
+                                 struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%lu\n", ksm_pages_skipped);
+}
+KSM_ATTR_RO(pages_skipped);
+
+static ssize_t ksm_zero_pages_show(struct kobject *kobj,
+                               struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%ld\n", ksm_zero_pages);
+}
+KSM_ATTR_RO(ksm_zero_pages);
+
+static ssize_t general_profit_show(struct kobject *kobj,
+                                  struct kobj_attribute *attr, char *buf)
+{
+       long general_profit;
+
+       general_profit = (ksm_pages_sharing + ksm_zero_pages) * PAGE_SIZE -
+                               ksm_rmap_items * sizeof(struct ksm_rmap_item);
+
+       return sysfs_emit(buf, "%ld\n", general_profit);
+}
+KSM_ATTR_RO(general_profit);
+
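Editor's note: general_profit applies the same formula system-wide (pages_sharing plus KSM-placed zero pages, minus the cost of all rmap_items) and, via KSM_ATTR_RO(), shows up as /sys/kernel/mm/ksm/general_profit. A small userspace sketch for reading it back, assuming only that the sysfs file exists on the running kernel:

#include <stdio.h>

int main(void)
{
	FILE *f = fopen("/sys/kernel/mm/ksm/general_profit", "r");
	long profit;

	if (!f) {
		perror("general_profit");
		return 1;
	}
	if (fscanf(f, "%ld", &profit) != 1)
		profit = 0;
	fclose(f);

	/* Negative values mean the rmap_item overhead currently outweighs
	 * the pages saved by sharing. */
	printf("KSM general profit: %ld bytes (%.1f MiB)\n",
	       profit, profit / (1024.0 * 1024.0));
	return 0;
}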
 static ssize_t stable_node_dups_show(struct kobject *kobj,
                                     struct kobj_attribute *attr, char *buf)
 {
@@ -3145,14 +3722,163 @@ static ssize_t full_scans_show(struct kobject *kobj,
 }
 KSM_ATTR_RO(full_scans);
 
+static ssize_t smart_scan_show(struct kobject *kobj,
+                              struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%u\n", ksm_smart_scan);
+}
+
+static ssize_t smart_scan_store(struct kobject *kobj,
+                               struct kobj_attribute *attr,
+                               const char *buf, size_t count)
+{
+       int err;
+       bool value;
+
+       err = kstrtobool(buf, &value);
+       if (err)
+               return -EINVAL;
+
+       ksm_smart_scan = value;
+       return count;
+}
+KSM_ATTR(smart_scan);
+
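Editor's note: because KSM_ATTR() registers the attribute under /sys/kernel/mm/ksm/, the heuristic can be toggled at runtime by writing 0 or 1 (kstrtobool also accepts y/n) to /sys/kernel/mm/ksm/smart_scan; the pages_skipped counter above shows how much work the skipping has saved.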
+static ssize_t advisor_mode_show(struct kobject *kobj,
+                                struct kobj_attribute *attr, char *buf)
+{
+       const char *output;
+
+       if (ksm_advisor == KSM_ADVISOR_NONE)
+               output = "[none] scan-time";
+       else if (ksm_advisor == KSM_ADVISOR_SCAN_TIME)
+               output = "none [scan-time]";
+
+       return sysfs_emit(buf, "%s\n", output);
+}
+
+static ssize_t advisor_mode_store(struct kobject *kobj,
+                                 struct kobj_attribute *attr, const char *buf,
+                                 size_t count)
+{
+       enum ksm_advisor_type curr_advisor = ksm_advisor;
+
+       if (sysfs_streq("scan-time", buf))
+               ksm_advisor = KSM_ADVISOR_SCAN_TIME;
+       else if (sysfs_streq("none", buf))
+               ksm_advisor = KSM_ADVISOR_NONE;
+       else
+               return -EINVAL;
+
+       /* Set advisor default values */
+       if (curr_advisor != ksm_advisor)
+               set_advisor_defaults();
+
+       return count;
+}
+KSM_ATTR(advisor_mode);
+
+static ssize_t advisor_max_cpu_show(struct kobject *kobj,
+                                   struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%u\n", ksm_advisor_max_cpu);
+}
+
+static ssize_t advisor_max_cpu_store(struct kobject *kobj,
+                                    struct kobj_attribute *attr,
+                                    const char *buf, size_t count)
+{
+       int err;
+       unsigned long value;
+
+       err = kstrtoul(buf, 10, &value);
+       if (err)
+               return -EINVAL;
+
+       ksm_advisor_max_cpu = value;
+       return count;
+}
+KSM_ATTR(advisor_max_cpu);
+
+static ssize_t advisor_min_pages_to_scan_show(struct kobject *kobj,
+                                       struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%lu\n", ksm_advisor_min_pages_to_scan);
+}
+
+static ssize_t advisor_min_pages_to_scan_store(struct kobject *kobj,
+                                       struct kobj_attribute *attr,
+                                       const char *buf, size_t count)
+{
+       int err;
+       unsigned long value;
+
+       err = kstrtoul(buf, 10, &value);
+       if (err)
+               return -EINVAL;
+
+       ksm_advisor_min_pages_to_scan = value;
+       return count;
+}
+KSM_ATTR(advisor_min_pages_to_scan);
+
+static ssize_t advisor_max_pages_to_scan_show(struct kobject *kobj,
+                                       struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%lu\n", ksm_advisor_max_pages_to_scan);
+}
+
+static ssize_t advisor_max_pages_to_scan_store(struct kobject *kobj,
+                                       struct kobj_attribute *attr,
+                                       const char *buf, size_t count)
+{
+       int err;
+       unsigned long value;
+
+       err = kstrtoul(buf, 10, &value);
+       if (err)
+               return -EINVAL;
+
+       ksm_advisor_max_pages_to_scan = value;
+       return count;
+}
+KSM_ATTR(advisor_max_pages_to_scan);
+
+static ssize_t advisor_target_scan_time_show(struct kobject *kobj,
+                                            struct kobj_attribute *attr, char *buf)
+{
+       return sysfs_emit(buf, "%lu\n", ksm_advisor_target_scan_time);
+}
+
+static ssize_t advisor_target_scan_time_store(struct kobject *kobj,
+                                             struct kobj_attribute *attr,
+                                             const char *buf, size_t count)
+{
+       int err;
+       unsigned long value;
+
+       err = kstrtoul(buf, 10, &value);
+       if (err)
+               return -EINVAL;
+       if (value < 1)
+               return -EINVAL;
+
+       ksm_advisor_target_scan_time = value;
+       return count;
+}
+KSM_ATTR(advisor_target_scan_time);
+
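Editor's note: taken together, these attributes drive the scan-time advisor. Writing 'scan-time' to /sys/kernel/mm/ksm/advisor_mode lets ksmd size pages_to_scan automatically, bounded by advisor_min_pages_to_scan, advisor_max_pages_to_scan and advisor_max_cpu, aiming to finish a full scan in roughly advisor_target_scan_time seconds. While the advisor is active, writing pages_to_scan directly is rejected with -EINVAL, as pages_to_scan_store() above shows.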
 static struct attribute *ksm_attrs[] = {
        &sleep_millisecs_attr.attr,
        &pages_to_scan_attr.attr,
        &run_attr.attr,
+       &pages_scanned_attr.attr,
        &pages_shared_attr.attr,
        &pages_sharing_attr.attr,
        &pages_unshared_attr.attr,
        &pages_volatile_attr.attr,
+       &pages_skipped_attr.attr,
+       &ksm_zero_pages_attr.attr,
        &full_scans_attr.attr,
 #ifdef CONFIG_NUMA
        &merge_across_nodes_attr.attr,
@@ -3162,6 +3888,13 @@ static struct attribute *ksm_attrs[] = {
        &stable_node_dups_attr.attr,
        &stable_node_chains_prune_millisecs_attr.attr,
        &use_zero_pages_attr.attr,
+       &general_profit_attr.attr,
+       &smart_scan_attr.attr,
+       &advisor_mode_attr.attr,
+       &advisor_max_cpu_attr.attr,
+       &advisor_min_pages_to_scan_attr.attr,
+       &advisor_max_pages_to_scan_attr.attr,
+       &advisor_target_scan_time_attr.attr,
        NULL,
 };
 
@@ -3206,7 +3939,7 @@ static int __init ksm_init(void)
 
 #ifdef CONFIG_MEMORY_HOTREMOVE
        /* There is no significance to this priority 100 */
-       hotplug_memory_notifier(ksm_memory_callback, 100);
+       hotplug_memory_notifier(ksm_memory_callback, KSM_CALLBACK_PRI);
 #endif
        return 0;