mm: memcontrol: do not miss MEMCG_MAX events for enforced allocations
[linux-2.6-microblaze.git] / mm / memcontrol.c
index 598fece..767f49a 100644 (file)
@@ -67,6 +67,7 @@
 #include <net/sock.h>
 #include <net/ip.h>
 #include "slab.h"
+#include "swap.h"
 
 #include <linux/uaccess.h>
 
@@ -89,7 +90,7 @@ static bool cgroup_memory_nokmem __ro_after_init;
 
 /* Whether the swap controller is active */
 #ifdef CONFIG_MEMCG_SWAP
-bool cgroup_memory_noswap __ro_after_init;
+static bool cgroup_memory_noswap __ro_after_init;
 #else
 #define cgroup_memory_noswap           1
 #endif
@@ -209,7 +210,6 @@ static struct move_charge_struct {
 enum res_type {
        _MEM,
        _MEMSWAP,
-       _OOM_TYPE,
        _KMEM,
        _TCP,
 };
@@ -217,8 +217,6 @@ enum res_type {
 #define MEMFILE_PRIVATE(x, val)        ((x) << 16 | (val))
 #define MEMFILE_TYPE(val)      ((val) >> 16 & 0xffff)
 #define MEMFILE_ATTR(val)      ((val) & 0xffff)
-/* Used for OOM notifier */
-#define OOM_CONTROL            (0)
 
 /*
  * Iteration constructs for visiting all cgroups (under a tree).  If
@@ -785,7 +783,7 @@ void __mod_lruvec_kmem_state(void *p, enum node_stat_item idx, int val)
        struct lruvec *lruvec;
 
        rcu_read_lock();
-       memcg = mem_cgroup_from_obj(p);
+       memcg = mem_cgroup_from_slab_obj(p);
 
        /*
         * Untracked pages have no memcg, no lruvec. Update only the
@@ -1013,9 +1011,6 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
        if (!root)
                root = root_mem_cgroup;
 
-       if (prev && !reclaim)
-               pos = prev;
-
        rcu_read_lock();
 
        if (reclaim) {
@@ -1024,7 +1019,13 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
                mz = root->nodeinfo[reclaim->pgdat->node_id];
                iter = &mz->iter;
 
-               if (prev && reclaim->generation != iter->generation)
+               /*
+                * On start, join the current reclaim iteration cycle.
+                * Exit when a concurrent walker completes it.
+                */
+               if (!prev)
+                       reclaim->generation = iter->generation;
+               else if (reclaim->generation != iter->generation)
                        goto out_unlock;
 
                while (1) {
@@ -1041,6 +1042,8 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
                         */
                        (void)cmpxchg(&iter->position, pos, NULL);
                }
+       } else if (prev) {
+               pos = prev;
        }
 
        if (pos)
@@ -1065,15 +1068,10 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
                 * is provided by the caller, so we know it's alive
                 * and kicking, and don't take an extra reference.
                 */
-               memcg = mem_cgroup_from_css(css);
-
-               if (css == &root->css)
+               if (css == &root->css || css_tryget(css)) {
+                       memcg = mem_cgroup_from_css(css);
                        break;
-
-               if (css_tryget(css))
-                       break;
-
-               memcg = NULL;
+               }
        }
 
        if (reclaim) {
@@ -1089,8 +1087,6 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
 
                if (!memcg)
                        iter->generation++;
-               else if (!prev)
-                       reclaim->generation = iter->generation;
        }
 
 out_unlock:
