mm: memcg/slab: allocate obj_cgroups for non-root slab pages
[linux-2.6-microblaze.git] mm/memcontrol.c
index 13f559a..e6cd4c0 100644
@@ -257,6 +257,98 @@ struct cgroup_subsys_state *vmpressure_to_css(struct vmpressure *vmpr)
 }
 
 #ifdef CONFIG_MEMCG_KMEM
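+/*
+ * css_set_lock protects the objcg_list of each memcg; see
+ * obj_cgroup_release() and memcg_reparent_objcgs() below.
+ */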
+extern spinlock_t css_set_lock;
+
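+/*
+ * Release callback for the objcg's percpu refcount: uncharge any
+ * leftover pre-charged memory and unlink the objcg from its memcg's
+ * objcg_list.
+ */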
+static void obj_cgroup_release(struct percpu_ref *ref)
+{
+       struct obj_cgroup *objcg = container_of(ref, struct obj_cgroup, refcnt);
+       struct mem_cgroup *memcg;
+       unsigned int nr_bytes;
+       unsigned int nr_pages;
+       unsigned long flags;
+
+       /*
+        * At this point all allocated objects are freed, and
+        * objcg->nr_charged_bytes can't have an arbitrary byte value.
+        * However, it can be PAGE_SIZE or (x * PAGE_SIZE).
+        *
+        * The following sequence can lead to it:
+        * 1) CPU0: objcg == stock->cached_objcg
+        * 2) CPU1: we do a small allocation (e.g. 92 bytes),
+        *          PAGE_SIZE bytes are charged
+        * 3) CPU1: a process from another memcg is allocating something,
+        *          the stock is flushed,
+        *          objcg->nr_charged_bytes = PAGE_SIZE - 92
+        * 4) CPU0: we do release this object,
+        *          92 bytes are added to stock->nr_bytes
+        * 5) CPU0: stock is flushed,
+        *          92 bytes are added to objcg->nr_charged_bytes
+        *
+        * As a result, nr_charged_bytes == PAGE_SIZE.
+        * This page will be uncharged in obj_cgroup_release().
+        */
+       nr_bytes = atomic_read(&objcg->nr_charged_bytes);
+       WARN_ON_ONCE(nr_bytes & (PAGE_SIZE - 1));
+       nr_pages = nr_bytes >> PAGE_SHIFT;
+
+       spin_lock_irqsave(&css_set_lock, flags);
+       memcg = obj_cgroup_memcg(objcg);
+       if (nr_pages)
+               __memcg_kmem_uncharge(memcg, nr_pages);
+       list_del(&objcg->list);
+       mem_cgroup_put(memcg);
+       spin_unlock_irqrestore(&css_set_lock, flags);
+
+       percpu_ref_exit(ref);
+       kfree_rcu(objcg, rcu);
+}
+
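+/* Allocate an obj_cgroup and set up its percpu refcount. */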
+static struct obj_cgroup *obj_cgroup_alloc(void)
+{
+       struct obj_cgroup *objcg;
+       int ret;
+
+       objcg = kzalloc(sizeof(struct obj_cgroup), GFP_KERNEL);
+       if (!objcg)
+               return NULL;
+
+       ret = percpu_ref_init(&objcg->refcnt, obj_cgroup_release, 0,
+                             GFP_KERNEL);
+       if (ret) {
+               kfree(objcg);
+               return NULL;
+       }
+       INIT_LIST_HEAD(&objcg->list);
+       return objcg;
+}
+
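+/*
+ * Reparent the memcg's obj_cgroups: point the active objcg and all
+ * already reparented objcgs at the parent, move them to the parent's
+ * objcg_list, and kill the active objcg's percpu refcount.
+ */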
+static void memcg_reparent_objcgs(struct mem_cgroup *memcg,
+                                 struct mem_cgroup *parent)
+{
+       struct obj_cgroup *objcg, *iter;
+
+       objcg = rcu_replace_pointer(memcg->objcg, NULL, true);
+
+       spin_lock_irq(&css_set_lock);
+
+       /* Move active objcg to the parent's list */
+       xchg(&objcg->memcg, parent);
+       css_get(&parent->css);
+       list_add(&objcg->list, &parent->objcg_list);
+
+       /* Move already reparented objcgs to the parent's list */
+       list_for_each_entry(iter, &memcg->objcg_list, list) {
+               css_get(&parent->css);
+               xchg(&iter->memcg, parent);
+               css_put(&memcg->css);
+       }
+       list_splice(&memcg->objcg_list, &parent->objcg_list);
+
+       spin_unlock_irq(&css_set_lock);
+
+       percpu_ref_kill(&objcg->refcnt);
+}
+
 /*
  * This will be the memcg's index in each cache's ->memcg_params.memcg_caches.
  * The main reason for not using cgroup id for this:
@@ -477,10 +569,21 @@ ino_t page_cgroup_ino(struct page *page)
        unsigned long ino = 0;
 
        rcu_read_lock();
-       if (PageSlab(page) && !PageTail(page))
+       if (PageSlab(page) && !PageTail(page)) {
                memcg = memcg_from_slab_page(page);
-       else
-               memcg = READ_ONCE(page->mem_cgroup);
+       } else {
+               memcg = page->mem_cgroup;
+
+               /*
+                * The lowest bit set means that memcg isn't a valid
+                * memcg pointer, but an obj_cgroups pointer.
+                * In this case the page is shared and doesn't belong
+                * to any specific memory cgroup.
+                */
+               if ((unsigned long) memcg & 0x1UL)
+                       memcg = NULL;
+       }
+
        while (memcg && !(memcg->css.flags & CSS_ONLINE))
                memcg = parent_mem_cgroup(memcg);
        if (memcg)
@@ -681,13 +784,16 @@ mem_cgroup_largest_soft_limit_node(struct mem_cgroup_tree_per_node *mctz)
  */
 void __mod_memcg_state(struct mem_cgroup *memcg, int idx, int val)
 {
-       long x;
+       long x, threshold = MEMCG_CHARGE_BATCH;
 
        if (mem_cgroup_disabled())
                return;
 
+       if (vmstat_item_in_bytes(idx))
+               threshold <<= PAGE_SHIFT;
+
        x = val + __this_cpu_read(memcg->vmstats_percpu->stat[idx]);
-       if (unlikely(abs(x) > MEMCG_CHARGE_BATCH)) {
+       if (unlikely(abs(x) > threshold)) {
                struct mem_cgroup *mi;
 
                /*
@@ -713,29 +819,12 @@ parent_nodeinfo(struct mem_cgroup_per_node *pn, int nid)
        return mem_cgroup_nodeinfo(parent, nid);
 }
 
-/**
- * __mod_lruvec_state - update lruvec memory statistics
- * @lruvec: the lruvec
- * @idx: the stat item
- * @val: delta to add to the counter, can be negative
- *
- * The lruvec is the intersection of the NUMA node and a cgroup. This
- * function updates the all three counters that are affected by a
- * change of state at this level: per-node, per-cgroup, per-lruvec.
- */
-void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
-                       int val)
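+/*
+ * __mod_memcg_lruvec_state - update the memcg and lruvec counters only,
+ * leaving the per-node counter to the caller.
+ */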
+void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
+                             int val)
 {
-       pg_data_t *pgdat = lruvec_pgdat(lruvec);
        struct mem_cgroup_per_node *pn;
        struct mem_cgroup *memcg;
-       long x;
-
-       /* Update node */
-       __mod_node_page_state(pgdat, idx, val);
-
-       if (mem_cgroup_disabled())
-               return;
+       long x, threshold = MEMCG_CHARGE_BATCH;
 
        pn = container_of(lruvec, struct mem_cgroup_per_node, lruvec);
        memcg = pn->memcg;
@@ -746,8 +835,12 @@ void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
        /* Update lruvec */
        __this_cpu_add(pn->lruvec_stat_local->count[idx], val);
 
+       if (vmstat_item_in_bytes(idx))
+               threshold <<= PAGE_SHIFT;
+
        x = val + __this_cpu_read(pn->lruvec_stat_cpu->count[idx]);
-       if (unlikely(abs(x) > MEMCG_CHARGE_BATCH)) {
+       if (unlikely(abs(x) > threshold)) {
+               pg_data_t *pgdat = lruvec_pgdat(lruvec);
                struct mem_cgroup_per_node *pi;
 
                for (pi = pn; pi; pi = parent_nodeinfo(pi, pgdat->node_id))
@@ -757,6 +850,27 @@ void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
        __this_cpu_write(pn->lruvec_stat_cpu->count[idx], x);
 }
 
+/**
+ * __mod_lruvec_state - update lruvec memory statistics
+ * @lruvec: the lruvec
+ * @idx: the stat item
+ * @val: delta to add to the counter, can be negative
+ *
+ * The lruvec is the intersection of the NUMA node and a cgroup. This
+ * function updates all three counters that are affected by a
+ * change of state at this level: per-node, per-cgroup, per-lruvec.
+ */
+void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
+                       int val)
+{
+       /* Update node */
+       __mod_node_page_state(lruvec_pgdat(lruvec), idx, val);
+
+       /* Update memcg and lruvec */
+       if (!mem_cgroup_disabled())
+               __mod_memcg_lruvec_state(lruvec, idx, val);
+}
+
 void __mod_lruvec_slab_state(void *p, enum node_stat_item idx, int val)
 {
        pg_data_t *pgdat = page_pgdat(virt_to_page(p));
@@ -1004,7 +1118,7 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
                                   struct mem_cgroup *prev,
                                   struct mem_cgroup_reclaim_cookie *reclaim)
 {
-       struct mem_cgroup_reclaim_iter *uninitialized_var(iter);
+       struct mem_cgroup_reclaim_iter *iter;
        struct cgroup_subsys_state *css = NULL;
        struct mem_cgroup *memcg = NULL;
        struct mem_cgroup *pos = NULL;
@@ -1380,9 +1494,8 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
                       (u64)memcg_page_state(memcg, MEMCG_KERNEL_STACK_KB) *
                       1024);
        seq_buf_printf(&s, "slab %llu\n",
-                      (u64)(memcg_page_state(memcg, NR_SLAB_RECLAIMABLE) +
-                            memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE)) *
-                      PAGE_SIZE);
+                      (u64)(memcg_page_state(memcg, NR_SLAB_RECLAIMABLE_B) +
+                            memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE_B)));
        seq_buf_printf(&s, "sock %llu\n",
                       (u64)memcg_page_state(memcg, MEMCG_SOCK) *
                       PAGE_SIZE);
@@ -1412,11 +1525,9 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
                               PAGE_SIZE);
 
        seq_buf_printf(&s, "slab_reclaimable %llu\n",
-                      (u64)memcg_page_state(memcg, NR_SLAB_RECLAIMABLE) *
-                      PAGE_SIZE);
+                      (u64)memcg_page_state(memcg, NR_SLAB_RECLAIMABLE_B));
        seq_buf_printf(&s, "slab_unreclaimable %llu\n",
-                      (u64)memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE) *
-                      PAGE_SIZE);
+                      (u64)memcg_page_state(memcg, NR_SLAB_UNRECLAIMABLE_B));
 
        /* Accumulated memory events */
 
@@ -2039,6 +2150,12 @@ EXPORT_SYMBOL(unlock_page_memcg);
 struct memcg_stock_pcp {
        struct mem_cgroup *cached; /* this never be root cgroup */
        unsigned int nr_pages;
+
+#ifdef CONFIG_MEMCG_KMEM
+       struct obj_cgroup *cached_objcg;
+       unsigned int nr_bytes;
+#endif
+
        struct work_struct work;
        unsigned long flags;
 #define FLUSHING_CACHED_CHARGE 0
@@ -2046,6 +2163,22 @@ struct memcg_stock_pcp {
 static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock);
 static DEFINE_MUTEX(percpu_charge_mutex);
 
+#ifdef CONFIG_MEMCG_KMEM
+static void drain_obj_stock(struct memcg_stock_pcp *stock);
+static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+                                    struct mem_cgroup *root_memcg);
+
+#else
+static inline void drain_obj_stock(struct memcg_stock_pcp *stock)
+{
+}
+static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+                                    struct mem_cgroup *root_memcg)
+{
+       return false;
+}
+#endif
+
 /**
  * consume_stock: Try to consume stocked charge on this cpu.
  * @memcg: memcg to consume from.
@@ -2086,13 +2219,17 @@ static void drain_stock(struct memcg_stock_pcp *stock)
 {
        struct mem_cgroup *old = stock->cached;
 
+       if (!old)
+               return;
+
        if (stock->nr_pages) {
                page_counter_uncharge(&old->memory, stock->nr_pages);
                if (do_memsw_account())
                        page_counter_uncharge(&old->memsw, stock->nr_pages);
-               css_put_many(&old->css, stock->nr_pages);
                stock->nr_pages = 0;
        }
+
+       css_put(&old->css);
        stock->cached = NULL;
 }
 
@@ -2108,6 +2245,7 @@ static void drain_local_stock(struct work_struct *dummy)
        local_irq_save(flags);
 
        stock = this_cpu_ptr(&memcg_stock);
+       drain_obj_stock(stock);
        drain_stock(stock);
        clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
 
@@ -2128,6 +2266,7 @@ static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
        stock = this_cpu_ptr(&memcg_stock);
        if (stock->cached != memcg) { /* reset if necessary */
                drain_stock(stock);
+               css_get(&memcg->css);
                stock->cached = memcg;
        }
        stock->nr_pages += nr_pages;
@@ -2166,6 +2305,8 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
                if (memcg && stock->nr_pages &&
                    mem_cgroup_is_descendant(memcg, root_memcg))
                        flush = true;
+               if (obj_stock_flush_required(stock, root_memcg))
+                       flush = true;
                rcu_read_unlock();
 
                if (flush &&
@@ -2586,12 +2727,10 @@ force:
        page_counter_charge(&memcg->memory, nr_pages);
        if (do_memsw_account())
                page_counter_charge(&memcg->memsw, nr_pages);
-       css_get_many(&memcg->css, nr_pages);
 
        return 0;
 
 done_restock:
-       css_get_many(&memcg->css, batch);
        if (batch > nr_pages)
                refill_stock(memcg, batch - nr_pages);
 
@@ -2649,8 +2788,6 @@ static void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
        page_counter_uncharge(&memcg->memory, nr_pages);
        if (do_memsw_account())
                page_counter_uncharge(&memcg->memsw, nr_pages);
-
-       css_put_many(&memcg->css, nr_pages);
 }
 #endif
 
@@ -2696,6 +2833,30 @@ struct mem_cgroup *mem_cgroup_from_obj(void *p)
        return page->mem_cgroup;
 }
 
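+/*
+ * Return the obj_cgroup of the current task's memcg, or of its nearest
+ * ancestor with a live objcg, with a reference taken. Returns NULL for
+ * kernel threads without an active memcg and for the root cgroup.
+ */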
+__always_inline struct obj_cgroup *get_obj_cgroup_from_current(void)
+{
+       struct obj_cgroup *objcg = NULL;
+       struct mem_cgroup *memcg;
+
+       if (unlikely(!current->mm && !current->active_memcg))
+               return NULL;
+
+       rcu_read_lock();
+       if (unlikely(current->active_memcg))
+               memcg = rcu_dereference(current->active_memcg);
+       else
+               memcg = mem_cgroup_from_task(current);
+
+       for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg)) {
+               objcg = rcu_dereference(memcg->objcg);
+               if (objcg && obj_cgroup_tryget(objcg))
+                       break;
+       }
+       rcu_read_unlock();
+
+       return objcg;
+}
+
 static int memcg_alloc_cache_id(void)
 {
        int id, size;
@@ -2958,6 +3119,7 @@ int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
                if (!ret) {
                        page->mem_cgroup = memcg;
                        __SetPageKmemcg(page);
+                       return 0;
                }
        }
        css_put(&memcg->css);
@@ -2980,13 +3142,146 @@ void __memcg_kmem_uncharge_page(struct page *page, int order)
        VM_BUG_ON_PAGE(mem_cgroup_is_root(memcg), page);
        __memcg_kmem_uncharge(memcg, nr_pages);
        page->mem_cgroup = NULL;
+       css_put(&memcg->css);
 
        /* slab pages do not have PageKmemcg flag set */
        if (PageKmemcg(page))
                __ClearPageKmemcg(page);
+}
+
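+/*
+ * Try to satisfy a byte-sized charge from the per-cpu object stock.
+ * Succeeds only if the stock caches this objcg and has enough bytes.
+ */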
+static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
+{
+       struct memcg_stock_pcp *stock;
+       unsigned long flags;
+       bool ret = false;
+
+       local_irq_save(flags);
+
+       stock = this_cpu_ptr(&memcg_stock);
+       if (objcg == stock->cached_objcg && stock->nr_bytes >= nr_bytes) {
+               stock->nr_bytes -= nr_bytes;
+               ret = true;
+       }
+
+       local_irq_restore(flags);
+
+       return ret;
+}
+
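+/* Flush the cached objcg charge from a per-cpu stock and drop its reference. */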
+static void drain_obj_stock(struct memcg_stock_pcp *stock)
+{
+       struct obj_cgroup *old = stock->cached_objcg;
+
+       if (!old)
+               return;
+
+       if (stock->nr_bytes) {
+               unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
+               unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
 
-       css_put_many(&memcg->css, nr_pages);
+               if (nr_pages) {
+                       rcu_read_lock();
+                       __memcg_kmem_uncharge(obj_cgroup_memcg(old), nr_pages);
+                       rcu_read_unlock();
+               }
+
+               /*
+                * The leftover is flushed to the centralized per-memcg value.
+                * On the next attempt to refill obj stock it will be moved
+                * to a per-cpu stock (probably on another CPU), see
+                * refill_obj_stock().
+                *
+                * How often it's flushed is a trade-off between the memory
+                * limit enforcement accuracy and potential CPU contention,
+                * so it might be changed in the future.
+                */
+               atomic_add(nr_bytes, &old->nr_charged_bytes);
+               stock->nr_bytes = 0;
+       }
+
+       obj_cgroup_put(old);
+       stock->cached_objcg = NULL;
 }
+
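+/*
+ * Check whether a per-cpu object stock caches charges belonging to the
+ * root_memcg's subtree and therefore must be flushed.
+ */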
+static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+                                    struct mem_cgroup *root_memcg)
+{
+       struct mem_cgroup *memcg;
+
+       if (stock->cached_objcg) {
+               memcg = obj_cgroup_memcg(stock->cached_objcg);
+               if (memcg && mem_cgroup_is_descendant(memcg, root_memcg))
+                       return true;
+       }
+
+       return false;
+}
+
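+/*
+ * Return nr_bytes to the per-cpu object stock. A stock caching a
+ * different objcg is drained first; if more than PAGE_SIZE bytes
+ * accumulate, the stock is drained as well.
+ */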
+static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
+{
+       struct memcg_stock_pcp *stock;
+       unsigned long flags;
+
+       local_irq_save(flags);
+
+       stock = this_cpu_ptr(&memcg_stock);
+       if (stock->cached_objcg != objcg) { /* reset if necessary */
+               drain_obj_stock(stock);
+               obj_cgroup_get(objcg);
+               stock->cached_objcg = objcg;
+               stock->nr_bytes = atomic_xchg(&objcg->nr_charged_bytes, 0);
+       }
+       stock->nr_bytes += nr_bytes;
+
+       if (stock->nr_bytes > PAGE_SIZE)
+               drain_obj_stock(stock);
+
+       local_irq_restore(flags);
+}
+
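+/*
+ * Charge size bytes to an obj_cgroup: consume the per-cpu stock if
+ * possible, otherwise charge whole pages to the memcg and stock the
+ * unused remainder of the last page.
+ */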
+int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
+{
+       struct mem_cgroup *memcg;
+       unsigned int nr_pages, nr_bytes;
+       int ret;
+
+       if (consume_obj_stock(objcg, size))
+               return 0;
+
+       /*
+        * In theory, objcg->nr_charged_bytes can have enough
+        * pre-charged bytes to satisfy the allocation. However,
+        * flushing objcg->nr_charged_bytes requires two atomic
+        * operations, and objcg->nr_charged_bytes can't be big,
+        * so it's better to ignore it and try to grab some new pages.
+        * objcg->nr_charged_bytes will be flushed in
+        * refill_obj_stock(), called from this function or
+        * independently later.
+        */
+       rcu_read_lock();
+       memcg = obj_cgroup_memcg(objcg);
+       css_get(&memcg->css);
+       rcu_read_unlock();
+
+       nr_pages = size >> PAGE_SHIFT;
+       nr_bytes = size & (PAGE_SIZE - 1);
+
+       if (nr_bytes)
+               nr_pages += 1;
+
+       ret = __memcg_kmem_charge(memcg, gfp, nr_pages);
+       if (!ret && nr_bytes)
+               refill_obj_stock(objcg, PAGE_SIZE - nr_bytes);
+
+       css_put(&memcg->css);
+       return ret;
+}
+
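+/* Uncharge size bytes by returning them to the per-cpu object stock. */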
+void obj_cgroup_uncharge(struct obj_cgroup *objcg, size_t size)
+{
+       refill_obj_stock(objcg, size);
+}
+
 #endif /* CONFIG_MEMCG_KMEM */
 
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
@@ -2997,13 +3292,16 @@ void __memcg_kmem_uncharge_page(struct page *page, int order)
  */
 void mem_cgroup_split_huge_fixup(struct page *head)
 {
+       struct mem_cgroup *memcg = head->mem_cgroup;
        int i;
 
        if (mem_cgroup_disabled())
                return;
 
-       for (i = 1; i < HPAGE_PMD_NR; i++)
-               head[i].mem_cgroup = head->mem_cgroup;
+       for (i = 1; i < HPAGE_PMD_NR; i++) {
+               css_get(&memcg->css);
+               head[i].mem_cgroup = memcg;
+       }
 }
 #endif /* CONFIG_TRANSPARENT_HUGEPAGE */
 
@@ -3404,6 +3702,7 @@ static void memcg_flush_percpu_vmevents(struct mem_cgroup *memcg)
 #ifdef CONFIG_MEMCG_KMEM
 static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
+       struct obj_cgroup *objcg;
        int memcg_id;
 
        if (cgroup_memory_nokmem)
@@ -3416,7 +3715,16 @@ static int memcg_online_kmem(struct mem_cgroup *memcg)
        if (memcg_id < 0)
                return memcg_id;
 
-       static_branch_inc(&memcg_kmem_enabled_key);
+       objcg = obj_cgroup_alloc();
+       if (!objcg) {
+               memcg_free_cache_id(memcg_id);
+               return -ENOMEM;
+       }
+       objcg->memcg = memcg;
+       rcu_assign_pointer(memcg->objcg, objcg);
+
+       static_branch_enable(&memcg_kmem_enabled_key);
+
        /*
         * A memory cgroup is considered kmem-online as soon as it gets
         * kmemcg_id. Setting the id after enabling static branching will
@@ -3451,9 +3759,10 @@ static void memcg_offline_kmem(struct mem_cgroup *memcg)
                parent = root_mem_cgroup;
 
        /*
-        * Deactivate and reparent kmem_caches.
+        * Deactivate and reparent kmem_caches and objcgs.
         */
        memcg_deactivate_kmem_caches(memcg, parent);
+       memcg_reparent_objcgs(memcg, parent);
 
        kmemcg_id = memcg->kmemcg_id;
        BUG_ON(kmemcg_id < 0);
@@ -3486,11 +3795,6 @@ static void memcg_free_kmem(struct mem_cgroup *memcg)
        /* css_alloc() failed, offlining didn't happen */
        if (unlikely(memcg->kmem_state == KMEM_ONLINE))
                memcg_offline_kmem(memcg);
-
-       if (memcg->kmem_state == KMEM_ALLOCATED) {
-               WARN_ON(!list_empty(&memcg->kmem_caches));
-               static_branch_dec(&memcg_kmem_enabled_key);
-       }
 }
 #else
 static int memcg_online_kmem(struct mem_cgroup *memcg)
@@ -5022,6 +5326,7 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
        memcg->socket_pressure = jiffies;
 #ifdef CONFIG_MEMCG_KMEM
        memcg->kmemcg_id = -1;
+       INIT_LIST_HEAD(&memcg->objcg_list);
 #endif
 #ifdef CONFIG_CGROUP_WRITEBACK
        INIT_LIST_HEAD(&memcg->cgwb_list);
@@ -5448,7 +5753,10 @@ static int mem_cgroup_move_account(struct page *page,
         */
        smp_mb();
 
-       page->mem_cgroup = to;  /* caller should have done css_get */
+       css_get(&to->css);
+       css_put(&from->css);
+
+       page->mem_cgroup = to;
 
        __unlock_page_memcg(from);
 
@@ -5669,8 +5977,6 @@ static void __mem_cgroup_clear_mc(void)
                if (!mem_cgroup_is_root(mc.to))
                        page_counter_uncharge(&mc.to->memory, mc.moved_swap);
 
-               css_put_many(&mc.to->css, mc.moved_swap);
-
                mc.moved_swap = 0;
        }
        memcg_oom_recover(from);
@@ -6498,6 +6804,7 @@ int mem_cgroup_charge(struct page *page, struct mm_struct *mm, gfp_t gfp_mask)
        if (ret)
                goto out_put;
 
+       css_get(&memcg->css);
        commit_charge(page, memcg);
 
        local_irq_disable();
@@ -6552,9 +6859,6 @@ static void uncharge_batch(const struct uncharge_gather *ug)
        __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, ug->nr_pages);
        memcg_check_events(ug->memcg, ug->dummy_page);
        local_irq_restore(flags);
-
-       if (!mem_cgroup_is_root(ug->memcg))
-               css_put_many(&ug->memcg->css, ug->nr_pages);
 }
 
 static void uncharge_page(struct page *page, struct uncharge_gather *ug)
@@ -6592,6 +6896,7 @@ static void uncharge_page(struct page *page, struct uncharge_gather *ug)
 
        ug->dummy_page = page;
        page->mem_cgroup = NULL;
+       css_put(&ug->memcg->css);
 }
 
 static void uncharge_list(struct list_head *page_list)
@@ -6697,8 +7002,8 @@ void mem_cgroup_migrate(struct page *oldpage, struct page *newpage)
        page_counter_charge(&memcg->memory, nr_pages);
        if (do_memsw_account())
                page_counter_charge(&memcg->memsw, nr_pages);
-       css_get_many(&memcg->css, nr_pages);
 
+       css_get(&memcg->css);
        commit_charge(newpage, memcg);
 
        local_irq_save(flags);
@@ -6935,8 +7240,7 @@ void mem_cgroup_swapout(struct page *page, swp_entry_t entry)
        mem_cgroup_charge_statistics(memcg, page, -nr_entries);
        memcg_check_events(memcg, page);
 
-       if (!mem_cgroup_is_root(memcg))
-               css_put_many(&memcg->css, nr_entries);
+       css_put(&memcg->css);
 }
 
 /**