Merge tag 'fs.idmapped.mount_setattr.v5.13-rc3' of gitolite.kernel.org:pub/scm/linux...

diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index e064ac0..64ada9e 100644
@@ -215,7 +215,7 @@ enum res_type {
 #define MEMFILE_PRIVATE(x, val)        ((x) << 16 | (val))
 #define MEMFILE_TYPE(val)      ((val) >> 16 & 0xffff)
 #define MEMFILE_ATTR(val)      ((val) & 0xffff)
-/* Used for OOM nofiier */
+/* Used for OOM notifier */
 #define OOM_CONTROL            (0)
 
 /*
@@ -255,10 +255,8 @@ struct cgroup_subsys_state *vmpressure_to_css(struct vmpressure *vmpr)
 #ifdef CONFIG_MEMCG_KMEM
 extern spinlock_t css_set_lock;
 
-static int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
-                              unsigned int nr_pages);
-static void __memcg_kmem_uncharge(struct mem_cgroup *memcg,
-                                 unsigned int nr_pages);
+static void obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
+                                     unsigned int nr_pages);
 
 static void obj_cgroup_release(struct percpu_ref *ref)
 {
@@ -295,7 +293,7 @@ static void obj_cgroup_release(struct percpu_ref *ref)
        spin_lock_irqsave(&css_set_lock, flags);
        memcg = obj_cgroup_memcg(objcg);
        if (nr_pages)
-               __memcg_kmem_uncharge(memcg, nr_pages);
+               obj_cgroup_uncharge_pages(objcg, nr_pages);
        list_del(&objcg->list);
        mem_cgroup_put(memcg);
        spin_unlock_irqrestore(&css_set_lock, flags);
@@ -402,129 +400,6 @@ DEFINE_STATIC_KEY_FALSE(memcg_kmem_enabled_key);
 EXPORT_SYMBOL(memcg_kmem_enabled_key);
 #endif
 
-static int memcg_shrinker_map_size;
-static DEFINE_MUTEX(memcg_shrinker_map_mutex);
-
-static void memcg_free_shrinker_map_rcu(struct rcu_head *head)
-{
-       kvfree(container_of(head, struct memcg_shrinker_map, rcu));
-}
-
-static int memcg_expand_one_shrinker_map(struct mem_cgroup *memcg,
-                                        int size, int old_size)
-{
-       struct memcg_shrinker_map *new, *old;
-       int nid;
-
-       lockdep_assert_held(&memcg_shrinker_map_mutex);
-
-       for_each_node(nid) {
-               old = rcu_dereference_protected(
-                       mem_cgroup_nodeinfo(memcg, nid)->shrinker_map, true);
-               /* Not yet online memcg */
-               if (!old)
-                       return 0;
-
-               new = kvmalloc_node(sizeof(*new) + size, GFP_KERNEL, nid);
-               if (!new)
-                       return -ENOMEM;
-
-               /* Set all old bits, clear all new bits */
-               memset(new->map, (int)0xff, old_size);
-               memset((void *)new->map + old_size, 0, size - old_size);
-
-               rcu_assign_pointer(memcg->nodeinfo[nid]->shrinker_map, new);
-               call_rcu(&old->rcu, memcg_free_shrinker_map_rcu);
-       }
-
-       return 0;
-}
-
-static void memcg_free_shrinker_maps(struct mem_cgroup *memcg)
-{
-       struct mem_cgroup_per_node *pn;
-       struct memcg_shrinker_map *map;
-       int nid;
-
-       if (mem_cgroup_is_root(memcg))
-               return;
-
-       for_each_node(nid) {
-               pn = mem_cgroup_nodeinfo(memcg, nid);
-               map = rcu_dereference_protected(pn->shrinker_map, true);
-               kvfree(map);
-               rcu_assign_pointer(pn->shrinker_map, NULL);
-       }
-}
-
-static int memcg_alloc_shrinker_maps(struct mem_cgroup *memcg)
-{
-       struct memcg_shrinker_map *map;
-       int nid, size, ret = 0;
-
-       if (mem_cgroup_is_root(memcg))
-               return 0;
-
-       mutex_lock(&memcg_shrinker_map_mutex);
-       size = memcg_shrinker_map_size;
-       for_each_node(nid) {
-               map = kvzalloc_node(sizeof(*map) + size, GFP_KERNEL, nid);
-               if (!map) {
-                       memcg_free_shrinker_maps(memcg);
-                       ret = -ENOMEM;
-                       break;
-               }
-               rcu_assign_pointer(memcg->nodeinfo[nid]->shrinker_map, map);
-       }
-       mutex_unlock(&memcg_shrinker_map_mutex);
-
-       return ret;
-}
-
-int memcg_expand_shrinker_maps(int new_id)
-{
-       int size, old_size, ret = 0;
-       struct mem_cgroup *memcg;
-
-       size = DIV_ROUND_UP(new_id + 1, BITS_PER_LONG) * sizeof(unsigned long);
-       old_size = memcg_shrinker_map_size;
-       if (size <= old_size)
-               return 0;
-
-       mutex_lock(&memcg_shrinker_map_mutex);
-       if (!root_mem_cgroup)
-               goto unlock;
-
-       for_each_mem_cgroup(memcg) {
-               if (mem_cgroup_is_root(memcg))
-                       continue;
-               ret = memcg_expand_one_shrinker_map(memcg, size, old_size);
-               if (ret) {
-                       mem_cgroup_iter_break(NULL, memcg);
-                       goto unlock;
-               }
-       }
-unlock:
-       if (!ret)
-               memcg_shrinker_map_size = size;
-       mutex_unlock(&memcg_shrinker_map_mutex);
-       return ret;
-}
-
-void memcg_set_shrinker_bit(struct mem_cgroup *memcg, int nid, int shrinker_id)
-{
-       if (shrinker_id >= 0 && memcg && !mem_cgroup_is_root(memcg)) {
-               struct memcg_shrinker_map *map;
-
-               rcu_read_lock();
-               map = rcu_dereference(memcg->nodeinfo[nid]->shrinker_map);
-               /* Pairs with smp mb in shrink_slab() */
-               smp_mb__before_atomic();
-               set_bit(shrinker_id, map->map);
-               rcu_read_unlock();
-       }
-}
-
 /**
  * mem_cgroup_css_from_page - css of the memcg associated with a page
  * @page: page of interest
@@ -713,7 +588,7 @@ static void mem_cgroup_remove_from_trees(struct mem_cgroup *memcg)
        int nid;
 
        for_each_node(nid) {
-               mz = mem_cgroup_nodeinfo(memcg, nid);
+               mz = memcg->nodeinfo[nid];
                mctz = soft_limit_tree_node(nid);
                if (mctz)
                        mem_cgroup_remove_exceeded(mz, mctz);
@@ -764,28 +639,37 @@ mem_cgroup_largest_soft_limit_node(struct mem_cgroup_tree_per_node *mctz)
  */
 void __mod_memcg_state(struct mem_cgroup *memcg, int idx, int val)
 {
-       long x, threshold = MEMCG_CHARGE_BATCH;
-
        if (mem_cgroup_disabled())
                return;
 
-       if (memcg_stat_item_in_bytes(idx))
-               threshold <<= PAGE_SHIFT;
+       __this_cpu_add(memcg->vmstats_percpu->state[idx], val);
+       cgroup_rstat_updated(memcg->css.cgroup, smp_processor_id());
+}
 
-       x = val + __this_cpu_read(memcg->vmstats_percpu->stat[idx]);
-       if (unlikely(abs(x) > threshold)) {
-               struct mem_cgroup *mi;
+/* idx can be of type enum memcg_stat_item or node_stat_item. */
+static unsigned long memcg_page_state(struct mem_cgroup *memcg, int idx)
+{
+       long x = READ_ONCE(memcg->vmstats.state[idx]);
+#ifdef CONFIG_SMP
+       if (x < 0)
+               x = 0;
+#endif
+       return x;
+}
 
-               /*
-                * Batch local counters to keep them in sync with
-                * the hierarchical ones.
-                */
-               __this_cpu_add(memcg->vmstats_local->stat[idx], x);
-               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                       atomic_long_add(x, &mi->vmstats[idx]);
+/* idx can be of type enum memcg_stat_item or node_stat_item. */
+static unsigned long memcg_page_state_local(struct mem_cgroup *memcg, int idx)
+{
+       long x = 0;
+       int cpu;
+
+       for_each_possible_cpu(cpu)
+               x += per_cpu(memcg->vmstats_percpu->state[idx], cpu);
+#ifdef CONFIG_SMP
+       if (x < 0)
                x = 0;
-       }
-       __this_cpu_write(memcg->vmstats_percpu->stat[idx], x);
+#endif
+       return x;
 }
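With the rstat-based scheme above, memcg_page_state() returns the hierarchical value as of the last flush, while memcg_page_state_local() sums the raw per-CPU counters. A minimal sketch of how an up-to-date hierarchical read pairs a flush with the lookup (memcg_page_state_sync() is a hypothetical, file-local helper, not part of this change):

	static unsigned long memcg_page_state_sync(struct mem_cgroup *memcg, int idx)
	{
		/* fold pending per-CPU deltas into memcg->vmstats.state[] first */
		cgroup_rstat_flush(memcg->css.cgroup);
		return memcg_page_state(memcg, idx);
	}
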
 
 static struct mem_cgroup_per_node *
@@ -796,7 +680,7 @@ parent_nodeinfo(struct mem_cgroup_per_node *pn, int nid)
        parent = parent_mem_cgroup(pn->memcg);
        if (!parent)
                return NULL;
-       return mem_cgroup_nodeinfo(parent, nid);
+       return parent->nodeinfo[nid];
 }
 
 void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
@@ -855,18 +739,22 @@ void __mod_lruvec_page_state(struct page *page, enum node_stat_item idx,
                             int val)
 {
        struct page *head = compound_head(page); /* rmap on tail pages */
-       struct mem_cgroup *memcg = page_memcg(head);
+       struct mem_cgroup *memcg;
        pg_data_t *pgdat = page_pgdat(page);
        struct lruvec *lruvec;
 
+       rcu_read_lock();
+       memcg = page_memcg(head);
        /* Untracked pages have no memcg, no lruvec. Update only the node */
        if (!memcg) {
+               rcu_read_unlock();
                __mod_node_page_state(pgdat, idx, val);
                return;
        }
 
        lruvec = mem_cgroup_lruvec(memcg, pgdat);
        __mod_lruvec_state(lruvec, idx, val);
+       rcu_read_unlock();
 }
 EXPORT_SYMBOL(__mod_lruvec_page_state);
 
@@ -898,35 +786,21 @@ void __mod_lruvec_kmem_state(void *p, enum node_stat_item idx, int val)
  * __count_memcg_events - account VM events in a cgroup
  * @memcg: the memory cgroup
  * @idx: the event item
- * @count: the number of events that occured
+ * @count: the number of events that occurred
  */
 void __count_memcg_events(struct mem_cgroup *memcg, enum vm_event_item idx,
                          unsigned long count)
 {
-       unsigned long x;
-
        if (mem_cgroup_disabled())
                return;
 
-       x = count + __this_cpu_read(memcg->vmstats_percpu->events[idx]);
-       if (unlikely(x > MEMCG_CHARGE_BATCH)) {
-               struct mem_cgroup *mi;
-
-               /*
-                * Batch local counters to keep them in sync with
-                * the hierarchical ones.
-                */
-               __this_cpu_add(memcg->vmstats_local->events[idx], x);
-               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                       atomic_long_add(x, &mi->vmevents[idx]);
-               x = 0;
-       }
-       __this_cpu_write(memcg->vmstats_percpu->events[idx], x);
+       __this_cpu_add(memcg->vmstats_percpu->events[idx], count);
+       cgroup_rstat_updated(memcg->css.cgroup, smp_processor_id());
 }
 
 static unsigned long memcg_events(struct mem_cgroup *memcg, int event)
 {
-       return atomic_long_read(&memcg->vmevents[event]);
+       return READ_ONCE(memcg->vmstats.events[event]);
 }
 
 static unsigned long memcg_events_local(struct mem_cgroup *memcg, int event)
@@ -935,7 +809,7 @@ static unsigned long memcg_events_local(struct mem_cgroup *memcg, int event)
        int cpu;
 
        for_each_possible_cpu(cpu)
-               x += per_cpu(memcg->vmstats_local->events[event], cpu);
+               x += per_cpu(memcg->vmstats_percpu->events[event], cpu);
        return x;
 }
 
@@ -1030,7 +904,7 @@ struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
        rcu_read_lock();
        do {
                /*
-                * Page cache insertions can happen withou an
+                * Page cache insertions can happen without an
                 * actual mm context, e.g. during disk probing
                 * on boot, loopback IO, acct() writes etc.
                 */
@@ -1055,20 +929,6 @@ static __always_inline struct mem_cgroup *active_memcg(void)
                return current->active_memcg;
 }
 
-static __always_inline struct mem_cgroup *get_active_memcg(void)
-{
-       struct mem_cgroup *memcg;
-
-       rcu_read_lock();
-       memcg = active_memcg();
-       /* remote memcg must hold a ref. */
-       if (memcg && WARN_ON_ONCE(!css_tryget(&memcg->css)))
-               memcg = root_mem_cgroup;
-       rcu_read_unlock();
-
-       return memcg;
-}
-
 static __always_inline bool memcg_kmem_bypass(void)
 {
        /* Allow remote memcg charging from any context. */
@@ -1082,20 +942,6 @@ static __always_inline bool memcg_kmem_bypass(void)
        return false;
 }
 
-/**
- * If active memcg is set, do not fallback to current->mm->memcg.
- */
-static __always_inline struct mem_cgroup *get_mem_cgroup_from_current(void)
-{
-       if (memcg_kmem_bypass())
-               return NULL;
-
-       if (unlikely(active_memcg()))
-               return get_active_memcg();
-
-       return get_mem_cgroup_from_mm(current->mm);
-}
-
 /**
  * mem_cgroup_iter - iterate over memory cgroup hierarchy
  * @root: hierarchy root
@@ -1136,7 +982,7 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
        if (reclaim) {
                struct mem_cgroup_per_node *mz;
 
-               mz = mem_cgroup_nodeinfo(root, reclaim->pgdat->node_id);
+               mz = root->nodeinfo[reclaim->pgdat->node_id];
                iter = &mz->iter;
 
                if (prev && reclaim->generation != iter->generation)
@@ -1238,7 +1084,7 @@ static void __invalidate_reclaim_iterators(struct mem_cgroup *from,
        int nid;
 
        for_each_node(nid) {
-               mz = mem_cgroup_nodeinfo(from, nid);
+               mz = from->nodeinfo[nid];
                iter = &mz->iter;
                cmpxchg(&iter->position, dead_memcg, NULL);
        }
@@ -1571,6 +1417,7 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
         *
         * Current memory state:
         */
+       cgroup_rstat_flush(memcg->css.cgroup);
 
        for (i = 0; i < ARRAY_SIZE(memory_stats); i++) {
                u64 size;
@@ -1865,7 +1712,7 @@ static void mem_cgroup_unmark_under_oom(struct mem_cgroup *memcg)
        struct mem_cgroup *iter;
 
        /*
-        * Be careful about under_oom underflows becase a child memcg
+        * Be careful about under_oom underflows because a child memcg
         * could have been added after mem_cgroup_mark_under_oom.
         */
        spin_lock(&memcg_oom_lock);
@@ -2037,7 +1884,7 @@ bool mem_cgroup_oom_synchronize(bool handle)
                /*
                 * There is no guarantee that an OOM-lock contender
                 * sees the wakeups triggered by the OOM kill
-                * uncharges.  Wake any sleepers explicitely.
+                * uncharges.  Wake any sleepers explicitly.
                 */
                memcg_oom_recover(memcg);
        }
@@ -2118,11 +1965,10 @@ void mem_cgroup_print_oom_group(struct mem_cgroup *memcg)
  * This function protects unlocked LRU pages from being moved to
  * another cgroup.
  *
- * It ensures lifetime of the returned memcg. Caller is responsible
- * for the lifetime of the page; __unlock_page_memcg() is available
- * when @page might get freed inside the locked section.
+ * It ensures lifetime of the locked memcg. Caller is responsible
+ * for the lifetime of the page.
  */
-struct mem_cgroup *lock_page_memcg(struct page *page)
+void lock_page_memcg(struct page *page)
 {
        struct page *head = compound_head(page); /* rmap on tail pages */
        struct mem_cgroup *memcg;
@@ -2132,21 +1978,15 @@ struct mem_cgroup *lock_page_memcg(struct page *page)
         * The RCU lock is held throughout the transaction.  The fast
         * path can get away without acquiring the memcg->move_lock
         * because page moving starts with an RCU grace period.
-        *
-        * The RCU lock also protects the memcg from being freed when
-        * the page state that is going to change is the only thing
-        * preventing the page itself from being freed. E.g. writeback
-        * doesn't hold a page reference and relies on PG_writeback to
-        * keep off truncation, migration and so forth.
          */
        rcu_read_lock();
 
        if (mem_cgroup_disabled())
-               return NULL;
+               return;
 again:
        memcg = page_memcg(head);
        if (unlikely(!memcg))
-               return NULL;
+               return;
 
 #ifdef CONFIG_PROVE_LOCKING
        local_irq_save(flags);
@@ -2155,7 +1995,7 @@ again:
 #endif
 
        if (atomic_read(&memcg->moving_account) <= 0)
-               return memcg;
+               return;
 
        spin_lock_irqsave(&memcg->move_lock, flags);
        if (memcg != page_memcg(head)) {
@@ -2164,24 +2004,17 @@ again:
        }
 
        /*
-        * When charge migration first begins, we can have locked and
-        * unlocked page stat updates happening concurrently.  Track
-        * the task who has the lock for unlock_page_memcg().
+        * When charge migration first begins, we can have multiple
+        * critical sections holding the fast-path RCU lock and one
+        * holding the slowpath move_lock. Track the task who has the
+        * move_lock for unlock_page_memcg().
         */
        memcg->move_lock_task = current;
        memcg->move_lock_flags = flags;
-
-       return memcg;
 }
 EXPORT_SYMBOL(lock_page_memcg);
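Since lock_page_memcg() no longer returns the memcg, callers simply bracket their updates with the lock/unlock pair. A minimal usage sketch, assuming the caller holds a reference that keeps the page alive across the critical section (the stat item chosen is illustrative):

	static void page_stat_update_example(struct page *page)
	{
		lock_page_memcg(page);		/* pins the page->memcg binding */
		mod_lruvec_page_state(page, NR_FILE_DIRTY, 1);
		unlock_page_memcg(page);
	}
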
 
-/**
- * __unlock_page_memcg - unlock and unpin a memcg
- * @memcg: the memcg
- *
- * Unlock and unpin a memcg returned by lock_page_memcg().
- */
-void __unlock_page_memcg(struct mem_cgroup *memcg)
+static void __unlock_page_memcg(struct mem_cgroup *memcg)
 {
        if (memcg && memcg->move_lock_task == current) {
                unsigned long flags = memcg->move_lock_flags;
@@ -2381,50 +2214,39 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
        mutex_unlock(&percpu_charge_mutex);
 }
 
-static int memcg_hotplug_cpu_dead(unsigned int cpu)
+static void memcg_flush_lruvec_page_state(struct mem_cgroup *memcg, int cpu)
 {
-       struct memcg_stock_pcp *stock;
-       struct mem_cgroup *memcg, *mi;
-
-       stock = &per_cpu(memcg_stock, cpu);
-       drain_stock(stock);
+       int nid;
 
-       for_each_mem_cgroup(memcg) {
+       for_each_node(nid) {
+               struct mem_cgroup_per_node *pn = memcg->nodeinfo[nid];
+               unsigned long stat[NR_VM_NODE_STAT_ITEMS];
+               struct batched_lruvec_stat *lstatc;
                int i;
 
-               for (i = 0; i < MEMCG_NR_STAT; i++) {
-                       int nid;
-                       long x;
-
-                       x = this_cpu_xchg(memcg->vmstats_percpu->stat[i], 0);
-                       if (x)
-                               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                                       atomic_long_add(x, &memcg->vmstats[i]);
-
-                       if (i >= NR_VM_NODE_STAT_ITEMS)
-                               continue;
+               lstatc = per_cpu_ptr(pn->lruvec_stat_cpu, cpu);
+               for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++) {
+                       stat[i] = lstatc->count[i];
+                       lstatc->count[i] = 0;
+               }
 
-                       for_each_node(nid) {
-                               struct mem_cgroup_per_node *pn;
+               do {
+                       for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
+                               atomic_long_add(stat[i], &pn->lruvec_stat[i]);
+               } while ((pn = parent_nodeinfo(pn, nid)));
+       }
+}
 
-                               pn = mem_cgroup_nodeinfo(memcg, nid);
-                               x = this_cpu_xchg(pn->lruvec_stat_cpu->count[i], 0);
-                               if (x)
-                                       do {
-                                               atomic_long_add(x, &pn->lruvec_stat[i]);
-                                       } while ((pn = parent_nodeinfo(pn, nid)));
-                       }
-               }
+static int memcg_hotplug_cpu_dead(unsigned int cpu)
+{
+       struct memcg_stock_pcp *stock;
+       struct mem_cgroup *memcg;
 
-               for (i = 0; i < NR_VM_EVENT_ITEMS; i++) {
-                       long x;
+       stock = &per_cpu(memcg_stock, cpu);
+       drain_stock(stock);
 
-                       x = this_cpu_xchg(memcg->vmstats_percpu->events[i], 0);
-                       if (x)
-                               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                                       atomic_long_add(x, &memcg->vmevents[i]);
-               }
-       }
+       for_each_mem_cgroup(memcg)
+               memcg_flush_lruvec_page_state(memcg, cpu);
 
        return 0;
 }
@@ -2793,9 +2615,6 @@ retry:
        if (gfp_mask & __GFP_RETRY_MAYFAIL)
                goto nomem;
 
-       if (gfp_mask & __GFP_NOFAIL)
-               goto force;
-
        if (fatal_signal_pending(current))
                goto force;
 
@@ -2905,6 +2724,20 @@ static void commit_charge(struct page *page, struct mem_cgroup *memcg)
        page->memcg_data = (unsigned long)memcg;
 }
 
+static struct mem_cgroup *get_mem_cgroup_from_objcg(struct obj_cgroup *objcg)
+{
+       struct mem_cgroup *memcg;
+
+       rcu_read_lock();
+retry:
+       memcg = obj_cgroup_memcg(objcg);
+       if (unlikely(!css_tryget(&memcg->css)))
+               goto retry;
+       rcu_read_unlock();
+
+       return memcg;
+}
+
 #ifdef CONFIG_MEMCG_KMEM
 int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
                                 gfp_t gfp, bool new_page)
@@ -3056,23 +2889,45 @@ static void memcg_free_cache_id(int id)
        ida_simple_remove(&memcg_cache_ida, id);
 }
 
-/**
- * __memcg_kmem_charge: charge a number of kernel pages to a memcg
- * @memcg: memory cgroup to charge
+/*
+ * obj_cgroup_uncharge_pages: uncharge a number of kernel pages from an objcg
+ * @objcg: object cgroup to uncharge
+ * @nr_pages: number of pages to uncharge
+ */
+static void obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
+                                     unsigned int nr_pages)
+{
+       struct mem_cgroup *memcg;
+
+       memcg = get_mem_cgroup_from_objcg(objcg);
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               page_counter_uncharge(&memcg->kmem, nr_pages);
+       refill_stock(memcg, nr_pages);
+
+       css_put(&memcg->css);
+}
+
+/*
+ * obj_cgroup_charge_pages: charge a number of kernel pages to an objcg
+ * @objcg: object cgroup to charge
  * @gfp: reclaim mode
  * @nr_pages: number of pages to charge
  *
  * Returns 0 on success, an error code on failure.
  */
-static int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
-                              unsigned int nr_pages)
+static int obj_cgroup_charge_pages(struct obj_cgroup *objcg, gfp_t gfp,
+                                  unsigned int nr_pages)
 {
        struct page_counter *counter;
+       struct mem_cgroup *memcg;
        int ret;
 
+       memcg = get_mem_cgroup_from_objcg(objcg);
+
        ret = try_charge(memcg, gfp, nr_pages);
        if (ret)
-               return ret;
+               goto out;
 
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) &&
            !page_counter_try_charge(&memcg->kmem, nr_pages, &counter)) {
@@ -3084,25 +2939,15 @@ static int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
                 */
                if (gfp & __GFP_NOFAIL) {
                        page_counter_charge(&memcg->kmem, nr_pages);
-                       return 0;
+                       goto out;
                }
                cancel_charge(memcg, nr_pages);
-               return -ENOMEM;
+               ret = -ENOMEM;
        }
-       return 0;
-}
-
-/**
- * __memcg_kmem_uncharge: uncharge a number of kernel pages from a memcg
- * @memcg: memcg to uncharge
- * @nr_pages: number of pages to uncharge
- */
-static void __memcg_kmem_uncharge(struct mem_cgroup *memcg, unsigned int nr_pages)
-{
-       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
-               page_counter_uncharge(&memcg->kmem, nr_pages);
+out:
+       css_put(&memcg->css);
 
-       refill_stock(memcg, nr_pages);
+       return ret;
 }
 
 /**
@@ -3115,18 +2960,18 @@ static void __memcg_kmem_uncharge(struct mem_cgroup *memcg, unsigned int nr_page
  */
 int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
 {
-       struct mem_cgroup *memcg;
+       struct obj_cgroup *objcg;
        int ret = 0;
 
-       memcg = get_mem_cgroup_from_current();
-       if (memcg && !mem_cgroup_is_root(memcg)) {
-               ret = __memcg_kmem_charge(memcg, gfp, 1 << order);
+       objcg = get_obj_cgroup_from_current();
+       if (objcg) {
+               ret = obj_cgroup_charge_pages(objcg, gfp, 1 << order);
                if (!ret) {
-                       page->memcg_data = (unsigned long)memcg |
+                       page->memcg_data = (unsigned long)objcg |
                                MEMCG_DATA_KMEM;
                        return 0;
                }
-               css_put(&memcg->css);
+               obj_cgroup_put(objcg);
        }
        return ret;
 }
@@ -3138,16 +2983,16 @@ int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
  */
 void __memcg_kmem_uncharge_page(struct page *page, int order)
 {
-       struct mem_cgroup *memcg = page_memcg(page);
+       struct obj_cgroup *objcg;
        unsigned int nr_pages = 1 << order;
 
-       if (!memcg)
+       if (!PageMemcgKmem(page))
                return;
 
-       VM_BUG_ON_PAGE(mem_cgroup_is_root(memcg), page);
-       __memcg_kmem_uncharge(memcg, nr_pages);
+       objcg = __page_objcg(page);
+       obj_cgroup_uncharge_pages(objcg, nr_pages);
        page->memcg_data = 0;
-       css_put(&memcg->css);
+       obj_cgroup_put(objcg);
 }
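For reference, these two hooks are reached from the page allocator whenever a kernel allocation passes __GFP_ACCOUNT; after this change the reference stored in page->memcg_data is an objcg rather than a memcg. A minimal sketch with a hypothetical caller:

	static void accounted_alloc_example(void)
	{
		/* charged via __memcg_kmem_charge_page() against current's objcg */
		struct page *page = alloc_pages(GFP_KERNEL | __GFP_ACCOUNT, 0);

		if (!page)
			return;
		/* ... use the page ... */
		__free_pages(page, 0);	/* uncharged via __memcg_kmem_uncharge_page() */
	}
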
 
 static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
@@ -3180,11 +3025,8 @@ static void drain_obj_stock(struct memcg_stock_pcp *stock)
                unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
                unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
 
-               if (nr_pages) {
-                       rcu_read_lock();
-                       __memcg_kmem_uncharge(obj_cgroup_memcg(old), nr_pages);
-                       rcu_read_unlock();
-               }
+               if (nr_pages)
+                       obj_cgroup_uncharge_pages(old, nr_pages);
 
                /*
                 * The leftover is flushed to the centralized per-memcg value.
@@ -3242,7 +3084,6 @@ static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
 
 int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
 {
-       struct mem_cgroup *memcg;
        unsigned int nr_pages, nr_bytes;
        int ret;
 
@@ -3259,24 +3100,16 @@ int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
         * refill_obj_stock(), called from this function or
         * independently later.
         */
-       rcu_read_lock();
-retry:
-       memcg = obj_cgroup_memcg(objcg);
-       if (unlikely(!css_tryget(&memcg->css)))
-               goto retry;
-       rcu_read_unlock();
-
        nr_pages = size >> PAGE_SHIFT;
        nr_bytes = size & (PAGE_SIZE - 1);
 
        if (nr_bytes)
                nr_pages += 1;
 
-       ret = __memcg_kmem_charge(memcg, gfp, nr_pages);
+       ret = obj_cgroup_charge_pages(objcg, gfp, nr_pages);
        if (!ret && nr_bytes)
                refill_obj_stock(objcg, PAGE_SIZE - nr_bytes);
 
-       css_put(&memcg->css);
        return ret;
 }
 
@@ -3300,7 +3133,11 @@ void split_page_memcg(struct page *head, unsigned int nr)
 
        for (i = 1; i < nr; i++)
                head[i].memcg_data = head->memcg_data;
-       css_get_many(&memcg->css, nr - 1);
+
+       if (PageMemcgKmem(head))
+               obj_cgroup_get_many(__page_objcg(head), nr - 1);
+       else
+               css_get_many(&memcg->css, nr - 1);
 }
 
 #ifdef CONFIG_MEMCG_SWAP
@@ -3549,6 +3386,7 @@ static unsigned long mem_cgroup_usage(struct mem_cgroup *memcg, bool swap)
        unsigned long val;
 
        if (mem_cgroup_is_root(memcg)) {
+               cgroup_rstat_flush(memcg->css.cgroup);
                val = memcg_page_state(memcg, NR_FILE_PAGES) +
                        memcg_page_state(memcg, NR_ANON_MAPPED);
                if (swap)
@@ -3613,57 +3451,6 @@ static u64 mem_cgroup_read_u64(struct cgroup_subsys_state *css,
        }
 }
 
-static void memcg_flush_percpu_vmstats(struct mem_cgroup *memcg)
-{
-       unsigned long stat[MEMCG_NR_STAT] = {0};
-       struct mem_cgroup *mi;
-       int node, cpu, i;
-
-       for_each_online_cpu(cpu)
-               for (i = 0; i < MEMCG_NR_STAT; i++)
-                       stat[i] += per_cpu(memcg->vmstats_percpu->stat[i], cpu);
-
-       for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-               for (i = 0; i < MEMCG_NR_STAT; i++)
-                       atomic_long_add(stat[i], &mi->vmstats[i]);
-
-       for_each_node(node) {
-               struct mem_cgroup_per_node *pn = memcg->nodeinfo[node];
-               struct mem_cgroup_per_node *pi;
-
-               for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
-                       stat[i] = 0;
-
-               for_each_online_cpu(cpu)
-                       for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
-                               stat[i] += per_cpu(
-                                       pn->lruvec_stat_cpu->count[i], cpu);
-
-               for (pi = pn; pi; pi = parent_nodeinfo(pi, node))
-                       for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
-                               atomic_long_add(stat[i], &pi->lruvec_stat[i]);
-       }
-}
-
-static void memcg_flush_percpu_vmevents(struct mem_cgroup *memcg)
-{
-       unsigned long events[NR_VM_EVENT_ITEMS];
-       struct mem_cgroup *mi;
-       int cpu, i;
-
-       for (i = 0; i < NR_VM_EVENT_ITEMS; i++)
-               events[i] = 0;
-
-       for_each_online_cpu(cpu)
-               for (i = 0; i < NR_VM_EVENT_ITEMS; i++)
-                       events[i] += per_cpu(memcg->vmstats_percpu->events[i],
-                                            cpu);
-
-       for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-               for (i = 0; i < NR_VM_EVENT_ITEMS; i++)
-                       atomic_long_add(events[i], &mi->vmevents[i]);
-}
-
 #ifdef CONFIG_MEMCG_KMEM
 static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
@@ -3980,6 +3767,8 @@ static int memcg_numa_stat_show(struct seq_file *m, void *v)
        int nid;
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
 
+       cgroup_rstat_flush(memcg->css.cgroup);
+
        for (stat = stats; stat < stats + ARRAY_SIZE(stats); stat++) {
                seq_printf(m, "%s=%lu", stat->name,
                           mem_cgroup_nr_lru_pages(memcg, stat->lru_mask,
@@ -4050,6 +3839,8 @@ static int memcg_stat_show(struct seq_file *m, void *v)
 
        BUILD_BUG_ON(ARRAY_SIZE(memcg1_stat_names) != ARRAY_SIZE(memcg1_stats));
 
+       cgroup_rstat_flush(memcg->css.cgroup);
+
        for (i = 0; i < ARRAY_SIZE(memcg1_stats); i++) {
                unsigned long nr;
 
@@ -4108,7 +3899,7 @@ static int memcg_stat_show(struct seq_file *m, void *v)
                unsigned long file_cost = 0;
 
                for_each_online_pgdat(pgdat) {
-                       mz = mem_cgroup_nodeinfo(memcg, pgdat->node_id);
+                       mz = memcg->nodeinfo[pgdat->node_id];
 
                        anon_cost += mz->lruvec.anon_cost;
                        file_cost += mz->lruvec.file_cost;
@@ -4137,7 +3928,7 @@ static int mem_cgroup_swappiness_write(struct cgroup_subsys_state *css,
        if (val > 100)
                return -EINVAL;
 
-       if (css->parent)
+       if (!mem_cgroup_is_root(memcg))
                memcg->swappiness = val;
        else
                vm_swappiness = val;
@@ -4487,7 +4278,7 @@ static int mem_cgroup_oom_control_write(struct cgroup_subsys_state *css,
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
        /* cannot set to root cgroup and only 0 and 1 are allowed */
-       if (!css->parent || !((val == 0) || (val == 1)))
+       if (mem_cgroup_is_root(memcg) || !((val == 0) || (val == 1)))
                return -EINVAL;
 
        memcg->oom_kill_disable = val;
@@ -4526,22 +4317,6 @@ struct wb_domain *mem_cgroup_wb_domain(struct bdi_writeback *wb)
        return &memcg->cgwb_domain;
 }
 
-/*
- * idx can be of type enum memcg_stat_item or node_stat_item.
- * Keep in sync with memcg_exact_page().
- */
-static unsigned long memcg_exact_page_state(struct mem_cgroup *memcg, int idx)
-{
-       long x = atomic_long_read(&memcg->vmstats[idx]);
-       int cpu;
-
-       for_each_online_cpu(cpu)
-               x += per_cpu_ptr(memcg->vmstats_percpu, cpu)->stat[idx];
-       if (x < 0)
-               x = 0;
-       return x;
-}
-
 /**
  * mem_cgroup_wb_stats - retrieve writeback related stats from its memcg
  * @wb: bdi_writeback in question
@@ -4567,13 +4342,14 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
        struct mem_cgroup *memcg = mem_cgroup_from_css(wb->memcg_css);
        struct mem_cgroup *parent;
 
-       *pdirty = memcg_exact_page_state(memcg, NR_FILE_DIRTY);
+       cgroup_rstat_flush_irqsafe(memcg->css.cgroup);
 
-       *pwriteback = memcg_exact_page_state(memcg, NR_WRITEBACK);
-       *pfilepages = memcg_exact_page_state(memcg, NR_INACTIVE_FILE) +
-                       memcg_exact_page_state(memcg, NR_ACTIVE_FILE);
-       *pheadroom = PAGE_COUNTER_MAX;
+       *pdirty = memcg_page_state(memcg, NR_FILE_DIRTY);
+       *pwriteback = memcg_page_state(memcg, NR_WRITEBACK);
+       *pfilepages = memcg_page_state(memcg, NR_INACTIVE_FILE) +
+                       memcg_page_state(memcg, NR_ACTIVE_FILE);
 
+       *pheadroom = PAGE_COUNTER_MAX;
        while ((parent = parent_mem_cgroup(memcg))) {
                unsigned long ceiling = min(READ_ONCE(memcg->memory.max),
                                            READ_ONCE(memcg->memory.high));
@@ -4588,7 +4364,7 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
  * Foreign dirty flushing
  *
  * There's an inherent mismatch between memcg and writeback.  The former
- * trackes ownership per-page while the latter per-inode.  This was a
+ * tracks ownership per-page while the latter per-inode.  This was a
  * deliberate design decision because honoring per-page ownership in the
  * writeback path is complicated, may lead to higher CPU and IO overheads
  * and deemed unnecessary given that write-sharing an inode across
@@ -4603,9 +4379,9 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
  * triggering background writeback.  A will be slowed down without a way to
  * make writeback of the dirty pages happen.
  *
- * Conditions like the above can lead to a cgroup getting repatedly and
+ * Conditions like the above can lead to a cgroup getting repeatedly and
  * severely throttled after making some progress after each
- * dirty_expire_interval while the underyling IO device is almost
+ * dirty_expire_interval while the underlying IO device is almost
  * completely idle.
  *
  * Solving this problem completely requires matching the ownership tracking
@@ -5205,19 +4981,20 @@ static void __mem_cgroup_free(struct mem_cgroup *memcg)
        for_each_node(node)
                free_mem_cgroup_per_node_info(memcg, node);
        free_percpu(memcg->vmstats_percpu);
-       free_percpu(memcg->vmstats_local);
        kfree(memcg);
 }
 
 static void mem_cgroup_free(struct mem_cgroup *memcg)
 {
+       int cpu;
+
        memcg_wb_domain_exit(memcg);
        /*
-        * Flush percpu vmstats and vmevents to guarantee the value correctness
-        * on parent's and all ancestor levels.
+        * Flush percpu lruvec stats to guarantee the value
+        * correctness on parent's and all ancestor levels.
         */
-       memcg_flush_percpu_vmstats(memcg);
-       memcg_flush_percpu_vmevents(memcg);
+       for_each_online_cpu(cpu)
+               memcg_flush_lruvec_page_state(memcg, cpu);
        __mem_cgroup_free(memcg);
 }
 
@@ -5244,11 +5021,6 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
                goto fail;
        }
 
-       memcg->vmstats_local = alloc_percpu_gfp(struct memcg_vmstats_percpu,
-                                               GFP_KERNEL_ACCOUNT);
-       if (!memcg->vmstats_local)
-               goto fail;
-
        memcg->vmstats_percpu = alloc_percpu_gfp(struct memcg_vmstats_percpu,
                                                 GFP_KERNEL_ACCOUNT);
        if (!memcg->vmstats_percpu)
@@ -5346,11 +5118,11 @@ static int mem_cgroup_css_online(struct cgroup_subsys_state *css)
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
        /*
-        * A memcg must be visible for memcg_expand_shrinker_maps()
+        * A memcg must be visible for expand_shrinker_info()
         * by the time the maps are allocated. So, we allocate maps
         * here, when for_each_mem_cgroup() can't skip it.
         */
-       if (memcg_alloc_shrinker_maps(memcg)) {
+       if (alloc_shrinker_info(memcg)) {
                mem_cgroup_id_remove(memcg);
                return -ENOMEM;
        }
@@ -5382,6 +5154,7 @@ static void mem_cgroup_css_offline(struct cgroup_subsys_state *css)
        page_counter_set_low(&memcg->memory, 0);
 
        memcg_offline_kmem(memcg);
+       reparent_shrinker_deferred(memcg);
        wb_memcg_offline(memcg);
 
        drain_all_stock(memcg);
@@ -5414,7 +5187,7 @@ static void mem_cgroup_css_free(struct cgroup_subsys_state *css)
        vmpressure_cleanup(&memcg->vmpressure);
        cancel_work_sync(&memcg->high_work);
        mem_cgroup_remove_from_trees(memcg);
-       memcg_free_shrinker_maps(memcg);
+       free_shrinker_info(memcg);
        memcg_free_kmem(memcg);
        mem_cgroup_free(memcg);
 }
@@ -5448,6 +5221,62 @@ static void mem_cgroup_css_reset(struct cgroup_subsys_state *css)
        memcg_wb_domain_size_changed(memcg);
 }
 
+static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+       struct mem_cgroup *parent = parent_mem_cgroup(memcg);
+       struct memcg_vmstats_percpu *statc;
+       long delta, v;
+       int i;
+
+       statc = per_cpu_ptr(memcg->vmstats_percpu, cpu);
+
+       for (i = 0; i < MEMCG_NR_STAT; i++) {
+               /*
+                * Collect the aggregated propagation counts of groups
+                * below us. We're in a per-cpu loop here and this is
+                * a global counter, so the first cycle will get them.
+                */
+               delta = memcg->vmstats.state_pending[i];
+               if (delta)
+                       memcg->vmstats.state_pending[i] = 0;
+
+               /* Add CPU changes on this level since the last flush */
+               v = READ_ONCE(statc->state[i]);
+               if (v != statc->state_prev[i]) {
+                       delta += v - statc->state_prev[i];
+                       statc->state_prev[i] = v;
+               }
+
+               if (!delta)
+                       continue;
+
+               /* Aggregate counts on this level and propagate upwards */
+               memcg->vmstats.state[i] += delta;
+               if (parent)
+                       parent->vmstats.state_pending[i] += delta;
+       }
+
+       for (i = 0; i < NR_VM_EVENT_ITEMS; i++) {
+               delta = memcg->vmstats.events_pending[i];
+               if (delta)
+                       memcg->vmstats.events_pending[i] = 0;
+
+               v = READ_ONCE(statc->events[i]);
+               if (v != statc->events_prev[i]) {
+                       delta += v - statc->events_prev[i];
+                       statc->events_prev[i] = v;
+               }
+
+               if (!delta)
+                       continue;
+
+               memcg->vmstats.events[i] += delta;
+               if (parent)
+                       parent->vmstats.events_pending[i] += delta;
+       }
+}
+
 #ifdef CONFIG_MMU
 /* Handlers for move charge at task migration. */
 static int mem_cgroup_do_precharge(unsigned long count)
@@ -5945,7 +5774,7 @@ static int mem_cgroup_can_attach(struct cgroup_taskset *tset)
                return 0;
 
        /*
-        * We are now commited to this value whatever it is. Changes in this
+        * We are now committed to this value whatever it is. Changes in this
         * tunable will only affect upcoming migrations, not the current one.
         * So we need to save it, and keep it going.
         */
@@ -6501,6 +6330,7 @@ struct cgroup_subsys memory_cgrp_subsys = {
        .css_released = mem_cgroup_css_released,
        .css_free = mem_cgroup_css_free,
        .css_reset = mem_cgroup_css_reset,
+       .css_rstat_flush = mem_cgroup_css_rstat_flush,
        .can_attach = mem_cgroup_can_attach,
        .cancel_attach = mem_cgroup_cancel_attach,
        .post_attach = mem_cgroup_move_task,
@@ -6683,6 +6513,27 @@ void mem_cgroup_calculate_protection(struct mem_cgroup *root,
                        atomic_long_read(&parent->memory.children_low_usage)));
 }
 
+static int __mem_cgroup_charge(struct page *page, struct mem_cgroup *memcg,
+                              gfp_t gfp)
+{
+       unsigned int nr_pages = thp_nr_pages(page);
+       int ret;
+
+       ret = try_charge(memcg, gfp, nr_pages);
+       if (ret)
+               goto out;
+
+       css_get(&memcg->css);
+       commit_charge(page, memcg);
+
+       local_irq_disable();
+       mem_cgroup_charge_statistics(memcg, page, nr_pages);
+       memcg_check_events(memcg, page);
+       local_irq_enable();
+out:
+       return ret;
+}
+
 /**
  * mem_cgroup_charge - charge a newly allocated page to a cgroup
  * @page: page to charge
@@ -6692,55 +6543,71 @@ void mem_cgroup_calculate_protection(struct mem_cgroup *root,
  * Try to charge @page to the memcg that @mm belongs to, reclaiming
  * pages according to @gfp_mask if necessary.
  *
+ * Do not use this for pages allocated for swapin.
+ *
  * Returns 0 on success. Otherwise, an error code is returned.
  */
 int mem_cgroup_charge(struct page *page, struct mm_struct *mm, gfp_t gfp_mask)
 {
-       unsigned int nr_pages = thp_nr_pages(page);
-       struct mem_cgroup *memcg = NULL;
-       int ret = 0;
+       struct mem_cgroup *memcg;
+       int ret;
 
        if (mem_cgroup_disabled())
-               goto out;
+               return 0;
 
-       if (PageSwapCache(page)) {
-               swp_entry_t ent = { .val = page_private(page), };
-               unsigned short id;
+       memcg = get_mem_cgroup_from_mm(mm);
+       ret = __mem_cgroup_charge(page, memcg, gfp_mask);
+       css_put(&memcg->css);
 
-               /*
-                * Every swap fault against a single page tries to charge the
-                * page, bail as early as possible.  shmem_unuse() encounters
-                * already charged pages, too.  page and memcg binding is
-                * protected by the page lock, which serializes swap cache
-                * removal, which in turn serializes uncharging.
-                */
-               VM_BUG_ON_PAGE(!PageLocked(page), page);
-               if (page_memcg(compound_head(page)))
-                       goto out;
+       return ret;
+}
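A minimal sketch of the intended call pattern for a freshly allocated, non-swapin page, e.g. a new anonymous page in a fault path (the allocation helper and flags are illustrative):

	static struct page *charged_anon_page_example(struct vm_area_struct *vma,
						      unsigned long addr)
	{
		struct page *page = alloc_page_vma(GFP_HIGHUSER_MOVABLE, vma, addr);

		if (page && mem_cgroup_charge(page, vma->vm_mm, GFP_KERNEL)) {
			put_page(page);
			return NULL;
		}
		return page;
	}
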
 
-               id = lookup_swap_cgroup_id(ent);
-               rcu_read_lock();
-               memcg = mem_cgroup_from_id(id);
-               if (memcg && !css_tryget_online(&memcg->css))
-                       memcg = NULL;
-               rcu_read_unlock();
-       }
+/**
+ * mem_cgroup_swapin_charge_page - charge a newly allocated page for swapin
+ * @page: page to charge
+ * @mm: mm context of the victim
+ * @gfp: reclaim mode
+ * @entry: swap entry for which the page is allocated
+ *
+ * This function charges a page allocated for swapin. Please call this before
+ * adding the page to the swapcache.
+ *
+ * Returns 0 on success. Otherwise, an error code is returned.
+ */
+int mem_cgroup_swapin_charge_page(struct page *page, struct mm_struct *mm,
+                                 gfp_t gfp, swp_entry_t entry)
+{
+       struct mem_cgroup *memcg;
+       unsigned short id;
+       int ret;
 
-       if (!memcg)
-               memcg = get_mem_cgroup_from_mm(mm);
+       if (mem_cgroup_disabled())
+               return 0;
 
-       ret = try_charge(memcg, gfp_mask, nr_pages);
-       if (ret)
-               goto out_put;
+       id = lookup_swap_cgroup_id(entry);
+       rcu_read_lock();
+       memcg = mem_cgroup_from_id(id);
+       if (!memcg || !css_tryget_online(&memcg->css))
+               memcg = get_mem_cgroup_from_mm(mm);
+       rcu_read_unlock();
 
-       css_get(&memcg->css);
-       commit_charge(page, memcg);
+       ret = __mem_cgroup_charge(page, memcg, gfp);
 
-       local_irq_disable();
-       mem_cgroup_charge_statistics(memcg, page, nr_pages);
-       memcg_check_events(memcg, page);
-       local_irq_enable();
+       css_put(&memcg->css);
+       return ret;
+}
 
+/*
+ * mem_cgroup_swapin_uncharge_swap - uncharge swap slot
+ * @entry: swap entry for which the page is charged
+ *
+ * Call this function after successfully adding the charged page to swapcache.
+ *
+ * Note: This function assumes the page for which the swap slot is being
+ * uncharged is an order-0 page.
+ */
+void mem_cgroup_swapin_uncharge_swap(swp_entry_t entry)
+{
        /*
         * Cgroup1's unified memory+swap counter has been charged with the
         * new swapcache page, finish the transfer by uncharging the swap
@@ -6753,25 +6620,19 @@ int mem_cgroup_charge(struct page *page, struct mm_struct *mm, gfp_t gfp_mask)
         * correspond 1:1 to page and swap slot lifetimes: we charge the
         * page to memory here, and uncharge swap when the slot is freed.
         */
-       if (do_memsw_account() && PageSwapCache(page)) {
-               swp_entry_t entry = { .val = page_private(page) };
+       if (!mem_cgroup_disabled() && do_memsw_account()) {
                /*
                 * The swap entry might not get freed for a long time,
                 * let's not wait for it.  The page already received a
                 * memory+swap charge, drop the swap entry duplicate.
                 */
-               mem_cgroup_uncharge_swap(entry, nr_pages);
+               mem_cgroup_uncharge_swap(entry, 1);
        }
-
-out_put:
-       css_put(&memcg->css);
-out:
-       return ret;
 }
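The two swapin helpers are meant to be used in a fixed order around swapcache insertion: charge the new page first, add it to the swapcache, then drop the swap entry's duplicate charge. A minimal sketch of that sequence (error handling abbreviated; the add_to_swap_cache() arguments are illustrative):

	static int swapin_charge_example(struct page *page, struct mm_struct *mm,
					 swp_entry_t entry, gfp_t gfp)
	{
		int ret = mem_cgroup_swapin_charge_page(page, mm, gfp, entry);

		if (ret)
			return ret;
		ret = add_to_swap_cache(page, entry, gfp, NULL);
		if (!ret)
			mem_cgroup_swapin_uncharge_swap(entry);
		return ret;
	}
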
 
 struct uncharge_gather {
        struct mem_cgroup *memcg;
-       unsigned long nr_pages;
+       unsigned long nr_memory;
        unsigned long pgpgout;
        unsigned long nr_kmem;
        struct page *dummy_page;
@@ -6786,10 +6647,10 @@ static void uncharge_batch(const struct uncharge_gather *ug)
 {
        unsigned long flags;
 
-       if (!mem_cgroup_is_root(ug->memcg)) {
-               page_counter_uncharge(&ug->memcg->memory, ug->nr_pages);
+       if (ug->nr_memory) {
+               page_counter_uncharge(&ug->memcg->memory, ug->nr_memory);
                if (do_memsw_account())
-                       page_counter_uncharge(&ug->memcg->memsw, ug->nr_pages);
+                       page_counter_uncharge(&ug->memcg->memsw, ug->nr_memory);
                if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && ug->nr_kmem)
                        page_counter_uncharge(&ug->memcg->kmem, ug->nr_kmem);
                memcg_oom_recover(ug->memcg);
@@ -6797,7 +6658,7 @@ static void uncharge_batch(const struct uncharge_gather *ug)
 
        local_irq_save(flags);
        __count_memcg_events(ug->memcg, PGPGOUT, ug->pgpgout);
-       __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, ug->nr_pages);
+       __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, ug->nr_memory);
        memcg_check_events(ug->memcg, ug->dummy_page);
        local_irq_restore(flags);
 
@@ -6808,40 +6669,60 @@ static void uncharge_batch(const struct uncharge_gather *ug)
 static void uncharge_page(struct page *page, struct uncharge_gather *ug)
 {
        unsigned long nr_pages;
+       struct mem_cgroup *memcg;
+       struct obj_cgroup *objcg;
 
        VM_BUG_ON_PAGE(PageLRU(page), page);
 
-       if (!page_memcg(page))
-               return;
-
        /*
         * Nobody should be changing or seriously looking at
-        * page_memcg(page) at this point, we have fully
+        * page memcg or objcg at this point, we have fully
         * exclusive access to the page.
         */
+       if (PageMemcgKmem(page)) {
+               objcg = __page_objcg(page);
+               /*
+                * This get matches the put at the end of the function and
+                * kmem pages do not hold memcg references anymore.
+                */
+               memcg = get_mem_cgroup_from_objcg(objcg);
+       } else {
+               memcg = __page_memcg(page);
+       }
 
-       if (ug->memcg != page_memcg(page)) {
+       if (!memcg)
+               return;
+
+       if (ug->memcg != memcg) {
                if (ug->memcg) {
                        uncharge_batch(ug);
                        uncharge_gather_clear(ug);
                }
-               ug->memcg = page_memcg(page);
+               ug->memcg = memcg;
+               ug->dummy_page = page;
 
                /* pairs with css_put in uncharge_batch */
-               css_get(&ug->memcg->css);
+               css_get(&memcg->css);
        }
 
        nr_pages = compound_nr(page);
-       ug->nr_pages += nr_pages;
 
-       if (PageMemcgKmem(page))
+       if (PageMemcgKmem(page)) {
+               ug->nr_memory += nr_pages;
                ug->nr_kmem += nr_pages;
-       else
+
+               page->memcg_data = 0;
+               obj_cgroup_put(objcg);
+       } else {
+               /* LRU pages aren't accounted at the root level */
+               if (!mem_cgroup_is_root(memcg))
+                       ug->nr_memory += nr_pages;
                ug->pgpgout++;
 
-       ug->dummy_page = page;
-       page->memcg_data = 0;
-       css_put(&ug->memcg->css);
+               page->memcg_data = 0;
+       }
+
+       css_put(&memcg->css);
 }
 
 /**