memcg: completely decouple memcg and obj stocks
author Shakeel Butt <shakeel.butt@linux.dev>
Tue, 6 May 2025 22:55:32 +0000 (15:55 -0700)
committer Andrew Morton <akpm@linux-foundation.org>
Tue, 13 May 2025 23:28:08 +0000 (16:28 -0700)
Let's completely decouple the memcg and obj per-cpu stocks.  This will
enable the memcg per-cpu stocks to be used without disabling irqs.  It
will also enable the obj stocks to be made nmi-safe independently, which
is required to make kmalloc/slab safe for allocations from nmi context.
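
For illustration, a minimal sketch of the resulting per-cpu layout (field
lists abbreviated; the full definitions are in the diff below).  Each
stock now gets its own local_trylock_t, its own flush work item, and its
own FLUSHING_CACHED_CHARGE flag:

	struct memcg_stock_pcp {
		local_trylock_t lock;	/* protects only the page-charge stock */
		uint8_t nr_pages[NR_MEMCG_STOCK];
		struct mem_cgroup *cached[NR_MEMCG_STOCK];
		struct work_struct work;
		unsigned long flags;
	};
	static DEFINE_PER_CPU_ALIGNED(struct memcg_stock_pcp, memcg_stock) = {
		.lock = INIT_LOCAL_TRYLOCK(lock),
	};

	struct obj_stock_pcp {
		local_trylock_t lock;	/* independent lock for the objcg byte stock */
		unsigned int nr_bytes;
		struct obj_cgroup *cached_objcg;
		struct pglist_data *cached_pgdat;
		/* remaining byte-level counters elided; see the diff below */
		struct work_struct work;
		unsigned long flags;
	};
	static DEFINE_PER_CPU_ALIGNED(struct obj_stock_pcp, obj_stock) = {
		.lock = INIT_LOCAL_TRYLOCK(lock),
	};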

Link: https://lkml.kernel.org/r/20250506225533.2580386-4-shakeel.butt@linux.dev
Signed-off-by: Shakeel Butt <shakeel.butt@linux.dev>
Acked-by: Vlastimil Babka <vbabka@suse.cz>
Cc: Alexei Starovoitov <ast@kernel.org>
Cc: Eric Dumazet <edumazet@google.com>
Cc: Jakub Kicinski <kuba@kernel.org>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Michal Hocko <mhocko@kernel.org>
Cc: Muchun Song <muchun.song@linux.dev>
Cc: Roman Gushchin <roman.gushchin@linux.dev>
Cc: Sebastian Andrzej Siewior <bigeasy@linutronix.de>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
mm/memcontrol.c

index 352aaae..1f8611e 100644
@@ -1778,12 +1778,22 @@ void mem_cgroup_print_oom_group(struct mem_cgroup *memcg)
  * nr_pages in a single cacheline. This may change in future.
  */
 #define NR_MEMCG_STOCK 7
+#define FLUSHING_CACHED_CHARGE 0
 struct memcg_stock_pcp {
-       local_trylock_t memcg_lock;
+       local_trylock_t lock;
        uint8_t nr_pages[NR_MEMCG_STOCK];
        struct mem_cgroup *cached[NR_MEMCG_STOCK];
 
-       local_trylock_t obj_lock;
+       struct work_struct work;
+       unsigned long flags;
+};
+
+static DEFINE_PER_CPU_ALIGNED(struct memcg_stock_pcp, memcg_stock) = {
+       .lock = INIT_LOCAL_TRYLOCK(lock),
+};
+
+struct obj_stock_pcp {
+       local_trylock_t lock;
        unsigned int nr_bytes;
        struct obj_cgroup *cached_objcg;
        struct pglist_data *cached_pgdat;
@@ -1792,16 +1802,16 @@ struct memcg_stock_pcp {
 
        struct work_struct work;
        unsigned long flags;
-#define FLUSHING_CACHED_CHARGE 0
 };
-static DEFINE_PER_CPU_ALIGNED(struct memcg_stock_pcp, memcg_stock) = {
-       .memcg_lock = INIT_LOCAL_TRYLOCK(memcg_lock),
-       .obj_lock = INIT_LOCAL_TRYLOCK(obj_lock),
+
+static DEFINE_PER_CPU_ALIGNED(struct obj_stock_pcp, obj_stock) = {
+       .lock = INIT_LOCAL_TRYLOCK(lock),
 };
+
 static DEFINE_MUTEX(percpu_charge_mutex);
 
-static void drain_obj_stock(struct memcg_stock_pcp *stock);
-static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+static void drain_obj_stock(struct obj_stock_pcp *stock);
+static bool obj_stock_flush_required(struct obj_stock_pcp *stock,
                                     struct mem_cgroup *root_memcg);
 
 /**
@@ -1824,7 +1834,7 @@ static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
        int i;
 
        if (nr_pages > MEMCG_CHARGE_BATCH ||
-           !local_trylock_irqsave(&memcg_stock.memcg_lock, flags))
+           !local_trylock_irqsave(&memcg_stock.lock, flags))
                return ret;
 
        stock = this_cpu_ptr(&memcg_stock);
@@ -1841,7 +1851,7 @@ static bool consume_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
                break;
        }
 
-       local_unlock_irqrestore(&memcg_stock.memcg_lock, flags);
+       local_unlock_irqrestore(&memcg_stock.lock, flags);
 
        return ret;
 }
@@ -1882,7 +1892,7 @@ static void drain_stock_fully(struct memcg_stock_pcp *stock)
                drain_stock(stock, i);
 }
 
-static void drain_local_stock(struct work_struct *dummy)
+static void drain_local_memcg_stock(struct work_struct *dummy)
 {
        struct memcg_stock_pcp *stock;
        unsigned long flags;
@@ -1890,16 +1900,30 @@ static void drain_local_stock(struct work_struct *dummy)
        if (WARN_ONCE(!in_task(), "drain in non-task context"))
                return;
 
-       local_lock_irqsave(&memcg_stock.obj_lock, flags);
-       stock = this_cpu_ptr(&memcg_stock);
-       drain_obj_stock(stock);
-       local_unlock_irqrestore(&memcg_stock.obj_lock, flags);
+       local_lock_irqsave(&memcg_stock.lock, flags);
 
-       local_lock_irqsave(&memcg_stock.memcg_lock, flags);
        stock = this_cpu_ptr(&memcg_stock);
        drain_stock_fully(stock);
        clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
-       local_unlock_irqrestore(&memcg_stock.memcg_lock, flags);
+
+       local_unlock_irqrestore(&memcg_stock.lock, flags);
+}
+
+static void drain_local_obj_stock(struct work_struct *dummy)
+{
+       struct obj_stock_pcp *stock;
+       unsigned long flags;
+
+       if (WARN_ONCE(!in_task(), "drain in non-task context"))
+               return;
+
+       local_lock_irqsave(&obj_stock.lock, flags);
+
+       stock = this_cpu_ptr(&obj_stock);
+       drain_obj_stock(stock);
+       clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
+
+       local_unlock_irqrestore(&obj_stock.lock, flags);
 }
 
 static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
@@ -1922,10 +1946,10 @@ static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
        VM_WARN_ON_ONCE(mem_cgroup_is_root(memcg));
 
        if (nr_pages > MEMCG_CHARGE_BATCH ||
-           !local_trylock_irqsave(&memcg_stock.memcg_lock, flags)) {
+           !local_trylock_irqsave(&memcg_stock.lock, flags)) {
                /*
                 * In case of larger than batch refill or unlikely failure to
-                * lock the percpu memcg_lock, uncharge memcg directly.
+                * lock the percpu memcg_stock.lock, uncharge memcg directly.
                 */
                memcg_uncharge(memcg, nr_pages);
                return;
@@ -1957,23 +1981,17 @@ static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
                WRITE_ONCE(stock->nr_pages[i], nr_pages);
        }
 
-       local_unlock_irqrestore(&memcg_stock.memcg_lock, flags);
+       local_unlock_irqrestore(&memcg_stock.lock, flags);
 }
 
-static bool is_drain_needed(struct memcg_stock_pcp *stock,
-                           struct mem_cgroup *root_memcg)
+static bool is_memcg_drain_needed(struct memcg_stock_pcp *stock,
+                                 struct mem_cgroup *root_memcg)
 {
        struct mem_cgroup *memcg;
        bool flush = false;
        int i;
 
        rcu_read_lock();
-
-       if (obj_stock_flush_required(stock, root_memcg)) {
-               flush = true;
-               goto out;
-       }
-
        for (i = 0; i < NR_MEMCG_STOCK; ++i) {
                memcg = READ_ONCE(stock->cached[i]);
                if (!memcg)
@@ -1985,7 +2003,6 @@ static bool is_drain_needed(struct memcg_stock_pcp *stock,
                        break;
                }
        }
-out:
        rcu_read_unlock();
        return flush;
 }
@@ -2010,15 +2027,27 @@ void drain_all_stock(struct mem_cgroup *root_memcg)
        migrate_disable();
        curcpu = smp_processor_id();
        for_each_online_cpu(cpu) {
-               struct memcg_stock_pcp *stock = &per_cpu(memcg_stock, cpu);
-               bool flush = is_drain_needed(stock, root_memcg);
+               struct memcg_stock_pcp *memcg_st = &per_cpu(memcg_stock, cpu);
+               struct obj_stock_pcp *obj_st = &per_cpu(obj_stock, cpu);
 
-               if (flush &&
-                   !test_and_set_bit(FLUSHING_CACHED_CHARGE, &stock->flags)) {
+               if (!test_bit(FLUSHING_CACHED_CHARGE, &memcg_st->flags) &&
+                   is_memcg_drain_needed(memcg_st, root_memcg) &&
+                   !test_and_set_bit(FLUSHING_CACHED_CHARGE,
+                                     &memcg_st->flags)) {
                        if (cpu == curcpu)
-                               drain_local_stock(&stock->work);
+                               drain_local_memcg_stock(&memcg_st->work);
                        else if (!cpu_is_isolated(cpu))
-                               schedule_work_on(cpu, &stock->work);
+                               schedule_work_on(cpu, &memcg_st->work);
+               }
+
+               if (!test_bit(FLUSHING_CACHED_CHARGE, &obj_st->flags) &&
+                   obj_stock_flush_required(obj_st, root_memcg) &&
+                   !test_and_set_bit(FLUSHING_CACHED_CHARGE,
+                                     &obj_st->flags)) {
+                       if (cpu == curcpu)
+                               drain_local_obj_stock(&obj_st->work);
+                       else if (!cpu_is_isolated(cpu))
+                               schedule_work_on(cpu, &obj_st->work);
                }
        }
        migrate_enable();
@@ -2027,18 +2056,18 @@ void drain_all_stock(struct mem_cgroup *root_memcg)
 
 static int memcg_hotplug_cpu_dead(unsigned int cpu)
 {
-       struct memcg_stock_pcp *stock;
+       struct obj_stock_pcp *obj_st;
        unsigned long flags;
 
-       stock = &per_cpu(memcg_stock, cpu);
+       obj_st = &per_cpu(obj_stock, cpu);
 
-       /* drain_obj_stock requires obj_lock */
-       local_lock_irqsave(&memcg_stock.obj_lock, flags);
-       drain_obj_stock(stock);
-       local_unlock_irqrestore(&memcg_stock.obj_lock, flags);
+       /* drain_obj_stock requires obj_stock.lock */
+       local_lock_irqsave(&obj_stock.lock, flags);
+       drain_obj_stock(obj_st);
+       local_unlock_irqrestore(&obj_stock.lock, flags);
 
        /* no need for the local lock */
-       drain_stock_fully(stock);
+       drain_stock_fully(&per_cpu(memcg_stock, cpu));
 
        return 0;
 }
@@ -2835,7 +2864,7 @@ void __memcg_kmem_uncharge_page(struct page *page, int order)
 }
 
 static void __account_obj_stock(struct obj_cgroup *objcg,
-                               struct memcg_stock_pcp *stock, int nr,
+                               struct obj_stock_pcp *stock, int nr,
                                struct pglist_data *pgdat, enum node_stat_item idx)
 {
        int *bytes;
@@ -2886,13 +2915,13 @@ static void __account_obj_stock(struct obj_cgroup *objcg,
 static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
                              struct pglist_data *pgdat, enum node_stat_item idx)
 {
-       struct memcg_stock_pcp *stock;
+       struct obj_stock_pcp *stock;
        unsigned long flags;
        bool ret = false;
 
-       local_lock_irqsave(&memcg_stock.obj_lock, flags);
+       local_lock_irqsave(&obj_stock.lock, flags);
 
-       stock = this_cpu_ptr(&memcg_stock);
+       stock = this_cpu_ptr(&obj_stock);
        if (objcg == READ_ONCE(stock->cached_objcg) && stock->nr_bytes >= nr_bytes) {
                stock->nr_bytes -= nr_bytes;
                ret = true;
@@ -2901,12 +2930,12 @@ static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
                        __account_obj_stock(objcg, stock, nr_bytes, pgdat, idx);
        }
 
-       local_unlock_irqrestore(&memcg_stock.obj_lock, flags);
+       local_unlock_irqrestore(&obj_stock.lock, flags);
 
        return ret;
 }
 
-static void drain_obj_stock(struct memcg_stock_pcp *stock)
+static void drain_obj_stock(struct obj_stock_pcp *stock)
 {
        struct obj_cgroup *old = READ_ONCE(stock->cached_objcg);
 
@@ -2967,32 +2996,35 @@ static void drain_obj_stock(struct memcg_stock_pcp *stock)
        obj_cgroup_put(old);
 }
 
-static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
+static bool obj_stock_flush_required(struct obj_stock_pcp *stock,
                                     struct mem_cgroup *root_memcg)
 {
        struct obj_cgroup *objcg = READ_ONCE(stock->cached_objcg);
        struct mem_cgroup *memcg;
+       bool flush = false;
 
+       rcu_read_lock();
        if (objcg) {
                memcg = obj_cgroup_memcg(objcg);
                if (memcg && mem_cgroup_is_descendant(memcg, root_memcg))
-                       return true;
+                       flush = true;
        }
+       rcu_read_unlock();
 
-       return false;
+       return flush;
 }
 
 static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
                bool allow_uncharge, int nr_acct, struct pglist_data *pgdat,
                enum node_stat_item idx)
 {
-       struct memcg_stock_pcp *stock;
+       struct obj_stock_pcp *stock;
        unsigned long flags;
        unsigned int nr_pages = 0;
 
-       local_lock_irqsave(&memcg_stock.obj_lock, flags);
+       local_lock_irqsave(&obj_stock.lock, flags);
 
-       stock = this_cpu_ptr(&memcg_stock);
+       stock = this_cpu_ptr(&obj_stock);
        if (READ_ONCE(stock->cached_objcg) != objcg) { /* reset if necessary */
                drain_obj_stock(stock);
                obj_cgroup_get(objcg);
@@ -3012,7 +3044,7 @@ static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
                stock->nr_bytes &= (PAGE_SIZE - 1);
        }
 
-       local_unlock_irqrestore(&memcg_stock.obj_lock, flags);
+       local_unlock_irqrestore(&obj_stock.lock, flags);
 
        if (nr_pages)
                obj_cgroup_uncharge_pages(objcg, nr_pages);
@@ -5077,9 +5109,12 @@ int __init mem_cgroup_init(void)
        cpuhp_setup_state_nocalls(CPUHP_MM_MEMCQ_DEAD, "mm/memctrl:dead", NULL,
                                  memcg_hotplug_cpu_dead);
 
-       for_each_possible_cpu(cpu)
+       for_each_possible_cpu(cpu) {
                INIT_WORK(&per_cpu_ptr(&memcg_stock, cpu)->work,
-                         drain_local_stock);
+                         drain_local_memcg_stock);
+               INIT_WORK(&per_cpu_ptr(&obj_stock, cpu)->work,
+                         drain_local_obj_stock);
+       }
 
        memcg_size = struct_size_t(struct mem_cgroup, nodeinfo, nr_node_ids);
        memcg_cachep = kmem_cache_create("mem_cgroup", memcg_size, 0,