@@ -1402,6 +1398,10 @@ static const struct memory_stat memory_stats[] = {
        { "sock",                       MEMCG_SOCK                      },
        { "vmalloc",                    MEMCG_VMALLOC                   },
        { "shmem",                      NR_SHMEM                        },
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
+       { "zswap",                      MEMCG_ZSWAP_B                   },
+       { "zswapped",                   MEMCG_ZSWAPPED                  },
+#endif
        { "file_mapped",                NR_FILE_MAPPED                  },
        { "file_dirty",                 NR_FILE_DIRTY                   },
        { "file_writeback",             NR_WRITEBACK                    },
@@ -1436,6 +1436,7 @@ static int memcg_page_state_unit(int item)
 {
        switch (item) {
        case MEMCG_PERCPU_B:
+       case MEMCG_ZSWAP_B:
        case NR_SLAB_RECLAIMABLE_B:
        case NR_SLAB_UNRECLAIMABLE_B:
        case WORKINGSET_REFAULT_ANON:
@@ -1459,6 +1460,29 @@ static inline unsigned long memcg_page_state_output(struct mem_cgroup *memcg,
        return memcg_page_state(memcg, item) * memcg_page_state_unit(item);
 }
 
+/* Subset of vm_event_item to report for memcg event stats */
+static const unsigned int memcg_vm_event_stat[] = {
+       PGSCAN_KSWAPD,
+       PGSCAN_DIRECT,
+       PGSTEAL_KSWAPD,
+       PGSTEAL_DIRECT,
+       PGFAULT,
+       PGMAJFAULT,
+       PGREFILL,
+       PGACTIVATE,
+       PGDEACTIVATE,
+       PGLAZYFREE,
+       PGLAZYFREED,
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
+       ZSWPIN,
+       ZSWPOUT,
+#endif
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+       THP_FAULT_ALLOC,
+       THP_COLLAPSE_ALLOC,
+#endif
+};
+
 static char *memory_stat_format(struct mem_cgroup *memcg)
 {
        struct seq_buf s;
@@ -1494,34 +1518,17 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
        }
 
        /* Accumulated memory events */
-
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(PGFAULT),
-                      memcg_events(memcg, PGFAULT));
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(PGMAJFAULT),
-                      memcg_events(memcg, PGMAJFAULT));
-       seq_buf_printf(&s, "%s %lu\n",  vm_event_name(PGREFILL),
-                      memcg_events(memcg, PGREFILL));
        seq_buf_printf(&s, "pgscan %lu\n",
                       memcg_events(memcg, PGSCAN_KSWAPD) +
                       memcg_events(memcg, PGSCAN_DIRECT));
        seq_buf_printf(&s, "pgsteal %lu\n",
                       memcg_events(memcg, PGSTEAL_KSWAPD) +
                       memcg_events(memcg, PGSTEAL_DIRECT));
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(PGACTIVATE),
-                      memcg_events(memcg, PGACTIVATE));
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(PGDEACTIVATE),
-                      memcg_events(memcg, PGDEACTIVATE));
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(PGLAZYFREE),
-                      memcg_events(memcg, PGLAZYFREE));
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(PGLAZYFREED),
-                      memcg_events(memcg, PGLAZYFREED));
 
-#ifdef CONFIG_TRANSPARENT_HUGEPAGE
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(THP_FAULT_ALLOC),
-                      memcg_events(memcg, THP_FAULT_ALLOC));
-       seq_buf_printf(&s, "%s %lu\n", vm_event_name(THP_COLLAPSE_ALLOC),
-                      memcg_events(memcg, THP_COLLAPSE_ALLOC));
-#endif /* CONFIG_TRANSPARENT_HUGEPAGE */
+       for (i = 0; i < ARRAY_SIZE(memcg_vm_event_stat); i++)
+               seq_buf_printf(&s, "%s %lu\n",
+                              vm_event_name(memcg_vm_event_stat[i]),
+                              memcg_events(memcg, memcg_vm_event_stat[i]));
 
        /* The above should easily fit into one page */
        WARN_ON_ONCE(seq_buf_has_overflowed(&s));
@@ -2570,6 +2577,7 @@ static int try_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp_mask,
        bool passed_oom = false;
        bool may_swap = true;
        bool drained = false;
+       bool raised_max_event = false;
        unsigned long pflags;
 
 retry:
@@ -2609,6 +2617,7 @@ retry:
                goto nomem;
 
        memcg_memory_event(mem_over_limit, MEMCG_MAX);
