Merge tag 'amd-drm-fixes-5.9-2020-08-20' of git://people.freedesktop.org/~agd5f/linux...
[linux-2.6-microblaze.git] / mm / memcontrol.c
index a3b97f1..b807952 100644 (file)
@@ -73,8 +73,6 @@ EXPORT_SYMBOL(memory_cgrp_subsys);
 
 struct mem_cgroup *root_mem_cgroup __read_mostly;
 
-#define MEM_CGROUP_RECLAIM_RETRIES     5
-
 /* Socket memory accounting disabled? */
 static bool cgroup_memory_nosocket;
 
@@ -83,9 +81,9 @@ static bool cgroup_memory_nokmem;
 
 /* Whether the swap controller is active */
 #ifdef CONFIG_MEMCG_SWAP
-int do_swap_account __read_mostly;
+bool cgroup_memory_noswap __read_mostly;
 #else
-#define do_swap_account                0
+#define cgroup_memory_noswap           1
 #endif
 
 #ifdef CONFIG_CGROUP_WRITEBACK
@@ -95,7 +93,7 @@ static DECLARE_WAIT_QUEUE_HEAD(memcg_cgwb_frn_waitq);
 /* Whether legacy memory+swap accounting is active */
 static bool do_memsw_account(void)
 {
-       return !cgroup_subsys_on_dfl(memory_cgrp_subsys) && do_swap_account;
+       return !cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_noswap;
 }
 
 #define THRESHOLDS_EVENTS_TARGET 128
@@ -257,8 +255,100 @@ struct cgroup_subsys_state *vmpressure_to_css(struct vmpressure *vmpr)
 }
 
 #ifdef CONFIG_MEMCG_KMEM
+extern spinlock_t css_set_lock;
+
+static void obj_cgroup_release(struct percpu_ref *ref)
+{
+       struct obj_cgroup *objcg = container_of(ref, struct obj_cgroup, refcnt);
+       struct mem_cgroup *memcg;
+       unsigned int nr_bytes;
+       unsigned int nr_pages;
+       unsigned long flags;
+
+       /*
+        * At this point all allocated objects are freed, and
+        * objcg->nr_charged_bytes can't have an arbitrary byte value.
+        * However, it can be PAGE_SIZE or (x * PAGE_SIZE).
+        *
+        * The following sequence can lead to it:
+        * 1) CPU0: objcg == stock->cached_objcg
+        * 2) CPU1: we do a small allocation (e.g. 92 bytes),
+        *          PAGE_SIZE bytes are charged
+        * 3) CPU1: a process from another memcg is allocating something,
+        *          the stock if flushed,
+        *          objcg->nr_charged_bytes = PAGE_SIZE - 92
+        * 5) CPU0: we do release this object,
+        *          92 bytes are added to stock->nr_bytes
+        * 6) CPU0: stock is flushed,
+        *          92 bytes are added to objcg->nr_charged_bytes
+        *
+        * In the result, nr_charged_bytes == PAGE_SIZE.
+        * This page will be uncharged in obj_cgroup_release().
+        */
+       nr_bytes = atomic_read(&objcg->nr_charged_bytes);
+       WARN_ON_ONCE(nr_bytes & (PAGE_SIZE - 1));
+       nr_pages = nr_bytes >> PAGE_SHIFT;
+
+       spin_lock_irqsave(&css_set_lock, flags);
+       memcg = obj_cgroup_memcg(objcg);
+       if (nr_pages)
+               __memcg_kmem_uncharge(memcg, nr_pages);
+       list_del(&objcg->list);
+       mem_cgroup_put(memcg);
+       spin_unlock_irqrestore(&css_set_lock, flags);
+
+       percpu_ref_exit(ref);
+       kfree_rcu(objcg, rcu);
+}
+
+static struct obj_cgroup *obj_cgroup_alloc(void)
+{
+       struct obj_cgroup *objcg;
+       int ret;
+
+       objcg = kzalloc(sizeof(struct obj_cgroup), GFP_KERNEL);
+       if (!objcg)
+               return NULL;
+
+       ret = percpu_ref_init(&objcg->refcnt, obj_cgroup_release, 0,
+                             GFP_KERNEL);
+       if (ret) {
+               kfree(objcg);
+               return NULL;
+       }
+       INIT_LIST_HEAD(&objcg->list);
+       return objcg;
+}
+
+static void memcg_reparent_objcgs(struct mem_cgroup *memcg,
+                                 struct mem_cgroup *parent)
+{
+       struct obj_cgroup *objcg, *iter;
+
+       objcg = rcu_replace_pointer(memcg->objcg, NULL, true);
+
+       spin_lock_irq(&css_set_lock);
+
+       /* Move active objcg to the parent's list */
+       xchg(&objcg->memcg, parent);
+       css_get(&parent->css);
+       list_add(&objcg->list, &parent->objcg_list);
+
+       /* Move already reparented objcgs to the parent's list */
+       list_for_each_entry(iter, &memcg->objcg_list, list) {
+               css_get(&parent->css);
+               xchg(&iter->memcg, parent);
+               css_put(&memcg->css);
+       }
+       list_splice(&memcg->objcg_list, &parent->objcg_list);
+
+       spin_unlock_irq(&css_set_lock);
+
+       percpu_ref_kill(&objcg->refcnt);
+}
+
 /*
- * This will be the memcg's index in each cache's ->memcg_params.memcg_caches.
+ * This will be used as a shrinker list's index.
  * The main reason for not using cgroup id for this:
  *  this works better in sparse environments, where we have a lot of memcgs,
  *  but only a few kmem-limited. Or also, if we have, for instance, 200
@@ -301,14 +391,12 @@ void memcg_put_cache_ids(void)
 
 /*
  * A lot of the calls to the cache allocation functions are expected to be
- * inlined by the compiler. Since the calls to memcg_kmem_get_cache are
+ * inlined by the compiler. Since the calls to memcg_slab_pre_alloc_hook() are
  * conditional to this static branch, we'll have to allow modules that does
  * kmem_cache_alloc and the such to see this symbol as well
  */
 DEFINE_STATIC_KEY_FALSE(memcg_kmem_enabled_key);
 EXPORT_SYMBOL(memcg_kmem_enabled_key);
-
-struct workqueue_struct *memcg_kmem_cache_wq;
 #endif
 
 static int memcg_shrinker_map_size;
@@ -477,10 +565,17 @@ ino_t page_cgroup_ino(struct page *page)
        unsigned long ino = 0;
 
        rcu_read_lock();
-       if (PageSlab(page) && !PageTail(page))
-               memcg = memcg_from_slab_page(page);
-       else
-               memcg = READ_ONCE(page->mem_cgroup);
+       memcg = page->mem_cgroup;
+
+       /*
+        * The lowest bit set means that memcg isn't a valid
+        * memcg pointer, but a obj_cgroups pointer.
+        * In this case the page is shared and doesn't belong
+        * to any specific memory cgroup.
+        */
+       if ((unsigned long) memcg & 0x1UL)
+               memcg = NULL;
+
        while (memcg && !(memcg->css.flags & CSS_ONLINE))
                memcg = parent_mem_cgroup(memcg);
        if (memcg)
@@ -681,13 +776,16 @@ mem_cgroup_largest_soft_limit_node(struct mem_cgroup_tree_per_node *mctz)
  */
 void __mod_memcg_state(struct mem_cgroup *memcg, int idx, int val)
 {
-       long x;
+       long x, threshold = MEMCG_CHARGE_BATCH;
 
        if (mem_cgroup_disabled())
                return;
 
+       if (memcg_stat_item_in_bytes(idx))
+               threshold <<= PAGE_SHIFT;
+
        x = val + __this_cpu_read(memcg->vmstats_percpu->stat[idx]);
-       if (unlikely(abs(x) > MEMCG_CHARGE_BATCH)) {
+       if (unlikely(abs(x) > threshold)) {
                struct mem_cgroup *mi;
 
                /*
@@ -713,29 +811,12 @@ parent_nodeinfo(struct mem_cgroup_per_node *pn, int nid)
        return mem_cgroup_nodeinfo(parent, nid);
 }
 
-/**
- * __mod_lruvec_state - update lruvec memory statistics
- * @lruvec: the lruvec
- * @idx: the stat item
- * @val: delta to add to the counter, can be negative
- *
- * The lruvec is the intersection of the NUMA node and a cgroup. This
- * function updates the all three counters that are affected by a
- * change of state at this level: per-node, per-cgroup, per-lruvec.
- */
-void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
-                       int val)
+void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
+                             int val)
 {
-       pg_data_t *pgdat = lruvec_pgdat(lruvec);
        struct mem_cgroup_per_node *pn;
        struct mem_cgroup *memcg;
-       long x;
-
-       /* Update node */
-       __mod_node_page_state(pgdat, idx, val);
-
-       if (mem_cgroup_disabled())
-               return;
+       long x, threshold = MEMCG_CHARGE_BATCH;
 
        pn = container_of(lruvec, struct mem_cgroup_per_node, lruvec);
        memcg = pn->memcg;
@@ -746,8 +827,12 @@ void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
        /* Update lruvec */
        __this_cpu_add(pn->lruvec_stat_local->count[idx], val);
 
+       if (vmstat_item_in_bytes(idx))
+               threshold <<= PAGE_SHIFT;
+
        x = val + __this_cpu_read(pn->lruvec_stat_cpu->count[idx]);
-       if (unlikely(abs(x) > MEMCG_CHARGE_BATCH)) {
+       if (unlikely(abs(x) > threshold)) {
+               pg_data_t *pgdat = lruvec_pgdat(lruvec);
                struct mem_cgroup_per_node *pi;
 
                for (pi = pn; pi; pi = parent_nodeinfo(pi, pgdat->node_id))
@@ -757,6 +842,27 @@ void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
        __this_cpu_write(pn->lruvec_stat_cpu->count[idx], x);
 }
 
+/**
+ * __mod_lruvec_state - update lruvec memory statistics
+ * @lruvec: the lruvec
+ * @idx: the stat item
+ * @val: delta to add to the counter, can be negative
+ *
+ * The lruvec is the intersection of the NUMA node and a cgroup. This
+ * function updates the all three counters that are affected by a
+ * change of state at this level: per-node, per-cgroup, per-lruvec.
+ */
+void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
+                       int val)
+{
+       /* Update node */
+       __mod_node_page_state(lruvec_pgdat(lruvec), idx, val);
+
+       /* Update memcg and lruvec */
+       if (!mem_cgroup_disabled())
+               __mod_memcg_lruvec_state(lruvec, idx, val);
+}
+
 void __mod_lruvec_slab_state(void *p, enum node_stat_item idx, int val)
 {
        pg_data_t *pgdat = page_pgdat(virt_to_page(p));
@@ -834,25 +940,8 @@ static unsigned long memcg_events_local(struct mem_cgroup *memcg, int event)
 
 static void mem_cgroup_charge_statistics(struct mem_cgroup *memcg,
                                         struct page *page,
-                                        bool compound, int nr_pages)
+                                        int nr_pages)
 {
-       /*
-        * Here, RSS means 'mapped anon' and anon's SwapCache. Shmem/tmpfs is
-        * counted as CACHE even if it's on ANON LRU.
-        */
-       if (PageAnon(page))
-               __mod_memcg_state(memcg, MEMCG_RSS, nr_pages);
-       else {
-               __mod_memcg_state(memcg, MEMCG_CACHE, nr_pages);
-               if (PageSwapBacked(page))
-                       __mod_memcg_state(memcg, NR_SHMEM, nr_pages);
-       }
-
-       if (compound) {
-               VM_BUG_ON_PAGE(!PageTransHuge(page), page);
-               __mod_memcg_state(memcg, MEMCG_RSS_HUGE, nr_pages);
-       }
-
        /* pagein of a big page is an event. So, ignore page size */
        if (nr_pages > 0)
                __count_memcg_events(memcg, PGPGIN, 1);
@@ -1021,7 +1110,7 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
                                   struct mem_cgroup *prev,
                                   struct mem_cgroup_reclaim_cookie *reclaim)
 {
-       struct mem_cgroup_reclaim_iter *uninitialized_var(iter);
+       struct mem_cgroup_reclaim_iter *iter;
        struct cgroup_subsys_state *css = NULL;
        struct mem_cgroup *memcg = NULL;
        struct mem_cgroup *pos = NULL;
@@ -1218,9 +1307,8 @@ int mem_cgroup_scan_tasks(struct mem_cgroup *memcg,
  * @page: the page
  * @pgdat: pgdat of the page
  *
- * This function is only safe when following the LRU page isolation
- * and putback protocol: the LRU lock must be held, and the page must
- * either be PageLRU() or the caller must have isolated/allocated it.
+ * This function relies on page->mem_cgroup being stable - see the
+ * access rules in commit_charge().
  */
 struct lruvec *mem_cgroup_page_lruvec(struct page *page, struct pglist_data *pgdat)
 {
@@ -1314,7 +1402,7 @@ static unsigned long mem_cgroup_margin(struct mem_cgroup *memcg)
        if (do_memsw_account()) {
                count = page_counter_read(&memcg->memsw);
                limit = READ_ONCE(memcg->memsw.max);
-               if (count <= limit)
+               if (count < limit)
                        margin = min(margin, limit - count);
                else
                        margin = 0;
@@ -1389,18 +1477,19 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
         */
 
        seq_buf_printf(&s, "anon %llu\n",
-                      (u64)memcg_page_state(memcg, MEMCG_RSS) *
+                      (u64)memcg_page_state(memcg, NR_ANON_MAPPED) *
                       PAGE_SIZE);
        seq_buf_printf(&s, "file %llu\n",
-                      (u64)memcg_page_state(memcg, MEMCG_CACHE) *
+                      (u64)memcg_page_state(memcg, NR_FILE_PAGES) *
                       PAGE_SIZE);
        seq_buf_printf(&s, "kernel_stack %llu\n",
-                      (u64)memcg_page_state(memcg, MEMCG_KERNEL_STACK_KB) *
+                      (u64)memcg_page_state(memcg, NR_KERNEL_STACK_KB) *
                       1024);
        seq_buf_printf(&s, "slab %llu\n",
-                      (u64)(memcg_page_state(memcg, NR_SLAB_RECLAIMABLE) +
-                            memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE)) *
-                      PAGE_SIZE);
+                      (u64)(memcg_page_state(memcg, NR_SLAB_RECLAIMABLE_B) +
+                            memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE_B)));
+       seq_buf_printf(&s, "percpu %llu\n",
+                      (u64)memcg_page_state(memcg, MEMCG_PERCPU_B));
        seq_buf_printf(&s, "sock %llu\n",
                       (u64)memcg_page_state(memcg, MEMCG_SOCK) *
                       PAGE_SIZE);
@@ -1418,15 +1507,11 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
                       (u64)memcg_page_state(memcg, NR_WRITEBACK) *
                       PAGE_SIZE);
 
-       /*
-        * TODO: We should eventually replace our own MEMCG_RSS_HUGE counter
-        * with the NR_ANON_THP vm counter, but right now it's a pain in the
-        * arse because it requires migrating the work out of rmap to a place
-        * where the page->mem_cgroup is set up and stable.
-        */
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
        seq_buf_printf(&s, "anon_thp %llu\n",
-                      (u64)memcg_page_state(memcg, MEMCG_RSS_HUGE) *
-                      PAGE_SIZE);
+                      (u64)memcg_page_state(memcg, NR_ANON_THPS) *
+                      HPAGE_PMD_SIZE);
+#endif
 
        for (i = 0; i < NR_LRU_LISTS; i++)
                seq_buf_printf(&s, "%s %llu\n", lru_list_name(i),
@@ -1434,11 +1519,9 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
                               PAGE_SIZE);
 
        seq_buf_printf(&s, "slab_reclaimable %llu\n",
-                      (u64)memcg_page_state(memcg, NR_SLAB_RECLAIMABLE) *
-                      PAGE_SIZE);
+                      (u64)memcg_page_state(memcg, NR_SLAB_RECLAIMABLE_B));
        seq_buf_printf(&s, "slab_unreclaimable %llu\n",
-                      (u64)memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE) *
-                      PAGE_SIZE);
+                      (u64)memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE_B));
 
        /* Accumulated memory events */
 
