diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index b69979c..fabce2b 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -33,6 +33,7 @@
 #include <linux/shmem_fs.h>
 #include <linux/hugetlb.h>
 #include <linux/pagemap.h>
+#include <linux/pagevec.h>
 #include <linux/vm_event_item.h>
 #include <linux/smp.h>
 #include <linux/page-flags.h>
@@ -63,6 +64,8 @@
 #include <linux/resume_user_mode.h>
 #include <linux/psi.h>
 #include <linux/seq_buf.h>
+#include <linux/sched/isolation.h>
+#include <linux/kmemleak.h>
 #include "internal.h"
 #include <net/sock.h>
 #include <net/ip.h>
@@ -88,12 +91,8 @@ static bool cgroup_memory_nosocket __ro_after_init;
 /* Kernel memory accounting disabled? */
 static bool cgroup_memory_nokmem __ro_after_init;
 
-/* Whether the swap controller is active */
-#ifdef CONFIG_MEMCG_SWAP
-static bool cgroup_memory_noswap __ro_after_init;
-#else
-#define cgroup_memory_noswap           1
-#endif
+/* BPF memory accounting disabled? */
+static bool cgroup_memory_nobpf __ro_after_init;
 
 #ifdef CONFIG_CGROUP_WRITEBACK
 static DECLARE_WAIT_QUEUE_HEAD(memcg_cgwb_frn_waitq);
@@ -102,7 +101,7 @@ static DECLARE_WAIT_QUEUE_HEAD(memcg_cgwb_frn_waitq);
 /* Whether legacy memory+swap accounting is active */
 static bool do_memsw_account(void)
 {
-       return !cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_noswap;
+       return !cgroup_subsys_on_dfl(memory_cgrp_subsys);
 }
 
 #define THRESHOLDS_EVENTS_TARGET 128
@@ -200,7 +199,7 @@ static struct move_charge_struct {
 };
 
 /*
- * Maximum loops in mem_cgroup_hierarchical_reclaim(), used for soft
+ * Maximum loops in mem_cgroup_soft_reclaim(), used for soft
  * limit reclaim to prevent infinite loops, if they ever occur.
  */
 #define        MEM_CGROUP_MAX_RECLAIM_LOOPS            100
@@ -252,6 +251,9 @@ struct mem_cgroup *vmpressure_to_memcg(struct vmpressure *vmpr)
        return container_of(vmpr, struct mem_cgroup, vmpressure);
 }
 
+#define CURRENT_OBJCG_UPDATE_BIT 0
+#define CURRENT_OBJCG_UPDATE_FLAG (1UL << CURRENT_OBJCG_UPDATE_BIT)
+
 #ifdef CONFIG_MEMCG_KMEM
 static DEFINE_SPINLOCK(objcg_lock);
 
@@ -352,26 +354,27 @@ static void memcg_reparent_objcgs(struct mem_cgroup *memcg,
  * conditional to this static branch, we'll have to allow modules that does
  * kmem_cache_alloc and the such to see this symbol as well
  */
-DEFINE_STATIC_KEY_FALSE(memcg_kmem_enabled_key);
-EXPORT_SYMBOL(memcg_kmem_enabled_key);
+DEFINE_STATIC_KEY_FALSE(memcg_kmem_online_key);
+EXPORT_SYMBOL(memcg_kmem_online_key);
+
+DEFINE_STATIC_KEY_FALSE(memcg_bpf_enabled_key);
+EXPORT_SYMBOL(memcg_bpf_enabled_key);
 #endif
 
 /**
- * mem_cgroup_css_from_page - css of the memcg associated with a page
- * @page: page of interest
+ * mem_cgroup_css_from_folio - css of the memcg associated with a folio
+ * @folio: folio of interest
  *
  * If memcg is bound to the default hierarchy, css of the memcg associated
- * with @page is returned.  The returned css remains associated with @page
+ * with @folio is returned.  The returned css remains associated with @folio
  * until it is released.
  *
  * If memcg is bound to a traditional hierarchy, the css of root_mem_cgroup
  * is returned.
  */
-struct cgroup_subsys_state *mem_cgroup_css_from_page(struct page *page)
+struct cgroup_subsys_state *mem_cgroup_css_from_folio(struct folio *folio)
 {
-       struct mem_cgroup *memcg;
-
-       memcg = page_memcg(page);
+       struct mem_cgroup *memcg = folio_memcg(folio);
 
        if (!memcg || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
                memcg = root_mem_cgroup;
@@ -398,7 +401,8 @@ ino_t page_cgroup_ino(struct page *page)
        unsigned long ino = 0;
 
        rcu_read_lock();
-       memcg = page_memcg_check(page);
+       /* page_folio() is racy here, but the entire function is racy anyway */
+       memcg = folio_memcg_check(page_folio(page));
 
        while (memcg && !(memcg->css.flags & CSS_ONLINE))
                memcg = parent_mem_cgroup(memcg);
@@ -484,6 +488,12 @@ static void mem_cgroup_update_tree(struct mem_cgroup *memcg, int nid)
        struct mem_cgroup_per_node *mz;
        struct mem_cgroup_tree_per_node *mctz;
 
+       if (lru_gen_enabled()) {
+               if (soft_limit_excess(memcg))
+                       lru_gen_soft_reclaim(memcg, nid);
+               return;
+       }
+
        mctz = soft_limit_tree.rb_tree_per_node[nid];
        if (!mctz)
                return;
@@ -565,6 +575,92 @@ mem_cgroup_largest_soft_limit_node(struct mem_cgroup_tree_per_node *mctz)
        return mz;
 }
 
+/* Subset of vm_event_item to report for memcg event stats */
+static const unsigned int memcg_vm_event_stat[] = {
+       PGPGIN,
+       PGPGOUT,
+       PGSCAN_KSWAPD,
+       PGSCAN_DIRECT,
+       PGSCAN_KHUGEPAGED,
+       PGSTEAL_KSWAPD,
+       PGSTEAL_DIRECT,
+       PGSTEAL_KHUGEPAGED,
+       PGFAULT,
+       PGMAJFAULT,
+       PGREFILL,
+       PGACTIVATE,
+       PGDEACTIVATE,
+       PGLAZYFREE,
+       PGLAZYFREED,
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
+       ZSWPIN,
+       ZSWPOUT,
+       ZSWPWB,
+#endif
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+       THP_FAULT_ALLOC,
+       THP_COLLAPSE_ALLOC,
+       THP_SWPOUT,
+       THP_SWPOUT_FALLBACK,
+#endif
+};
+
+#define NR_MEMCG_EVENTS ARRAY_SIZE(memcg_vm_event_stat)
+static int mem_cgroup_events_index[NR_VM_EVENT_ITEMS] __read_mostly;
+
+static void init_memcg_events(void)
+{
+       int i;
+
+       for (i = 0; i < NR_MEMCG_EVENTS; ++i)
+               mem_cgroup_events_index[memcg_vm_event_stat[i]] = i + 1;
+}
+
+static inline int memcg_events_index(enum vm_event_item idx)
+{
+       return mem_cgroup_events_index[idx] - 1;
+}
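
The two helpers above compress the sparse vm_event_item space into the short NR_MEMCG_EVENTS array: init_memcg_events() stores index + 1, so the zero left in untracked slots decodes to -1 from memcg_events_index(). A minimal userspace sketch of the same sentinel trick (the enum values and names below are invented for the illustration, not taken from the kernel):

    #include <stdio.h>

    enum { EV_A, EV_B, EV_C, EV_D, EV_E, NR_EVENTS };
    static const unsigned int tracked[] = { EV_B, EV_D };
    #define NR_TRACKED (sizeof(tracked) / sizeof(tracked[0]))

    static int event_index[NR_EVENTS];      /* zero-initialized: 0 means "not tracked" */

    static void init_events(void)
    {
        for (unsigned int i = 0; i < NR_TRACKED; i++)
            event_index[tracked[i]] = i + 1;    /* store index + 1 */
    }

    static int events_index(unsigned int ev)
    {
        return event_index[ev] - 1;             /* untracked events map to -1 */
    }

    int main(void)
    {
        init_events();
        printf("EV_B -> %d, EV_C -> %d\n", events_index(EV_B), events_index(EV_C));
        /* prints: EV_B -> 0, EV_C -> -1 */
        return 0;
    }

Storing the offset-by-one value keeps the lookup table zero-initializable, which is why __count_memcg_events() further down can simply bail out on a negative index.
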
+
+struct memcg_vmstats_percpu {
+       /* Stats updates since the last flush */
+       unsigned int                    stats_updates;
+
+       /* Cached pointers for fast iteration in memcg_rstat_updated() */
+       struct memcg_vmstats_percpu     *parent;
+       struct memcg_vmstats            *vmstats;
+
+       /* The above should fit a single cacheline for memcg_rstat_updated() */
+
+       /* Local (CPU and cgroup) page state & events */
+       long                    state[MEMCG_NR_STAT];
+       unsigned long           events[NR_MEMCG_EVENTS];
+
+       /* Delta calculation for lockless upward propagation */
+       long                    state_prev[MEMCG_NR_STAT];
+       unsigned long           events_prev[NR_MEMCG_EVENTS];
+
+       /* Cgroup1: threshold notifications & softlimit tree updates */
+       unsigned long           nr_page_events;
+       unsigned long           targets[MEM_CGROUP_NTARGETS];
+} ____cacheline_aligned;
+
+struct memcg_vmstats {
+       /* Aggregated (CPU and subtree) page state & events */
+       long                    state[MEMCG_NR_STAT];
+       unsigned long           events[NR_MEMCG_EVENTS];
+
+       /* Non-hierarchical (CPU aggregated) page state & events */
+       long                    state_local[MEMCG_NR_STAT];
+       unsigned long           events_local[NR_MEMCG_EVENTS];
+
+       /* Pending child counts during tree propagation */
+       long                    state_pending[MEMCG_NR_STAT];
+       unsigned long           events_pending[NR_MEMCG_EVENTS];
+
+       /* Stats updates since the last flush */
+       atomic64_t              stats_updates;
+};
+
 /*
  * memcg and lruvec stats flushing
  *
@@ -582,10 +678,7 @@ mem_cgroup_largest_soft_limit_node(struct mem_cgroup_tree_per_node *mctz)
  */
 static void flush_memcg_stats_dwork(struct work_struct *w);
 static DECLARE_DEFERRABLE_WORK(stats_flush_dwork, flush_memcg_stats_dwork);
-static DEFINE_SPINLOCK(stats_flush_lock);
-static DEFINE_PER_CPU(unsigned int, stats_updates);
-static atomic_t stats_flush_threshold = ATOMIC_INIT(0);
-static u64 flush_next_time;
+static u64 flush_last_time;
 
 #define FLUSH_TIME (2UL*HZ)
 
@@ -597,78 +690,125 @@ static u64 flush_next_time;
  */
 static void memcg_stats_lock(void)
 {
-#ifdef CONFIG_PREEMPT_RT
-      preempt_disable();
-#else
-      VM_BUG_ON(!irqs_disabled());
-#endif
+       preempt_disable_nested();
+       VM_WARN_ON_IRQS_ENABLED();
 }
 
 static void __memcg_stats_lock(void)
 {
-#ifdef CONFIG_PREEMPT_RT
-      preempt_disable();
-#endif
+       preempt_disable_nested();
 }
 
 static void memcg_stats_unlock(void)
 {
-#ifdef CONFIG_PREEMPT_RT
-      preempt_enable();
-#endif
+       preempt_enable_nested();
+}
+
+
+static bool memcg_vmstats_needs_flush(struct memcg_vmstats *vmstats)
+{
+       return atomic64_read(&vmstats->stats_updates) >
+               MEMCG_CHARGE_BATCH * num_online_cpus();
 }
 
 static inline void memcg_rstat_updated(struct mem_cgroup *memcg, int val)
 {
-       unsigned int x;
+       struct memcg_vmstats_percpu *statc;
+       int cpu = smp_processor_id();
+
+       if (!val)
+               return;
 
-       cgroup_rstat_updated(memcg->css.cgroup, smp_processor_id());
+       cgroup_rstat_updated(memcg->css.cgroup, cpu);
+       statc = this_cpu_ptr(memcg->vmstats_percpu);
+       for (; statc; statc = statc->parent) {
+               statc->stats_updates += abs(val);
+               if (statc->stats_updates < MEMCG_CHARGE_BATCH)
+                       continue;
 
-       x = __this_cpu_add_return(stats_updates, abs(val));
-       if (x > MEMCG_CHARGE_BATCH) {
                /*
-                * If stats_flush_threshold exceeds the threshold
-                * (>num_online_cpus()), cgroup stats update will be triggered
-                * in __mem_cgroup_flush_stats(). Increasing this var further
-                * is redundant and simply adds overhead in atomic update.
+                * If @memcg is already flush-able, increasing stats_updates is
+                * redundant. Avoid the overhead of the atomic update.
                 */
-               if (atomic_read(&stats_flush_threshold) <= num_online_cpus())
-                       atomic_add(x / MEMCG_CHARGE_BATCH, &stats_flush_threshold);
-               __this_cpu_write(stats_updates, 0);
+               if (!memcg_vmstats_needs_flush(statc->vmstats))
+                       atomic64_add(statc->stats_updates,
+                                    &statc->vmstats->stats_updates);
+               statc->stats_updates = 0;
        }
 }
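
Roughly, memcg_rstat_updated() above batches updates per CPU and folds them into the shared atomic only once they reach MEMCG_CHARGE_BATCH, while memcg_vmstats_needs_flush() declares a flush worthwhile once the shared total exceeds one batch per online CPU. The single-threaded userspace model below captures that shape only; the constants are stand-ins and the walk up the memcg hierarchy is omitted:

    #include <stdatomic.h>
    #include <stdio.h>
    #include <stdlib.h>

    #define BATCH   64                          /* stand-in for MEMCG_CHARGE_BATCH */
    #define NR_CPUS 4

    static unsigned int pcpu_pending[NR_CPUS];  /* per-CPU stats_updates */
    static _Atomic long total_pending;          /* memcg_vmstats.stats_updates */

    static int needs_flush(void)
    {
        return atomic_load(&total_pending) > (long)BATCH * NR_CPUS;
    }

    static void stat_updated(int cpu, int val)
    {
        if (!val)
            return;
        pcpu_pending[cpu] += abs(val);
        if (pcpu_pending[cpu] < BATCH)
            return;
        /* Fold into the shared counter only while a flush isn't already due. */
        if (!needs_flush())
            atomic_fetch_add(&total_pending, (long)pcpu_pending[cpu]);
        pcpu_pending[cpu] = 0;
    }

    int main(void)
    {
        for (int i = 0; i < 1000; i++)
            stat_updated(i % NR_CPUS, 1);
        printf("flush worthwhile: %s\n", needs_flush() ? "yes" : "no");
        return 0;
    }
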
 
-static void __mem_cgroup_flush_stats(void)
+static void do_flush_stats(struct mem_cgroup *memcg)
 {
-       unsigned long flag;
-
-       if (!spin_trylock_irqsave(&stats_flush_lock, flag))
-               return;
+       if (mem_cgroup_is_root(memcg))
+               WRITE_ONCE(flush_last_time, jiffies_64);
 
-       flush_next_time = jiffies_64 + 2*FLUSH_TIME;
-       cgroup_rstat_flush_irqsafe(root_mem_cgroup->css.cgroup);
-       atomic_set(&stats_flush_threshold, 0);
-       spin_unlock_irqrestore(&stats_flush_lock, flag);
+       cgroup_rstat_flush(memcg->css.cgroup);
 }
 
-void mem_cgroup_flush_stats(void)
+/*
+ * mem_cgroup_flush_stats - flush the stats of a memory cgroup subtree
+ * @memcg: root of the subtree to flush
+ *
+ * Flushing is serialized by the underlying global rstat lock. There is also a
+ * minimum amount of work to be done even if there are no stat updates to flush.
+ * Hence, we only flush the stats if the updates delta exceeds a threshold. This
+ * avoids unnecessary work and contention on the underlying lock.
+ */
+void mem_cgroup_flush_stats(struct mem_cgroup *memcg)
 {
-       if (atomic_read(&stats_flush_threshold) > num_online_cpus())
-               __mem_cgroup_flush_stats();
+       if (mem_cgroup_disabled())
+               return;
+
+       if (!memcg)
+               memcg = root_mem_cgroup;
+
+       if (memcg_vmstats_needs_flush(memcg->vmstats))
+               do_flush_stats(memcg);
 }
 
-void mem_cgroup_flush_stats_delayed(void)
+void mem_cgroup_flush_stats_ratelimited(struct mem_cgroup *memcg)
 {
-       if (time_after64(jiffies_64, flush_next_time))
-               mem_cgroup_flush_stats();
+       /* Only flush if the periodic flusher is one full cycle late */
+       if (time_after64(jiffies_64, READ_ONCE(flush_last_time) + 2*FLUSH_TIME))
+               mem_cgroup_flush_stats(memcg);
 }
 
 static void flush_memcg_stats_dwork(struct work_struct *w)
 {
-       __mem_cgroup_flush_stats();
+       /*
+        * Deliberately ignore memcg_vmstats_needs_flush() here so that flushing
+        * in latency-sensitive paths is as cheap as possible.
+        */
+       do_flush_stats(root_mem_cgroup);
        queue_delayed_work(system_unbound_wq, &stats_flush_dwork, FLUSH_TIME);
 }
 
+unsigned long memcg_page_state(struct mem_cgroup *memcg, int idx)
+{
+       long x = READ_ONCE(memcg->vmstats->state[idx]);
+#ifdef CONFIG_SMP
+       if (x < 0)
+               x = 0;
+#endif
+       return x;
+}
+
+static int memcg_page_state_unit(int item);
+
+/*
+ * Normalize the value passed into memcg_rstat_updated() to be in pages. Round
+ * up non-zero sub-page updates to 1 page as zero page updates are ignored.
+ */
+static int memcg_state_val_in_pages(int idx, int val)
+{
+       int unit = memcg_page_state_unit(idx);
+
+       if (!val || unit == PAGE_SIZE)
+               return val;
+       else
+               return max(val * unit / PAGE_SIZE, 1UL);
+}
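
The effect of the rule above: a byte-accounted counter such as NR_SLAB_RECLAIMABLE_B that changes by a few hundred bytes still nudges the flush accounting by one page rather than by zero. A tiny sketch of the rounding on the magnitude of a delta, with PAGE_SIZE hard-coded for the example:

    #include <stdio.h>

    #define PAGE_SIZE 4096UL

    /* Any non-zero byte-sized delta contributes at least one page. */
    static unsigned long bytes_to_pages(unsigned long bytes)
    {
        if (!bytes)
            return 0;
        unsigned long pages = bytes / PAGE_SIZE;
        return pages ? pages : 1;
    }

    int main(void)
    {
        printf("%lu %lu %lu\n",
               bytes_to_pages(300),     /* sub-page slab delta -> 1 */
               bytes_to_pages(8192),    /* two pages worth     -> 2 */
               bytes_to_pages(0));      /* no change           -> 0 */
        return 0;
    }
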
+
 /**
  * __mod_memcg_state - update cgroup memory statistics
  * @memcg: the memory cgroup
@@ -681,17 +821,14 @@ void __mod_memcg_state(struct mem_cgroup *memcg, int idx, int val)
                return;
 
        __this_cpu_add(memcg->vmstats_percpu->state[idx], val);
-       memcg_rstat_updated(memcg, val);
+       memcg_rstat_updated(memcg, memcg_state_val_in_pages(idx, val));
 }
 
 /* idx can be of type enum memcg_stat_item or node_stat_item. */
 static unsigned long memcg_page_state_local(struct mem_cgroup *memcg, int idx)
 {
-       long x = 0;
-       int cpu;
+       long x = READ_ONCE(memcg->vmstats->state_local[idx]);
 
-       for_each_possible_cpu(cpu)
-               x += per_cpu(memcg->vmstats_percpu->state[idx], cpu);
 #ifdef CONFIG_SMP
        if (x < 0)
                x = 0;
@@ -709,13 +846,13 @@ void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
        memcg = pn->memcg;
 
        /*
-        * The caller from rmap relay on disabled preemption becase they never
+        * The caller from rmap relies on disabled preemption because they never
         * update their counter from in-interrupt context. For these two
         * counters we check that the update is never performed from an
         * interrupt context while other caller need to have disabled interrupt.
         */
        __memcg_stats_lock();
-       if (IS_ENABLED(CONFIG_DEBUG_VM) && !IS_ENABLED(CONFIG_PREEMPT_RT)) {
+       if (IS_ENABLED(CONFIG_DEBUG_VM)) {
                switch (idx) {
                case NR_ANON_MAPPED:
                case NR_FILE_MAPPED:
@@ -725,7 +862,7 @@ void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
                        WARN_ON_ONCE(!in_task());
                        break;
                default:
-                       WARN_ON_ONCE(!irqs_disabled());
+                       VM_WARN_ON_IRQS_ENABLED();
                }
        }
 
@@ -735,7 +872,7 @@ void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
        /* Update lruvec */
        __this_cpu_add(pn->lruvec_stats_percpu->state[idx], val);
 
-       memcg_rstat_updated(memcg, val);
+       memcg_rstat_updated(memcg, memcg_state_val_in_pages(idx, val));
        memcg_stats_unlock();
 }
 
@@ -760,16 +897,15 @@ void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
                __mod_memcg_lruvec_state(lruvec, idx, val);
 }
 
-void __mod_lruvec_page_state(struct page *page, enum node_stat_item idx,
+void __lruvec_stat_mod_folio(struct folio *folio, enum node_stat_item idx,
                             int val)
 {
-       struct page *head = compound_head(page); /* rmap on tail pages */
        struct mem_cgroup *memcg;
-       pg_data_t *pgdat = page_pgdat(page);
+       pg_data_t *pgdat = folio_pgdat(folio);
        struct lruvec *lruvec;
 
        rcu_read_lock();
-       memcg = page_memcg(head);
+       memcg = folio_memcg(folio);
        /* Untracked pages have no memcg, no lruvec. Update only the node */
        if (!memcg) {
                rcu_read_unlock();
@@ -781,7 +917,7 @@ void __mod_lruvec_page_state(struct page *page, enum node_stat_item idx,
        __mod_lruvec_state(lruvec, idx, val);
        rcu_read_unlock();
 }
-EXPORT_SYMBOL(__mod_lruvec_page_state);
+EXPORT_SYMBOL(__lruvec_stat_mod_folio);
 
 void __mod_lruvec_kmem_state(void *p, enum node_stat_item idx, int val)
 {
@@ -816,28 +952,34 @@ void __mod_lruvec_kmem_state(void *p, enum node_stat_item idx, int val)
 void __count_memcg_events(struct mem_cgroup *memcg, enum vm_event_item idx,
                          unsigned long count)
 {
-       if (mem_cgroup_disabled())
+       int index = memcg_events_index(idx);
+
+       if (mem_cgroup_disabled() || index < 0)
                return;
 
        memcg_stats_lock();
-       __this_cpu_add(memcg->vmstats_percpu->events[idx], count);
+       __this_cpu_add(memcg->vmstats_percpu->events[index], count);
        memcg_rstat_updated(memcg, count);
        memcg_stats_unlock();
 }
 
 static unsigned long memcg_events(struct mem_cgroup *memcg, int event)
 {
-       return READ_ONCE(memcg->vmstats.events[event]);
+       int index = memcg_events_index(event);
+
+       if (index < 0)
+               return 0;
+       return READ_ONCE(memcg->vmstats->events[index]);
 }
 
 static unsigned long memcg_events_local(struct mem_cgroup *memcg, int event)
 {
-       long x = 0;
-       int cpu;
+       int index = memcg_events_index(event);
 
-       for_each_possible_cpu(cpu)
-               x += per_cpu(memcg->vmstats_percpu->events[event], cpu);
-       return x;
+       if (index < 0)
+               return 0;
+
+       return READ_ONCE(memcg->vmstats->events_local[index]);
 }
 
 static void mem_cgroup_charge_statistics(struct mem_cgroup *memcg,
@@ -973,17 +1115,25 @@ struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
 }
 EXPORT_SYMBOL(get_mem_cgroup_from_mm);
 
-static __always_inline bool memcg_kmem_bypass(void)
+/**
+ * get_mem_cgroup_from_current - Obtain a reference on current task's memcg.
+ */
+struct mem_cgroup *get_mem_cgroup_from_current(void)
 {
-       /* Allow remote memcg charging from any context. */
-       if (unlikely(active_memcg()))
-               return false;
+       struct mem_cgroup *memcg;
 
-       /* Memcg to charge can't be determined. */
-       if (!in_task() || !current->mm || (current->flags & PF_KTHREAD))
-               return true;
+       if (mem_cgroup_disabled())
+               return NULL;
 
-       return false;
+again:
+       rcu_read_lock();
+       memcg = mem_cgroup_from_task(current);
+       if (!css_tryget(&memcg->css)) {
+               rcu_read_unlock();
+               goto again;
+       }
+       rcu_read_unlock();
+       return memcg;
 }
 
 /**
@@ -1143,12 +1293,12 @@ static void invalidate_reclaim_iterators(struct mem_cgroup *dead_memcg)
        } while ((memcg = parent_mem_cgroup(memcg)));
 
        /*
-        * When cgruop1 non-hierarchy mode is used,
+        * When cgroup1 non-hierarchy mode is used,
         * parent_mem_cgroup() does not walk all the way up to the
         * cgroup root (root_mem_cgroup). So we have to handle
         * dead_memcg from cgroup root separately.
         */
-       if (last != root_mem_cgroup)
+       if (!mem_cgroup_is_root(last))
                __invalidate_reclaim_iterators(root_mem_cgroup,
                                                dead_memcg);
 }
@@ -1161,18 +1311,18 @@ static void invalidate_reclaim_iterators(struct mem_cgroup *dead_memcg)
  *
  * This function iterates over tasks attached to @memcg or to any of its
  * descendants and calls @fn for each task. If @fn returns a non-zero
- * value, the function breaks the iteration loop and returns the value.
- * Otherwise, it will iterate over all tasks and return 0.
+ * value, the function breaks the iteration loop. Otherwise, it will iterate
+ * over all tasks and return 0.
  *
  * This function must not be called for the root memory cgroup.
  */
-int mem_cgroup_scan_tasks(struct mem_cgroup *memcg,
-                         int (*fn)(struct task_struct *, void *), void *arg)
+void mem_cgroup_scan_tasks(struct mem_cgroup *memcg,
+                          int (*fn)(struct task_struct *, void *), void *arg)
 {
        struct mem_cgroup *iter;
        int ret = 0;
 
-       BUG_ON(memcg == root_mem_cgroup);
+       BUG_ON(mem_cgroup_is_root(memcg));
 
        for_each_mem_cgroup_tree(iter, memcg) {
                struct css_task_iter it;
@@ -1187,7 +1337,6 @@ int mem_cgroup_scan_tasks(struct mem_cgroup *memcg,
                        break;
                }
        }
-       return ret;
 }
 
 #ifdef CONFIG_DEBUG_VM
@@ -1201,7 +1350,7 @@ void lruvec_memcg_debug(struct lruvec *lruvec, struct folio *folio)
        memcg = folio_memcg(folio);
 
        if (!memcg)
-               VM_BUG_ON_FOLIO(lruvec_memcg(lruvec) != root_mem_cgroup, folio);
+               VM_BUG_ON_FOLIO(!mem_cgroup_is_root(lruvec_memcg(lruvec)), folio);
        else
                VM_BUG_ON_FOLIO(lruvec_memcg(lruvec) != memcg, folio);
 }
@@ -1401,6 +1550,7 @@ static const struct memory_stat memory_stats[] = {
        { "kernel",                     MEMCG_KMEM                      },
        { "kernel_stack",               NR_KERNEL_STACK_KB              },
        { "pagetables",                 NR_PAGETABLE                    },
+       { "sec_pagetables",             NR_SECONDARY_PAGETABLE          },
        { "percpu",                     MEMCG_PERCPU_B                  },
        { "sock",                       MEMCG_SOCK                      },
        { "vmalloc",                    MEMCG_VMALLOC                   },
@@ -1438,7 +1588,7 @@ static const struct memory_stat memory_stats[] = {
        { "workingset_nodereclaim",     WORKINGSET_NODERECLAIM          },
 };
 
-/* Translate stat items to the correct unit for memory.stat output */
+/* The actual unit of the state item, not the same as the output unit */
 static int memcg_page_state_unit(int item)
 {
        switch (item) {
@@ -1446,6 +1596,22 @@ static int memcg_page_state_unit(int item)
        case MEMCG_ZSWAP_B:
        case NR_SLAB_RECLAIMABLE_B:
        case NR_SLAB_UNRECLAIMABLE_B:
+               return 1;
+       case NR_KERNEL_STACK_KB:
+               return SZ_1K;
+       default:
+               return PAGE_SIZE;
+       }
+}
+
+/* Translate stat items to the correct unit for memory.stat output */
+static int memcg_page_state_output_unit(int item)
+{
+       /*
+        * Workingset state is actually in pages, but we export it to userspace
+        * as a scalar count of events, so special case it here.
+        */
+       switch (item) {
        case WORKINGSET_REFAULT_ANON:
        case WORKINGSET_REFAULT_FILE:
        case WORKINGSET_ACTIVATE_ANON:
@@ -1454,49 +1620,29 @@ static int memcg_page_state_unit(int item)
        case WORKINGSET_RESTORE_FILE:
        case WORKINGSET_NODERECLAIM:
                return 1;
-       case NR_KERNEL_STACK_KB:
-               return SZ_1K;
        default:
-               return PAGE_SIZE;
+               return memcg_page_state_unit(item);
        }
 }
 
 static inline unsigned long memcg_page_state_output(struct mem_cgroup *memcg,
                                                    int item)
 {
-       return memcg_page_state(memcg, item) * memcg_page_state_unit(item);
+       return memcg_page_state(memcg, item) *
+               memcg_page_state_output_unit(item);
 }
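
The two-level unit scheme above means memory.stat reports most counters in bytes (pages times PAGE_SIZE), kernel stacks by multiplying the stored KiB value by SZ_1K, and the workingset counters as bare event counts. A compressed illustration with made-up stored values; the enum and helper below are illustrative, not the kernel's:

    #include <stdio.h>

    #define PAGE_SIZE 4096ULL
    #define SZ_1K     1024ULL

    enum item { FILE_PAGES, KERNEL_STACK_KB, WORKINGSET_REFAULT };

    /* Reporting unit only: bytes for most counters, raw events for workingset. */
    static unsigned long long output_unit(enum item it)
    {
        switch (it) {
        case WORKINGSET_REFAULT:
            return 1;
        case KERNEL_STACK_KB:
            return SZ_1K;
        default:
            return PAGE_SIZE;
        }
    }

    int main(void)
    {
        /* stored as 10 pages, 16 KiB and 42 refaults respectively */
        printf("file %llu\n", 10 * output_unit(FILE_PAGES));              /* 40960 */
        printf("kernel_stack %llu\n", 16 * output_unit(KERNEL_STACK_KB)); /* 16384 */
        printf("workingset_refault_anon %llu\n",
               42 * output_unit(WORKINGSET_REFAULT));                     /* 42 */
        return 0;
    }
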
 
-/* Subset of vm_event_item to report for memcg event stats */
-static const unsigned int memcg_vm_event_stat[] = {
-       PGSCAN_KSWAPD,
-       PGSCAN_DIRECT,
-       PGSTEAL_KSWAPD,
-       PGSTEAL_DIRECT,
-       PGFAULT,
-       PGMAJFAULT,
-       PGREFILL,
-       PGACTIVATE,
-       PGDEACTIVATE,
-       PGLAZYFREE,
-       PGLAZYFREED,
-#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
-       ZSWPIN,
-       ZSWPOUT,
-#endif
-#ifdef CONFIG_TRANSPARENT_HUGEPAGE
-       THP_FAULT_ALLOC,
-       THP_COLLAPSE_ALLOC,
-#endif
-};
+static inline unsigned long memcg_page_state_local_output(
+               struct mem_cgroup *memcg, int item)
+{
+       return memcg_page_state_local(memcg, item) *
+               memcg_page_state_output_unit(item);
+}
 
-static void memory_stat_format(struct mem_cgroup *memcg, char *buf, int bufsize)
+static void memcg_stat_format(struct mem_cgroup *memcg, struct seq_buf *s)
 {
-       struct seq_buf s;
        int i;
 
-       seq_buf_init(&s, buf, bufsize);
-
        /*
         * Provide statistics on the state of the memory subsystem as
         * well as cumulative event counters that show past behavior.
@@ -1507,39 +1653,56 @@ static void memory_stat_format(struct mem_cgroup *memcg, char *buf, int bufsize)
         *
         * Current memory state:
         */
-       mem_cgroup_flush_stats();
+       mem_cgroup_flush_stats(memcg);
 
        for (i = 0; i < ARRAY_SIZE(memory_stats); i++) {
                u64 size;
 
                size = memcg_page_state_output(memcg, memory_stats[i].idx);
-               seq_buf_printf(&s, "%s %llu\n", memory_stats[i].name, size);
+               seq_buf_printf(s, "%s %llu\n", memory_stats[i].name, size);
 
                if (unlikely(memory_stats[i].idx == NR_SLAB_UNRECLAIMABLE_B)) {
                        size += memcg_page_state_output(memcg,
                                                        NR_SLAB_RECLAIMABLE_B);
-                       seq_buf_printf(&s, "slab %llu\n", size);
+                       seq_buf_printf(s, "slab %llu\n", size);
                }
        }
 
        /* Accumulated memory events */
-       seq_buf_printf(&s, "pgscan %lu\n",
+       seq_buf_printf(s, "pgscan %lu\n",
                       memcg_events(memcg, PGSCAN_KSWAPD) +
-                      memcg_events(memcg, PGSCAN_DIRECT));
-       seq_buf_printf(&s, "pgsteal %lu\n",
+                      memcg_events(memcg, PGSCAN_DIRECT) +
+                      memcg_events(memcg, PGSCAN_KHUGEPAGED));
+       seq_buf_printf(s, "pgsteal %lu\n",
                       memcg_events(memcg, PGSTEAL_KSWAPD) +
-                      memcg_events(memcg, PGSTEAL_DIRECT));
+                      memcg_events(memcg, PGSTEAL_DIRECT) +
+                      memcg_events(memcg, PGSTEAL_KHUGEPAGED));
 
-       for (i = 0; i < ARRAY_SIZE(memcg_vm_event_stat); i++)
-               seq_buf_printf(&s, "%s %lu\n",
+       for (i = 0; i < ARRAY_SIZE(memcg_vm_event_stat); i++) {
+               if (memcg_vm_event_stat[i] == PGPGIN ||
+                   memcg_vm_event_stat[i] == PGPGOUT)
+                       continue;
+
+               seq_buf_printf(s, "%s %lu\n",
                               vm_event_name(memcg_vm_event_stat[i]),
                               memcg_events(memcg, memcg_vm_event_stat[i]));
+       }
 
        /* The above should easily fit into one page */
-       WARN_ON_ONCE(seq_buf_has_overflowed(&s));
+       WARN_ON_ONCE(seq_buf_has_overflowed(s));
+}
+
+static void memcg1_stat_format(struct mem_cgroup *memcg, struct seq_buf *s);
+
+static void memory_stat_format(struct mem_cgroup *memcg, struct seq_buf *s)
+{
+       if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               memcg_stat_format(memcg, s);
+       else
+               memcg1_stat_format(memcg, s);
+       WARN_ON_ONCE(seq_buf_has_overflowed(s));
 }
 
-#define K(x) ((x) << (PAGE_SHIFT-10))
 /**
  * mem_cgroup_print_oom_context: Print OOM information relevant to
  * memory controller.
@@ -1574,6 +1737,7 @@ void mem_cgroup_print_oom_meminfo(struct mem_cgroup *memcg)
 {
        /* Use static buffer, for the caller is holding oom_lock. */
        static char buf[PAGE_SIZE];
+       struct seq_buf s;
 
        lockdep_assert_held(&oom_lock);
 
@@ -1596,8 +1760,9 @@ void mem_cgroup_print_oom_meminfo(struct mem_cgroup *memcg)
        pr_info("Memory cgroup stats for ");
        pr_cont_cgroup_path(memcg->css.cgroup);
        pr_cont(":");
-       memory_stat_format(memcg, buf, sizeof(buf));
-       pr_info("%s", buf);
+       seq_buf_init(&s, buf, sizeof(buf));
+       memory_stat_format(memcg, &s);
+       seq_buf_do_printk(&s, KERN_INFO);
 }
 
 /*
@@ -1607,17 +1772,17 @@ unsigned long mem_cgroup_get_max(struct mem_cgroup *memcg)
 {
        unsigned long max = READ_ONCE(memcg->memory.max);
 
-       if (cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
-               if (mem_cgroup_swappiness(memcg))
-                       max += min(READ_ONCE(memcg->swap.max),
-                                  (unsigned long)total_swap_pages);
-       } else { /* v1 */
+       if (do_memsw_account()) {
                if (mem_cgroup_swappiness(memcg)) {
                        /* Calculate swap excess capacity from memsw limit */
                        unsigned long swap = READ_ONCE(memcg->memsw.max) - max;
 
                        max += min(swap, (unsigned long)total_swap_pages);
                }
+       } else {
+               if (mem_cgroup_swappiness(memcg))
+                       max += min(READ_ONCE(memcg->swap.max),
+                                  (unsigned long)total_swap_pages);
        }
        return max;
 }
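
In the memsw branch above, the extra headroom visible to the OOM code is the gap between the memsw limit and the memory limit, capped by the swap that actually exists: with memory.limit = 1 GiB, memsw.limit = 1.5 GiB and 2 GiB of configured swap, the result is 1 GiB + min(0.5 GiB, 2 GiB) = 1.5 GiB. A toy re-expression of the calculation in plain integers (page-sized units, hypothetical helper name):

    #include <stdio.h>

    /* Toy re-expression of the headroom math, in pages. */
    static unsigned long max_reclaimable(unsigned long memory_max,
                                         unsigned long memsw_max,
                                         unsigned long swap_max,
                                         unsigned long total_swap_pages,
                                         int memsw_account, int swappiness)
    {
        unsigned long max = memory_max;

        if (memsw_account) {
            if (swappiness) {
                /* v1: headroom implied by the memsw limit */
                unsigned long swap = memsw_max - memory_max;

                max += swap < total_swap_pages ? swap : total_swap_pages;
            }
        } else if (swappiness) {
            /* v2: explicit swap limit, capped by the swap present */
            max += swap_max < total_swap_pages ? swap_max : total_swap_pages;
        }
        return max;
    }

    int main(void)
    {
        /* 1 GiB memory limit, 1.5 GiB memsw limit, 2 GiB swap, 4 KiB pages */
        printf("%lu pages\n",
               max_reclaimable(262144, 393216, 0, 524288, 1, 60)); /* 393216 */
        return 0;
    }
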
@@ -1861,7 +2026,7 @@ static bool mem_cgroup_oom(struct mem_cgroup *memcg, gfp_t mask, int order)
         * Please note that mem_cgroup_out_of_memory might fail to find a
         * victim and then we have to bail out from the charge path.
         */
-       if (memcg->oom_kill_disable) {
+       if (READ_ONCE(memcg->oom_kill_disable)) {
                if (current->in_user_fault) {
                        css_get(&memcg->css);
                        current->memcg_in_oom = memcg;
@@ -1931,26 +2096,12 @@ bool mem_cgroup_oom_synchronize(bool handle)
        if (locked)
                mem_cgroup_oom_notify(memcg);
 
-       if (locked && !memcg->oom_kill_disable) {
-               mem_cgroup_unmark_under_oom(memcg);
-               finish_wait(&memcg_oom_waitq, &owait.wait);
-               mem_cgroup_out_of_memory(memcg, current->memcg_oom_gfp_mask,
-                                        current->memcg_oom_order);
-       } else {
-               schedule();
-               mem_cgroup_unmark_under_oom(memcg);
-               finish_wait(&memcg_oom_waitq, &owait.wait);
-       }
+       schedule();
+       mem_cgroup_unmark_under_oom(memcg);
+       finish_wait(&memcg_oom_waitq, &owait.wait);
 
-       if (locked) {
+       if (locked)
                mem_cgroup_oom_unlock(memcg);
-               /*
-                * There is no guarantee that an OOM-lock contender
-                * sees the wakeups triggered by the OOM kill
-                * uncharges.  Wake any sleepers explicitly.
-                */
-               memcg_oom_recover(memcg);
-       }
 cleanup:
        current->memcg_in_oom = NULL;
        css_put(&memcg->css);
@@ -1982,7 +2133,7 @@ struct mem_cgroup *mem_cgroup_get_oom_group(struct task_struct *victim,
        rcu_read_lock();
 
        memcg = mem_cgroup_from_task(victim);
-       if (memcg == root_mem_cgroup)
+       if (mem_cgroup_is_root(memcg))
                goto out;
 
        /*
@@ -1999,7 +2150,7 @@ struct mem_cgroup *mem_cgroup_get_oom_group(struct task_struct *victim,
         * highest-level memory cgroup with oom.group set.
         */
        for (; memcg; memcg = parent_mem_cgroup(memcg)) {
-               if (memcg->oom_group)
+               if (READ_ONCE(memcg->oom_group))
                        oom_group = memcg;
 
                if (memcg == oom_domain)
@@ -2069,17 +2220,12 @@ again:
         * When charge migration first begins, we can have multiple
         * critical sections holding the fast-path RCU lock and one
         * holding the slowpath move_lock. Track the task who has the
-        * move_lock for unlock_page_memcg().
+        * move_lock for folio_memcg_unlock().
         */
        memcg->move_lock_task = current;
        memcg->move_lock_flags = flags;
 }
 
-void lock_page_memcg(struct page *page)
-{
-       folio_memcg_lock(page_folio(page));
-}
-
 static void __folio_memcg_unlock(struct mem_cgroup *memcg)
 {
        if (memcg && memcg->move_lock_task == current) {
@@ -2107,11 +2253,6 @@ void folio_memcg_unlock(struct folio *folio)
        __folio_memcg_unlock(folio_memcg(folio));
 }
 
-void unlock_page_memcg(struct page *page)
-{
-       folio_memcg_unlock(page_folio(page));
-}
-
 struct memcg_stock_pcp {
        local_lock_t stock_lock;
        struct mem_cgroup *cached; /* this never be root cgroup */
@@ -2178,7 +2319,7 @@ static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
        local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
        stock = this_cpu_ptr(&memcg_stock);
-       if (memcg == stock->cached && stock->nr_pages >= nr_pages) {
+       if (memcg == READ_ONCE(stock->cached) && stock->nr_pages >= nr_pages) {
                stock->nr_pages -= nr_pages;
                ret = true;
        }
@@ -2193,7 +2334,7 @@ static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
  */
 static void drain_stock(struct memcg_stock_pcp *stock)
 {
-       struct mem_cgroup *old = stock->cached;
+       struct mem_cgroup *old = READ_ONCE(stock->cached);
 
        if (!old)
                return;
@@ -2206,7 +2347,7 @@ static void drain_stock(struct memcg_stock_pcp *stock)
        }
 
        css_put(&old->css);
-       stock->cached = NULL;
+       WRITE_ONCE(stock->cached, NULL);
 }
 
 static void drain_local_stock(struct work_struct *dummy)
@@ -2241,10 +2382,10 @@ static void __refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
        struct memcg_stock_pcp *stock;
 
        stock = this_cpu_ptr(&memcg_stock);
-       if (stock->cached != memcg) { /* reset if necessary */
+       if (READ_ONCE(stock->cached) != memcg) { /* reset if necessary */
                drain_stock(stock);
                css_get(&memcg->css);
-               stock->cached = memcg;
+               WRITE_ONCE(stock->cached, memcg);
        }
        stock->nr_pages += nr_pages;
 
@@ -2286,7 +2427,7 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
                bool flush = false;
 
                rcu_read_lock();
-               memcg = stock->cached;
+               memcg = READ_ONCE(stock->cached);
                if (memcg && stock->nr_pages &&
                    mem_cgroup_is_descendant(memcg, root_memcg))
                        flush = true;
@@ -2298,7 +2439,7 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
                    !test_and_set_bit(FLUSHING_CACHED_CHARGE, &stock->flags)) {
                        if (cpu == curcpu)
                                drain_local_stock(&stock->work);
-                       else
+                       else if (!cpu_is_isolated(cpu))
                                schedule_work_on(cpu, &stock->work);
                }
        }
@@ -2488,10 +2629,11 @@ static unsigned long calculate_high_delay(struct mem_cgroup *memcg,
 }
 
 /*
- * Scheduled by try_charge() to be executed from the userland return path
- * and reclaims memory over the high limit.
+ * Reclaims memory over the high limit. Called directly from
+ * try_charge() (context permitting), as well as from the userland
+ * return path where reclaim is always able to block.
  */
-void mem_cgroup_handle_over_high(void)
+void mem_cgroup_handle_over_high(gfp_t gfp_mask)
 {
        unsigned long penalty_jiffies;
        unsigned long pflags;
@@ -2508,6 +2650,17 @@ void mem_cgroup_handle_over_high(void)
        current->memcg_nr_pages_over_high = 0;
 
 retry_reclaim:
+       /*
+        * Bail if the task is already exiting. Unlike memory.max,
+        * memory.high enforcement isn't as strict, and there is no
+        * OOM killer involved, which means the excess could already
+        * be much bigger (and still growing) than it could for
+        * memory.max; the dying task could get stuck in fruitless
+        * reclaim for a long time, which isn't desirable.
+        */
+       if (task_is_dying())
+               goto out;
+
        /*
         * The allocating task should reclaim at least the batch size, but for
         * subsequent retries we only want to do what's necessary to prevent oom
@@ -2519,7 +2672,7 @@ retry_reclaim:
         */
        nr_reclaimed = reclaim_high(memcg,
                                    in_retry ? SWAP_CLUSTER_MAX : nr_pages,
-                                   GFP_KERNEL);
+                                   gfp_mask);
 
        /*
         * memory.high is breached and reclaim is unable to keep up. Throttle
@@ -2558,6 +2711,9 @@ retry_reclaim:
        }
 
        /*
+        * Reclaim didn't manage to push usage below the limit, slow
+        * this allocating task down.
+        *
         * If we exit early, we're guaranteed to die (since
         * schedule_timeout_killable sets TASK_KILLABLE). This means we don't
         * need to account for any ill-begotten jiffies to pay them off later.
@@ -2752,11 +2908,17 @@ done_restock:
                }
        } while ((memcg = parent_mem_cgroup(memcg)));
 
+       /*
+        * Reclaim is set up above to be called from the userland
+        * return path. But also attempt synchronous reclaim to avoid
+        * excessive overrun while the task is still inside the
+        * kernel. If this is successful, the return path will see it
+        * when it rechecks the overage and simply bail out.
+        */
        if (current->memcg_nr_pages_over_high > MEMCG_CHARGE_BATCH &&
            !(current->flags & PF_MEMALLOC) &&
-           gfpflags_allow_blocking(gfp_mask)) {
-               mem_cgroup_handle_over_high();
-       }
+           gfpflags_allow_blocking(gfp_mask))
+               mem_cgroup_handle_over_high(gfp_mask);
        return 0;
 }
 
@@ -2769,7 +2931,12 @@ static inline int try_charge(struct mem_cgroup *memcg, gfp_t gfp_mask,
        return try_charge_memcg(memcg, gfp_mask, nr_pages);
 }
 
-static inline void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
+/**
+ * mem_cgroup_cancel_charge() - cancel an uncommitted try_charge() call.
+ * @memcg: memcg previously charged.
+ * @nr_pages: number of pages previously charged.
+ */
+void mem_cgroup_cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
        if (mem_cgroup_is_root(memcg))
                return;
@@ -2787,19 +2954,37 @@ static void commit_charge(struct folio *folio, struct mem_cgroup *memcg)
         *
         * - the page lock
         * - LRU isolation
-        * - lock_page_memcg()
+        * - folio_memcg_lock()
         * - exclusive reference
+        * - mem_cgroup_trylock_pages()
         */
        folio->memcg_data = (unsigned long)memcg;
 }
 
+/**
+ * mem_cgroup_commit_charge - commit a previously successful try_charge().
+ * @folio: folio to commit the charge to.
+ * @memcg: memcg previously charged.
+ */
+void mem_cgroup_commit_charge(struct folio *folio, struct mem_cgroup *memcg)
+{
+       css_get(&memcg->css);
+       commit_charge(folio, memcg);
+
+       local_irq_disable();
+       mem_cgroup_charge_statistics(memcg, folio_nr_pages(folio));
+       memcg_check_events(memcg, folio_nid(folio));
+       local_irq_enable();
+}
+
 #ifdef CONFIG_MEMCG_KMEM
 /*
  * The allocated objcg pointers array is not accounted directly.
  * Moreover, it should not come from DMA buffer and is not readily
  * reclaimable. So those GFP bits should be masked off.
  */
-#define OBJCGS_CLEAR_MASK      (__GFP_DMA | __GFP_RECLAIMABLE | __GFP_ACCOUNT)
+#define OBJCGS_CLEAR_MASK      (__GFP_DMA | __GFP_RECLAIMABLE | \
+                                __GFP_ACCOUNT | __GFP_NOFAIL)
 
 /*
  * mod_objcg_mlstate() may be called with irq enabled, so
@@ -2880,13 +3065,13 @@ struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
        }
 
        /*
-        * page_memcg_check() is used here, because in theory we can encounter
+        * folio_memcg_check() is used here, because in theory we can encounter
         * a folio where the slab flag has been cleared already, but
         * slab->memcg_data has not been freed yet
-        * page_memcg_check(page) will guarantee that a proper memory
+        * folio_memcg_check() will guarantee that a proper memory
         * cgroup pointer or NULL will be returned.
         */
-       return page_memcg_check(folio_page(folio, 0));
+       return folio_memcg_check(folio);
 }
 
 /*
@@ -2940,48 +3125,125 @@ static struct obj_cgroup *__get_obj_cgroup_from_memcg(struct mem_cgroup *memcg)
 {
        struct obj_cgroup *objcg = NULL;
 
-       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg)) {
+       for (; !mem_cgroup_is_root(memcg); memcg = parent_mem_cgroup(memcg)) {
                objcg = rcu_dereference(memcg->objcg);
-               if (objcg && obj_cgroup_tryget(objcg))
+               if (likely(objcg && obj_cgroup_tryget(objcg)))
                        break;
                objcg = NULL;
        }
        return objcg;
 }
 
-__always_inline struct obj_cgroup *get_obj_cgroup_from_current(void)
+static struct obj_cgroup *current_objcg_update(void)
 {
-       struct obj_cgroup *objcg = NULL;
        struct mem_cgroup *memcg;
+       struct obj_cgroup *old, *objcg = NULL;
 
-       if (memcg_kmem_bypass())
-               return NULL;
+       do {
+               /* Atomically drop the update bit. */
+               old = xchg(&current->objcg, NULL);
+               if (old) {
+                       old = (struct obj_cgroup *)
+                               ((unsigned long)old & ~CURRENT_OBJCG_UPDATE_FLAG);
+                       if (old)
+                               obj_cgroup_put(old);
+
+                       old = NULL;
+               }
 
-       rcu_read_lock();
-       if (unlikely(active_memcg()))
-               memcg = active_memcg();
-       else
+               /* If new objcg is NULL, no reason for the second atomic update. */
+               if (!current->mm || (current->flags & PF_KTHREAD))
+                       return NULL;
+
+               /*
+                * Release the objcg pointer from the previous iteration,
+                * if try_cmpxchg() below fails.
+                */
+               if (unlikely(objcg)) {
+                       obj_cgroup_put(objcg);
+                       objcg = NULL;
+               }
+
+               /*
+                * Obtain the new objcg pointer. The current task can be
+                * asynchronously moved to another memcg and the previous
+                * memcg can be offlined. So let's get the memcg pointer
+                * and try to get a reference to objcg under an RCU read lock.
+                */
+
+               rcu_read_lock();
                memcg = mem_cgroup_from_task(current);
-       objcg = __get_obj_cgroup_from_memcg(memcg);
-       rcu_read_unlock();
+               objcg = __get_obj_cgroup_from_memcg(memcg);
+               rcu_read_unlock();
+
+               /*
+                * Try to set up a new objcg pointer atomically. If it
+                * fails, it means the update flag was set concurrently, so
+                * the whole procedure should be repeated.
+                */
+       } while (!try_cmpxchg(&current->objcg, &old, objcg));
+
+       return objcg;
+}
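
current_objcg_update() works because objcg pointers are at least word-aligned, so bit 0 (CURRENT_OBJCG_UPDATE_FLAG) can carry a "this cached pointer is stale" mark that is set when the task's memcg changes and cleared atomically by the task itself on its next accounted allocation. The userspace model below shows only that tag-and-refresh idea with C11 atomics; the try_cmpxchg() retry loop and the reference counting of the real code are reduced to comments, and all names are invented:

    #include <stdatomic.h>
    #include <stdint.h>
    #include <stdio.h>

    #define UPDATE_FLAG 0x1UL              /* low bit of an aligned pointer */

    struct objcg { int id; };

    static _Atomic uintptr_t cached;       /* models current->objcg */

    /* Migration side: mark the cached pointer as needing a refresh. */
    static void mark_stale(void)
    {
        atomic_fetch_or(&cached, UPDATE_FLAG);
    }

    /* Owner side: take the tagged value, drop the old pointer, install a new one.
     * The real code retries with try_cmpxchg() if the flag reappears meanwhile. */
    static struct objcg *refresh(struct objcg *fresh)
    {
        uintptr_t old = atomic_exchange(&cached, 0);
        struct objcg *prev = (struct objcg *)(old & ~UPDATE_FLAG);

        (void)prev;                        /* a reference would be dropped here */
        atomic_store(&cached, (uintptr_t)fresh);
        return fresh;
    }

    int main(void)
    {
        struct objcg a = { 1 }, b = { 2 };

        atomic_store(&cached, (uintptr_t)&a);
        mark_stale();                      /* e.g. the task moved to another memcg */

        if (atomic_load(&cached) & UPDATE_FLAG)
            printf("refreshed to objcg %d\n", refresh(&b)->id);
        return 0;
    }
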
+
+__always_inline struct obj_cgroup *current_obj_cgroup(void)
+{
+       struct mem_cgroup *memcg;
+       struct obj_cgroup *objcg;
+
+       if (in_task()) {
+               memcg = current->active_memcg;
+               if (unlikely(memcg))
+                       goto from_memcg;
+
+               objcg = READ_ONCE(current->objcg);
+               if (unlikely((unsigned long)objcg & CURRENT_OBJCG_UPDATE_FLAG))
+                       objcg = current_objcg_update();
+               /*
+                * Objcg reference is kept by the task, so it's safe
+                * for the current task to use the objcg.
+                */
+               return objcg;
+       }
+
+       memcg = this_cpu_read(int_active_memcg);
+       if (unlikely(memcg))
+               goto from_memcg;
+
+       return NULL;
+
+from_memcg:
+       objcg = NULL;
+       for (; !mem_cgroup_is_root(memcg); memcg = parent_mem_cgroup(memcg)) {
+               /*
+                * Memcg pointer is protected by scope (see set_active_memcg())
+                * and is pinning the corresponding objcg, so objcg can't go
+                * away and can be used within the scope without any additional
+                * protection.
+                */
+               objcg = rcu_dereference_check(memcg->objcg, 1);
+               if (likely(objcg))
+                       break;
+       }
+
        return objcg;
 }
 
-struct obj_cgroup *get_obj_cgroup_from_page(struct page *page)
+struct obj_cgroup *get_obj_cgroup_from_folio(struct folio *folio)
 {
        struct obj_cgroup *objcg;
 
-       if (!memcg_kmem_enabled() || memcg_kmem_bypass())
+       if (!memcg_kmem_online())
                return NULL;
 
-       if (PageMemcgKmem(page)) {
-               objcg = __folio_objcg(page_folio(page));
+       if (folio_memcg_kmem(folio)) {
+               objcg = __folio_objcg(folio);
                obj_cgroup_get(objcg);
        } else {
                struct mem_cgroup *memcg;
 
                rcu_read_lock();
-               memcg = __folio_memcg(page_folio(page));
+               memcg = __folio_memcg(folio);
                if (memcg)
                        objcg = __get_obj_cgroup_from_memcg(memcg);
                else
@@ -3061,15 +3323,15 @@ int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
        struct obj_cgroup *objcg;
        int ret = 0;
 
-       objcg = get_obj_cgroup_from_current();
+       objcg = current_obj_cgroup();
        if (objcg) {
                ret = obj_cgroup_charge_pages(objcg, gfp, 1 << order);
                if (!ret) {
+                       obj_cgroup_get(objcg);
                        page->memcg_data = (unsigned long)objcg |
                                MEMCG_DATA_KMEM;
                        return 0;
                }
-               obj_cgroup_put(objcg);
        }
        return ret;
 }
@@ -3110,12 +3372,12 @@ void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
         * accumulating over a page of vmstat data or when pgdat or idx
         * changes.
         */
-       if (stock->cached_objcg != objcg) {
+       if (READ_ONCE(stock->cached_objcg) != objcg) {
                old = drain_obj_stock(stock);
                obj_cgroup_get(objcg);
                stock->nr_bytes = atomic_read(&objcg->nr_charged_bytes)
                                ? atomic_xchg(&objcg->nr_charged_bytes, 0) : 0;
-               stock->cached_objcg = objcg;
+               WRITE_ONCE(stock->cached_objcg, objcg);
                stock->cached_pgdat = pgdat;
        } else if (stock->cached_pgdat != pgdat) {
                /* Flush the existing cached vmstat data */
@@ -3169,7 +3431,7 @@ static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
        local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
        stock = this_cpu_ptr(&memcg_stock);
-       if (objcg == stock->cached_objcg && stock->nr_bytes >= nr_bytes) {
+       if (objcg == READ_ONCE(stock->cached_objcg) && stock->nr_bytes >= nr_bytes) {
                stock->nr_bytes -= nr_bytes;
                ret = true;
        }
@@ -3181,7 +3443,7 @@ static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
 
 static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
 {
-       struct obj_cgroup *old = stock->cached_objcg;
+       struct obj_cgroup *old = READ_ONCE(stock->cached_objcg);
 
        if (!old)
                return NULL;
@@ -3234,7 +3496,7 @@ static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
                stock->cached_pgdat = NULL;
        }
 
-       stock->cached_objcg = NULL;
+       WRITE_ONCE(stock->cached_objcg, NULL);
        /*
         * The `old' objects needs to be released by the caller via
         * obj_cgroup_put() outside of memcg_stock_pcp::stock_lock.
@@ -3245,10 +3507,11 @@ static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
                                     struct mem_cgroup *root_memcg)
 {
+       struct obj_cgroup *objcg = READ_ONCE(stock->cached_objcg);
        struct mem_cgroup *memcg;
 
-       if (stock->cached_objcg) {
-               memcg = obj_cgroup_memcg(stock->cached_objcg);
+       if (objcg) {
+               memcg = obj_cgroup_memcg(objcg);
                if (memcg && mem_cgroup_is_descendant(memcg, root_memcg))
                        return true;
        }
@@ -3267,10 +3530,10 @@ static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
        local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
        stock = this_cpu_ptr(&memcg_stock);
-       if (stock->cached_objcg != objcg) { /* reset if necessary */
+       if (READ_ONCE(stock->cached_objcg) != objcg) { /* reset if necessary */
                old = drain_obj_stock(stock);
                obj_cgroup_get(objcg);
-               stock->cached_objcg = objcg;
+               WRITE_ONCE(stock->cached_objcg, objcg);
                stock->nr_bytes = atomic_read(&objcg->nr_charged_bytes)
                                ? atomic_xchg(&objcg->nr_charged_bytes, 0) : 0;
                allow_uncharge = true;  /* Allow uncharge when objcg changes */
@@ -3344,25 +3607,27 @@ void obj_cgroup_uncharge(struct obj_cgroup *objcg, size_t size)
 /*
  * Because page_memcg(head) is not set on tails, set it now.
  */
-void split_page_memcg(struct page *head, unsigned int nr)
+void split_page_memcg(struct page *head, int old_order, int new_order)
 {
        struct folio *folio = page_folio(head);
        struct mem_cgroup *memcg = folio_memcg(folio);
        int i;
+       unsigned int old_nr = 1 << old_order;
+       unsigned int new_nr = 1 << new_order;
 
        if (mem_cgroup_disabled() || !memcg)
                return;
 
-       for (i = 1; i < nr; i++)
+       for (i = new_nr; i < old_nr; i += new_nr)
                folio_page(folio, i)->memcg_data = folio->memcg_data;
 
        if (folio_memcg_kmem(folio))
-               obj_cgroup_get_many(__folio_objcg(folio), nr - 1);
+               obj_cgroup_get_many(__folio_objcg(folio), old_nr / new_nr - 1);
        else
-               css_get_many(&memcg->css, nr - 1);
+               css_get_many(&memcg->css, old_nr / new_nr - 1);
 }
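
With orders instead of a raw page count, the loop above only initializes every new_nr-th page (the heads of the resulting folios) and takes old_nr / new_nr - 1 extra references: splitting an order-9 folio to order 0 touches 511 tails, while splitting it to order 2 touches 127 new heads. A quick check of that arithmetic:

    #include <stdio.h>

    int main(void)
    {
        int cases[][2] = { { 9, 0 }, { 9, 2 } };    /* { old_order, new_order } */

        for (unsigned int i = 0; i < sizeof(cases) / sizeof(cases[0]); i++) {
            unsigned int old_nr = 1u << cases[i][0];
            unsigned int new_nr = 1u << cases[i][1];
            unsigned int heads = 0;

            /* same stride as the loop above: one new head every new_nr pages */
            for (unsigned int p = new_nr; p < old_nr; p += new_nr)
                heads++;

            printf("order %d -> %d: %u new heads, %u extra refs\n",
                   cases[i][0], cases[i][1], heads, old_nr / new_nr - 1);
        }
        return 0;
    }
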
 
-#ifdef CONFIG_MEMCG_SWAP
+#ifdef CONFIG_SWAP
 /**
  * mem_cgroup_move_swap_account - move swap charge and swap_cgroup's record.
  * @entry: swap entry to be moved
@@ -3467,6 +3732,9 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
        struct mem_cgroup_tree_per_node *mctz;
        unsigned long excess;
 
+       if (lru_gen_enabled())
+               return 0;
+
        if (order > 0)
                return 0;
 
@@ -3597,11 +3865,14 @@ static unsigned long mem_cgroup_usage(struct mem_cgroup *memcg, bool swap)
        unsigned long val;
 
        if (mem_cgroup_is_root(memcg)) {
-               mem_cgroup_flush_stats();
-               val = memcg_page_state(memcg, NR_FILE_PAGES) +
-                       memcg_page_state(memcg, NR_ANON_MAPPED);
+               /*
+                * Approximate root's usage from global state. This isn't
+                * perfect, but the root usage was always an approximation.
+                */
+               val = global_node_page_state(NR_FILE_PAGES) +
+                       global_node_page_state(NR_ANON_MAPPED);
                if (swap)
-                       val += memcg_page_state(memcg, MEMCG_SWAP);
+                       val += total_swap_pages - get_nr_swap_pages();
        } else {
                if (!swap)
                        val = page_counter_read(&memcg->memory);
@@ -3656,12 +3927,22 @@ static u64 mem_cgroup_read_u64(struct cgroup_subsys_state *css,
        case RES_FAILCNT:
                return counter->failcnt;
        case RES_SOFT_LIMIT:
-               return (u64)memcg->soft_limit * PAGE_SIZE;
+               return (u64)READ_ONCE(memcg->soft_limit) * PAGE_SIZE;
        default:
                BUG();
        }
 }
 
+/*
+ * This function doesn't do anything useful. Its only job is to provide a read
+ * handler for a file so that cgroup_file_mode() will add read permissions.
+ */
+static int mem_cgroup_dummy_seq_show(__always_unused struct seq_file *m,
+                                    __always_unused void *v)
+{
+       return -EINVAL;
+}
+
 #ifdef CONFIG_MEMCG_KMEM
 static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
@@ -3679,8 +3960,10 @@ static int memcg_online_kmem(struct mem_cgroup *memcg)
 
        objcg->memcg = memcg;
        rcu_assign_pointer(memcg->objcg, objcg);
+       obj_cgroup_get(objcg);
+       memcg->orig_objcg = objcg;
 
-       static_branch_enable(&memcg_kmem_enabled_key);
+       static_branch_enable(&memcg_kmem_online_key);
 
        memcg->kmemcg_id = memcg->id.id;
 
@@ -3786,8 +4069,11 @@ static ssize_t mem_cgroup_write(struct kernfs_open_file *of,
                        ret = mem_cgroup_resize_max(memcg, nr_pages, true);
                        break;
                case _KMEM:
-                       /* kmem.limit_in_bytes is deprecated. */
-                       ret = -EOPNOTSUPP;
+                       pr_warn_once("kmem.limit_in_bytes is deprecated and will be removed. "
+                                    "Writing any value to this file has no effect. "
+                                    "Please report your usecase to linux-mm@kvack.org if you "
+                                    "depend on this functionality.\n");
+                       ret = 0;
                        break;
                case _TCP:
                        ret = memcg_update_tcp_max(memcg, nr_pages);
@@ -3798,7 +4084,7 @@ static ssize_t mem_cgroup_write(struct kernfs_open_file *of,
                if (IS_ENABLED(CONFIG_PREEMPT_RT)) {
                        ret = -EOPNOTSUPP;
                } else {
-                       memcg->soft_limit = nr_pages;
+                       WRITE_ONCE(memcg->soft_limit, nr_pages);
                        ret = 0;
                }
                break;
@@ -3855,6 +4141,10 @@ static int mem_cgroup_move_charge_write(struct cgroup_subsys_state *css,
 {
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
+       pr_warn_once("Cgroup memory moving (move_charge_at_immigrate) is deprecated. "
+                    "Please report your usecase to linux-mm@kvack.org if you "
+                    "depend on this functionality.\n");
+
        if (val & ~MOVE_MASK)
                return -EINVAL;
 
@@ -3936,7 +4226,7 @@ static int memcg_numa_stat_show(struct seq_file *m, void *v)
        int nid;
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
 
-       mem_cgroup_flush_stats();
+       mem_cgroup_flush_stats(memcg);
 
        for (stat = stats; stat < stats + ARRAY_SIZE(stats); stat++) {
                seq_printf(m, "%s=%lu", stat->name,
@@ -3975,7 +4265,12 @@ static const unsigned int memcg1_stats[] = {
        NR_FILE_MAPPED,
        NR_FILE_DIRTY,
        NR_WRITEBACK,
+       WORKINGSET_REFAULT_ANON,
+       WORKINGSET_REFAULT_FILE,
+#ifdef CONFIG_SWAP
        MEMCG_SWAP,
+       NR_SWAPCACHE,
+#endif
 };
 
 static const char *const memcg1_stat_names[] = {
@@ -3988,7 +4283,12 @@ static const char *const memcg1_stat_names[] = {
        "mapped_file",
        "dirty",
        "writeback",
+       "workingset_refault_anon",
+       "workingset_refault_file",
+#ifdef CONFIG_SWAP
        "swap",
+       "swapcached",
+#endif
 };
 
 /* Universal VM events cgroup1 shows, original sort order */
@@ -3999,34 +4299,31 @@ static const unsigned int memcg1_events[] = {
        PGMAJFAULT,
 };
 
-static int memcg_stat_show(struct seq_file *m, void *v)
+static void memcg1_stat_format(struct mem_cgroup *memcg, struct seq_buf *s)
 {
-       struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
        unsigned long memory, memsw;
        struct mem_cgroup *mi;
        unsigned int i;
 
        BUILD_BUG_ON(ARRAY_SIZE(memcg1_stat_names) != ARRAY_SIZE(memcg1_stats));
 
-       mem_cgroup_flush_stats();
+       mem_cgroup_flush_stats(memcg);
 
        for (i = 0; i < ARRAY_SIZE(memcg1_stats); i++) {
                unsigned long nr;
 
-               if (memcg1_stats[i] == MEMCG_SWAP && !do_memsw_account())
-                       continue;
-               nr = memcg_page_state_local(memcg, memcg1_stats[i]);
-               seq_printf(m, "%s %lu\n", memcg1_stat_names[i], nr * PAGE_SIZE);
+               nr = memcg_page_state_local_output(memcg, memcg1_stats[i]);
+               seq_buf_printf(s, "%s %lu\n", memcg1_stat_names[i], nr);
        }
 
        for (i = 0; i < ARRAY_SIZE(memcg1_events); i++)
-               seq_printf(m, "%s %lu\n", vm_event_name(memcg1_events[i]),
-                          memcg_events_local(memcg, memcg1_events[i]));
+               seq_buf_printf(s, "%s %lu\n", vm_event_name(memcg1_events[i]),
+                              memcg_events_local(memcg, memcg1_events[i]));
 
        for (i = 0; i < NR_LRU_LISTS; i++)
-               seq_printf(m, "%s %lu\n", lru_list_name(i),
-                          memcg_page_state_local(memcg, NR_LRU_BASE + i) *
-                          PAGE_SIZE);
+               seq_buf_printf(s, "%s %lu\n", lru_list_name(i),
+                              memcg_page_state_local(memcg, NR_LRU_BASE + i) *
+                              PAGE_SIZE);
 
        /* Hierarchical information */
        memory = memsw = PAGE_COUNTER_MAX;
@@ -4034,31 +4331,28 @@ static int memcg_stat_show(struct seq_file *m, void *v)
                memory = min(memory, READ_ONCE(mi->memory.max));
                memsw = min(memsw, READ_ONCE(mi->memsw.max));
        }
-       seq_printf(m, "hierarchical_memory_limit %llu\n",
-                  (u64)memory * PAGE_SIZE);
-       if (do_memsw_account())
-               seq_printf(m, "hierarchical_memsw_limit %llu\n",
-                          (u64)memsw * PAGE_SIZE);
+       seq_buf_printf(s, "hierarchical_memory_limit %llu\n",
+                      (u64)memory * PAGE_SIZE);
+       seq_buf_printf(s, "hierarchical_memsw_limit %llu\n",
+                      (u64)memsw * PAGE_SIZE);
 
        for (i = 0; i < ARRAY_SIZE(memcg1_stats); i++) {
                unsigned long nr;
 
-               if (memcg1_stats[i] == MEMCG_SWAP && !do_memsw_account())
-                       continue;
-               nr = memcg_page_state(memcg, memcg1_stats[i]);
-               seq_printf(m, "total_%s %llu\n", memcg1_stat_names[i],
-                                               (u64)nr * PAGE_SIZE);
+               nr = memcg_page_state_output(memcg, memcg1_stats[i]);
+               seq_buf_printf(s, "total_%s %llu\n", memcg1_stat_names[i],
+                              (u64)nr);
        }
 
        for (i = 0; i < ARRAY_SIZE(memcg1_events); i++)
-               seq_printf(m, "total_%s %llu\n",
-                          vm_event_name(memcg1_events[i]),
-                          (u64)memcg_events(memcg, memcg1_events[i]));
+               seq_buf_printf(s, "total_%s %llu\n",
+                              vm_event_name(memcg1_events[i]),
+                              (u64)memcg_events(memcg, memcg1_events[i]));
 
        for (i = 0; i < NR_LRU_LISTS; i++)
-               seq_printf(m, "total_%s %llu\n", lru_list_name(i),
-                          (u64)memcg_page_state(memcg, NR_LRU_BASE + i) *
-                          PAGE_SIZE);
+               seq_buf_printf(s, "total_%s %llu\n", lru_list_name(i),
+                              (u64)memcg_page_state(memcg, NR_LRU_BASE + i) *
+                              PAGE_SIZE);
 
 #ifdef CONFIG_DEBUG_VM
        {
@@ -4073,12 +4367,10 @@ static int memcg_stat_show(struct seq_file *m, void *v)
                        anon_cost += mz->lruvec.anon_cost;
                        file_cost += mz->lruvec.file_cost;
                }
-               seq_printf(m, "anon_cost %lu\n", anon_cost);
-               seq_printf(m, "file_cost %lu\n", file_cost);
+               seq_buf_printf(s, "anon_cost %lu\n", anon_cost);
+               seq_buf_printf(s, "file_cost %lu\n", file_cost);
        }
 #endif
-
-       return 0;
 }
 
 static u64 mem_cgroup_swappiness_read(struct cgroup_subsys_state *css,
@@ -4098,9 +4390,9 @@ static int mem_cgroup_swappiness_write(struct cgroup_subsys_state *css,
                return -EINVAL;
 
        if (!mem_cgroup_is_root(memcg))
-               memcg->swappiness = val;
+               WRITE_ONCE(memcg->swappiness, val);
        else
-               vm_swappiness = val;
+               WRITE_ONCE(vm_swappiness, val);
 
        return 0;
 }
@@ -4136,7 +4428,7 @@ static void __mem_cgroup_threshold(struct mem_cgroup *memcg, bool swap)
         * only one element of the array here.
         */
        for (; i >= 0 && unlikely(t->entries[i].threshold > usage); i--)
-               eventfd_signal(t->entries[i].eventfd, 1);
+               eventfd_signal(t->entries[i].eventfd);
 
        /* i = current_threshold + 1 */
        i++;
@@ -4148,7 +4440,7 @@ static void __mem_cgroup_threshold(struct mem_cgroup *memcg, bool swap)
         * only one element of the array here.
         */
        for (; i < t->size && unlikely(t->entries[i].threshold <= usage); i++)
-               eventfd_signal(t->entries[i].eventfd, 1);
+               eventfd_signal(t->entries[i].eventfd);
 
        /* Update current_threshold */
        t->current_threshold = i - 1;
@@ -4188,7 +4480,7 @@ static int mem_cgroup_oom_notify_cb(struct mem_cgroup *memcg)
        spin_lock(&memcg_oom_lock);
 
        list_for_each_entry(ev, &memcg->oom_notify, list)
-               eventfd_signal(ev->eventfd, 1);
+               eventfd_signal(ev->eventfd);
 
        spin_unlock(&memcg_oom_lock);
        return 0;
@@ -4407,7 +4699,7 @@ static int mem_cgroup_oom_register_event(struct mem_cgroup *memcg,
 
        /* already in OOM ? */
        if (memcg->under_oom)
-               eventfd_signal(eventfd, 1);
+               eventfd_signal(eventfd);
        spin_unlock(&memcg_oom_lock);
 
        return 0;
@@ -4434,7 +4726,7 @@ static int mem_cgroup_oom_control_read(struct seq_file *sf, void *v)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_seq(sf);
 
-       seq_printf(sf, "oom_kill_disable %d\n", memcg->oom_kill_disable);
+       seq_printf(sf, "oom_kill_disable %d\n", READ_ONCE(memcg->oom_kill_disable));
        seq_printf(sf, "under_oom %d\n", (bool)memcg->under_oom);
        seq_printf(sf, "oom_kill %lu\n",
                   atomic_long_read(&memcg->memory_events[MEMCG_OOM_KILL]));
@@ -4450,7 +4742,7 @@ static int mem_cgroup_oom_control_write(struct cgroup_subsys_state *css,
        if (mem_cgroup_is_root(memcg) || !((val == 0) || (val == 1)))
                return -EINVAL;
 
-       memcg->oom_kill_disable = val;
+       WRITE_ONCE(memcg->oom_kill_disable, val);
        if (!val)
                memcg_oom_recover(memcg);
 
@@ -4511,7 +4803,7 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
        struct mem_cgroup *memcg = mem_cgroup_from_css(wb->memcg_css);
        struct mem_cgroup *parent;
 
-       mem_cgroup_flush_stats();
+       mem_cgroup_flush_stats_ratelimited(memcg);
 
        *pdirty = memcg_page_state(memcg, NR_FILE_DIRTY);
        *pwriteback = memcg_page_state(memcg, NR_WRITEBACK);
@@ -4699,7 +4991,7 @@ static void memcg_event_remove(struct work_struct *work)
        event->unregister_event(memcg, event->eventfd);
 
        /* Notify userspace the event is going away. */
-       eventfd_signal(event->eventfd, 1);
+       eventfd_signal(event->eventfd);
 
        eventfd_ctx_put(event->eventfd);
        kfree(event);
@@ -4772,6 +5064,7 @@ static ssize_t memcg_write_event_control(struct kernfs_open_file *of,
        unsigned int efd, cfd;
        struct fd efile;
        struct fd cfile;
+       struct dentry *cdentry;
        const char *name;
        char *endp;
        int ret;
@@ -4825,6 +5118,16 @@ static ssize_t memcg_write_event_control(struct kernfs_open_file *of,
        if (ret < 0)
                goto out_put_cfile;
 
+       /*
+        * The control file must be a regular cgroup1 file. As a regular cgroup
+        * file can't be renamed, it's safe to access its name afterwards.
+        */
+       cdentry = cfile.file->f_path.dentry;
+       if (cdentry->d_sb->s_type != &cgroup_fs_type || !d_is_reg(cdentry)) {
+               ret = -EINVAL;
+               goto out_put_cfile;
+       }
+
        /*
         * Determine the event callbacks and set them in @event.  This used
         * to be done via struct cftype but cgroup core no longer knows
@@ -4833,7 +5136,7 @@ static ssize_t memcg_write_event_control(struct kernfs_open_file *of,
         *
         * DO NOT ADD NEW FILES.
         */
-       name = cfile.file->f_path.dentry->d_name.name;
+       name = cdentry->d_name.name;
 
        if (!strcmp(name, "memory.usage_in_bytes")) {
                event->register_event = mem_cgroup_usage_register_event;
@@ -4857,7 +5160,7 @@ static ssize_t memcg_write_event_control(struct kernfs_open_file *of,
         * automatically removed on cgroup destruction but the removal is
         * asynchronous, so take an extra ref on @css.
         */
-       cfile_css = css_tryget_online_from_dir(cfile.file->f_path.dentry->d_parent,
+       cfile_css = css_tryget_online_from_dir(cdentry->d_parent,
                                               &memory_cgrp_subsys);
        ret = -EINVAL;
        if (IS_ERR(cfile_css))
@@ -4896,7 +5199,7 @@ out_kfree:
        return ret;
 }
 
-#if defined(CONFIG_MEMCG_KMEM) && (defined(CONFIG_SLAB) || defined(CONFIG_SLUB_DEBUG))
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_SLUB_DEBUG)
 static int mem_cgroup_slab_show(struct seq_file *m, void *p)
 {
        /*
@@ -4907,6 +5210,8 @@ static int mem_cgroup_slab_show(struct seq_file *m, void *p)
 }
 #endif
 
+static int memory_stat_show(struct seq_file *m, void *v);
+
 static struct cftype mem_cgroup_legacy_files[] = {
        {
                .name = "usage_in_bytes",
@@ -4939,7 +5244,7 @@ static struct cftype mem_cgroup_legacy_files[] = {
        },
        {
                .name = "stat",
-               .seq_show = memcg_stat_show,
+               .seq_show = memory_stat_show,
        },
        {
                .name = "force_empty",
@@ -4972,6 +5277,7 @@ static struct cftype mem_cgroup_legacy_files[] = {
        },
        {
                .name = "pressure_level",
+               .seq_show = mem_cgroup_dummy_seq_show,
        },
 #ifdef CONFIG_NUMA
        {
@@ -5002,8 +5308,7 @@ static struct cftype mem_cgroup_legacy_files[] = {
                .write = mem_cgroup_reset,
                .read_u64 = mem_cgroup_read_u64,
        },
-#if defined(CONFIG_MEMCG_KMEM) && \
-       (defined(CONFIG_SLAB) || defined(CONFIG_SLUB_DEBUG))
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_SLUB_DEBUG)
        {
                .name = "kmem.slabinfo",
                .seq_show = mem_cgroup_slab_show,
@@ -5059,6 +5364,7 @@ static struct cftype mem_cgroup_legacy_files[] = {
  * those references are manageable from userspace.
  */
 
+#define MEM_CGROUP_ID_MAX      ((1UL << MEM_CGROUP_ID_SHIFT) - 1)
 static DEFINE_IDR(mem_cgroup_idr);
 
 static void mem_cgroup_id_remove(struct mem_cgroup *memcg)
@@ -5110,8 +5416,8 @@ struct mem_cgroup *mem_cgroup_get_from_ino(unsigned long ino)
        struct mem_cgroup *memcg;
 
        cgrp = cgroup_get_from_id(ino);
-       if (!cgrp)
-               return ERR_PTR(-ENOENT);
+       if (IS_ERR(cgrp))
+               return ERR_CAST(cgrp);
 
        css = cgroup_get_e_css(cgrp, &memory_cgrp_subsys);
        if (css)
@@ -5162,22 +5468,28 @@ static void __mem_cgroup_free(struct mem_cgroup *memcg)
 {
        int node;
 
+       if (memcg->orig_objcg)
+               obj_cgroup_put(memcg->orig_objcg);
+
        for_each_node(node)
                free_mem_cgroup_per_node_info(memcg, node);
+       kfree(memcg->vmstats);
        free_percpu(memcg->vmstats_percpu);
        kfree(memcg);
 }
 
 static void mem_cgroup_free(struct mem_cgroup *memcg)
 {
+       lru_gen_exit_memcg(memcg);
        memcg_wb_domain_exit(memcg);
        __mem_cgroup_free(memcg);
 }
 
-static struct mem_cgroup *mem_cgroup_alloc(void)
+static struct mem_cgroup *mem_cgroup_alloc(struct mem_cgroup *parent)
 {
+       struct memcg_vmstats_percpu *statc, *pstatc;
        struct mem_cgroup *memcg;
-       int node;
+       int node, cpu;
        int __maybe_unused i;
        long error = -ENOMEM;
 
@@ -5192,11 +5504,23 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
                goto fail;
        }
 
+       memcg->vmstats = kzalloc(sizeof(struct memcg_vmstats), GFP_KERNEL);
+       if (!memcg->vmstats)
+               goto fail;
+
        memcg->vmstats_percpu = alloc_percpu_gfp(struct memcg_vmstats_percpu,
                                                 GFP_KERNEL_ACCOUNT);
        if (!memcg->vmstats_percpu)
                goto fail;
 
+       for_each_possible_cpu(cpu) {
+               if (parent)
+                       pstatc = per_cpu_ptr(parent->vmstats_percpu, cpu);
+               statc = per_cpu_ptr(memcg->vmstats_percpu, cpu);
+               statc->parent = parent ? pstatc : NULL;
+               statc->vmstats = memcg->vmstats;
+       }
+
        for_each_node(node)
                if (alloc_mem_cgroup_per_node_info(memcg, node))
                        goto fail;
@@ -5227,7 +5551,7 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
        INIT_LIST_HEAD(&memcg->deferred_split_queue.split_queue);
        memcg->deferred_split_queue.split_queue_len = 0;
 #endif
-       idr_replace(&mem_cgroup_idr, memcg, memcg->id.id);
+       lru_gen_init_memcg(memcg);
        return memcg;
 fail:
        mem_cgroup_id_remove(memcg);
@@ -5242,26 +5566,29 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
        struct mem_cgroup *memcg, *old_memcg;
 
        old_memcg = set_active_memcg(parent);
-       memcg = mem_cgroup_alloc();
+       memcg = mem_cgroup_alloc(parent);
        set_active_memcg(old_memcg);
        if (IS_ERR(memcg))
                return ERR_CAST(memcg);
 
        page_counter_set_high(&memcg->memory, PAGE_COUNTER_MAX);
-       memcg->soft_limit = PAGE_COUNTER_MAX;
+       WRITE_ONCE(memcg->soft_limit, PAGE_COUNTER_MAX);
 #if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
        memcg->zswap_max = PAGE_COUNTER_MAX;
+       WRITE_ONCE(memcg->zswap_writeback,
+               !parent || READ_ONCE(parent->zswap_writeback));
 #endif
        page_counter_set_high(&memcg->swap, PAGE_COUNTER_MAX);
        if (parent) {
-               memcg->swappiness = mem_cgroup_swappiness(parent);
-               memcg->oom_kill_disable = parent->oom_kill_disable;
+               WRITE_ONCE(memcg->swappiness, mem_cgroup_swappiness(parent));
+               WRITE_ONCE(memcg->oom_kill_disable, READ_ONCE(parent->oom_kill_disable));
 
                page_counter_init(&memcg->memory, &parent->memory);
                page_counter_init(&memcg->swap, &parent->swap);
                page_counter_init(&memcg->kmem, &parent->kmem);
                page_counter_init(&memcg->tcpmem, &parent->tcpmem);
        } else {
+               init_memcg_events();
                page_counter_init(&memcg->memory, NULL);
                page_counter_init(&memcg->swap, NULL);
                page_counter_init(&memcg->kmem, NULL);
@@ -5274,6 +5601,11 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
        if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
                static_branch_inc(&memcg_sockets_enabled_key);
 
+#if defined(CONFIG_MEMCG_KMEM)
+       if (!cgroup_memory_nobpf)
+               static_branch_inc(&memcg_bpf_enabled_key);
+#endif
+
        return &memcg->css;
 }
 
@@ -5292,13 +5624,27 @@ static int mem_cgroup_css_online(struct cgroup_subsys_state *css)
        if (alloc_shrinker_info(memcg))
                goto offline_kmem;
 
+       if (unlikely(mem_cgroup_is_root(memcg)) && !mem_cgroup_disabled())
+               queue_delayed_work(system_unbound_wq, &stats_flush_dwork,
+                                  FLUSH_TIME);
+       lru_gen_online_memcg(memcg);
+
        /* Online state pins memcg ID, memcg ID pins CSS */
        refcount_set(&memcg->id.ref, 1);
        css_get(css);
 
-       if (unlikely(mem_cgroup_is_root(memcg)))
-               queue_delayed_work(system_unbound_wq, &stats_flush_dwork,
-                                  2UL*HZ);
+       /*
+        * Ensure mem_cgroup_from_id() works once we're fully online.
+        *
+        * We could do this earlier and require callers to filter with
+        * css_tryget_online(). But right now there are no users that
+        * need earlier access, and the workingset code relies on the
+        * cgroup tree linkage (mem_cgroup_get_nr_swap_pages()). So
+        * publish it here at the end of onlining. This matches the
+        * regular ID destruction during offlining.
+        */
+       idr_replace(&mem_cgroup_idr, memcg, memcg->id.id);
+
        return 0;
 offline_kmem:
        memcg_offline_kmem(memcg);
@@ -5327,9 +5673,12 @@ static void mem_cgroup_css_offline(struct cgroup_subsys_state *css)
        page_counter_set_min(&memcg->memory, 0);
        page_counter_set_low(&memcg->memory, 0);
 
+       zswap_memcg_offline_cleanup(memcg);
+
        memcg_offline_kmem(memcg);
        reparent_shrinker_deferred(memcg);
        wb_memcg_offline(memcg);
+       lru_gen_offline_memcg(memcg);
 
        drain_all_stock(memcg);
 
@@ -5341,6 +5690,7 @@ static void mem_cgroup_css_released(struct cgroup_subsys_state *css)
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
        invalidate_reclaim_iterators(memcg);
+       lru_gen_release_memcg(memcg);
 }
 
 static void mem_cgroup_css_free(struct cgroup_subsys_state *css)
@@ -5358,6 +5708,11 @@ static void mem_cgroup_css_free(struct cgroup_subsys_state *css)
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && memcg->tcpmem_active)
                static_branch_dec(&memcg_sockets_enabled_key);
 
+#if defined(CONFIG_MEMCG_KMEM)
+       if (!cgroup_memory_nobpf)
+               static_branch_dec(&memcg_bpf_enabled_key);
+#endif
+
        vmpressure_cleanup(&memcg->vmpressure);
        cancel_work_sync(&memcg->high_work);
        mem_cgroup_remove_from_trees(memcg);
@@ -5389,7 +5744,7 @@ static void mem_cgroup_css_reset(struct cgroup_subsys_state *css)
        page_counter_set_min(&memcg->memory, 0);
        page_counter_set_low(&memcg->memory, 0);
        page_counter_set_high(&memcg->memory, PAGE_COUNTER_MAX);
-       memcg->soft_limit = PAGE_COUNTER_MAX;
+       WRITE_ONCE(memcg->soft_limit, PAGE_COUNTER_MAX);
        page_counter_set_high(&memcg->swap, PAGE_COUNTER_MAX);
        memcg_wb_domain_size_changed(memcg);
 }
@@ -5399,7 +5754,7 @@ static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
        struct mem_cgroup *parent = parent_mem_cgroup(memcg);
        struct memcg_vmstats_percpu *statc;
-       long delta, v;
+       long delta, delta_cpu, v;
        int i, nid;
 
        statc = per_cpu_ptr(memcg->vmstats_percpu, cpu);
@@ -5410,43 +5765,51 @@ static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
                 * below us. We're in a per-cpu loop here and this is
                 * a global counter, so the first cycle will get them.
                 */
-               delta = memcg->vmstats.state_pending[i];
+               delta = memcg->vmstats->state_pending[i];
                if (delta)
-                       memcg->vmstats.state_pending[i] = 0;
+                       memcg->vmstats->state_pending[i] = 0;
 
                /* Add CPU changes on this level since the last flush */
+               delta_cpu = 0;
                v = READ_ONCE(statc->state[i]);
                if (v != statc->state_prev[i]) {
-                       delta += v - statc->state_prev[i];
+                       delta_cpu = v - statc->state_prev[i];
+                       delta += delta_cpu;
                        statc->state_prev[i] = v;
                }
 
-               if (!delta)
-                       continue;
-
                /* Aggregate counts on this level and propagate upwards */
-               memcg->vmstats.state[i] += delta;
-               if (parent)
-                       parent->vmstats.state_pending[i] += delta;
+               if (delta_cpu)
+                       memcg->vmstats->state_local[i] += delta_cpu;
+
+               if (delta) {
+                       memcg->vmstats->state[i] += delta;
+                       if (parent)
+                               parent->vmstats->state_pending[i] += delta;
+               }
        }
 
-       for (i = 0; i < NR_VM_EVENT_ITEMS; i++) {
-               delta = memcg->vmstats.events_pending[i];
+       for (i = 0; i < NR_MEMCG_EVENTS; i++) {
+               delta = memcg->vmstats->events_pending[i];
                if (delta)
-                       memcg->vmstats.events_pending[i] = 0;
+                       memcg->vmstats->events_pending[i] = 0;
 
+               delta_cpu = 0;
                v = READ_ONCE(statc->events[i]);
                if (v != statc->events_prev[i]) {
-                       delta += v - statc->events_prev[i];
+                       delta_cpu = v - statc->events_prev[i];
+                       delta += delta_cpu;
                        statc->events_prev[i] = v;
                }
 
-               if (!delta)
-                       continue;
+               if (delta_cpu)
+                       memcg->vmstats->events_local[i] += delta_cpu;
 
-               memcg->vmstats.events[i] += delta;
-               if (parent)
-                       parent->vmstats.events_pending[i] += delta;
+               if (delta) {
+                       memcg->vmstats->events[i] += delta;
+                       if (parent)
+                               parent->vmstats->events_pending[i] += delta;
+               }
        }
 
        for_each_node_state(nid, N_MEMORY) {
@@ -5464,20 +5827,28 @@ static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
                        if (delta)
                                pn->lruvec_stats.state_pending[i] = 0;
 
+                       delta_cpu = 0;
                        v = READ_ONCE(lstatc->state[i]);
                        if (v != lstatc->state_prev[i]) {
-                               delta += v - lstatc->state_prev[i];
+                               delta_cpu = v - lstatc->state_prev[i];
+                               delta += delta_cpu;
                                lstatc->state_prev[i] = v;
                        }
 
-                       if (!delta)
-                               continue;
+                       if (delta_cpu)
+                               pn->lruvec_stats.state_local[i] += delta_cpu;
 
-                       pn->lruvec_stats.state[i] += delta;
-                       if (ppn)
-                               ppn->lruvec_stats.state_pending[i] += delta;
+                       if (delta) {
+                               pn->lruvec_stats.state[i] += delta;
+                               if (ppn)
+                                       ppn->lruvec_stats.state_pending[i] += delta;
+                       }
                }
        }
+       statc->stats_updates = 0;
+       /* We are in a per-cpu loop here, only do the atomic write once */
+       if (atomic64_read(&memcg->vmstats->stats_updates))
+               atomic64_set(&memcg->vmstats->stats_updates, 0);
 }
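/*
 * Illustrative sketch (not kernel code): the aggregation rule used by
 * mem_cgroup_css_rstat_flush() above, reduced to one counter and one CPU.
 * "struct node" and flush_one() are hypothetical names; only the
 * delta/delta_cpu split mirrors the real code. The per-CPU delta feeds the
 * local (non-hierarchical) counter, while the combined delta (per-CPU change
 * plus the children's pending contribution) feeds the hierarchical counter
 * and is forwarded to the parent's pending slot.
 */
struct node {
	struct node *parent;
	long state;		/* hierarchical total: self + descendants */
	long state_local;	/* this node only */
	long state_pending;	/* queued by children, not yet applied */
	long cpu_state;		/* per-CPU counter (one CPU for brevity) */
	long cpu_state_prev;	/* snapshot taken at the previous flush */
};

static void flush_one(struct node *n)
{
	long delta = n->state_pending;			/* from children */
	long delta_cpu = n->cpu_state - n->cpu_state_prev;

	n->state_pending = 0;
	n->cpu_state_prev = n->cpu_state;

	if (delta_cpu)
		n->state_local += delta_cpu;		/* own CPUs only */

	delta += delta_cpu;
	if (delta) {
		n->state += delta;			/* self + subtree */
		if (n->parent)
			n->parent->state_pending += delta;
	}
}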
 
 #ifdef CONFIG_MMU
@@ -5505,7 +5876,7 @@ static int mem_cgroup_do_precharge(unsigned long count)
 }
 
 union mc_target {
-       struct page     *page;
+       struct folio    *folio;
        swp_entry_t     ent;
 };
 
@@ -5521,7 +5892,7 @@ static struct page *mc_handle_present_pte(struct vm_area_struct *vma,
 {
        struct page *page = vm_normal_page(vma, addr, ptent);
 
-       if (!page || !page_mapped(page))
+       if (!page)
                return NULL;
        if (PageAnon(page)) {
                if (!(mc.flags & MOVE_ANON))
@@ -5530,8 +5901,7 @@ static struct page *mc_handle_present_pte(struct vm_area_struct *vma,
                if (!(mc.flags & MOVE_FILE))
                        return NULL;
        }
-       if (!get_page_unless_zero(page))
-               return NULL;
+       get_page(page);
 
        return page;
 }
@@ -5561,7 +5931,7 @@ static struct page *mc_handle_swap_pte(struct vm_area_struct *vma,
                return NULL;
 
        /*
-        * Because lookup_swap_cache() updates some statistics counter,
+        * Because swap_cache_get_folio() updates some statistics counter,
         * we call find_get_page() with swapper_space directly.
         */
        page = find_get_page(swap_address_space(ent), swp_offset(ent));
@@ -5580,55 +5950,53 @@ static struct page *mc_handle_swap_pte(struct vm_area_struct *vma,
 static struct page *mc_handle_file_pte(struct vm_area_struct *vma,
                        unsigned long addr, pte_t ptent)
 {
+       unsigned long index;
+       struct folio *folio;
+
        if (!vma->vm_file) /* anonymous vma */
                return NULL;
        if (!(mc.flags & MOVE_FILE))
                return NULL;
 
-       /* page is moved even if it's not RSS of this task(page-faulted). */
+       /* folio is moved even if it's not RSS of this task (page-faulted). */
        /* shmem/tmpfs may report page out on swap: account for that too. */
-       return find_get_incore_page(vma->vm_file->f_mapping,
-                       linear_page_index(vma, addr));
+       index = linear_page_index(vma, addr);
+       folio = filemap_get_incore_folio(vma->vm_file->f_mapping, index);
+       if (IS_ERR(folio))
+               return NULL;
+       return folio_file_page(folio, index);
 }
 
 /**
- * mem_cgroup_move_account - move account of the page
- * @page: the page
+ * mem_cgroup_move_account - move account of the folio
+ * @folio: The folio.
  * @compound: charge the page as compound or small page
- * @from: mem_cgroup which the page is moved from.
- * @to:        mem_cgroup which the page is moved to. @from != @to.
+ * @from: mem_cgroup which the folio is moved from.
+ * @to:        mem_cgroup which the folio is moved to. @from != @to.
  *
- * The caller must make sure the page is not on LRU (isolate_page() is useful.)
+ * The folio must be locked and not on the LRU.
  *
  * This function doesn't do "charge" to new cgroup and doesn't do "uncharge"
  * from old cgroup.
  */
-static int mem_cgroup_move_account(struct page *page,
+static int mem_cgroup_move_account(struct folio *folio,
                                   bool compound,
                                   struct mem_cgroup *from,
                                   struct mem_cgroup *to)
 {
-       struct folio *folio = page_folio(page);
        struct lruvec *from_vec, *to_vec;
        struct pglist_data *pgdat;
        unsigned int nr_pages = compound ? folio_nr_pages(folio) : 1;
        int nid, ret;
 
        VM_BUG_ON(from == to);
+       VM_BUG_ON_FOLIO(!folio_test_locked(folio), folio);
        VM_BUG_ON_FOLIO(folio_test_lru(folio), folio);
        VM_BUG_ON(compound && !folio_test_large(folio));
 
-       /*
-        * Prevent mem_cgroup_migrate() from looking at
-        * page's memory cgroup of its source page while we change it.
-        */
-       ret = -EBUSY;
-       if (!folio_trylock(folio))
-               goto out;
-
        ret = -EINVAL;
        if (folio_memcg(folio) != from)
-               goto out_unlock;
+               goto out;
 
        pgdat = folio_pgdat(folio);
        from_vec = mem_cgroup_lruvec(from, pgdat);
@@ -5640,7 +6008,7 @@ static int mem_cgroup_move_account(struct page *page,
                if (folio_mapped(folio)) {
                        __mod_lruvec_state(from_vec, NR_ANON_MAPPED, -nr_pages);
                        __mod_lruvec_state(to_vec, NR_ANON_MAPPED, nr_pages);
-                       if (folio_test_transhuge(folio)) {
+                       if (folio_test_pmd_mappable(folio)) {
                                __mod_lruvec_state(from_vec, NR_ANON_THPS,
                                                   -nr_pages);
                                __mod_lruvec_state(to_vec, NR_ANON_THPS,
@@ -5673,6 +6041,12 @@ static int mem_cgroup_move_account(struct page *page,
                }
        }
 
+#ifdef CONFIG_SWAP
+       if (folio_test_swapcache(folio)) {
+               __mod_lruvec_state(from_vec, NR_SWAPCACHE, -nr_pages);
+               __mod_lruvec_state(to_vec, NR_SWAPCACHE, nr_pages);
+       }
+#endif
        if (folio_test_writeback(folio)) {
                __mod_lruvec_state(from_vec, NR_WRITEBACK, -nr_pages);
                __mod_lruvec_state(to_vec, NR_WRITEBACK, nr_pages);
@@ -5686,7 +6060,7 @@ static int mem_cgroup_move_account(struct page *page,
         * with (un)charging, migration, LRU putback, or anything else
         * that would rely on a stable page's memory cgroup.
         *
-        * Note that lock_page_memcg is a memcg lock, not a page lock,
+        * Note that folio_memcg_lock is a memcg lock, not a page lock,
         * to save space. As soon as we switch page's memory cgroup to a
         * new memcg that isn't locked, the above state can change
         * concurrently again. Make sure we're truly done with it.
@@ -5709,8 +6083,6 @@ static int mem_cgroup_move_account(struct page *page,
        mem_cgroup_charge_statistics(from, -nr_pages);
        memcg_check_events(from, nid);
        local_irq_enable();
-out_unlock:
-       folio_unlock(folio);
 out:
        return ret;
 }
@@ -5722,29 +6094,25 @@ out:
  * @ptent: the pte to be checked
  * @target: the pointer where the target page or swap entry will be stored (can be NULL)
  *
- * Returns
- *   0(MC_TARGET_NONE): if the pte is not a target for move charge.
- *   1(MC_TARGET_PAGE): if the page corresponding to this pte is a target for
- *     move charge. if @target is not NULL, the page is stored in target->page
- *     with extra refcnt got(Callers should handle it).
- *   2(MC_TARGET_SWAP): if the swap entry corresponding to this pte is a
- *     target for charge migration. if @target is not NULL, the entry is stored
- *     in target->ent.
- *   3(MC_TARGET_DEVICE): like MC_TARGET_PAGE  but page is device memory and
- *   thus not on the lru.
- *     For now we such page is charge like a regular page would be as for all
- *     intent and purposes it is just special memory taking the place of a
- *     regular page.
- *
- *     See Documentations/vm/hmm.txt and include/linux/hmm.h
- *
- * Called with pte lock held.
+ * Context: Called with pte lock held.
+ * Return:
+ * * MC_TARGET_NONE - If the pte is not a target for move charge.
+ * * MC_TARGET_PAGE - If the page corresponding to this pte is a target for
+ *   move charge. If @target is not NULL, the folio is stored in target->folio
+ *   with extra refcnt taken (Caller should release it).
+ * * MC_TARGET_SWAP - If the swap entry corresponding to this pte is a
+ *   target for charge migration.  If @target is not NULL, the entry is
+ *   stored in target->ent.
+ * * MC_TARGET_DEVICE - Like MC_TARGET_PAGE but page is device memory and
+ *   thus not on the lru.  For now such page is charged like a regular page
+ *   would be as it is just special memory taking the place of a regular page.
+ *   See Documentation/mm/hmm.rst and include/linux/hmm.h
  */
-
 static enum mc_target_type get_mctgt_type(struct vm_area_struct *vma,
                unsigned long addr, pte_t ptent, union mc_target *target)
 {
        struct page *page = NULL;
+       struct folio *folio;
        enum mc_target_type ret = MC_TARGET_NONE;
        swp_entry_t ent = { .val = 0 };
 
@@ -5759,6 +6127,31 @@ static enum mc_target_type get_mctgt_type(struct vm_area_struct *vma,
        else if (is_swap_pte(ptent))
                page = mc_handle_swap_pte(vma, ptent, &ent);
 
+       if (page)
+               folio = page_folio(page);
+       if (target && page) {
+               if (!folio_trylock(folio)) {
+                       folio_put(folio);
+                       return ret;
+               }
+               /*
+                * page_mapped() must be stable during the move. This
+                * pte is locked, so if it's present, the page cannot
+                * become unmapped. If it isn't, we have only partial
+                * control over the mapped state: the page lock will
+                * prevent new faults against pagecache and swapcache,
+                * so an unmapped page cannot become mapped. However,
+                * if the page is already mapped elsewhere, it can
+                * unmap, and there is nothing we can do about it.
+                * Alas, skip moving the page in this case.
+                */
+               if (!pte_present(ptent) && page_mapped(page)) {
+                       folio_unlock(folio);
+                       folio_put(folio);
+                       return ret;
+               }
+       }
+
        if (!page && !ent.val)
                return ret;
        if (page) {
@@ -5767,16 +6160,19 @@ static enum mc_target_type get_mctgt_type(struct vm_area_struct *vma,
                 * mem_cgroup_move_account() checks the page is valid or
                 * not under LRU exclusion.
                 */
-               if (page_memcg(page) == mc.from) {
+               if (folio_memcg(folio) == mc.from) {
                        ret = MC_TARGET_PAGE;
-                       if (is_device_private_page(page) ||
-                           is_device_coherent_page(page))
+                       if (folio_is_device_private(folio) ||
+                           folio_is_device_coherent(folio))
                                ret = MC_TARGET_DEVICE;
                        if (target)
-                               target->page = page;
+                               target->folio = folio;
+               }
+               if (!ret || !target) {
+                       if (target)
+                               folio_unlock(folio);
+                       folio_put(folio);
                }
-               if (!ret || !target)
-                       put_page(page);
        }
        /*
         * There is a swap entry and a page doesn't exist or isn't charged.
@@ -5801,6 +6197,7 @@ static enum mc_target_type get_mctgt_type_thp(struct vm_area_struct *vma,
                unsigned long addr, pmd_t pmd, union mc_target *target)
 {
        struct page *page = NULL;
+       struct folio *folio;
        enum mc_target_type ret = MC_TARGET_NONE;
 
        if (unlikely(is_swap_pmd(pmd))) {
@@ -5810,13 +6207,18 @@ static enum mc_target_type get_mctgt_type_thp(struct vm_area_struct *vma,
        }
        page = pmd_page(pmd);
        VM_BUG_ON_PAGE(!page || !PageHead(page), page);
+       folio = page_folio(page);
        if (!(mc.flags & MOVE_ANON))
                return ret;
-       if (page_memcg(page) == mc.from) {
+       if (folio_memcg(folio) == mc.from) {
                ret = MC_TARGET_PAGE;
                if (target) {
-                       get_page(page);
-                       target->page = page;
+                       folio_get(folio);
+                       if (!folio_trylock(folio)) {
+                               folio_put(folio);
+                               return MC_TARGET_NONE;
+                       }
+                       target->folio = folio;
                }
        }
        return ret;
@@ -5850,11 +6252,11 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
                return 0;
        }
 
-       if (pmd_trans_unstable(pmd))
-               return 0;
        pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
+       if (!pte)
+               return 0;
        for (; addr != end; pte++, addr += PAGE_SIZE)
-               if (get_mctgt_type(vma, addr, *pte, NULL))
+               if (get_mctgt_type(vma, addr, ptep_get(pte), NULL))
                        mc.precharge++; /* increment precharge temporarily */
        pte_unmap_unlock(pte - 1, ptl);
        cond_resched();
@@ -5864,6 +6266,7 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
 
 static const struct mm_walk_ops precharge_walk_ops = {
        .pmd_entry      = mem_cgroup_count_precharge_pte_range,
+       .walk_lock      = PGWALK_RDLOCK,
 };
 
 static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
@@ -5871,7 +6274,7 @@ static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
        unsigned long precharge;
 
        mmap_read_lock(mm);
-       walk_page_range(mm, 0, mm->highest_vm_end, &precharge_walk_ops, NULL);
+       walk_page_range(mm, 0, ULONG_MAX, &precharge_walk_ops, NULL);
        mmap_read_unlock(mm);
 
        precharge = mc.precharge;
@@ -5897,7 +6300,7 @@ static void __mem_cgroup_clear_mc(void)
 
        /* we must uncharge all the leftover precharges from mc.to */
        if (mc.precharge) {
-               cancel_charge(mc.to, mc.precharge);
+               mem_cgroup_cancel_charge(mc.to, mc.precharge);
                mc.precharge = 0;
        }
        /*
@@ -5905,7 +6308,7 @@ static void __mem_cgroup_clear_mc(void)
         * we must uncharge here.
         */
        if (mc.moved_charge) {
-               cancel_charge(mc.from, mc.moved_charge);
+               mem_cgroup_cancel_charge(mc.from, mc.moved_charge);
                mc.moved_charge = 0;
        }
        /* we must fixup refcnts and charges */
@@ -6035,7 +6438,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
        spinlock_t *ptl;
        enum mc_target_type target_type;
        union mc_target target;
-       struct page *page;
+       struct folio *folio;
 
        ptl = pmd_trans_huge_lock(pmd, vma);
        if (ptl) {
@@ -6045,35 +6448,37 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
                }
                target_type = get_mctgt_type_thp(vma, addr, *pmd, &target);
                if (target_type == MC_TARGET_PAGE) {
-                       page = target.page;
-                       if (!isolate_lru_page(page)) {
-                               if (!mem_cgroup_move_account(page, true,
+                       folio = target.folio;
+                       if (folio_isolate_lru(folio)) {
+                               if (!mem_cgroup_move_account(folio, true,
                                                             mc.from, mc.to)) {
                                        mc.precharge -= HPAGE_PMD_NR;
                                        mc.moved_charge += HPAGE_PMD_NR;
                                }
-                               putback_lru_page(page);
+                               folio_putback_lru(folio);
                        }
-                       put_page(page);
+                       folio_unlock(folio);
+                       folio_put(folio);
                } else if (target_type == MC_TARGET_DEVICE) {
-                       page = target.page;
-                       if (!mem_cgroup_move_account(page, true,
+                       folio = target.folio;
+                       if (!mem_cgroup_move_account(folio, true,
                                                     mc.from, mc.to)) {
                                mc.precharge -= HPAGE_PMD_NR;
                                mc.moved_charge += HPAGE_PMD_NR;
                        }
-                       put_page(page);
+                       folio_unlock(folio);
+                       folio_put(folio);
                }
                spin_unlock(ptl);
                return 0;
        }
 
-       if (pmd_trans_unstable(pmd))
-               return 0;
 retry:
        pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
+       if (!pte)
+               return 0;
        for (; addr != end; addr += PAGE_SIZE) {
-               pte_t ptent = *(pte++);
+               pte_t ptent = ptep_get(pte++);
                bool device = false;
                swp_entry_t ent;
 
@@ -6085,27 +6490,28 @@ retry:
                        device = true;
                        fallthrough;
                case MC_TARGET_PAGE:
-                       page = target.page;
+                       folio = target.folio;
                        /*
                         * We can have a part of the split pmd here. Moving it
                         * can be done but it would be too convoluted so simply
                         * ignore such a partial THP and keep it in original
                         * memcg. There should be somebody mapping the head.
                         */
-                       if (PageTransCompound(page))
+                       if (folio_test_large(folio))
                                goto put;
-                       if (!device && isolate_lru_page(page))
+                       if (!device && !folio_isolate_lru(folio))
                                goto put;
-                       if (!mem_cgroup_move_account(page, false,
+                       if (!mem_cgroup_move_account(folio, false,
                                                mc.from, mc.to)) {
                                mc.precharge--;
                                /* we uncharge from mc.from later. */
                                mc.moved_charge++;
                        }
                        if (!device)
-                               putback_lru_page(page);
-put:                   /* get_mctgt_type() gets the page */
-                       put_page(page);
+                               folio_putback_lru(folio);
+put:                   /* get_mctgt_type() gets & locks the page */
+                       folio_unlock(folio);
+                       folio_put(folio);
                        break;
                case MC_TARGET_SWAP:
                        ent = target.ent;
@@ -6140,13 +6546,14 @@ put:                    /* get_mctgt_type() gets the page */
 
 static const struct mm_walk_ops charge_walk_ops = {
        .pmd_entry      = mem_cgroup_move_charge_pte_range,
+       .walk_lock      = PGWALK_RDLOCK,
 };
 
 static void mem_cgroup_move_charge(void)
 {
        lru_add_drain_all();
        /*
-        * Signal lock_page_memcg() to take the memcg's move_lock
+        * Signal folio_memcg_lock() to take the memcg's move_lock
         * while we're moving its pages to another memcg. Then wait
         * for already started RCU-only updates to finish.
         */
@@ -6169,9 +6576,7 @@ retry:
         * When we have consumed all precharges and failed in doing
         * additional charge, the page walk just aborts.
         */
-       walk_page_range(mc.mm, 0, mc.mm->highest_vm_end, &charge_walk_ops,
-                       NULL);
-
+       walk_page_range(mc.mm, 0, ULONG_MAX, &charge_walk_ops, NULL);
        mmap_read_unlock(mc.mm);
        atomic_dec(&mc.from->moving_account);
 }
@@ -6183,6 +6588,7 @@ static void mem_cgroup_move_task(void)
                mem_cgroup_clear_mc();
        }
 }
+
 #else  /* !CONFIG_MMU */
 static int mem_cgroup_can_attach(struct cgroup_taskset *tset)
 {
@@ -6196,6 +6602,82 @@ static void mem_cgroup_move_task(void)
 }
 #endif
 
+#ifdef CONFIG_MEMCG_KMEM
+static void mem_cgroup_fork(struct task_struct *task)
+{
+       /*
+        * Set the update flag to cause task->objcg to be initialized lazily
+        * on the first allocation. It can be done without any synchronization
+        * because it's always performed on the current task, as is
+        * current_objcg_update().
+        */
+       task->objcg = (struct obj_cgroup *)CURRENT_OBJCG_UPDATE_FLAG;
+}
+
+static void mem_cgroup_exit(struct task_struct *task)
+{
+       struct obj_cgroup *objcg = task->objcg;
+
+       objcg = (struct obj_cgroup *)
+               ((unsigned long)objcg & ~CURRENT_OBJCG_UPDATE_FLAG);
+       if (objcg)
+               obj_cgroup_put(objcg);
+
+       /*
+        * Some kernel allocations can happen after this point,
+        * but let's ignore them. It can be done without any synchronization
+        * because it's always performed on the current task, as is
+        * current_objcg_update().
+        */
+       task->objcg = NULL;
+}
+#endif
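/*
 * Illustrative sketch: mem_cgroup_fork() and mem_cgroup_kmem_attach() above
 * use the low bit of task->objcg (CURRENT_OBJCG_UPDATE_FLAG) as an "update
 * needed" tag. The two helpers below are hypothetical and show only the
 * tagging arithmetic, not the kernel's current_objcg_update(); they rely on
 * struct obj_cgroup being at least 2-byte aligned, so bit 0 is never part of
 * a valid pointer.
 */
static inline struct obj_cgroup *objcg_untag(struct obj_cgroup *tagged)
{
	/* Strip the update flag to recover the real pointer (may be NULL). */
	return (struct obj_cgroup *)
		((unsigned long)tagged & ~CURRENT_OBJCG_UPDATE_FLAG);
}

static inline bool objcg_needs_update(struct obj_cgroup *tagged)
{
	/* Bit 0 set: the cached objcg must be re-derived from the task's memcg. */
	return (unsigned long)tagged & CURRENT_OBJCG_UPDATE_FLAG;
}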
+
+#ifdef CONFIG_LRU_GEN
+static void mem_cgroup_lru_gen_attach(struct cgroup_taskset *tset)
+{
+       struct task_struct *task;
+       struct cgroup_subsys_state *css;
+
+       /* find the first leader if there is any */
+       cgroup_taskset_for_each_leader(task, css, tset)
+               break;
+
+       if (!task)
+               return;
+
+       task_lock(task);
+       if (task->mm && READ_ONCE(task->mm->owner) == task)
+               lru_gen_migrate_mm(task->mm);
+       task_unlock(task);
+}
+#else
+static void mem_cgroup_lru_gen_attach(struct cgroup_taskset *tset) {}
+#endif /* CONFIG_LRU_GEN */
+
+#ifdef CONFIG_MEMCG_KMEM
+static void mem_cgroup_kmem_attach(struct cgroup_taskset *tset)
+{
+       struct task_struct *task;
+       struct cgroup_subsys_state *css;
+
+       cgroup_taskset_for_each(task, css, tset) {
+               /* atomically set the update bit */
+               set_bit(CURRENT_OBJCG_UPDATE_BIT, (unsigned long *)&task->objcg);
+       }
+}
+#else
+static void mem_cgroup_kmem_attach(struct cgroup_taskset *tset) {}
+#endif /* CONFIG_MEMCG_KMEM */
+
+#if defined(CONFIG_LRU_GEN) || defined(CONFIG_MEMCG_KMEM)
+static void mem_cgroup_attach(struct cgroup_taskset *tset)
+{
+       mem_cgroup_lru_gen_attach(tset);
+       mem_cgroup_kmem_attach(tset);
+}
+#endif
+
 static int seq_puts_memcg_tunable(struct seq_file *m, unsigned long value)
 {
        if (value == PAGE_COUNTER_MAX)
@@ -6370,6 +6852,10 @@ static ssize_t memory_max_write(struct kernfs_open_file *of,
        return nbytes;
 }
 
+/*
+ * Note: don't forget to update the 'samples/cgroup/memcg_event_listener'
+ * if any new events become available.
+ */
 static void __memory_events_show(struct seq_file *m, atomic_long_t *events)
 {
        seq_printf(m, "low %lu\n", atomic_long_read(&events[MEMCG_LOW]));
@@ -6402,10 +6888,12 @@ static int memory_stat_show(struct seq_file *m, void *v)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
        char *buf = kmalloc(PAGE_SIZE, GFP_KERNEL);
+       struct seq_buf s;
 
        if (!buf)
                return -ENOMEM;
-       memory_stat_format(memcg, buf, PAGE_SIZE);
+       seq_buf_init(&s, buf, PAGE_SIZE);
+       memory_stat_format(memcg, &s);
        seq_puts(m, buf);
        kfree(buf);
        return 0;
@@ -6415,7 +6903,8 @@ static int memory_stat_show(struct seq_file *m, void *v)
 static inline unsigned long lruvec_page_state_output(struct lruvec *lruvec,
                                                     int item)
 {
-       return lruvec_page_state(lruvec, item) * memcg_page_state_unit(item);
+       return lruvec_page_state(lruvec, item) *
+               memcg_page_state_output_unit(item);
 }
 
 static int memory_numa_stat_show(struct seq_file *m, void *v)
@@ -6423,7 +6912,7 @@ static int memory_numa_stat_show(struct seq_file *m, void *v)
        int i;
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
 
-       mem_cgroup_flush_stats();
+       mem_cgroup_flush_stats(memcg);
 
        for (i = 0; i < ARRAY_SIZE(memory_stats); i++) {
                int nid;
@@ -6452,7 +6941,7 @@ static int memory_oom_group_show(struct seq_file *m, void *v)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
 
-       seq_printf(m, "%d\n", memcg->oom_group);
+       seq_printf(m, "%d\n", READ_ONCE(memcg->oom_group));
 
        return 0;
 }
@@ -6474,7 +6963,7 @@ static ssize_t memory_oom_group_write(struct kernfs_open_file *of,
        if (oom_group != 0 && oom_group != 1)
                return -EINVAL;
 
-       memcg->oom_group = oom_group;
+       WRITE_ONCE(memcg->oom_group, oom_group);
 
        return nbytes;
 }
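/*
 * Illustrative note: oom_group, swappiness, soft_limit and oom_kill_disable
 * are written under cgroup file locks but read locklessly elsewhere, which
 * is why both sides of those accesses are annotated in this patch. A minimal
 * sketch of the reader side (the helper name is hypothetical):
 */
static inline bool memcg_oom_group_enabled(struct mem_cgroup *memcg)
{
	/* Pairs with WRITE_ONCE() in memory_oom_group_write() above. */
	return READ_ONCE(memcg->oom_group);
}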
@@ -6495,6 +6984,8 @@ static ssize_t memory_reclaim(struct kernfs_open_file *of, char *buf,
 
        reclaim_options = MEMCG_RECLAIM_MAY_SWAP | MEMCG_RECLAIM_PROACTIVE;
        while (nr_reclaimed < nr_to_reclaim) {
+               /* Will converge on zero, but reclaim enforces a minimum */
+               unsigned long batch_size = (nr_to_reclaim - nr_reclaimed) / 4;
                unsigned long reclaimed;
 
                if (signal_pending(current))
@@ -6509,8 +7000,7 @@ static ssize_t memory_reclaim(struct kernfs_open_file *of, char *buf,
                        lru_add_drain_all();
 
                reclaimed = try_to_free_mem_cgroup_pages(memcg,
-                                               nr_to_reclaim - nr_reclaimed,
-                                               GFP_KERNEL, reclaim_options);
+                                       batch_size, GFP_KERNEL, reclaim_options);
 
                if (!reclaimed && !nr_retries--)
                        return -EAGAIN;
@@ -6601,8 +7091,15 @@ struct cgroup_subsys memory_cgrp_subsys = {
        .css_reset = mem_cgroup_css_reset,
        .css_rstat_flush = mem_cgroup_css_rstat_flush,
        .can_attach = mem_cgroup_can_attach,
+#if defined(CONFIG_LRU_GEN) || defined(CONFIG_MEMCG_KMEM)
+       .attach = mem_cgroup_attach,
+#endif
        .cancel_attach = mem_cgroup_cancel_attach,
        .post_attach = mem_cgroup_move_task,
+#ifdef CONFIG_MEMCG_KMEM
+       .fork = mem_cgroup_fork,
+       .exit = mem_cgroup_exit,
+#endif
        .dfl_cftypes = memory_files,
        .legacy_cftypes = mem_cgroup_legacy_files,
        .early_init = 0,
@@ -6663,7 +7160,7 @@ static unsigned long effective_protection(unsigned long usage,
        protected = min(usage, setting);
        /*
         * If all cgroups at this level combined claim and use more
-        * protection then what the parent affords them, distribute
+        * protection than what the parent affords them, distribute
         * shares in proportion to utilization.
         *
         * We are using actual utilization rather than the statically
@@ -6782,20 +7279,13 @@ void mem_cgroup_calculate_protection(struct mem_cgroup *root,
 static int charge_memcg(struct folio *folio, struct mem_cgroup *memcg,
                        gfp_t gfp)
 {
-       long nr_pages = folio_nr_pages(folio);
        int ret;
 
-       ret = try_charge(memcg, gfp, nr_pages);
+       ret = try_charge(memcg, gfp, folio_nr_pages(folio));
        if (ret)
                goto out;
 
-       css_get(&memcg->css);
-       commit_charge(folio, memcg);
-
-       local_irq_disable();
-       mem_cgroup_charge_statistics(memcg, nr_pages);
-       memcg_check_events(memcg, folio_nid(folio));
-       local_irq_enable();
+       mem_cgroup_commit_charge(folio, memcg);
 out:
        return ret;
 }
@@ -6813,21 +7303,55 @@ int __mem_cgroup_charge(struct folio *folio, struct mm_struct *mm, gfp_t gfp)
 }
 
 /**
- * mem_cgroup_swapin_charge_page - charge a newly allocated page for swapin
- * @page: page to charge
+ * mem_cgroup_hugetlb_try_charge - try to charge the memcg for a hugetlb folio
+ * @memcg: memcg to charge.
+ * @gfp: reclaim mode.
+ * @nr_pages: number of pages to charge.
+ *
+ * This function is called when allocating a huge page folio to determine if
+ * the memcg has the capacity for it. It does not commit the charge yet,
+ * as the hugetlb folio itself has not been obtained from the hugetlb pool.
+ *
+ * Once we have obtained the hugetlb folio, we can call
+ * mem_cgroup_commit_charge() to commit the charge. If we fail to obtain the
+ * folio, we should instead call mem_cgroup_cancel_charge() to undo the effect
+ * of try_charge().
+ *
+ * Returns 0 on success. Otherwise, an error code is returned.
+ */
+int mem_cgroup_hugetlb_try_charge(struct mem_cgroup *memcg, gfp_t gfp,
+                       long nr_pages)
+{
+       /*
+        * If hugetlb memcg charging is not enabled, do not fail hugetlb allocation,
+        * but do not attempt to commit charge later (or cancel on error) either.
+        */
+       if (mem_cgroup_disabled() || !memcg ||
+               !cgroup_subsys_on_dfl(memory_cgrp_subsys) ||
+               !(cgrp_dfl_root.flags & CGRP_ROOT_MEMORY_HUGETLB_ACCOUNTING))
+               return -EOPNOTSUPP;
+
+       if (try_charge(memcg, gfp, nr_pages))
+               return -ENOMEM;
+
+       return 0;
+}
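/*
 * Illustrative sketch of the calling convention described in the kerneldoc
 * above, not a copy of the hugetlb code: try the charge before dipping into
 * the pool, commit once a folio is in hand, cancel on failure, and skip both
 * steps when charging is not enabled (-EOPNOTSUPP). alloc_from_pool() is a
 * hypothetical placeholder allocator.
 */
static struct folio *alloc_from_pool(long nr_pages);	/* hypothetical */

static struct folio *hugetlb_alloc_charged(struct mem_cgroup *memcg,
					   gfp_t gfp, long nr_pages)
{
	struct folio *folio;
	int ret = mem_cgroup_hugetlb_try_charge(memcg, gfp, nr_pages);

	if (ret == -ENOMEM)
		return NULL;			/* memcg is at its limit */

	folio = alloc_from_pool(nr_pages);
	if (!folio) {
		if (!ret)			/* only if the charge succeeded */
			mem_cgroup_cancel_charge(memcg, nr_pages);
		return NULL;
	}
	if (!ret)
		mem_cgroup_commit_charge(folio, memcg);
	return folio;
}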
+
+/**
+ * mem_cgroup_swapin_charge_folio - Charge a newly allocated folio for swapin.
+ * @folio: folio to charge.
  * @mm: mm context of the victim
  * @gfp: reclaim mode
- * @entry: swap entry for which the page is allocated
+ * @entry: swap entry for which the folio is allocated
  *
- * This function charges a page allocated for swapin. Please call this before
- * adding the page to the swapcache.
+ * This function charges a folio allocated for swapin. Please call this before
+ * adding the folio to the swapcache.
  *
  * Returns 0 on success. Otherwise, an error code is returned.
  */
-int mem_cgroup_swapin_charge_page(struct page *page, struct mm_struct *mm,
+int mem_cgroup_swapin_charge_folio(struct folio *folio, struct mm_struct *mm,
                                  gfp_t gfp, swp_entry_t entry)
 {
-       struct folio *folio = page_folio(page);
        struct mem_cgroup *memcg;
        unsigned short id;
        int ret;
@@ -6989,36 +7513,30 @@ void __mem_cgroup_uncharge(struct folio *folio)
        uncharge_batch(&ug);
 }
 
-/**
- * __mem_cgroup_uncharge_list - uncharge a list of page
- * @page_list: list of pages to uncharge
- *
- * Uncharge a list of pages previously charged with
- * __mem_cgroup_charge().
- */
-void __mem_cgroup_uncharge_list(struct list_head *page_list)
+void __mem_cgroup_uncharge_folios(struct folio_batch *folios)
 {
        struct uncharge_gather ug;
-       struct folio *folio;
+       unsigned int i;
 
        uncharge_gather_clear(&ug);
-       list_for_each_entry(folio, page_list, lru)
-               uncharge_folio(folio, &ug);
+       for (i = 0; i < folios->nr; i++)
+               uncharge_folio(folios->folios[i], &ug);
        if (ug.memcg)
                uncharge_batch(&ug);
 }
 
 /**
- * mem_cgroup_migrate - Charge a folio's replacement.
+ * mem_cgroup_replace_folio - Charge a folio's replacement.
  * @old: Currently circulating folio.
  * @new: Replacement folio.
  *
  * Charge @new as a replacement folio for @old. @old will
- * be uncharged upon free.
+ * be uncharged upon free. This is only used by the page cache
+ * (in replace_page_cache_folio()).
  *
  * Both folios must be locked, @new->mapping must be set up.
  */
-void mem_cgroup_migrate(struct folio *old, struct folio *new)
+void mem_cgroup_replace_folio(struct folio *old, struct folio *new)
 {
        struct mem_cgroup *memcg;
        long nr_pages = folio_nr_pages(new);
@@ -7057,6 +7575,55 @@ void mem_cgroup_migrate(struct folio *old, struct folio *new)
        local_irq_restore(flags);
 }
 
+/**
+ * mem_cgroup_migrate - Transfer the memcg data from the old to the new folio.
+ * @old: Currently circulating folio.
+ * @new: Replacement folio.
+ *
+ * Transfer the memcg data from the old folio to the new folio for migration.
+ * The old folio's memcg data will be cleared. Note that the memory counters
+ * will remain unchanged throughout the process.
+ *
+ * Both folios must be locked, @new->mapping must be set up.
+ */
+void mem_cgroup_migrate(struct folio *old, struct folio *new)
+{
+       struct mem_cgroup *memcg;
+
+       VM_BUG_ON_FOLIO(!folio_test_locked(old), old);
+       VM_BUG_ON_FOLIO(!folio_test_locked(new), new);
+       VM_BUG_ON_FOLIO(folio_test_anon(old) != folio_test_anon(new), new);
+       VM_BUG_ON_FOLIO(folio_nr_pages(old) != folio_nr_pages(new), new);
+
+       if (mem_cgroup_disabled())
+               return;
+
+       memcg = folio_memcg(old);
+       /*
+        * Note that it is normal to see !memcg for a hugetlb folio.
+        * For example, it could have been allocated when memory_hugetlb_accounting
+        * was not selected.
+        */
+       VM_WARN_ON_ONCE_FOLIO(!folio_test_hugetlb(old) && !memcg, old);
+       if (!memcg)
+               return;
+
+       /* Transfer the charge and the css ref */
+       commit_charge(new, memcg);
+       /*
+        * If the old folio is a large folio and is in the split queue, it needs
+        * to be removed from the split queue now, to avoid using an incorrect
+        * split queue in destroy_large_folio() after the memcg of the old folio
+        * is cleared.
+        *
+        * In addition, the old folio is about to be freed after migration, so
+        * removing from the split queue a bit earlier seems reasonable.
+        */
+       if (folio_test_large(old) && folio_test_large_rmappable(old))
+               folio_undo_large_rmappable(old);
+       old->memcg_data = 0;
+}
+
 DEFINE_STATIC_KEY_FALSE(memcg_sockets_enabled_key);
 EXPORT_SYMBOL(memcg_sockets_enabled_key);
 
@@ -7073,7 +7640,7 @@ void mem_cgroup_sk_alloc(struct sock *sk)
 
        rcu_read_lock();
        memcg = mem_cgroup_from_task(current);
-       if (memcg == root_mem_cgroup)
+       if (mem_cgroup_is_root(memcg))
                goto out;
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && !memcg->tcpmem_active)
                goto out;
@@ -7152,6 +7719,8 @@ static int __init cgroup_memory(char *s)
                        cgroup_memory_nosocket = true;
                if (!strcmp(token, "nokmem"))
                        cgroup_memory_nokmem = true;
+               if (!strcmp(token, "nobpf"))
+                       cgroup_memory_nobpf = true;
        }
        return 1;
 }
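
The parser above walks a comma-separated token list, so the new switch composes with the existing ones. Assuming the cgroup.memory= spelling registered for this handler elsewhere in the file, a boot line disabling socket, kernel and BPF memory accounting would look like:

	cgroup.memory=nosocket,nokmem,nobpf
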
@@ -7187,8 +7756,7 @@ static int __init mem_cgroup_init(void)
        for_each_node(node) {
                struct mem_cgroup_tree_per_node *rtpn;
 
-               rtpn = kzalloc_node(sizeof(*rtpn), GFP_KERNEL,
-                                   node_online(node) ? node : NUMA_NO_NODE);
+               rtpn = kzalloc_node(sizeof(*rtpn), GFP_KERNEL, node);
 
                rtpn->rb_root = RB_ROOT;
                rtpn->rb_rightmost = NULL;
@@ -7200,7 +7768,7 @@ static int __init mem_cgroup_init(void)
 }
 subsys_initcall(mem_cgroup_init);
 
-#ifdef CONFIG_MEMCG_SWAP
+#ifdef CONFIG_SWAP
 static struct mem_cgroup *mem_cgroup_id_get_online(struct mem_cgroup *memcg)
 {
        while (!refcount_inc_not_zero(&memcg->id.ref)) {
@@ -7208,7 +7776,7 @@ static struct mem_cgroup *mem_cgroup_id_get_online(struct mem_cgroup *memcg)
                 * The root cgroup cannot be destroyed, so its refcount must
                 * always be >= 1.
                 */
-               if (WARN_ON_ONCE(memcg == root_mem_cgroup)) {
+               if (WARN_ON_ONCE(mem_cgroup_is_root(memcg))) {
                        VM_BUG_ON(1);
                        break;
                }
@@ -7238,7 +7806,7 @@ void mem_cgroup_swapout(struct folio *folio, swp_entry_t entry)
        if (mem_cgroup_disabled())
                return;
 
-       if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
+       if (!do_memsw_account())
                return;
 
        memcg = folio_memcg(folio);
@@ -7267,7 +7835,7 @@ void mem_cgroup_swapout(struct folio *folio, swp_entry_t entry)
        if (!mem_cgroup_is_root(memcg))
                page_counter_uncharge(&memcg->memory, nr_entries);
 
-       if (!cgroup_memory_noswap && memcg != swap_memcg) {
+       if (memcg != swap_memcg) {
                if (!mem_cgroup_is_root(swap_memcg))
                        page_counter_charge(&swap_memcg->memsw, nr_entries);
                page_counter_uncharge(&memcg->memsw, nr_entries);
@@ -7303,7 +7871,7 @@ int __mem_cgroup_try_charge_swap(struct folio *folio, swp_entry_t entry)
        struct mem_cgroup *memcg;
        unsigned short oldid;
 
-       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+       if (do_memsw_account())
                return 0;
 
        memcg = folio_memcg(folio);
@@ -7319,7 +7887,7 @@ int __mem_cgroup_try_charge_swap(struct folio *folio, swp_entry_t entry)
 
        memcg = mem_cgroup_id_get_online(memcg);
 
-       if (!cgroup_memory_noswap && !mem_cgroup_is_root(memcg) &&
+       if (!mem_cgroup_is_root(memcg) &&
            !page_counter_try_charge(&memcg->swap, nr_pages, &counter)) {
                memcg_memory_event(memcg, MEMCG_SWAP_MAX);
                memcg_memory_event(memcg, MEMCG_SWAP_FAIL);
@@ -7351,11 +7919,11 @@ void __mem_cgroup_uncharge_swap(swp_entry_t entry, unsigned int nr_pages)
        rcu_read_lock();
        memcg = mem_cgroup_from_id(id);
        if (memcg) {
-               if (!cgroup_memory_noswap && !mem_cgroup_is_root(memcg)) {
-                       if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
-                               page_counter_uncharge(&memcg->swap, nr_pages);
-                       else
+               if (!mem_cgroup_is_root(memcg)) {
+                       if (do_memsw_account())
                                page_counter_uncharge(&memcg->memsw, nr_pages);
+                       else
+                               page_counter_uncharge(&memcg->swap, nr_pages);
                }
                mod_memcg_state(memcg, MEMCG_SWAP, -nr_pages);
                mem_cgroup_id_put_many(memcg, nr_pages);
@@ -7367,31 +7935,31 @@ long mem_cgroup_get_nr_swap_pages(struct mem_cgroup *memcg)
 {
        long nr_swap_pages = get_nr_swap_pages();
 
-       if (cgroup_memory_noswap || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
+       if (mem_cgroup_disabled() || do_memsw_account())
                return nr_swap_pages;
-       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg))
+       for (; !mem_cgroup_is_root(memcg); memcg = parent_mem_cgroup(memcg))
                nr_swap_pages = min_t(long, nr_swap_pages,
                                      READ_ONCE(memcg->swap.max) -
                                      page_counter_read(&memcg->swap));
        return nr_swap_pages;
 }
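
A worked example of the loop above, with hypothetical numbers: if the system has 1048576 free swap pages, the memcg has swap.max = 262144 with 65536 pages of swap charged, and its parent has swap.max = 524288 with 131072 pages charged, the walk returns

	min(1048576, 262144 - 65536, 524288 - 131072) = 196608 pages.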
 
-bool mem_cgroup_swap_full(struct page *page)
+bool mem_cgroup_swap_full(struct folio *folio)
 {
        struct mem_cgroup *memcg;
 
-       VM_BUG_ON_PAGE(!PageLocked(page), page);
+       VM_BUG_ON_FOLIO(!folio_test_locked(folio), folio);
 
        if (vm_swap_full())
                return true;
-       if (cgroup_memory_noswap || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
+       if (do_memsw_account())
                return false;
 
-       memcg = page_memcg(page);
+       memcg = folio_memcg(folio);
        if (!memcg)
                return false;
 
-       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg)) {
+       for (; !mem_cgroup_is_root(memcg); memcg = parent_mem_cgroup(memcg)) {
                unsigned long usage = page_counter_read(&memcg->swap);
 
                if (usage * 2 >= READ_ONCE(memcg->swap.high) ||
@@ -7404,10 +7972,13 @@ bool mem_cgroup_swap_full(struct page *page)
 
 static int __init setup_swap_account(char *s)
 {
-       if (!strcmp(s, "1"))
-               cgroup_memory_noswap = false;
-       else if (!strcmp(s, "0"))
-               cgroup_memory_noswap = true;
+       bool res;
+
+       if (!kstrtobool(s, &res) && !res)
+               pr_warn_once("The swapaccount=0 commandline option is deprecated "
+                            "in favor of configuring swap control via cgroupfs. "
+                            "Please report your usecase to linux-mm@kvack.org if you "
+                            "depend on this functionality.\n");
        return 1;
 }
 __setup("swapaccount=", setup_swap_account);
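
Given the spellings kstrtobool() accepts, booting with swapaccount=0, swapaccount=n or swapaccount=off now only prints the deprecation warning above, while swapaccount=1 (or an unparsable value) is silently ignored; in neither case does the option change accounting behaviour any more.
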
@@ -7420,6 +7991,14 @@ static u64 swap_current_read(struct cgroup_subsys_state *css,
        return (u64)page_counter_read(&memcg->swap) * PAGE_SIZE;
 }
 
+static u64 swap_peak_read(struct cgroup_subsys_state *css,
+                         struct cftype *cft)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+
+       return (u64)memcg->swap.watermark * PAGE_SIZE;
+}
+
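The new file surfaces the swap page_counter watermark in bytes. A minimal userspace sketch for reading it follows; the /sys/fs/cgroup/test path is an assumed example of a non-root cgroup v2 directory:

/* Hypothetical reader for the memory.swap.peak interface file. */
#include <stdio.h>

int main(void)
{
	FILE *f = fopen("/sys/fs/cgroup/test/memory.swap.peak", "r");
	unsigned long long peak_bytes;

	if (!f) {
		perror("memory.swap.peak");
		return 1;
	}
	if (fscanf(f, "%llu", &peak_bytes) == 1)
		printf("peak swap usage: %llu bytes\n", peak_bytes);
	fclose(f);
	return 0;
}
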
 static int swap_high_show(struct seq_file *m, void *v)
 {
        return seq_puts_memcg_tunable(m,
@@ -7498,6 +8077,11 @@ static struct cftype swap_files[] = {
                .seq_show = swap_max_show,
                .write = swap_max_write,
        },
+       {
+               .name = "swap.peak",
+               .flags = CFTYPE_NOT_ON_ROOT,
+               .read_u64 = swap_peak_read,
+       },
        {
                .name = "swap.events",
                .flags = CFTYPE_NOT_ON_ROOT,
@@ -7543,7 +8127,7 @@ static struct cftype memsw_files[] = {
  *
  * This doesn't check for specific headroom, and it is not atomic
  * either. But with zswap, the size of the allocation is only known
- * once compression has occured, and this optimistic pre-check avoids
+ * once compression has occurred, and this optimistic pre-check avoids
  * spending cycles on compression when there is already no room left
  * or zswap is disabled altogether somewhere in the hierarchy.
  */
@@ -7556,7 +8140,7 @@ bool obj_cgroup_may_zswap(struct obj_cgroup *objcg)
                return true;
 
        original_memcg = get_mem_cgroup_from_objcg(objcg);
-       for (memcg = original_memcg; memcg != root_mem_cgroup;
+       for (memcg = original_memcg; !mem_cgroup_is_root(memcg);
             memcg = parent_mem_cgroup(memcg)) {
                unsigned long max = READ_ONCE(memcg->zswap_max);
                unsigned long pages;
@@ -7568,7 +8152,11 @@ bool obj_cgroup_may_zswap(struct obj_cgroup *objcg)
                        break;
                }
 
-               cgroup_rstat_flush(memcg->css.cgroup);
+               /*
+                * mem_cgroup_flush_stats() ignores small changes. Use
+                * do_flush_stats() directly to get accurate stats for charging.
+                */
+               do_flush_stats(memcg);
                pages = memcg_page_state(memcg, MEMCG_ZSWAP_B) / PAGE_SIZE;
                if (pages < max)
                        continue;
@@ -7584,7 +8172,7 @@ bool obj_cgroup_may_zswap(struct obj_cgroup *objcg)
  * @objcg: the object cgroup
  * @size: size of compressed object
  *
- * This forces the charge after obj_cgroup_may_swap() allowed
+ * This forces the charge after obj_cgroup_may_zswap() allowed
  * compression and storage in zswap for this cgroup to go ahead.
  */
 void obj_cgroup_charge_zswap(struct obj_cgroup *objcg, size_t size)
@@ -7630,11 +8218,19 @@ void obj_cgroup_uncharge_zswap(struct obj_cgroup *objcg, size_t size)
        rcu_read_unlock();
 }
 
+bool mem_cgroup_zswap_writeback_enabled(struct mem_cgroup *memcg)
+{
+       /* if zswap is disabled, do not block pages going to the swapping device */
+       return !is_zswap_enabled() || !memcg || READ_ONCE(memcg->zswap_writeback);
+}
+
 static u64 zswap_current_read(struct cgroup_subsys_state *css,
                              struct cftype *cft)
 {
-       cgroup_rstat_flush(css->cgroup);
-       return memcg_page_state(mem_cgroup_from_css(css), MEMCG_ZSWAP_B);
+       struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+
+       mem_cgroup_flush_stats(memcg);
+       return memcg_page_state(memcg, MEMCG_ZSWAP_B);
 }
 
 static int zswap_max_show(struct seq_file *m, void *v)
@@ -7660,6 +8256,31 @@ static ssize_t zswap_max_write(struct kernfs_open_file *of,
        return nbytes;
 }
 
+static int zswap_writeback_show(struct seq_file *m, void *v)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
+
+       seq_printf(m, "%d\n", READ_ONCE(memcg->zswap_writeback));
+       return 0;
+}
+
+static ssize_t zswap_writeback_write(struct kernfs_open_file *of,
+                               char *buf, size_t nbytes, loff_t off)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+       int zswap_writeback;
+       ssize_t parse_ret = kstrtoint(strstrip(buf), 0, &zswap_writeback);
+
+       if (parse_ret)
+               return parse_ret;
+
+       if (zswap_writeback != 0 && zswap_writeback != 1)
+               return -EINVAL;
+
+       WRITE_ONCE(memcg->zswap_writeback, zswap_writeback);
+       return nbytes;
+}
+
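Only the literal values 0 and 1 are accepted by the write handler above. A minimal userspace sketch that disables zswap writeback for one cgroup might look like this; the cgroup path is an assumed example:

/* Hypothetical helper toggling the memory.zswap.writeback interface file. */
#include <stdio.h>

static int set_zswap_writeback(const char *cgroup_dir, int enable)
{
	char path[256];
	FILE *f;

	snprintf(path, sizeof(path), "%s/memory.zswap.writeback", cgroup_dir);
	f = fopen(path, "w");
	if (!f)
		return -1;
	/* The kernel rejects anything other than "0" or "1" with -EINVAL. */
	fprintf(f, "%d\n", enable ? 1 : 0);
	return fclose(f);
}

int main(void)
{
	/* Example path; substitute an existing cgroup v2 directory. */
	return set_zswap_writeback("/sys/fs/cgroup/test", 0) ? 1 : 0;
}
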
 static struct cftype zswap_files[] = {
        {
                .name = "zswap.current",
@@ -7672,24 +8293,18 @@ static struct cftype zswap_files[] = {
                .seq_show = zswap_max_show,
                .write = zswap_max_write,
        },
+       {
+               .name = "zswap.writeback",
+               .seq_show = zswap_writeback_show,
+               .write = zswap_writeback_write,
+       },
        { }     /* terminate */
 };
 #endif /* CONFIG_MEMCG_KMEM && CONFIG_ZSWAP */
 
-/*
- * If mem_cgroup_swap_init() is implemented as a subsys_initcall()
- * instead of a core_initcall(), this could mean cgroup_memory_noswap still
- * remains set to false even when memcg is disabled via "cgroup_disable=memory"
- * boot parameter. This may result in premature OOPS inside
- * mem_cgroup_get_nr_swap_pages() function in corner cases.
- */
 static int __init mem_cgroup_swap_init(void)
 {
-       /* No memory control -> no swap control */
        if (mem_cgroup_disabled())
-               cgroup_memory_noswap = true;
-
-       if (cgroup_memory_noswap)
                return 0;
 
        WARN_ON(cgroup_add_dfl_cftypes(&memory_cgrp_subsys, swap_files));
@@ -7699,6 +8314,6 @@ static int __init mem_cgroup_swap_init(void)
 #endif
        return 0;
 }
-core_initcall(mem_cgroup_swap_init);
+subsys_initcall(mem_cgroup_swap_init);
 
-#endif /* CONFIG_MEMCG_SWAP */
+#endif /* CONFIG_SWAP */