+       raised_max_event = true;
 
        psi_memstall_enter(&pflags);
        nr_reclaimed = try_to_free_mem_cgroup_pages(mem_over_limit, nr_pages,
@@ -2675,6 +2684,13 @@ nomem:
        if (!(gfp_mask & (__GFP_NOFAIL | __GFP_HIGH)))
                return -ENOMEM;
 force:
+       /*
+        * If the allocation has to be enforced, don't forget to raise
+        * a MEMCG_MAX event.
+        */
+       if (!raised_max_event)
+               memcg_memory_event(mem_over_limit, MEMCG_MAX);
+
        /*
         * The allocation either can't fail or will lead to more memory
         * being freed very soon.  Allow memory usage go over the limit
@@ -2834,27 +2850,9 @@ int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
        return 0;
 }
 
-/*
- * Returns a pointer to the memory cgroup to which the kernel object is charged.
- *
- * A passed kernel object can be a slab object or a generic kernel page, so
- * different mechanisms for getting the memory cgroup pointer should be used.
- * In certain cases (e.g. kernel stacks or large kmallocs with SLUB) the caller
- * can not know for sure how the kernel object is implemented.
- * mem_cgroup_from_obj() can be safely used in such cases.
- *
- * The caller must ensure the memcg lifetime, e.g. by taking rcu_read_lock(),
- * cgroup_mutex, etc.
- */
-struct mem_cgroup *mem_cgroup_from_obj(void *p)
+static __always_inline
+struct mem_cgroup *mem_cgroup_from_obj_folio(struct folio *folio, void *p)
 {
-       struct folio *folio;
-
-       if (mem_cgroup_disabled())
-               return NULL;
-
-       folio = virt_to_folio(p);
-
        /*
         * Slab objects are accounted individually, not per-page.
         * Memcg membership data for each individual object is saved in
@@ -2887,6 +2885,66 @@ struct mem_cgroup *mem_cgroup_from_obj(void *p)
        return page_memcg_check(folio_page(folio, 0));
 }
 
+/*
+ * Returns a pointer to the memory cgroup to which the kernel object is charged.
+ *
+ * A passed kernel object can be a slab object, vmalloc object or a generic
+ * kernel page, so different mechanisms for getting the memory cgroup pointer
+ * should be used.
+ *
+ * In certain cases (e.g. kernel stacks or large kmallocs with SLUB) the caller
+ * can not know for sure how the kernel object is implemented.
+ * mem_cgroup_from_obj() can be safely used in such cases.
+ *
+ * The caller must ensure the memcg lifetime, e.g. by taking rcu_read_lock(),
+ * cgroup_mutex, etc.
+ */
+struct mem_cgroup *mem_cgroup_from_obj(void *p)
+{
+       struct folio *folio;
+
+       if (mem_cgroup_disabled())
+               return NULL;
+
+       if (unlikely(is_vmalloc_addr(p)))
+               folio = page_folio(vmalloc_to_page(p));
+       else
+               folio = virt_to_folio(p);
+
+       return mem_cgroup_from_obj_folio(folio, p);
+}
+
+/*
+ * Returns a pointer to the memory cgroup to which the kernel object is charged.
+ * Similar to mem_cgroup_from_obj(), but faster and not suitable for objects,
+ * allocated using vmalloc().
+ *
+ * A passed kernel object must be a slab object or a generic kernel page.
+ *
+ * The caller must ensure the memcg lifetime, e.g. by taking rcu_read_lock(),
+ * cgroup_mutex, etc.
+ */
+struct mem_cgroup *mem_cgroup_from_slab_obj(void *p)
+{
+       if (mem_cgroup_disabled())
+               return NULL;
+
+       return mem_cgroup_from_obj_folio(virt_to_folio(p), p);
+}
+
+static struct obj_cgroup *__get_obj_cgroup_from_memcg(struct mem_cgroup *memcg)
+{
+       struct obj_cgroup *objcg = NULL;
+
+       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg)) {
+               objcg = rcu_dereference(memcg->objcg);
+               if (objcg && obj_cgroup_tryget(objcg))
+                       break;
+               objcg = NULL;
+       }
+       return objcg;
+}
+
 __always_inline struct obj_cgroup *get_obj_cgroup_from_current(void)
 {
        struct obj_cgroup *objcg = NULL;
@@ -2900,15 +2958,32 @@ __always_inline struct obj_cgroup *get_obj_cgroup_from_current(void)
                memcg = active_memcg();
        else
                memcg = mem_cgroup_from_task(current);
-
-       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg)) {
-               objcg = rcu_dereference(memcg->objcg);
-               if (objcg && obj_cgroup_tryget(objcg))
-                       break;
-               objcg = NULL;
-       }
+       objcg = __get_obj_cgroup_from_memcg(memcg);
        rcu_read_unlock();