@@ -1447,10 +1530,18 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
        seq_buf_printf(&s, "%s %lu\n", vm_event_name(PGMAJFAULT),
                       memcg_events(memcg, PGMAJFAULT));
 
-       seq_buf_printf(&s, "workingset_refault %lu\n",
-                      memcg_page_state(memcg, WORKINGSET_REFAULT));
-       seq_buf_printf(&s, "workingset_activate %lu\n",
-                      memcg_page_state(memcg, WORKINGSET_ACTIVATE));
+       seq_buf_printf(&s, "workingset_refault_anon %lu\n",
+                      memcg_page_state(memcg, WORKINGSET_REFAULT_ANON));
+       seq_buf_printf(&s, "workingset_refault_file %lu\n",
+                      memcg_page_state(memcg, WORKINGSET_REFAULT_FILE));
+       seq_buf_printf(&s, "workingset_activate_anon %lu\n",
+                      memcg_page_state(memcg, WORKINGSET_ACTIVATE_ANON));
+       seq_buf_printf(&s, "workingset_activate_file %lu\n",
+                      memcg_page_state(memcg, WORKINGSET_ACTIVATE_FILE));
+       seq_buf_printf(&s, "workingset_restore %lu\n",
+                      memcg_page_state(memcg, WORKINGSET_RESTORE_ANON));
+       seq_buf_printf(&s, "workingset_restore %lu\n",
+                      memcg_page_state(memcg, WORKINGSET_RESTORE_FILE));
        seq_buf_printf(&s, "workingset_nodereclaim %lu\n",
                       memcg_page_state(memcg, WORKINGSET_NODERECLAIM));
 
@@ -1580,15 +1671,21 @@ static bool mem_cgroup_out_of_memory(struct mem_cgroup *memcg, gfp_t gfp_mask,
                .gfp_mask = gfp_mask,
                .order = order,
        };
-       bool ret;
+       bool ret = true;
 
        if (mutex_lock_killable(&oom_lock))
                return true;
+
+       if (mem_cgroup_margin(memcg) >= (1 << order))
+               goto unlock;
+
        /*
         * A few threads which were not waiting at mutex_lock_killable() can
         * fail to bail out. Therefore, check again after holding oom_lock.
         */
        ret = should_force_charge() || out_of_memory(&oc);
+
+unlock:
        mutex_unlock(&oom_lock);
        return ret;
 }
@@ -1979,6 +2076,7 @@ void mem_cgroup_print_oom_group(struct mem_cgroup *memcg)
  */
 struct mem_cgroup *lock_page_memcg(struct page *page)
 {
+       struct page *head = compound_head(page); /* rmap on tail pages */
        struct mem_cgroup *memcg;
        unsigned long flags;
 
@@ -1998,7 +2096,7 @@ struct mem_cgroup *lock_page_memcg(struct page *page)
        if (mem_cgroup_disabled())
                return NULL;
 again:
-       memcg = page->mem_cgroup;
+       memcg = head->mem_cgroup;
        if (unlikely(!memcg))
                return NULL;
 
@@ -2006,7 +2104,7 @@ again:
                return memcg;
 
        spin_lock_irqsave(&memcg->move_lock, flags);
-       if (memcg != page->mem_cgroup) {
+       if (memcg != head->mem_cgroup) {
                spin_unlock_irqrestore(&memcg->move_lock, flags);
                goto again;
        }
@@ -2049,13 +2147,21 @@ void __unlock_page_memcg(struct mem_cgroup *memcg)
  */
 void unlock_page_memcg(struct page *page)
 {
-       __unlock_page_memcg(page->mem_cgroup);
+       struct page *head = compound_head(page);
+
+       __unlock_page_memcg(head->mem_cgroup);
 }
 EXPORT_SYMBOL(unlock_page_memcg);
 
 struct memcg_stock_pcp {
        struct mem_cgroup *cached; /* this never be root cgroup */
        unsigned int nr_pages;
+
+#ifdef CONFIG_MEMCG_KMEM
+       struct obj_cgroup *cached_objcg;
+       unsigned int nr_bytes;
+#endif
+
        struct work_struct work;
        unsigned long flags;
 #define FLUSHING_CACHED_CHARGE 0
@@ -2063,6 +2169,22 @@ struct memcg_stock_pcp {
 static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock);
 static DEFINE_MUTEX(percpu_charge_mutex);
 
+#ifdef CONFIG_MEMCG_KMEM
+static void drain_obj_stock(struct memcg_stock_pcp *stock);
+static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+                                    struct mem_cgroup *root_memcg);
+
+#else
+static inline void drain_obj_stock(struct memcg_stock_pcp *stock)
+{
+}
+static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+                                    struct mem_cgroup *root_memcg)
+{
+       return false;
+}
+#endif
+
 /**
  * consume_stock: Try to consume stocked charge on this cpu.
  * @memcg: memcg to consume from.
@@ -2103,13 +2225,17 @@ static void drain_stock(struct memcg_stock_pcp *stock)
 {
        struct mem_cgroup *old = stock->cached;
 
+       if (!old)
+               return;
+
        if (stock->nr_pages) {
                page_counter_uncharge(&old->memory, stock->nr_pages);
                if (do_memsw_account())
                        page_counter_uncharge(&old->memsw, stock->nr_pages);
-               css_put_many(&old->css, stock->nr_pages);
                stock->nr_pages = 0;
        }
+
+       css_put(&old->css);
        stock->cached = NULL;
 }
 
@@ -2125,6 +2251,7 @@ static void drain_local_stock(struct work_struct *dummy)
        local_irq_save(flags);
 
        stock = this_cpu_ptr(&memcg_stock);
+       drain_obj_stock(stock);
        drain_stock(stock);
        clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
 
@@ -2145,6 +2272,7 @@ static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
        stock = this_cpu_ptr(&memcg_stock);
        if (stock->cached != memcg) { /* reset if necessary */
                drain_stock(stock);
+               css_get(&memcg->css);
                stock->cached = memcg;
        }
        stock->nr_pages += nr_pages;
@@ -2183,6 +2311,8 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
                if (memcg && stock->nr_pages &&
                    mem_cgroup_is_descendant(memcg, root_memcg))
                        flush = true;
+               if (obj_stock_flush_required(stock, root_memcg))
+                       flush = true;
                rcu_read_unlock();
 
                if (flush &&
@@ -2245,17 +2375,29 @@ static int memcg_hotplug_cpu_dead(unsigned int cpu)
        return 0;
 }
 
-static void reclaim_high(struct mem_cgroup *memcg,
-                        unsigned int nr_pages,
-                        gfp_t gfp_mask)
+static unsigned long reclaim_high(struct mem_cgroup *memcg,
+                                 unsigned int nr_pages,
+                                 gfp_t gfp_mask)
 {
+       unsigned long nr_reclaimed = 0;
+
        do {
-               if (page_counter_read(&memcg->memory) <= READ_ONCE(memcg->high))
+               unsigned long pflags;
+
+               if (page_counter_read(&memcg->memory) <=
+                   READ_ONCE(memcg->memory.high))
                        continue;
+
                memcg_memory_event(memcg, MEMCG_HIGH);
-               try_to_free_mem_cgroup_pages(memcg, nr_pages, gfp_mask, true);
+
+               psi_memstall_enter(&pflags);
+               nr_reclaimed += try_to_free_mem_cgroup_pages(memcg, nr_pages,
+                                                            gfp_mask, true);
+               psi_memstall_leave(&pflags);
        } while ((memcg = parent_mem_cgroup(memcg)) &&
                 !mem_cgroup_is_root(memcg));
+
+       return nr_reclaimed;
 }
 
 static void high_work_func(struct work_struct *work)
@@ -2280,7 +2422,7 @@ static void high_work_func(struct work_struct *work)
  *
  * - MEMCG_DELAY_PRECISION_SHIFT: Extra precision bits while translating the
  *   overage ratio to a delay.
- * - MEMCG_DELAY_SCALING_SHIFT: The number of bits to scale down down the
+ * - MEMCG_DELAY_SCALING_SHIFT: The number of bits to scale down the
  *   proposed penalty in order to reduce to a reasonable number of jiffies, and
  *   to produce a reasonable delay curve.
  *
@@ -2319,41 +2461,64 @@ static void high_work_func(struct work_struct *work)
  #define MEMCG_DELAY_PRECISION_SHIFT 20
  #define MEMCG_DELAY_SCALING_SHIFT 14
 