+       return objcg;
+}
+
+struct obj_cgroup *get_obj_cgroup_from_page(struct page *page)
+{
+       struct obj_cgroup *objcg;
+
+       if (!memcg_kmem_enabled() || memcg_kmem_bypass())
+               return NULL;
+
+       if (PageMemcgKmem(page)) {
+               objcg = __folio_objcg(page_folio(page));
+               obj_cgroup_get(objcg);
+       } else {
+               struct mem_cgroup *memcg;
 
+               rcu_read_lock();
+               memcg = __folio_memcg(page_folio(page));
+               if (memcg)
+                       objcg = __get_obj_cgroup_from_memcg(memcg);
+               else
+                       objcg = NULL;
+               rcu_read_unlock();
+       }
        return objcg;
 }
 
@@ -3387,7 +3462,6 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
        int loop = 0;
        struct mem_cgroup_tree_per_node *mctz;
        unsigned long excess;
-       unsigned long nr_scanned;
 
        if (order > 0)
                return 0;
@@ -3415,13 +3489,10 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
                if (!mz)
                        break;
 
-               nr_scanned = 0;
                reclaimed = mem_cgroup_soft_reclaim(mz->memcg, pgdat,
-                                                   gfp_mask, &nr_scanned);
+                                                   gfp_mask, total_scanned);
                nr_reclaimed += reclaimed;
-               *total_scanned += nr_scanned;
                spin_lock_irq(&mctz->lock);
-               __mem_cgroup_remove_exceeded(mz, mctz);
 
                /*
                 * If we failed to reclaim anything from this memory cgroup
@@ -3591,7 +3662,7 @@ static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
        struct obj_cgroup *objcg;
 
-       if (cgroup_memory_nokmem)
+       if (mem_cgroup_kmem_disabled())
                return 0;
 
        if (unlikely(mem_cgroup_is_root(memcg)))
@@ -3615,7 +3686,7 @@ static void memcg_offline_kmem(struct mem_cgroup *memcg)
 {
        struct mem_cgroup *parent;
 
-       if (cgroup_memory_nokmem)
+       if (mem_cgroup_kmem_disabled())
                return;
 
        if (unlikely(mem_cgroup_is_root(memcg)))
@@ -4825,7 +4896,7 @@ static int mem_cgroup_slab_show(struct seq_file *m, void *p)
 {
        /*
         * Deprecated.
-        * Please, take a look at tools/cgroup/slabinfo.py .
+        * Please, take a look at tools/cgroup/memcg_slabinfo.py .
         */
        return 0;
 }
@@ -4893,7 +4964,6 @@ static struct cftype mem_cgroup_legacy_files[] = {
                .name = "oom_control",
                .seq_show = mem_cgroup_oom_control_read,
                .write_u64 = mem_cgroup_oom_control_write,
-               .private = MEMFILE_PRIVATE(_OOM_TYPE, OOM_CONTROL),
        },
        {
                .name = "pressure_level",
@@ -5027,6 +5097,29 @@ struct mem_cgroup *mem_cgroup_from_id(unsigned short id)
        return idr_find(&mem_cgroup_idr, id);
 }
 