-/*
- * Get the number of jiffies that we should penalise a mischievous cgroup which
- * is exceeding its memory.high by checking both it and its ancestors.
- */
-static unsigned long calculate_high_delay(struct mem_cgroup *memcg,
-                                         unsigned int nr_pages)
+static u64 calculate_overage(unsigned long usage, unsigned long high)
 {
-       unsigned long penalty_jiffies;
-       u64 max_overage = 0;
+       u64 overage;
 
-       do {
-               unsigned long usage, high;
-               u64 overage;
+       if (usage <= high)
+               return 0;
 
-               usage = page_counter_read(&memcg->memory);
-               high = READ_ONCE(memcg->high);
+       /*
+        * Prevent division by 0 in overage calculation by acting as if
+        * it was a threshold of 1 page
+        */
+       high = max(high, 1UL);
 
-               if (usage <= high)
-                       continue;
+       overage = usage - high;
+       overage <<= MEMCG_DELAY_PRECISION_SHIFT;
+       return div64_u64(overage, high);
+}
 
-               /*
-                * Prevent division by 0 in overage calculation by acting as if
-                * it was a threshold of 1 page
-                */
-               high = max(high, 1UL);
+static u64 mem_find_max_overage(struct mem_cgroup *memcg)
+{
+       u64 overage, max_overage = 0;
+
+       do {
+               overage = calculate_overage(page_counter_read(&memcg->memory),
+                                           READ_ONCE(memcg->memory.high));
+               max_overage = max(overage, max_overage);
+       } while ((memcg = parent_mem_cgroup(memcg)) &&
+                !mem_cgroup_is_root(memcg));
 
-               overage = usage - high;
-               overage <<= MEMCG_DELAY_PRECISION_SHIFT;
-               overage = div64_u64(overage, high);
+       return max_overage;
+}
 
-               if (overage > max_overage)
-                       max_overage = overage;
+static u64 swap_find_max_overage(struct mem_cgroup *memcg)
+{
+       u64 overage, max_overage = 0;
+
+       do {
+               overage = calculate_overage(page_counter_read(&memcg->swap),
+                                           READ_ONCE(memcg->swap.high));
+               if (overage)
+                       memcg_memory_event(memcg, MEMCG_SWAP_HIGH);
+               max_overage = max(overage, max_overage);
        } while ((memcg = parent_mem_cgroup(memcg)) &&
                 !mem_cgroup_is_root(memcg));
 
+       return max_overage;
+}
+
+/*
+ * Get the number of jiffies that we should penalise a mischievous cgroup which
+ * is exceeding its memory.high by checking both it and its ancestors.
+ */
+static unsigned long calculate_high_delay(struct mem_cgroup *memcg,
+                                         unsigned int nr_pages,
+                                         u64 max_overage)
+{
+       unsigned long penalty_jiffies;
+
        if (!max_overage)
                return 0;
 
@@ -2377,14 +2542,7 @@ static unsigned long calculate_high_delay(struct mem_cgroup *memcg,
         * MEMCG_CHARGE_BATCH pages is nominal, so work out how much smaller or
         * larger the current charge patch is than that.
         */
-       penalty_jiffies = penalty_jiffies * nr_pages / MEMCG_CHARGE_BATCH;
-
-       /*
-        * Clamp the max delay per usermode return so as to still keep the
-        * application moving forwards and also permit diagnostics, albeit
-        * extremely slowly.
-        */
-       return min(penalty_jiffies, MEMCG_MAX_HIGH_DELAY_JIFFIES);
+       return penalty_jiffies * nr_pages / MEMCG_CHARGE_BATCH;
 }
 
 /*
@@ -2395,21 +2553,48 @@ void mem_cgroup_handle_over_high(void)
 {
        unsigned long penalty_jiffies;
        unsigned long pflags;
+       unsigned long nr_reclaimed;
        unsigned int nr_pages = current->memcg_nr_pages_over_high;
+       int nr_retries = MAX_RECLAIM_RETRIES;
        struct mem_cgroup *memcg;
+       bool in_retry = false;
 
        if (likely(!nr_pages))
                return;
 
        memcg = get_mem_cgroup_from_mm(current->mm);
-       reclaim_high(memcg, nr_pages, GFP_KERNEL);
        current->memcg_nr_pages_over_high = 0;
 
+retry_reclaim:
+       /*
+        * The allocating task should reclaim at least the batch size, but for
+        * subsequent retries we only want to do what's necessary to prevent oom
+        * or breaching resource isolation.
+        *
+        * This is distinct from memory.max or page allocator behaviour because
+        * memory.high is currently batched, whereas memory.max and the page
+        * allocator run every time an allocation is made.
+        */
+       nr_reclaimed = reclaim_high(memcg,
+                                   in_retry ? SWAP_CLUSTER_MAX : nr_pages,
+                                   GFP_KERNEL);
+
        /*
         * memory.high is breached and reclaim is unable to keep up. Throttle
         * allocators proactively to slow down excessive growth.
         */
-       penalty_jiffies = calculate_high_delay(memcg, nr_pages);
+       penalty_jiffies = calculate_high_delay(memcg, nr_pages,
+                                              mem_find_max_overage(memcg));
+
+       penalty_jiffies += calculate_high_delay(memcg, nr_pages,
+                                               swap_find_max_overage(memcg));
+
+       /*
+        * Clamp the max delay per usermode return so as to still keep the
+        * application moving forwards and also permit diagnostics, albeit
+        * extremely slowly.
+        */
+       penalty_jiffies = min(penalty_jiffies, MEMCG_MAX_HIGH_DELAY_JIFFIES);
 
        /*
         * Don't sleep if the amount of jiffies this memcg owes us is so low
@@ -2420,6 +2605,16 @@ void mem_cgroup_handle_over_high(void)
        if (penalty_jiffies <= HZ / 100)
                goto out;
 
+       /*
+        * If reclaim is making forward progress but we're still over
+        * memory.high, we want to encourage that rather than doing allocator
+        * throttling.
+        */
+       if (nr_reclaimed || nr_retries--) {
+               in_retry = true;
+               goto retry_reclaim;
+       }
+
        /*
         * If we exit early, we're guaranteed to die (since
         * schedule_timeout_killable sets TASK_KILLABLE). This means we don't
@@ -2437,13 +2632,14 @@ static int try_charge(struct mem_cgroup *memcg, gfp_t gfp_mask,
                      unsigned int nr_pages)
 {
        unsigned int batch = max(MEMCG_CHARGE_BATCH, nr_pages);
-       int nr_retries = MEM_CGROUP_RECLAIM_RETRIES;
+       int nr_retries = MAX_RECLAIM_RETRIES;
        struct mem_cgroup *mem_over_limit;
        struct page_counter *counter;
+       enum oom_status oom_status;
        unsigned long nr_reclaimed;
        bool may_swap = true;
        bool drained = false;
-       enum oom_status oom_status;
+       unsigned long pflags;
 
        if (mem_cgroup_is_root(memcg))
                return 0;
@@ -2503,8 +2699,10 @@ retry:
 
        memcg_memory_event(mem_over_limit, MEMCG_MAX);
 
+       psi_memstall_enter(&pflags);
        nr_reclaimed = try_to_free_mem_cgroup_pages(mem_over_limit, nr_pages,
                                                    gfp_mask, may_swap);
+       psi_memstall_leave(&pflags);
 
        if (mem_cgroup_margin(mem_over_limit) >= nr_pages)
                goto retry;
@@ -2556,7 +2754,7 @@ retry:
                       get_order(nr_pages * PAGE_SIZE));
        switch (oom_status) {
        case OOM_SUCCESS:
-               nr_retries = MEM_CGROUP_RECLAIM_RETRIES;
+               nr_retries = MAX_RECLAIM_RETRIES;
                goto retry;
        case OOM_FAILED:
                goto force;
@@ -2575,12 +2773,10 @@ force:
        page_counter_charge(&memcg->memory, nr_pages);
        if (do_memsw_account())
                page_counter_charge(&memcg->memsw, nr_pages);
-       css_get_many(&memcg->css, nr_pages);
 
        return 0;
 
 done_restock:
-       css_get_many(&memcg->css, batch);
        if (batch > nr_pages)
                refill_stock(memcg, batch - nr_pages);
 
@@ -2594,12 +2790,32 @@ done_restock:
         * reclaim, the cost of mismatch is negligible.
         */
        do {
-               if (page_counter_read(&memcg->memory) > READ_ONCE(memcg->high)) {
-                       /* Don't bother a random interrupted task */
-                       if (in_interrupt()) {
+               bool mem_high, swap_high;
+
+               mem_high = page_counter_read(&memcg->memory) >
+                       READ_ONCE(memcg->memory.high);
+               swap_high = page_counter_read(&memcg->swap) >
+                       READ_ONCE(memcg->swap.high);
+
+               /* Don't bother a random interrupted task */
+               if (in_interrupt()) {
+                       if (mem_high) {
                                schedule_work(&memcg->high_work);
                                break;
                        }
+                       continue;
+               }
+
+               if (mem_high || swap_high) {
+                       /*
+                        * The allocating tasks in this cgroup will need to do
+                        * reclaim or be throttled to prevent further growth
+                        * of the memory or swap footprints.
+                        *
+                        * Target some best-effort fairness between the tasks,
+                        * and distribute reclaim work and delay penalties
+                        * based on how much each task is actually allocating.
+                        */
                        current->memcg_nr_pages_over_high += batch;
                        set_notify_resume(current);
                        break;
@@ -2609,6 +2825,7 @@ done_restock:
        return 0;
 }
 
+#if defined(CONFIG_MEMCG_KMEM) || defined(CONFIG_MMU)
 static void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
        if (mem_cgroup_is_root(memcg))
@@ -2617,76 +2834,44 @@ static void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
        page_counter_uncharge(&memcg->memory, nr_pages);
        if (do_memsw_account())
                page_counter_uncharge(&memcg->memsw, nr_pages);
-
-       css_put_many(&memcg->css, nr_pages);
-}
-
-static void lock_page_lru(struct page *page, int *isolated)
-{
-       pg_data_t *pgdat = page_pgdat(page);
-
-       spin_lock_irq(&pgdat->lru_lock);
-       if (PageLRU(page)) {
-               struct lruvec *lruvec;
-
-               lruvec = mem_cgroup_page_lruvec(page, pgdat);
-               ClearPageLRU(page);
-               del_page_from_lru_list(page, lruvec, page_lru(page));
-               *isolated = 1;
-       } else
-               *isolated = 0;
-}
-
-static void unlock_page_lru(struct page *page, int isolated)
-{
-       pg_data_t *pgdat = page_pgdat(page);
-
-       if (isolated) {
-               struct lruvec *lruvec;
-
-               lruvec = mem_cgroup_page_lruvec(page, pgdat);
-               VM_BUG_ON_PAGE(PageLRU(page), page);
-               SetPageLRU(page);
-               add_page_to_lru_list(page, lruvec, page_lru(page));
-       }
-       spin_unlock_irq(&pgdat->lru_lock);
 }
+#endif
 
-static void commit_charge(struct page *page, struct mem_cgroup *memcg,
-                         bool lrucare)
+static void commit_charge(struct page *page, struct mem_cgroup *memcg)
 {
-       int isolated;
-
        VM_BUG_ON_PAGE(page->mem_cgroup, page);
-
        /*
-        * In some cases, SwapCache and FUSE(splice_buf->radixtree), the page
-        * may already be on some other mem_cgroup's LRU.  Take care of it.
-        */
-       if (lrucare)
-               lock_page_lru(page, &isolated);
-
-       /*
-        * Nobody should be changing or seriously looking at
-        * page->mem_cgroup at this point:
-        *
-        * - the page is uncharged
+        * Any of the following ensures page->mem_cgroup stability:
         *
-        * - the page is off-LRU
-        *
-        * - an anonymous fault has exclusive page access, except for
-        *   a locked page table
-        *
-        * - a page cache insertion, a swapin fault, or a migration
-        *   have the page locked
+        * - the page lock
+        * - LRU isolation
+        * - lock_page_memcg()
+        * - exclusive reference
         */
        page->mem_cgroup = memcg;
-
-       if (lrucare)
-               unlock_page_lru(page, isolated);
 }
 
 #ifdef CONFIG_MEMCG_KMEM
+int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
+                                gfp_t gfp)
+{
+       unsigned int objects = objs_per_slab_page(s, page);
+       void *vec;
+
+       vec = kcalloc_node(objects, sizeof(struct obj_cgroup *), gfp,
+                          page_to_nid(page));
+       if (!vec)
+               return -ENOMEM;
+
+       if (cmpxchg(&page->obj_cgroups, NULL,
+                   (struct obj_cgroup **) ((unsigned long)vec | 0x1UL)))
+               kfree(vec);
+       else
+               kmemleak_not_leak(vec);
+
+       return 0;
+}
+
 /*
  * Returns a pointer to the memory cgroup to which the kernel object is charged.
  *
@@ -2703,17 +2888,50 @@ struct mem_cgroup *mem_cgroup_from_obj(void *p)
        page = virt_to_head_page(p);
 
        /*
-        * Slab pages don't have page->mem_cgroup set because corresponding
-        * kmem caches can be reparented during the lifetime. That's why
-        * memcg_from_slab_page() should be used instead.
+        * Slab objects are accounted individually, not per-page.
+        * Memcg membership data for each individual object is saved in
+        * the page->obj_cgroups.
         */
-       if (PageSlab(page))
-               return memcg_from_slab_page(page);
+       if (page_has_obj_cgroups(page)) {
+               struct obj_cgroup *objcg;
+               unsigned int off;
+
+               off = obj_to_index(page->slab_cache, page, p);
+               objcg = page_obj_cgroups(page)[off];
+               if (objcg)
+                       return obj_cgroup_memcg(objcg);
+
+               return NULL;
+       }
 
        /* All other pages use page->mem_cgroup */
        return page->mem_cgroup;
 }
 
+__always_inline struct obj_cgroup *get_obj_cgroup_from_current(void)
+{
+       struct obj_cgroup *objcg = NULL;
+       struct mem_cgroup *memcg;
+
+       if (unlikely(!current->mm && !current->active_memcg))
+               return NULL;
+
+       rcu_read_lock();
+       if (unlikely(current->active_memcg))
+               memcg = rcu_dereference(current->active_memcg);
+       else
+               memcg = mem_cgroup_from_task(current);
+
+       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg)) {
+               objcg = rcu_dereference(memcg->objcg);
+               if (objcg && obj_cgroup_tryget(objcg))
+                       break;
+       }
+       rcu_read_unlock();
+
+       return objcg;
+}
+
 static int memcg_alloc_cache_id(void)
 {
        int id, size;
@@ -2739,9 +2957,7 @@ static int memcg_alloc_cache_id(void)
        else if (size > MEMCG_CACHES_MAX_SIZE)
                size = MEMCG_CACHES_MAX_SIZE;
 
-       err = memcg_update_all_caches(size);
-       if (!err)
-               err = memcg_update_all_list_lrus(size);
+       err = memcg_update_all_list_lrus(size);
        if (!err)
                memcg_nr_cache_ids = size;
 
@@ -2759,245 +2975,242 @@ static void memcg_free_cache_id(int id)
        ida_simple_remove(&memcg_cache_ida, id);
 }
 
-struct memcg_kmem_cache_create_work {
-       struct mem_cgroup *memcg;
-       struct kmem_cache *cachep;
-       struct work_struct work;
-};
-
-static void memcg_kmem_cache_create_func(struct work_struct *w)
-{
-       struct memcg_kmem_cache_create_work *cw =
-               container_of(w, struct memcg_kmem_cache_create_work, work);
-       struct mem_cgroup *memcg = cw->memcg;
-       struct kmem_cache *cachep = cw->cachep;
-
-       memcg_create_kmem_cache(memcg, cachep);
-
-       css_put(&memcg->css);
-       kfree(cw);
-}
-
-/*
- * Enqueue the creation of a per-memcg kmem_cache.
+/**
+ * __memcg_kmem_charge: charge a number of kernel pages to a memcg
+ * @memcg: memory cgroup to charge
+ * @gfp: reclaim mode
+ * @nr_pages: number of pages to charge
+ *
+ * Returns 0 on success, an error code on failure.
  */
-static void memcg_schedule_kmem_cache_create(struct mem_cgroup *memcg,
-                                              struct kmem_cache *cachep)
+int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
+                       unsigned int nr_pages)
 {
-       struct memcg_kmem_cache_create_work *cw;
-
-       if (!css_tryget_online(&memcg->css))
-               return;
+       struct page_counter *counter;
+       int ret;
 
-       cw = kmalloc(sizeof(*cw), GFP_NOWAIT | __GFP_NOWARN);
-       if (!cw)
-               return;
+       ret = try_charge(memcg, gfp, nr_pages);
+       if (ret)
+               return ret;
 
-       cw->memcg = memcg;
-       cw->cachep = cachep;
-       INIT_WORK(&cw->work, memcg_kmem_cache_create_func);
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) &&
+           !page_counter_try_charge(&memcg->kmem, nr_pages, &counter)) {
 
-       queue_work(memcg_kmem_cache_wq, &cw->work);
+               /*
+                * Enforce __GFP_NOFAIL allocation because callers are not
+                * prepared to see failures and likely do not have any failure
+                * handling code.
+                */
+               if (gfp & __GFP_NOFAIL) {
+                       page_counter_charge(&memcg->kmem, nr_pages);
+                       return 0;
+               }
+               cancel_charge(memcg, nr_pages);
+               return -ENOMEM;
+       }
+       return 0;
 }
 
-static inline bool memcg_kmem_bypass(void)
+/**
+ * __memcg_kmem_uncharge: uncharge a number of kernel pages from a memcg
+ * @memcg: memcg to uncharge
+ * @nr_pages: number of pages to uncharge
+ */
+void __memcg_kmem_uncharge(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
-       if (in_interrupt() || !current->mm || (current->flags & PF_KTHREAD))
-               return true;
-       return false;
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               page_counter_uncharge(&memcg->kmem, nr_pages);
+
+       page_counter_uncharge(&memcg->memory, nr_pages);
+       if (do_memsw_account())
+               page_counter_uncharge(&memcg->memsw, nr_pages);
 }
 
 /**
- * memcg_kmem_get_cache: select the correct per-memcg cache for allocation
- * @cachep: the original global kmem cache
- *
- * Return the kmem_cache we're supposed to use for a slab allocation.
- * We try to use the current memcg's version of the cache.
- *
- * If the cache does not exist yet, if we are the first user of it, we
- * create it asynchronously in a workqueue and let the current allocation
- * go through with the original cache.
+ * __memcg_kmem_charge_page: charge a kmem page to the current memory cgroup
+ * @page: page to charge
+ * @gfp: reclaim mode
+ * @order: allocation order
  *
- * This function takes a reference to the cache it returns to assure it
- * won't get destroyed while we are working with it. Once the caller is
- * done with it, memcg_kmem_put_cache() must be called to release the
- * reference.
+ * Returns 0 on success, an error code on failure.
  */
-struct kmem_cache *memcg_kmem_get_cache(struct kmem_cache *cachep)
+int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
 {
        struct mem_cgroup *memcg;
-       struct kmem_cache *memcg_cachep;
-       struct memcg_cache_array *arr;
-       int kmemcg_id;
-
-       VM_BUG_ON(!is_root_cache(cachep));
+       int ret = 0;
 
        if (memcg_kmem_bypass())
-               return cachep;
-
-       rcu_read_lock();
-
-       if (unlikely(current->active_memcg))
-               memcg = current->active_memcg;
-       else
-               memcg = mem_cgroup_from_task(current);
+               return 0;
 
-       if (!memcg || memcg == root_mem_cgroup)
-               goto out_unlock;
+       memcg = get_mem_cgroup_from_current();
+       if (!mem_cgroup_is_root(memcg)) {
+               ret = __memcg_kmem_charge(memcg, gfp, 1 << order);
+               if (!ret) {
+                       page->mem_cgroup = memcg;
+                       __SetPageKmemcg(page);
+                       return 0;
+               }
+       }
+       css_put(&memcg->css);
+       return ret;
+}
 
-       kmemcg_id = READ_ONCE(memcg->kmemcg_id);
-       if (kmemcg_id < 0)
-               goto out_unlock;
+/**
+ * __memcg_kmem_uncharge_page: uncharge a kmem page
+ * @page: page to uncharge
+ * @order: allocation order
+ */
+void __memcg_kmem_uncharge_page(struct page *page, int order)
+{
+       struct mem_cgroup *memcg = page->mem_cgroup;
+       unsigned int nr_pages = 1 << order;
 
-       arr = rcu_dereference(cachep->memcg_params.memcg_caches);
+       if (!memcg)
+               return;
 
-       /*
-        * Make sure we will access the up-to-date value. The code updating
-        * memcg_caches issues a write barrier to match the data dependency
-        * barrier inside READ_ONCE() (see memcg_create_kmem_cache()).
-        */
-       memcg_cachep = READ_ONCE(arr->entries[kmemcg_id]);
+       VM_BUG_ON_PAGE(mem_cgroup_is_root(memcg), page);
+       __memcg_kmem_uncharge(memcg, nr_pages);
+       page->mem_cgroup = NULL;
+       css_put(&memcg->css);
 
-       /*
-        * If we are in a safe context (can wait, and not in interrupt
-        * context), we could be be predictable and return right away.
-        * This would guarantee that the allocation being performed
-        * already belongs in the new cache.
-        *
-        * However, there are some clashes that can arrive from locking.
-        * For instance, because we acquire the slab_mutex while doing
-        * memcg_create_kmem_cache, this means no further allocation
-        * could happen with the slab_mutex held. So it's better to
-        * defer everything.
-        *
-        * If the memcg is dying or memcg_cache is about to be released,
-        * don't bother creating new kmem_caches. Because memcg_cachep
-        * is ZEROed as the fist step of kmem offlining, we don't need
-        * percpu_ref_tryget_live() here. css_tryget_online() check in
-        * memcg_schedule_kmem_cache_create() will prevent us from
-        * creation of a new kmem_cache.
-        */
-       if (unlikely(!memcg_cachep))
-               memcg_schedule_kmem_cache_create(memcg, cachep);
-       else if (percpu_ref_tryget(&memcg_cachep->memcg_params.refcnt))
-               cachep = memcg_cachep;
-out_unlock:
-       rcu_read_unlock();
-       return cachep;
+       /* slab pages do not have PageKmemcg flag set */
+       if (PageKmemcg(page))
+               __ClearPageKmemcg(page);
 }
 
-/**
- * memcg_kmem_put_cache: drop reference taken by memcg_kmem_get_cache
- * @cachep: the cache returned by memcg_kmem_get_cache
- */
-void memcg_kmem_put_cache(struct kmem_cache *cachep)
+static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
 {
-       if (!is_root_cache(cachep))
-               percpu_ref_put(&cachep->memcg_params.refcnt);
+       struct memcg_stock_pcp *stock;
+       unsigned long flags;
+       bool ret = false;
+
+       local_irq_save(flags);
+
+       stock = this_cpu_ptr(&memcg_stock);
+       if (objcg == stock->cached_objcg && stock->nr_bytes >= nr_bytes) {
+               stock->nr_bytes -= nr_bytes;
+               ret = true;
+       }
+
+       local_irq_restore(flags);
+
+       return ret;
 }
 
-/**
- * __memcg_kmem_charge: charge a number of kernel pages to a memcg
- * @memcg: memory cgroup to charge
- * @gfp: reclaim mode
- * @nr_pages: number of pages to charge
- *
- * Returns 0 on success, an error code on failure.
- */
-int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
-                       unsigned int nr_pages)
+static void drain_obj_stock(struct memcg_stock_pcp *stock)
 {
-       struct page_counter *counter;
-       int ret;
+       struct obj_cgroup *old = stock->cached_objcg;
 
-       ret = try_charge(memcg, gfp, nr_pages);
-       if (ret)
-               return ret;
+       if (!old)
+               return;
 
-       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) &&
-           !page_counter_try_charge(&memcg->kmem, nr_pages, &counter)) {
+       if (stock->nr_bytes) {
+               unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
+               unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
+
+               if (nr_pages) {
+                       rcu_read_lock();
+                       __memcg_kmem_uncharge(obj_cgroup_memcg(old), nr_pages);
+                       rcu_read_unlock();
+               }
 
                /*
-                * Enforce __GFP_NOFAIL allocation because callers are not
-                * prepared to see failures and likely do not have any failure
-                * handling code.
+                * The leftover is flushed to the centralized per-memcg value.
+                * On the next attempt to refill obj stock it will be moved
+                * to a per-cpu stock (probably, on an other CPU), see
+                * refill_obj_stock().
+                *
+                * How often it's flushed is a trade-off between the memory
+                * limit enforcement accuracy and potential CPU contention,
+                * so it might be changed in the future.
                 */
-               if (gfp & __GFP_NOFAIL) {
-                       page_counter_charge(&memcg->kmem, nr_pages);
-                       return 0;
-               }
-               cancel_charge(memcg, nr_pages);
-               return -ENOMEM;
+               atomic_add(nr_bytes, &old->nr_charged_bytes);
+               stock->nr_bytes = 0;
+       }
+
+       obj_cgroup_put(old);
+       stock->cached_objcg = NULL;
+}
+
+static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+                                    struct mem_cgroup *root_memcg)
+{
+       struct mem_cgroup *memcg;
+
+       if (stock->cached_objcg) {
+               memcg = obj_cgroup_memcg(stock->cached_objcg);
+               if (memcg && mem_cgroup_is_descendant(memcg, root_memcg))
+                       return true;
        }
-       return 0;
+
+       return false;
 }
 
-/**
- * __memcg_kmem_uncharge: uncharge a number of kernel pages from a memcg
- * @memcg: memcg to uncharge
- * @nr_pages: number of pages to uncharge
- */
-void __memcg_kmem_uncharge(struct mem_cgroup *memcg, unsigned int nr_pages)
+static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
 {
-       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
-               page_counter_uncharge(&memcg->kmem, nr_pages);
+       struct memcg_stock_pcp *stock;
+       unsigned long flags;
 
-       page_counter_uncharge(&memcg->memory, nr_pages);
-       if (do_memsw_account())
-               page_counter_uncharge(&memcg->memsw, nr_pages);
+       local_irq_save(flags);
+
+       stock = this_cpu_ptr(&memcg_stock);
+       if (stock->cached_objcg != objcg) { /* reset if necessary */
+               drain_obj_stock(stock);
+               obj_cgroup_get(objcg);
+               stock->cached_objcg = objcg;
+               stock->nr_bytes = atomic_xchg(&objcg->nr_charged_bytes, 0);
+       }
+       stock->nr_bytes += nr_bytes;
+
+       if (stock->nr_bytes > PAGE_SIZE)
+               drain_obj_stock(stock);
+
+       local_irq_restore(flags);
 }
 
-/**
- * __memcg_kmem_charge_page: charge a kmem page to the current memory cgroup
- * @page: page to charge
- * @gfp: reclaim mode
- * @order: allocation order
- *
- * Returns 0 on success, an error code on failure.
- */
-int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
+int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
 {
        struct mem_cgroup *memcg;
-       int ret = 0;
+       unsigned int nr_pages, nr_bytes;
+       int ret;
 
-       if (memcg_kmem_bypass())
+       if (consume_obj_stock(objcg, size))
                return 0;
 
-       memcg = get_mem_cgroup_from_current();
-       if (!mem_cgroup_is_root(memcg)) {
-               ret = __memcg_kmem_charge(memcg, gfp, 1 << order);
-               if (!ret) {
-                       page->mem_cgroup = memcg;
-                       __SetPageKmemcg(page);
-               }
-       }
-       css_put(&memcg->css);
-       return ret;
-}
+       /*
+        * In theory, memcg->nr_charged_bytes can have enough
+        * pre-charged bytes to satisfy the allocation. However,
+        * flushing memcg->nr_charged_bytes requires two atomic
+        * operations, and memcg->nr_charged_bytes can't be big,
+        * so it's better to ignore it and try grab some new pages.
+        * memcg->nr_charged_bytes will be flushed in
+        * refill_obj_stock(), called from this function or
+        * independently later.
+        */
+       rcu_read_lock();
+       memcg = obj_cgroup_memcg(objcg);
+       css_get(&memcg->css);
+       rcu_read_unlock();
 
-/**
- * __memcg_kmem_uncharge_page: uncharge a kmem page
- * @page: page to uncharge
- * @order: allocation order
- */
-void __memcg_kmem_uncharge_page(struct page *page, int order)
-{
-       struct mem_cgroup *memcg = page->mem_cgroup;
-       unsigned int nr_pages = 1 << order;
+       nr_pages = size >> PAGE_SHIFT;
+       nr_bytes = size & (PAGE_SIZE - 1);
 
-       if (!memcg)
-               return;
+       if (nr_bytes)
+               nr_pages += 1;
 
-       VM_BUG_ON_PAGE(mem_cgroup_is_root(memcg), page);
-       __memcg_kmem_uncharge(memcg, nr_pages);
-       page->mem_cgroup = NULL;
+       ret = __memcg_kmem_charge(memcg, gfp, nr_pages);
+       if (!ret && nr_bytes)
+               refill_obj_stock(objcg, PAGE_SIZE - nr_bytes);
 
-       /* slab pages do not have PageKmemcg flag set */
-       if (PageKmemcg(page))
-               __ClearPageKmemcg(page);
+       css_put(&memcg->css);
+       return ret;
+}
 