+#ifdef CONFIG_SHRINKER_DEBUG
+struct mem_cgroup *mem_cgroup_get_from_ino(unsigned long ino)
+{
+       struct cgroup *cgrp;
+       struct cgroup_subsys_state *css;
+       struct mem_cgroup *memcg;
+
+       cgrp = cgroup_get_from_id(ino);
+       if (!cgrp)
+               return ERR_PTR(-ENOENT);
+
+       css = cgroup_get_e_css(cgrp, &memory_cgrp_subsys);
+       if (css)
+               memcg = container_of(css, struct mem_cgroup, css);
+       else
+               memcg = ERR_PTR(-ENOENT);
+
+       cgroup_put(cgrp);
+
+       return memcg;
+}
+#endif
+
 static int alloc_mem_cgroup_per_node_info(struct mem_cgroup *memcg, int node)
 {
        struct mem_cgroup_per_node *pn;
@@ -5151,6 +5244,9 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
 
        page_counter_set_high(&memcg->memory, PAGE_COUNTER_MAX);
        memcg->soft_limit = PAGE_COUNTER_MAX;
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
+       memcg->zswap_max = PAGE_COUNTER_MAX;
+#endif
        page_counter_set_high(&memcg->swap, PAGE_COUNTER_MAX);
        if (parent) {
                memcg->swappiness = mem_cgroup_swappiness(parent);
@@ -5629,8 +5725,8 @@ out:
  *   2(MC_TARGET_SWAP): if the swap entry corresponding to this pte is a
  *     target for charge migration. if @target is not NULL, the entry is stored
  *     in target->ent.
- *   3(MC_TARGET_DEVICE): like MC_TARGET_PAGE  but page is MEMORY_DEVICE_PRIVATE
- *     (so ZONE_DEVICE page and thus not on the lru).
+ *   3(MC_TARGET_DEVICE): like MC_TARGET_PAGE  but page is device memory and
+ *   thus not on the lru.
  *     For now we such page is charge like a regular page would be as for all
  *     intent and purposes it is just special memory taking the place of a
  *     regular page.
@@ -5649,10 +5745,14 @@ static enum mc_target_type get_mctgt_type(struct vm_area_struct *vma,
 
        if (pte_present(ptent))
                page = mc_handle_present_pte(vma, addr, ptent);
+       else if (pte_none_mostly(ptent))
+               /*
+                * PTE markers should be treated as a none pte here, separated
+                * from other swap handling below.
+                */
+               page = mc_handle_file_pte(vma, addr, ptent);
        else if (is_swap_pte(ptent))
                page = mc_handle_swap_pte(vma, ptent, &ent);
-       else if (pte_none(ptent))
-               page = mc_handle_file_pte(vma, addr, ptent);
 
        if (!page && !ent.val)
                return ret;
@@ -5664,7 +5764,8 @@ static enum mc_target_type get_mctgt_type(struct vm_area_struct *vma,
                 */
                if (page_memcg(page) == mc.from) {
                        ret = MC_TARGET_PAGE;
-                       if (is_device_private_page(page))
+                       if (is_device_private_page(page) ||
+                           is_device_coherent_page(page))
                                ret = MC_TARGET_DEVICE;
                        if (target)
                                target->page = page;
@@ -6108,6 +6209,14 @@ static u64 memory_current_read(struct cgroup_subsys_state *css,
        return (u64)page_counter_read(&memcg->memory) * PAGE_SIZE;
 }
 
+static u64 memory_peak_read(struct cgroup_subsys_state *css,
+                           struct cftype *cft)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+
+       return (u64)memcg->memory.watermark * PAGE_SIZE;
+}
+
 static int memory_min_show(struct seq_file *m, void *v)
 {
        return seq_puts_memcg_tunable(m,
@@ -6365,12 +6474,57 @@ static ssize_t memory_oom_group_write(struct kernfs_open_file *of,
        return nbytes;
 }
 
+static ssize_t memory_reclaim(struct kernfs_open_file *of, char *buf,
+                             size_t nbytes, loff_t off)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+       unsigned int nr_retries = MAX_RECLAIM_RETRIES;
+       unsigned long nr_to_reclaim, nr_reclaimed = 0;
+       int err;
+
+       buf = strstrip(buf);
+       err = page_counter_memparse(buf, "", &nr_to_reclaim);
+       if (err)
+               return err;
+
+       while (nr_reclaimed < nr_to_reclaim) {
+               unsigned long reclaimed;
+
+               if (signal_pending(current))
+                       return -EINTR;
+
+               /*
+                * This is the final attempt, drain percpu lru caches in the
+                * hope of introducing more evictable pages for
+                * try_to_free_mem_cgroup_pages().
+                */
+               if (!nr_retries)
+                       lru_add_drain_all();
+
+               reclaimed = try_to_free_mem_cgroup_pages(memcg,
+                                               nr_to_reclaim - nr_reclaimed,
+                                               GFP_KERNEL, true);
+
+               if (!reclaimed && !nr_retries--)
+                       return -EAGAIN;
+
+               nr_reclaimed += reclaimed;
+       }
+
+       return nbytes;
+}
+
 static struct cftype memory_files[] = {
        {
                .name = "current",
                .flags = CFTYPE_NOT_ON_ROOT,
                .read_u64 = memory_current_read,
        },
+       {
+               .name = "peak",
+               .flags = CFTYPE_NOT_ON_ROOT,
+               .read_u64 = memory_peak_read,
+       },
        {
                .name = "min",
                .flags = CFTYPE_NOT_ON_ROOT,
@@ -6423,6 +6577,11 @@ static struct cftype memory_files[] = {
                .seq_show = memory_oom_group_show,
                .write = memory_oom_group_write,
        },
+       {
+               .name = "reclaim",
+               .flags = CFTYPE_NS_DELEGATABLE,
+               .write = memory_reclaim,
+       },
        { }     /* terminate */
 };
 
@@ -6593,9 +6752,6 @@ void mem_cgroup_calculate_protection(struct mem_cgroup *root,
                return;
 
        parent = parent_mem_cgroup(memcg);
-       /* No parent means a non-hierarchical mode on v1 memcg */
-       if (!parent)
-               return;
 
        if (parent == root) {
                memcg->memory.emin = READ_ONCE(memcg->memory.min);
@@ -7125,17 +7281,17 @@ void mem_cgroup_swapout(struct folio *folio, swp_entry_t entry)
 }
 
 /**
- * __mem_cgroup_try_charge_swap - try charging swap space for a page
- * @page: page being added to swap
+ * __mem_cgroup_try_charge_swap - try charging swap space for a folio
+ * @folio: folio being added to swap
  * @entry: swap entry to charge
  *
- * Try to charge @page's memcg for the swap space at @entry.
+ * Try to charge @folio's memcg for the swap space at @entry.
  *
  * Returns 0 on success, -ENOMEM on failure.
  */
-int __mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry)
+int __mem_cgroup_try_charge_swap(struct folio *folio, swp_entry_t entry)
 {
-       unsigned int nr_pages = thp_nr_pages(page);
+       unsigned int nr_pages = folio_nr_pages(folio);
        struct page_counter *counter;
        struct mem_cgroup *memcg;
        unsigned short oldid;
@@ -7143,9 +7299,9 @@ int __mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry)
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
                return 0;
 
-       memcg = page_memcg(page);
+       memcg = folio_memcg(folio);
 
-       VM_WARN_ON_ONCE_PAGE(!memcg, page);
+       VM_WARN_ON_ONCE_FOLIO(!memcg, folio);
        if (!memcg)
                return 0;
 
@@ -7168,7 +7324,7 @@ int __mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry)
        if (nr_pages > 1)
                mem_cgroup_id_get_many(memcg, nr_pages - 1);
        oldid = swap_cgroup_record(entry, mem_cgroup_id(memcg), nr_pages);
-       VM_BUG_ON_PAGE(oldid, page);
+       VM_BUG_ON_FOLIO(oldid, folio);
        mod_memcg_state(memcg, MEMCG_SWAP, nr_pages);
 
        return 0;
@@ -7371,6 +7527,148 @@ static struct cftype memsw_files[] = {
        { },    /* terminate */
 };
 
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
+/**
+ * obj_cgroup_may_zswap - check if this cgroup can zswap
+ * @objcg: the object cgroup
+ *
+ * Check if the hierarchical zswap limit has been reached.
+ *
+ * This doesn't check for specific headroom, and it is not atomic
+ * either. But with zswap, the size of the allocation is only known
+ * once compression has occured, and this optimistic pre-check avoids
+ * spending cycles on compression when there is already no room left
+ * or zswap is disabled altogether somewhere in the hierarchy.
+ */
+bool obj_cgroup_may_zswap(struct obj_cgroup *objcg)
+{
+       struct mem_cgroup *memcg, *original_memcg;
+       bool ret = true;
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               return true;
+
+       original_memcg = get_mem_cgroup_from_objcg(objcg);
+       for (memcg = original_memcg; memcg != root_mem_cgroup;
+            memcg = parent_mem_cgroup(memcg)) {
+               unsigned long max = READ_ONCE(memcg->zswap_max);
+               unsigned long pages;
+
+               if (max == PAGE_COUNTER_MAX)
+                       continue;
+               if (max == 0) {
+                       ret = false;
+                       break;
+               }
+
+               cgroup_rstat_flush(memcg->css.cgroup);
+               pages = memcg_page_state(memcg, MEMCG_ZSWAP_B) / PAGE_SIZE;
+               if (pages < max)
+                       continue;
+               ret = false;
+               break;
+       }
+       mem_cgroup_put(original_memcg);
+       return ret;
+}
+
+/**
+ * obj_cgroup_charge_zswap - charge compression backend memory
+ * @objcg: the object cgroup
+ * @size: size of compressed object
+ *
+ * This forces the charge after obj_cgroup_may_swap() allowed
+ * compression and storage in zwap for this cgroup to go ahead.
+ */
+void obj_cgroup_charge_zswap(struct obj_cgroup *objcg, size_t size)
+{
+       struct mem_cgroup *memcg;
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               return;
+
+       VM_WARN_ON_ONCE(!(current->flags & PF_MEMALLOC));
+
+       /* PF_MEMALLOC context, charging must succeed */
+       if (obj_cgroup_charge(objcg, GFP_KERNEL, size))
+               VM_WARN_ON_ONCE(1);
+
+       rcu_read_lock();
+       memcg = obj_cgroup_memcg(objcg);
+       mod_memcg_state(memcg, MEMCG_ZSWAP_B, size);
+       mod_memcg_state(memcg, MEMCG_ZSWAPPED, 1);
+       rcu_read_unlock();
+}
+
+/**
+ * obj_cgroup_uncharge_zswap - uncharge compression backend memory
+ * @objcg: the object cgroup
+ * @size: size of compressed object
+ *
+ * Uncharges zswap memory on page in.
+ */
+void obj_cgroup_uncharge_zswap(struct obj_cgroup *objcg, size_t size)
+{
+       struct mem_cgroup *memcg;
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               return;
+
+       obj_cgroup_uncharge(objcg, size);
+
+       rcu_read_lock();
+       memcg = obj_cgroup_memcg(objcg);
+       mod_memcg_state(memcg, MEMCG_ZSWAP_B, -size);
+       mod_memcg_state(memcg, MEMCG_ZSWAPPED, -1);
+       rcu_read_unlock();
+}
+
+static u64 zswap_current_read(struct cgroup_subsys_state *css,
+                             struct cftype *cft)
+{
+       cgroup_rstat_flush(css->cgroup);
+       return memcg_page_state(mem_cgroup_from_css(css), MEMCG_ZSWAP_B);
+}
+
+static int zswap_max_show(struct seq_file *m, void *v)
+{
+       return seq_puts_memcg_tunable(m,
+               READ_ONCE(mem_cgroup_from_seq(m)->zswap_max));
+}
+
+static ssize_t zswap_max_write(struct kernfs_open_file *of,
+                              char *buf, size_t nbytes, loff_t off)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+       unsigned long max;
+       int err;
+
+       buf = strstrip(buf);
+       err = page_counter_memparse(buf, "max", &max);
+       if (err)
+               return err;
+
+       xchg(&memcg->zswap_max, max);
+
+       return nbytes;
+}
+
+static struct cftype zswap_files[] = {
+       {
+               .name = "zswap.current",
+               .flags = CFTYPE_NOT_ON_ROOT,
+               .read_u64 = zswap_current_read,
+       },
+       {
+               .name = "zswap.max",
+               .flags = CFTYPE_NOT_ON_ROOT,
+               .seq_show = zswap_max_show,
+               .write = zswap_max_write,
+       },
+       { }     /* terminate */
+};
+#endif /* CONFIG_MEMCG_KMEM && CONFIG_ZSWAP */
+
 /*
  * If mem_cgroup_swap_init() is implemented as a subsys_initcall()
  * instead of a core_initcall(), this could mean cgroup_memory_noswap still
@@ -7389,7 +7687,9 @@ static int __init mem_cgroup_swap_init(void)
 
        WARN_ON(cgroup_add_dfl_cftypes(&memory_cgrp_subsys, swap_files));
        WARN_ON(cgroup_add_legacy_cftypes(&memory_cgrp_subsys, memsw_files));
-
+#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_ZSWAP)
+       WARN_ON(cgroup_add_dfl_cftypes(&memory_cgrp_subsys, zswap_files));
+#endif
        return 0;
 }
 core_initcall(mem_cgroup_swap_init);