-       css_put_many(&memcg->css, nr_pages);
+void obj_cgroup_uncharge(struct obj_cgroup *objcg, size_t size)
+{
+       refill_obj_stock(objcg, size);
 }
+
 #endif /* CONFIG_MEMCG_KMEM */
 
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
@@ -3008,15 +3221,16 @@ void __memcg_kmem_uncharge_page(struct page *page, int order)
  */
 void mem_cgroup_split_huge_fixup(struct page *head)
 {
+       struct mem_cgroup *memcg = head->mem_cgroup;
        int i;
 
        if (mem_cgroup_disabled())
                return;
 
-       for (i = 1; i < HPAGE_PMD_NR; i++)
-               head[i].mem_cgroup = head->mem_cgroup;
-
-       __mod_memcg_state(head->mem_cgroup, MEMCG_RSS_HUGE, -HPAGE_PMD_NR);
+       for (i = 1; i < HPAGE_PMD_NR; i++) {
+               css_get(&memcg->css);
+               head[i].mem_cgroup = memcg;
+       }
 }
 #endif /* CONFIG_TRANSPARENT_HUGEPAGE */
 
@@ -3201,7 +3415,7 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
  * Test whether @memcg has children, dead or alive.  Note that this
  * function doesn't care whether @memcg has use_hierarchy enabled and
  * returns %true if there are child csses according to the cgroup
- * hierarchy.  Testing use_hierarchy is the caller's responsiblity.
+ * hierarchy.  Testing use_hierarchy is the caller's responsibility.
  */
 static inline bool memcg_has_children(struct mem_cgroup *memcg)
 {
@@ -3220,7 +3434,7 @@ static inline bool memcg_has_children(struct mem_cgroup *memcg)
  */
 static int mem_cgroup_force_empty(struct mem_cgroup *memcg)
 {
-       int nr_retries = MEM_CGROUP_RECLAIM_RETRIES;
+       int nr_retries = MAX_RECLAIM_RETRIES;
 
        /* we call try-to-free pages for make this cgroup empty */
        lru_add_drain_all();
@@ -3299,8 +3513,8 @@ static unsigned long mem_cgroup_usage(struct mem_cgroup *memcg, bool swap)
        unsigned long val;
 
        if (mem_cgroup_is_root(memcg)) {
-               val = memcg_page_state(memcg, MEMCG_CACHE) +
-                       memcg_page_state(memcg, MEMCG_RSS);
+               val = memcg_page_state(memcg, NR_FILE_PAGES) +
+                       memcg_page_state(memcg, NR_ANON_MAPPED);
                if (swap)
                        val += memcg_page_state(memcg, MEMCG_SWAP);
        } else {
@@ -3417,6 +3631,7 @@ static void memcg_flush_percpu_vmevents(struct mem_cgroup *memcg)
 #ifdef CONFIG_MEMCG_KMEM
 static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
+       struct obj_cgroup *objcg;
        int memcg_id;
 
        if (cgroup_memory_nokmem)
@@ -3429,7 +3644,16 @@ static int memcg_online_kmem(struct mem_cgroup *memcg)
        if (memcg_id < 0)
                return memcg_id;
 
-       static_branch_inc(&memcg_kmem_enabled_key);
+       objcg = obj_cgroup_alloc();
+       if (!objcg) {
+               memcg_free_cache_id(memcg_id);
+               return -ENOMEM;
+       }
+       objcg->memcg = memcg;
+       rcu_assign_pointer(memcg->objcg, objcg);
+
+       static_branch_enable(&memcg_kmem_enabled_key);
+
        /*
         * A memory cgroup is considered kmem-online as soon as it gets
         * kmemcg_id. Setting the id after enabling static branching will
@@ -3438,7 +3662,6 @@ static int memcg_online_kmem(struct mem_cgroup *memcg)
         */
        memcg->kmemcg_id = memcg_id;
        memcg->kmem_state = KMEM_ONLINE;
-       INIT_LIST_HEAD(&memcg->kmem_caches);
 
        return 0;
 }
@@ -3451,22 +3674,14 @@ static void memcg_offline_kmem(struct mem_cgroup *memcg)
 
        if (memcg->kmem_state != KMEM_ONLINE)
                return;
-       /*
-        * Clear the online state before clearing memcg_caches array
-        * entries. The slab_mutex in memcg_deactivate_kmem_caches()
-        * guarantees that no cache will be created for this cgroup
-        * after we are done (see memcg_create_kmem_cache()).
-        */
+
        memcg->kmem_state = KMEM_ALLOCATED;
 
        parent = parent_mem_cgroup(memcg);
        if (!parent)
                parent = root_mem_cgroup;
 
-       /*
-        * Deactivate and reparent kmem_caches.
-        */
-       memcg_deactivate_kmem_caches(memcg, parent);
+       memcg_reparent_objcgs(memcg, parent);
 
        kmemcg_id = memcg->kmemcg_id;
        BUG_ON(kmemcg_id < 0);
@@ -3499,11 +3714,6 @@ static void memcg_free_kmem(struct mem_cgroup *memcg)
        /* css_alloc() failed, offlining didn't happen */
        if (unlikely(memcg->kmem_state == KMEM_ONLINE))
                memcg_offline_kmem(memcg);
-
-       if (memcg->kmem_state == KMEM_ALLOCATED) {
-               WARN_ON(!list_empty(&memcg->kmem_caches));
-               static_branch_dec(&memcg_kmem_enabled_key);
-       }
 }
 #else
 static int memcg_online_kmem(struct mem_cgroup *memcg)
@@ -3688,7 +3898,7 @@ static int mem_cgroup_move_charge_write(struct cgroup_subsys_state *css,
 #define LRU_ALL             ((1 << NR_LRU_LISTS) - 1)
 
 static unsigned long mem_cgroup_node_nr_lru_pages(struct mem_cgroup *memcg,
-                                          int nid, unsigned int lru_mask)
+                               int nid, unsigned int lru_mask, bool tree)
 {
        struct lruvec *lruvec = mem_cgroup_lruvec(memcg, NODE_DATA(nid));
        unsigned long nr = 0;
@@ -3699,13 +3909,17 @@ static unsigned long mem_cgroup_node_nr_lru_pages(struct mem_cgroup *memcg,
        for_each_lru(lru) {
                if (!(BIT(lru) & lru_mask))
                        continue;
-               nr += lruvec_page_state_local(lruvec, NR_LRU_BASE + lru);
+               if (tree)
+                       nr += lruvec_page_state(lruvec, NR_LRU_BASE + lru);
+               else
+                       nr += lruvec_page_state_local(lruvec, NR_LRU_BASE + lru);
        }
        return nr;
 }
 
 static unsigned long mem_cgroup_nr_lru_pages(struct mem_cgroup *memcg,
-                                            unsigned int lru_mask)
+                                            unsigned int lru_mask,
+                                            bool tree)
 {
        unsigned long nr = 0;
        enum lru_list lru;
@@ -3713,7 +3927,10 @@ static unsigned long mem_cgroup_nr_lru_pages(struct mem_cgroup *memcg,
        for_each_lru(lru) {
                if (!(BIT(lru) & lru_mask))
                        continue;
-               nr += memcg_page_state_local(memcg, NR_LRU_BASE + lru);
+               if (tree)
+                       nr += memcg_page_state(memcg, NR_LRU_BASE + lru);
+               else
+                       nr += memcg_page_state_local(memcg, NR_LRU_BASE + lru);
        }
        return nr;
 }
@@ -3733,34 +3950,28 @@ static int memcg_numa_stat_show(struct seq_file *m, void *v)
        };
        const struct numa_stat *stat;
        int nid;
-       unsigned long nr;
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
 
        for (stat = stats; stat < stats + ARRAY_SIZE(stats); stat++) {
-               nr = mem_cgroup_nr_lru_pages(memcg, stat->lru_mask);
-               seq_printf(m, "%s=%lu", stat->name, nr);
-               for_each_node_state(nid, N_MEMORY) {
-                       nr = mem_cgroup_node_nr_lru_pages(memcg, nid,
-                                                         stat->lru_mask);
-                       seq_printf(m, " N%d=%lu", nid, nr);
-               }
+               seq_printf(m, "%s=%lu", stat->name,
+                          mem_cgroup_nr_lru_pages(memcg, stat->lru_mask,
+                                                  false));
+               for_each_node_state(nid, N_MEMORY)
+                       seq_printf(m, " N%d=%lu", nid,
+                                  mem_cgroup_node_nr_lru_pages(memcg, nid,
+                                                       stat->lru_mask, false));
                seq_putc(m, '\n');
        }
 
        for (stat = stats; stat < stats + ARRAY_SIZE(stats); stat++) {
-               struct mem_cgroup *iter;
-
-               nr = 0;
-               for_each_mem_cgroup_tree(iter, memcg)
-                       nr += mem_cgroup_nr_lru_pages(iter, stat->lru_mask);
-               seq_printf(m, "hierarchical_%s=%lu", stat->name, nr);
-               for_each_node_state(nid, N_MEMORY) {
-                       nr = 0;
-                       for_each_mem_cgroup_tree(iter, memcg)
-                               nr += mem_cgroup_node_nr_lru_pages(
-                                       iter, nid, stat->lru_mask);
-                       seq_printf(m, " N%d=%lu", nid, nr);
-               }
+
+               seq_printf(m, "hierarchical_%s=%lu", stat->name,
+                          mem_cgroup_nr_lru_pages(memcg, stat->lru_mask,
+                                                  true));
+               for_each_node_state(nid, N_MEMORY)
+                       seq_printf(m, " N%d=%lu", nid,
+                                  mem_cgroup_node_nr_lru_pages(memcg, nid,
+                                                       stat->lru_mask, true));
                seq_putc(m, '\n');
        }
 
@@ -3769,9 +3980,11 @@ static int memcg_numa_stat_show(struct seq_file *m, void *v)
 #endif /* CONFIG_NUMA */
 
 static const unsigned int memcg1_stats[] = {
-       MEMCG_CACHE,
-       MEMCG_RSS,
-       MEMCG_RSS_HUGE,
+       NR_FILE_PAGES,
+       NR_ANON_MAPPED,
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+       NR_ANON_THPS,
+#endif
        NR_SHMEM,
        NR_FILE_MAPPED,
        NR_FILE_DIRTY,
@@ -3782,7 +3995,9 @@ static const unsigned int memcg1_stats[] = {
 static const char *const memcg1_stat_names[] = {
        "cache",
        "rss",
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
        "rss_huge",
+#endif
        "shmem",
        "mapped_file",
        "dirty",
@@ -3808,11 +4023,16 @@ static int memcg_stat_show(struct seq_file *m, void *v)
        BUILD_BUG_ON(ARRAY_SIZE(memcg1_stat_names) != ARRAY_SIZE(memcg1_stats));
 
        for (i = 0; i < ARRAY_SIZE(memcg1_stats); i++) {
+               unsigned long nr;
+
                if (memcg1_stats[i] == MEMCG_SWAP && !do_memsw_account())
                        continue;
-               seq_printf(m, "%s %lu\n", memcg1_stat_names[i],
-                          memcg_page_state_local(memcg, memcg1_stats[i]) *
-                          PAGE_SIZE);
+               nr = memcg_page_state_local(memcg, memcg1_stats[i]);
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+               if (memcg1_stats[i] == NR_ANON_THPS)
+                       nr *= HPAGE_PMD_NR;
+#endif
+               seq_printf(m, "%s %lu\n", memcg1_stat_names[i], nr * PAGE_SIZE);
        }
 
        for (i = 0; i < ARRAY_SIZE(memcg1_events); i++)
@@ -3858,23 +4078,17 @@ static int memcg_stat_show(struct seq_file *m, void *v)
        {
                pg_data_t *pgdat;
                struct mem_cgroup_per_node *mz;
-               struct zone_reclaim_stat *rstat;
-               unsigned long recent_rotated[2] = {0, 0};
-               unsigned long recent_scanned[2] = {0, 0};
+               unsigned long anon_cost = 0;
+               unsigned long file_cost = 0;
 
                for_each_online_pgdat(pgdat) {
                        mz = mem_cgroup_nodeinfo(memcg, pgdat->node_id);
-                       rstat = &mz->lruvec.reclaim_stat;
 
-                       recent_rotated[0] += rstat->recent_rotated[0];
-                       recent_rotated[1] += rstat->recent_rotated[1];
-                       recent_scanned[0] += rstat->recent_scanned[0];
-                       recent_scanned[1] += rstat->recent_scanned[1];
+                       anon_cost += mz->lruvec.anon_cost;
+                       file_cost += mz->lruvec.file_cost;
                }
-               seq_printf(m, "recent_rotated_anon %lu\n", recent_rotated[0]);
-               seq_printf(m, "recent_rotated_file %lu\n", recent_rotated[1]);
-               seq_printf(m, "recent_scanned_anon %lu\n", recent_scanned[0]);
-               seq_printf(m, "recent_scanned_file %lu\n", recent_scanned[1]);
+               seq_printf(m, "anon_cost %lu\n", anon_cost);
+               seq_printf(m, "file_cost %lu\n", file_cost);
        }
 #endif
 
@@ -4330,7 +4544,6 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
 
        *pdirty = memcg_exact_page_state(memcg, NR_FILE_DIRTY);
 
-       /* this should eventually include NR_UNSTABLE_NFS */
        *pwriteback = memcg_exact_page_state(memcg, NR_WRITEBACK);
        *pfilepages = memcg_exact_page_state(memcg, NR_INACTIVE_FILE) +
                        memcg_exact_page_state(memcg, NR_ACTIVE_FILE);
@@ -4338,7 +4551,7 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
 
        while ((parent = parent_mem_cgroup(memcg))) {
                unsigned long ceiling = min(READ_ONCE(memcg->memory.max),
-                                           READ_ONCE(memcg->high));
+                                           READ_ONCE(memcg->memory.high));
                unsigned long used = page_counter_read(&memcg->memory);
 
                *pheadroom = min(*pheadroom, ceiling - min(ceiling, used));
@@ -4810,9 +5023,6 @@ static struct cftype mem_cgroup_legacy_files[] = {
        (defined(CONFIG_SLAB) || defined(CONFIG_SLUB_DEBUG))
        {
                .name = "kmem.slabinfo",
-               .seq_start = memcg_slab_start,
-               .seq_next = memcg_slab_next,
-               .seq_stop = memcg_slab_stop,
                .seq_show = memcg_slab_show,
        },
 #endif
@@ -4850,7 +5060,7 @@ static struct cftype mem_cgroup_legacy_files[] = {
  * limited to 16 bit (MEM_CGROUP_ID_MAX), limiting the total number of
  * memory-controlled cgroups to 64k.
  *
- * However, there usually are many references to the oflline CSS after
+ * However, there usually are many references to the offline CSS after
  * the cgroup has been destroyed, such as page cache or reclaimable
  * slab objects, that don't need to hang on to the ID. We want to keep
  * those dead CSS from occupying IDs, or we might quickly exhaust the
@@ -4927,13 +5137,15 @@ static int alloc_mem_cgroup_per_node_info(struct mem_cgroup *memcg, int node)
        if (!pn)
                return 1;
 
-       pn->lruvec_stat_local = alloc_percpu(struct lruvec_stat);
+       pn->lruvec_stat_local = alloc_percpu_gfp(struct lruvec_stat,
+                                                GFP_KERNEL_ACCOUNT);
        if (!pn->lruvec_stat_local) {
                kfree(pn);
                return 1;
        }
 
-       pn->lruvec_stat_cpu = alloc_percpu(struct lruvec_stat);
+       pn->lruvec_stat_cpu = alloc_percpu_gfp(struct lruvec_stat,
+                                              GFP_KERNEL_ACCOUNT);
        if (!pn->lruvec_stat_cpu) {
                free_percpu(pn->lruvec_stat_local);
                kfree(pn);
@@ -5007,11 +5219,13 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
                goto fail;
        }
 
-       memcg->vmstats_local = alloc_percpu(struct memcg_vmstats_percpu);
+       memcg->vmstats_local = alloc_percpu_gfp(struct memcg_vmstats_percpu,
+                                               GFP_KERNEL_ACCOUNT);
        if (!memcg->vmstats_local)
                goto fail;
 
-       memcg->vmstats_percpu = alloc_percpu(struct memcg_vmstats_percpu);
+       memcg->vmstats_percpu = alloc_percpu_gfp(struct memcg_vmstats_percpu,
+                                                GFP_KERNEL_ACCOUNT);
        if (!memcg->vmstats_percpu)
                goto fail;
 
@@ -5032,6 +5246,7 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
        memcg->socket_pressure = jiffies;
 #ifdef CONFIG_MEMCG_KMEM
        memcg->kmemcg_id = -1;
+       INIT_LIST_HEAD(&memcg->objcg_list);
 #endif
 #ifdef CONFIG_CGROUP_WRITEBACK
        INIT_LIST_HEAD(&memcg->cgwb_list);
@@ -5059,12 +5274,15 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
        struct mem_cgroup *memcg;
        long error = -ENOMEM;
 
+       memalloc_use_memcg(parent);
        memcg = mem_cgroup_alloc();
+       memalloc_unuse_memcg();
        if (IS_ERR(memcg))
                return ERR_CAST(memcg);
 
-       WRITE_ONCE(memcg->high, PAGE_COUNTER_MAX);
+       page_counter_set_high(&memcg->memory, PAGE_COUNTER_MAX);
        memcg->soft_limit = PAGE_COUNTER_MAX;
+       page_counter_set_high(&memcg->swap, PAGE_COUNTER_MAX);
        if (parent) {
                memcg->swappiness = mem_cgroup_swappiness(parent);
                memcg->oom_kill_disable = parent->oom_kill_disable;
@@ -5093,9 +5311,6 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
 
        /* The following stuff does not apply to the root */
        if (!parent) {
-#ifdef CONFIG_MEMCG_KMEM
-               INIT_LIST_HEAD(&memcg->kmem_caches);
-#endif
                root_mem_cgroup = memcg;
                return &memcg->css;
        }
@@ -5216,8 +5431,9 @@ static void mem_cgroup_css_reset(struct cgroup_subsys_state *css)
        page_counter_set_max(&memcg->tcpmem, PAGE_COUNTER_MAX);
        page_counter_set_min(&memcg->memory, 0);
        page_counter_set_low(&memcg->memory, 0);
-       WRITE_ONCE(memcg->high, PAGE_COUNTER_MAX);
+       page_counter_set_high(&memcg->memory, PAGE_COUNTER_MAX);
        memcg->soft_limit = PAGE_COUNTER_MAX;
+       page_counter_set_high(&memcg->swap, PAGE_COUNTER_MAX);
        memcg_wb_domain_size_changed(memcg);
 }
 
@@ -5308,8 +5524,7 @@ static struct page *mc_handle_swap_pte(struct vm_area_struct *vma,
         * we call find_get_page() with swapper_space directly.
         */
        page = find_get_page(swap_address_space(ent), swp_offset(ent));
-       if (do_memsw_account())
-               entry->val = ent.val;
+       entry->val = ent.val;
 
        return page;
 }
@@ -5343,8 +5558,7 @@ static struct page *mc_handle_file_pte(struct vm_area_struct *vma,
                page = find_get_entry(mapping, pgoff);
                if (xa_is_value(page)) {
                        swp_entry_t swp = radix_to_swp_entry(page);
-                       if (do_memsw_account())
-                               *entry = swp;
+                       *entry = swp;
                        page = find_get_page(swap_address_space(swp),
                                             swp_offset(swp));
                }
@@ -5375,10 +5589,8 @@ static int mem_cgroup_move_account(struct page *page,
 {
        struct lruvec *from_vec, *to_vec;
        struct pglist_data *pgdat;
-       unsigned long flags;
-       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
+       unsigned int nr_pages = compound ? thp_nr_pages(page) : 1;
        int ret;
-       bool anon;
 
        VM_BUG_ON(from == to);
        VM_BUG_ON_PAGE(PageLRU(page), page);
@@ -5396,30 +5608,47 @@ static int mem_cgroup_move_account(struct page *page,
        if (page->mem_cgroup != from)
                goto out_unlock;
 
-       anon = PageAnon(page);
-
        pgdat = page_pgdat(page);
        from_vec = mem_cgroup_lruvec(from, pgdat);
        to_vec = mem_cgroup_lruvec(to, pgdat);
 
-       spin_lock_irqsave(&from->move_lock, flags);
+       lock_page_memcg(page);
 
-       if (!anon && page_mapped(page)) {
-               __mod_lruvec_state(from_vec, NR_FILE_MAPPED, -nr_pages);
-               __mod_lruvec_state(to_vec, NR_FILE_MAPPED, nr_pages);
-       }
+       if (PageAnon(page)) {
+               if (page_mapped(page)) {
+                       __mod_lruvec_state(from_vec, NR_ANON_MAPPED, -nr_pages);
+                       __mod_lruvec_state(to_vec, NR_ANON_MAPPED, nr_pages);
+                       if (PageTransHuge(page)) {
+                               __mod_lruvec_state(from_vec, NR_ANON_THPS,
+                                                  -nr_pages);
+                               __mod_lruvec_state(to_vec, NR_ANON_THPS,
+                                                  nr_pages);
+                       }
 
-       /*
-        * move_lock grabbed above and caller set from->moving_account, so
-        * mod_memcg_page_state will serialize updates to PageDirty.
-        * So mapping should be stable for dirty pages.
-        */
-       if (!anon && PageDirty(page)) {
-               struct address_space *mapping = page_mapping(page);
+               }
+       } else {
+               __mod_lruvec_state(from_vec, NR_FILE_PAGES, -nr_pages);
+               __mod_lruvec_state(to_vec, NR_FILE_PAGES, nr_pages);
 
-               if (mapping_cap_account_dirty(mapping)) {
-                       __mod_lruvec_state(from_vec, NR_FILE_DIRTY, -nr_pages);
-                       __mod_lruvec_state(to_vec, NR_FILE_DIRTY, nr_pages);
+               if (PageSwapBacked(page)) {
+                       __mod_lruvec_state(from_vec, NR_SHMEM, -nr_pages);
+                       __mod_lruvec_state(to_vec, NR_SHMEM, nr_pages);
+               }
+
+               if (page_mapped(page)) {
+                       __mod_lruvec_state(from_vec, NR_FILE_MAPPED, -nr_pages);
+                       __mod_lruvec_state(to_vec, NR_FILE_MAPPED, nr_pages);
+               }
+
+               if (PageDirty(page)) {
+                       struct address_space *mapping = page_mapping(page);
+
+                       if (mapping_cap_account_dirty(mapping)) {
+                               __mod_lruvec_state(from_vec, NR_FILE_DIRTY,
+                                                  -nr_pages);
+                               __mod_lruvec_state(to_vec, NR_FILE_DIRTY,
+                                                  nr_pages);
+                       }
                }
        }
 
@@ -5429,22 +5658,33 @@ static int mem_cgroup_move_account(struct page *page,
        }
 
        /*
+        * All state has been migrated, let's switch to the new memcg.
+        *
         * It is safe to change page->mem_cgroup here because the page
-        * is referenced, charged, and isolated - we can't race with
-        * uncharging, charging, migration, or LRU putback.
+        * is referenced, charged, isolated, and locked: we can't race
+        * with (un)charging, migration, LRU putback, or anything else
+        * that would rely on a stable page->mem_cgroup.
+        *
+        * Note that lock_page_memcg is a memcg lock, not a page lock,
+        * to save space. As soon as we switch page->mem_cgroup to a
+        * new memcg that isn't locked, the above state can change
+        * concurrently again. Make sure we're truly done with it.
         */
+       smp_mb();
+
+       css_get(&to->css);
+       css_put(&from->css);
 
-       /* caller should have done css_get */
        page->mem_cgroup = to;
 
-       spin_unlock_irqrestore(&from->move_lock, flags);
+       __unlock_page_memcg(from);
 
        ret = 0;
 
        local_irq_disable();
-       mem_cgroup_charge_statistics(to, page, compound, nr_pages);
+       mem_cgroup_charge_statistics(to, page, nr_pages);
        memcg_check_events(to, page);
-       mem_cgroup_charge_statistics(from, page, compound, -nr_pages);
+       mem_cgroup_charge_statistics(from, page, -nr_pages);
        memcg_check_events(from, page);
        local_irq_enable();
 out_unlock:
@@ -5603,9 +5843,9 @@ static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
 {
        unsigned long precharge;
 
-       down_read(&mm->mmap_sem);
+       mmap_read_lock(mm);
        walk_page_range(mm, 0, mm->highest_vm_end, &precharge_walk_ops, NULL);
-       up_read(&mm->mmap_sem);
+       mmap_read_unlock(mm);
 
        precharge = mc.precharge;
        mc.precharge = 0;
@@ -5656,9 +5896,6 @@ static void __mem_cgroup_clear_mc(void)
                if (!mem_cgroup_is_root(mc.to))
                        page_counter_uncharge(&mc.to->memory, mc.moved_swap);
 
-               mem_cgroup_id_get_many(mc.to, mc.moved_swap);
-               css_put_many(&mc.to->css, mc.moved_swap);
-
                mc.moved_swap = 0;
        }
        memcg_oom_recover(from);
@@ -5847,7 +6084,8 @@ put:                      /* get_mctgt_type() gets the page */
                        ent = target.ent;
                        if (!mem_cgroup_move_swap_account(ent, mc.from, mc.to)) {
                                mc.precharge--;
-                               /* we fixup refcnts and charges later. */
+                               mem_cgroup_id_get_many(mc.to, 1);
+                               /* we fixup other refcnts and charges later. */
                                mc.moved_swap++;
                        }
                        break;
@@ -5888,9 +6126,9 @@ static void mem_cgroup_move_charge(void)
        atomic_inc(&mc.from->moving_account);
        synchronize_rcu();
 retry:
-       if (unlikely(!down_read_trylock(&mc.mm->mmap_sem))) {
+       if (unlikely(!mmap_read_trylock(mc.mm))) {
                /*
-                * Someone who are holding the mmap_sem might be waiting in
+                * Someone who are holding the mmap_lock might be waiting in
                 * waitq. So we cancel all extra charges, wake up all waiters,
                 * and retry. Because we cancel precharges, we might not be able
                 * to move enough charges, but moving charge is a best-effort
@@ -5907,7 +6145,7 @@ retry:
        walk_page_range(mc.mm, 0, mc.mm->highest_vm_end, &charge_walk_ops,
                        NULL);
 
-       up_read(&mc.mm->mmap_sem);
+       mmap_read_unlock(mc.mm);
        atomic_dec(&mc.from->moving_account);
 }
 
@@ -6015,14 +6253,15 @@ static ssize_t memory_low_write(struct kernfs_open_file *of,
 
 static int memory_high_show(struct seq_file *m, void *v)
 {
-       return seq_puts_memcg_tunable(m, READ_ONCE(mem_cgroup_from_seq(m)->high));
+       return seq_puts_memcg_tunable(m,
+               READ_ONCE(mem_cgroup_from_seq(m)->memory.high));
 }
 
 static ssize_t memory_high_write(struct kernfs_open_file *of,
                                 char *buf, size_t nbytes, loff_t off)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
-       unsigned int nr_retries = MEM_CGROUP_RECLAIM_RETRIES;
+       unsigned int nr_retries = MAX_RECLAIM_RETRIES;
        bool drained = false;
        unsigned long high;
        int err;
@@ -6032,8 +6271,6 @@ static ssize_t memory_high_write(struct kernfs_open_file *of,
        if (err)
                return err;
 
-       WRITE_ONCE(memcg->high, high);
-
        for (;;) {
                unsigned long nr_pages = page_counter_read(&memcg->memory);
                unsigned long reclaimed;
@@ -6057,6 +6294,10 @@ static ssize_t memory_high_write(struct kernfs_open_file *of,
                        break;
        }
 
+       page_counter_set_high(&memcg->memory, high);
+
+       memcg_wb_domain_size_changed(memcg);
+
        return nbytes;
 }
 
@@ -6070,7 +6311,7 @@ static ssize_t memory_max_write(struct kernfs_open_file *of,
                                char *buf, size_t nbytes, loff_t off)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
-       unsigned int nr_reclaims = MEM_CGROUP_RECLAIM_RETRIES;
+       unsigned int nr_reclaims = MAX_RECLAIM_RETRIES;
        bool drained = false;
        unsigned long max;
        int err;
@@ -6227,7 +6468,6 @@ static struct cftype memory_files[] = {
        },
        {
                .name = "stat",
-               .flags = CFTYPE_NOT_ON_ROOT,
                .seq_show = memory_stat_show,
        },
        {
@@ -6349,11 +6589,16 @@ static unsigned long effective_protection(unsigned long usage,
         * We're using unprotected memory for the weight so that if
         * some cgroups DO claim explicit protection, we don't protect
         * the same bytes twice.
+        *
+        * Check both usage and parent_usage against the respective
+        * protected values. One should imply the other, but they
+        * aren't read atomically - make sure the division is sane.
         */
        if (!(cgrp_dfl_root.flags & CGRP_ROOT_MEMORY_RECURSIVE_PROT))
                return ep;
-
-       if (parent_effective > siblings_protected && usage > protected) {
+       if (parent_effective > siblings_protected &&
+           parent_usage > siblings_protected &&
+           usage > protected) {
                unsigned long unclaimed;
 
                unclaimed = parent_effective - siblings_protected;
@@ -6373,40 +6618,42 @@ static unsigned long effective_protection(unsigned long usage,
  *
  * WARNING: This function is not stateless! It can only be used as part
  *          of a top-down tree iteration, not for isolated queries.
- *
- * Returns one of the following:
- *   MEMCG_PROT_NONE: cgroup memory is not protected
- *   MEMCG_PROT_LOW: cgroup memory is protected as long there is
- *     an unprotected supply of reclaimable memory from other cgroups.
- *   MEMCG_PROT_MIN: cgroup memory is protected
  */
-enum mem_cgroup_protection mem_cgroup_protected(struct mem_cgroup *root,
-                                               struct mem_cgroup *memcg)
+void mem_cgroup_calculate_protection(struct mem_cgroup *root,
+                                    struct mem_cgroup *memcg)
 {
        unsigned long usage, parent_usage;
        struct mem_cgroup *parent;
 
        if (mem_cgroup_disabled())
-               return MEMCG_PROT_NONE;
+               return;
 
        if (!root)
                root = root_mem_cgroup;
+
+       /*
+        * Effective values of the reclaim targets are ignored so they
+        * can be stale. Have a look at mem_cgroup_protection for more
+        * details.
+        * TODO: calculation should be more robust so that we do not need
+        * that special casing.
+        */
        if (memcg == root)
-               return MEMCG_PROT_NONE;
+               return;
 
        usage = page_counter_read(&memcg->memory);
        if (!usage)
-               return MEMCG_PROT_NONE;
+               return;
 
        parent = parent_mem_cgroup(memcg);
        /* No parent means a non-hierarchical mode on v1 memcg */
        if (!parent)
-               return MEMCG_PROT_NONE;
+               return;
 
        if (parent == root) {
                memcg->memory.emin = READ_ONCE(memcg->memory.min);
-               memcg->memory.elow = memcg->memory.low;
-               goto out;
+               memcg->memory.elow = READ_ONCE(memcg->memory.low);
+               return;
        }
 
        parent_usage = page_counter_read(&parent->memory);
@@ -6417,138 +6664,70 @@ enum mem_cgroup_protection mem_cgroup_protected(struct mem_cgroup *root,
                        atomic_long_read(&parent->memory.children_min_usage)));
 
        WRITE_ONCE(memcg->memory.elow, effective_protection(usage, parent_usage,
-                       memcg->memory.low, READ_ONCE(parent->memory.elow),
+                       READ_ONCE(memcg->memory.low),
+                       READ_ONCE(parent->memory.elow),
                        atomic_long_read(&parent->memory.children_low_usage)));
-
-out:
-       if (usage <= memcg->memory.emin)
-               return MEMCG_PROT_MIN;
-       else if (usage <= memcg->memory.elow)
-               return MEMCG_PROT_LOW;
-       else
-               return MEMCG_PROT_NONE;
 }
 
 /**
- * mem_cgroup_try_charge - try charging a page
+ * mem_cgroup_charge - charge a newly allocated page to a cgroup
  * @page: page to charge
  * @mm: mm context of the victim
  * @gfp_mask: reclaim mode
- * @memcgp: charged memcg return
- * @compound: charge the page as compound or small page
  *
  * Try to charge @page to the memcg that @mm belongs to, reclaiming
  * pages according to @gfp_mask if necessary.
  *
- * Returns 0 on success, with *@memcgp pointing to the charged memcg.
- * Otherwise, an error code is returned.
- *
- * After page->mapping has been set up, the caller must finalize the
- * charge with mem_cgroup_commit_charge().  Or abort the transaction
- * with mem_cgroup_cancel_charge() in case page instantiation fails.
+ * Returns 0 on success. Otherwise, an error code is returned.
  */
-int mem_cgroup_try_charge(struct page *page, struct mm_struct *mm,
-                         gfp_t gfp_mask, struct mem_cgroup **memcgp,
-                         bool compound)
+int mem_cgroup_charge(struct page *page, struct mm_struct *mm, gfp_t gfp_mask)
 {
+       unsigned int nr_pages = thp_nr_pages(page);
        struct mem_cgroup *memcg = NULL;
-       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
        int ret = 0;
 
        if (mem_cgroup_disabled())
                goto out;
 
        if (PageSwapCache(page)) {
+               swp_entry_t ent = { .val = page_private(page), };
+               unsigned short id;
+
                /*
                 * Every swap fault against a single page tries to charge the
                 * page, bail as early as possible.  shmem_unuse() encounters
-                * already charged pages, too.  The USED bit is protected by
-                * the page lock, which serializes swap cache removal, which
+                * already charged pages, too.  page->mem_cgroup is protected
+                * by the page lock, which serializes swap cache removal, which
                 * in turn serializes uncharging.
                 */
                VM_BUG_ON_PAGE(!PageLocked(page), page);
                if (compound_head(page)->mem_cgroup)
                        goto out;
 
-               if (do_swap_account) {
-                       swp_entry_t ent = { .val = page_private(page), };
-                       unsigned short id = lookup_swap_cgroup_id(ent);
-
-                       rcu_read_lock();
-                       memcg = mem_cgroup_from_id(id);
-                       if (memcg && !css_tryget_online(&memcg->css))
-                               memcg = NULL;
-                       rcu_read_unlock();
-               }
+               id = lookup_swap_cgroup_id(ent);
+               rcu_read_lock();
+               memcg = mem_cgroup_from_id(id);
+               if (memcg && !css_tryget_online(&memcg->css))
+                       memcg = NULL;
+               rcu_read_unlock();
        }
 
        if (!memcg)
                memcg = get_mem_cgroup_from_mm(mm);
 
        ret = try_charge(memcg, gfp_mask, nr_pages);
+       if (ret)
+               goto out_put;
 
-       css_put(&memcg->css);
-out:
-       *memcgp = memcg;
-       return ret;
-}
-
-int mem_cgroup_try_charge_delay(struct page *page, struct mm_struct *mm,
-                         gfp_t gfp_mask, struct mem_cgroup **memcgp,
-                         bool compound)
-{
-       struct mem_cgroup *memcg;
-       int ret;
-
-       ret = mem_cgroup_try_charge(page, mm, gfp_mask, memcgp, compound);
-       memcg = *memcgp;
-       mem_cgroup_throttle_swaprate(memcg, page_to_nid(page), gfp_mask);
-       return ret;
-}
-
-/**
- * mem_cgroup_commit_charge - commit a page charge
- * @page: page to charge
- * @memcg: memcg to charge the page to
- * @lrucare: page might be on LRU already
- * @compound: charge the page as compound or small page
- *
- * Finalize a charge transaction started by mem_cgroup_try_charge(),
- * after page->mapping has been set up.  This must happen atomically
- * as part of the page instantiation, i.e. under the page table lock
- * for anonymous pages, under the page lock for page and swap cache.
- *
- * In addition, the page must not be on the LRU during the commit, to
- * prevent racing with task migration.  If it might be, use @lrucare.
- *
- * Use mem_cgroup_cancel_charge() to cancel the transaction instead.
- */
-void mem_cgroup_commit_charge(struct page *page, struct mem_cgroup *memcg,
-                             bool lrucare, bool compound)
-{
-       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
-
-       VM_BUG_ON_PAGE(!page->mapping, page);
-       VM_BUG_ON_PAGE(PageLRU(page) && !lrucare, page);
-
-       if (mem_cgroup_disabled())
-               return;
-       /*
-        * Swap faults will attempt to charge the same page multiple
-        * times.  But reuse_swap_page() might have removed the page
-        * from swapcache already, so we can't check PageSwapCache().
-        */
-       if (!memcg)
-               return;
-
-       commit_charge(page, memcg, lrucare);
+       css_get(&memcg->css);
+       commit_charge(page, memcg);
 
        local_irq_disable();
-       mem_cgroup_charge_statistics(memcg, page, compound, nr_pages);
+       mem_cgroup_charge_statistics(memcg, page, nr_pages);
        memcg_check_events(memcg, page);
        local_irq_enable();
 
-       if (do_memsw_account() && PageSwapCache(page)) {
+       if (PageSwapCache(page)) {
                swp_entry_t entry = { .val = page_private(page) };
                /*
                 * The swap entry might not get freed for a long time,
@@ -6557,42 +6736,18 @@ void mem_cgroup_commit_charge(struct page *page, struct mem_cgroup *memcg,
                 */
                mem_cgroup_uncharge_swap(entry, nr_pages);
        }
-}
-
-/**
- * mem_cgroup_cancel_charge - cancel a page charge
- * @page: page to charge
- * @memcg: memcg to charge the page to
- * @compound: charge the page as compound or small page
- *
- * Cancel a charge transaction started by mem_cgroup_try_charge().
- */
-void mem_cgroup_cancel_charge(struct page *page, struct mem_cgroup *memcg,
-               bool compound)
-{
-       unsigned int nr_pages = compound ? hpage_nr_pages(page) : 1;
-
-       if (mem_cgroup_disabled())
-               return;
-       /*
-        * Swap faults will attempt to charge the same page multiple
-        * times.  But reuse_swap_page() might have removed the page
-        * from swapcache already, so we can't check PageSwapCache().
-        */
-       if (!memcg)
-               return;
 
-       cancel_charge(memcg, nr_pages);
+out_put:
+       css_put(&memcg->css);
+out:
+       return ret;
 }
 
 struct uncharge_gather {
        struct mem_cgroup *memcg;
+       unsigned long nr_pages;
        unsigned long pgpgout;
-       unsigned long nr_anon;
-       unsigned long nr_file;
        unsigned long nr_kmem;
-       unsigned long nr_huge;
-       unsigned long nr_shmem;
        struct page *dummy_page;
 };
 
@@ -6603,37 +6758,29 @@ static inline void uncharge_gather_clear(struct uncharge_gather *ug)
 
 static void uncharge_batch(const struct uncharge_gather *ug)
 {
-       unsigned long nr_pages = ug->nr_anon + ug->nr_file + ug->nr_kmem;
        unsigned long flags;
 
        if (!mem_cgroup_is_root(ug->memcg)) {
-               page_counter_uncharge(&ug->memcg->memory, nr_pages);
+               page_counter_uncharge(&ug->memcg->memory, ug->nr_pages);
                if (do_memsw_account())
-                       page_counter_uncharge(&ug->memcg->memsw, nr_pages);
+                       page_counter_uncharge(&ug->memcg->memsw, ug->nr_pages);
                if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && ug->nr_kmem)
                        page_counter_uncharge(&ug->memcg->kmem, ug->nr_kmem);
                memcg_oom_recover(ug->memcg);
        }
 
        local_irq_save(flags);
-       __mod_memcg_state(ug->memcg, MEMCG_RSS, -ug->nr_anon);
-       __mod_memcg_state(ug->memcg, MEMCG_CACHE, -ug->nr_file);
-       __mod_memcg_state(ug->memcg, MEMCG_RSS_HUGE, -ug->nr_huge);
-       __mod_memcg_state(ug->memcg, NR_SHMEM, -ug->nr_shmem);
        __count_memcg_events(ug->memcg, PGPGOUT, ug->pgpgout);
-       __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, nr_pages);
+       __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, ug->nr_pages);
        memcg_check_events(ug->memcg, ug->dummy_page);
        local_irq_restore(flags);
-
-       if (!mem_cgroup_is_root(ug->memcg))
-               css_put_many(&ug->memcg->css, nr_pages);
 }
 
 static void uncharge_page(struct page *page, struct uncharge_gather *ug)
 {
+       unsigned long nr_pages;
+
        VM_BUG_ON_PAGE(PageLRU(page), page);
-       VM_BUG_ON_PAGE(page_count(page) && !is_zone_device_page(page) &&
-                       !PageHWPoison(page) , page);
 
        if (!page->mem_cgroup)
                return;
@@ -6652,28 +6799,19 @@ static void uncharge_page(struct page *page, struct uncharge_gather *ug)
                ug->memcg = page->mem_cgroup;
        }
 
-       if (!PageKmemcg(page)) {
-               unsigned int nr_pages = 1;
+       nr_pages = compound_nr(page);
+       ug->nr_pages += nr_pages;
 
-               if (PageTransHuge(page)) {
-                       nr_pages = compound_nr(page);
-                       ug->nr_huge += nr_pages;
-               }
-               if (PageAnon(page))
-                       ug->nr_anon += nr_pages;
-               else {
-                       ug->nr_file += nr_pages;
-                       if (PageSwapBacked(page))
-                               ug->nr_shmem += nr_pages;
-               }
+       if (!PageKmemcg(page)) {
                ug->pgpgout++;
        } else {
-               ug->nr_kmem += compound_nr(page);
+               ug->nr_kmem += nr_pages;
                __ClearPageKmemcg(page);
        }
 
        ug->dummy_page = page;
        page->mem_cgroup = NULL;
+       css_put(&ug->memcg->css);
 }
 
 static void uncharge_list(struct list_head *page_list)
@@ -6705,8 +6843,7 @@ static void uncharge_list(struct list_head *page_list)
  * mem_cgroup_uncharge - uncharge a page
  * @page: page to uncharge
  *
- * Uncharge a page previously charged with mem_cgroup_try_charge() and
- * mem_cgroup_commit_charge().
+ * Uncharge a page previously charged with mem_cgroup_charge().
  */
 void mem_cgroup_uncharge(struct page *page)
 {
@@ -6729,7 +6866,7 @@ void mem_cgroup_uncharge(struct page *page)
  * @page_list: list of pages to uncharge
  *
  * Uncharge a list of pages previously charged with
- * mem_cgroup_try_charge() and mem_cgroup_commit_charge().
+ * mem_cgroup_charge().
  */
 void mem_cgroup_uncharge_list(struct list_head *page_list)
 {
@@ -6775,18 +6912,17 @@ void mem_cgroup_migrate(struct page *oldpage, struct page *newpage)
                return;
 
        /* Force-charge the new page. The old one will be freed soon */
-       nr_pages = hpage_nr_pages(newpage);
+       nr_pages = thp_nr_pages(newpage);
 
        page_counter_charge(&memcg->memory, nr_pages);
        if (do_memsw_account())
                page_counter_charge(&memcg->memsw, nr_pages);
-       css_get_many(&memcg->css, nr_pages);
 
-       commit_charge(newpage, memcg, false);
+       css_get(&memcg->css);
+       commit_charge(newpage, memcg);
 
        local_irq_save(flags);
-       mem_cgroup_charge_statistics(memcg, newpage, PageTransHuge(newpage),
-                       nr_pages);
+       mem_cgroup_charge_statistics(memcg, newpage, nr_pages);
        memcg_check_events(memcg, newpage);
        local_irq_restore(flags);
 }
@@ -6905,17 +7041,6 @@ static int __init mem_cgroup_init(void)
 {
        int cpu, node;
 
-#ifdef CONFIG_MEMCG_KMEM
-       /*
-        * Kmem cache creation is mostly done with the slab_mutex held,
-        * so use a workqueue with limited concurrency to avoid stalling
-        * all worker threads in case lots of cgroups are created and
-        * destroyed simultaneously.
-        */
-       memcg_kmem_cache_wq = alloc_workqueue("memcg_kmem_cache", 0, 1);
-       BUG_ON(!memcg_kmem_cache_wq);
-#endif
-
        cpuhp_setup_state_nocalls(CPUHP_MM_MEMCQ_DEAD, "mm/memctrl:dead", NULL,
                                  memcg_hotplug_cpu_dead);
 
@@ -6974,7 +7099,7 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
        VM_BUG_ON_PAGE(PageLRU(page), page);
        VM_BUG_ON_PAGE(page_count(page), page);
 
-       if (!do_memsw_account())
+       if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
                return;
 
        memcg = page->mem_cgroup;
@@ -6989,7 +7114,7 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
         * ancestor for the swap instead and transfer the memory+swap charge.
         */
        swap_memcg = mem_cgroup_id_get_online(memcg);
-       nr_entries = hpage_nr_pages(page);
+       nr_entries = thp_nr_pages(page);
        /* Get references for the tail pages, too */
        if (nr_entries > 1)
                mem_cgroup_id_get_many(swap_memcg, nr_entries - 1);
@@ -7003,7 +7128,7 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
        if (!mem_cgroup_is_root(memcg))
                page_counter_uncharge(&memcg->memory, nr_entries);
 
-       if (memcg != swap_memcg) {
+       if (!cgroup_memory_noswap && memcg != swap_memcg) {
                if (!mem_cgroup_is_root(swap_memcg))
                        page_counter_charge(&swap_memcg->memsw, nr_entries);
                page_counter_uncharge(&memcg->memsw, nr_entries);
@@ -7016,12 +7141,10 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
         * only synchronisation we have for updating the per-CPU variables.
         */
        VM_BUG_ON(!irqs_disabled());
-       mem_cgroup_charge_statistics(memcg, page, PageTransHuge(page),
-                                    -nr_entries);
+       mem_cgroup_charge_statistics(memcg, page, -nr_entries);
        memcg_check_events(memcg, page);
 
-       if (!mem_cgroup_is_root(memcg))
-               css_put_many(&memcg->css, nr_entries);
+       css_put(&memcg->css);
 }
 
 /**
@@ -7035,12 +7158,12 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
  */
 int mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry)
 {
-       unsigned int nr_pages = hpage_nr_pages(page);
+       unsigned int nr_pages = thp_nr_pages(page);
        struct page_counter *counter;
        struct mem_cgroup *memcg;
        unsigned short oldid;
 
-       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) || !do_swap_account)
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
                return 0;
 
        memcg = page->mem_cgroup;
@@ -7056,7 +7179,7 @@ int mem_cgroup_try_charge_swap(struct page *page, swp_entry_t entry)
 
        memcg = mem_cgroup_id_get_online(memcg);
 
-       if (!mem_cgroup_is_root(memcg) &&
+       if (!cgroup_memory_noswap && !mem_cgroup_is_root(memcg) &&
            !page_counter_try_charge(&memcg->swap, nr_pages, &counter)) {
                memcg_memory_event(memcg, MEMCG_SWAP_MAX);
                memcg_memory_event(memcg, MEMCG_SWAP_FAIL);
@@ -7084,14 +7207,11 @@ void mem_cgroup_uncharge_swap(swp_entry_t entry, unsigned int nr_pages)
        struct mem_cgroup *memcg;
        unsigned short id;
 
-       if (!do_swap_account)
-               return;
-
        id = swap_cgroup_record(entry, 0, nr_pages);
        rcu_read_lock();
        memcg = mem_cgroup_from_id(id);
        if (memcg) {
-               if (!mem_cgroup_is_root(memcg)) {
+               if (!cgroup_memory_noswap && !mem_cgroup_is_root(memcg)) {
                        if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
                                page_counter_uncharge(&memcg->swap, nr_pages);
                        else
@@ -7107,7 +7227,7 @@ long mem_cgroup_get_nr_swap_pages(struct mem_cgroup *memcg)
 {
        long nr_swap_pages = get_nr_swap_pages();
 
-       if (!do_swap_account || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
+       if (cgroup_memory_noswap || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
                return nr_swap_pages;
        for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg))
                nr_swap_pages = min_t(long, nr_swap_pages,
@@ -7124,37 +7244,33 @@ bool mem_cgroup_swap_full(struct page *page)
 
        if (vm_swap_full())
                return true;
-       if (!do_swap_account || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
+       if (cgroup_memory_noswap || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
                return false;
 
        memcg = page->mem_cgroup;
        if (!memcg)
                return false;
 
-       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg))
-               if (page_counter_read(&memcg->swap) * 2 >=
-                   READ_ONCE(memcg->swap.max))
+       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg)) {
+               unsigned long usage = page_counter_read(&memcg->swap);
+
+               if (usage * 2 >= READ_ONCE(memcg->swap.high) ||
+                   usage * 2 >= READ_ONCE(memcg->swap.max))
                        return true;
+       }
 
        return false;
 }
 
-/* for remember boot option*/
-#ifdef CONFIG_MEMCG_SWAP_ENABLED
-static int really_do_swap_account __initdata = 1;
-#else
-static int really_do_swap_account __initdata;
-#endif
-
-static int __init enable_swap_account(char *s)
+static int __init setup_swap_account(char *s)
 {
        if (!strcmp(s, "1"))
-               really_do_swap_account = 1;
+               cgroup_memory_noswap = 0;
        else if (!strcmp(s, "0"))
-               really_do_swap_account = 0;
+               cgroup_memory_noswap = 1;
        return 1;
 }
-__setup("swapaccount=", enable_swap_account);
+__setup("swapaccount=", setup_swap_account);
 
 static u64 swap_current_read(struct cgroup_subsys_state *css,
                             struct cftype *cft)
@@ -7164,6 +7280,29 @@ static u64 swap_current_read(struct cgroup_subsys_state *css,
        return (u64)page_counter_read(&memcg->swap) * PAGE_SIZE;
 }
 
+static int swap_high_show(struct seq_file *m, void *v)
+{
+       return seq_puts_memcg_tunable(m,
+               READ_ONCE(mem_cgroup_from_seq(m)->swap.high));
+}
+
+static ssize_t swap_high_write(struct kernfs_open_file *of,
+                              char *buf, size_t nbytes, loff_t off)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+       unsigned long high;
+       int err;
+
+       buf = strstrip(buf);
+       err = page_counter_memparse(buf, "max", &high);
+       if (err)
+               return err;
+
+       page_counter_set_high(&memcg->swap, high);
+
+       return nbytes;
+}
+
 static int swap_max_show(struct seq_file *m, void *v)
 {
        return seq_puts_memcg_tunable(m,
@@ -7191,6 +7330,8 @@ static int swap_events_show(struct seq_file *m, void *v)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
 
+       seq_printf(m, "high %lu\n",
+                  atomic_long_read(&memcg->memory_events[MEMCG_SWAP_HIGH]));
        seq_printf(m, "max %lu\n",
                   atomic_long_read(&memcg->memory_events[MEMCG_SWAP_MAX]));
        seq_printf(m, "fail %lu\n",
@@ -7205,6 +7346,12 @@ static struct cftype swap_files[] = {
                .flags = CFTYPE_NOT_ON_ROOT,
                .read_u64 = swap_current_read,
        },
+       {
+               .name = "swap.high",
+               .flags = CFTYPE_NOT_ON_ROOT,
+               .seq_show = swap_high_show,
+               .write = swap_high_write,
+       },
        {
                .name = "swap.max",
                .flags = CFTYPE_NOT_ON_ROOT,
@@ -7220,7 +7367,7 @@ static struct cftype swap_files[] = {
        { }     /* terminate */
 };
 
-static struct cftype memsw_cgroup_files[] = {
+static struct cftype memsw_files[] = {
        {
                .name = "memsw.usage_in_bytes",
                .private = MEMFILE_PRIVATE(_MEMSWAP, RES_USAGE),
@@ -7247,17 +7394,27 @@ static struct cftype memsw_cgroup_files[] = {
        { },    /* terminate */
 };
 
+/*
+ * If mem_cgroup_swap_init() is implemented as a subsys_initcall()
+ * instead of a core_initcall(), this could mean cgroup_memory_noswap still
+ * remains set to false even when memcg is disabled via "cgroup_disable=memory"
+ * boot parameter. This may result in premature OOPS inside
+ * mem_cgroup_get_nr_swap_pages() function in corner cases.
+ */
 static int __init mem_cgroup_swap_init(void)
 {
-       if (!mem_cgroup_disabled() && really_do_swap_account) {
-               do_swap_account = 1;
-               WARN_ON(cgroup_add_dfl_cftypes(&memory_cgrp_subsys,
-                                              swap_files));
-               WARN_ON(cgroup_add_legacy_cftypes(&memory_cgrp_subsys,
-                                                 memsw_cgroup_files));
-       }
+       /* No memory control -> no swap control */
+       if (mem_cgroup_disabled())
+               cgroup_memory_noswap = true;
+
+       if (cgroup_memory_noswap)
+               return 0;
+
+       WARN_ON(cgroup_add_dfl_cftypes(&memory_cgrp_subsys, swap_files));
+       WARN_ON(cgroup_add_legacy_cftypes(&memory_cgrp_subsys, memsw_files));
+
        return 0;
 }
-subsys_initcall(mem_cgroup_swap_init);
+core_initcall(mem_cgroup_swap_init);
 
 #endif /* CONFIG_MEMCG_SWAP */