kasan, page_alloc: refactor init checks in post_alloc_hook
[linux-2.6-microblaze.git] / mm / list_lru.c
index 0cd5e89..c669d87 100644 (file)
@@ -13,6 +13,7 @@
 #include <linux/mutex.h>
 #include <linux/memcontrol.h>
 #include "slab.h"
+#include "internal.h"
 
 #ifdef CONFIG_MEMCG_KMEM
 static LIST_HEAD(memcg_list_lrus);
@@ -49,35 +50,32 @@ static int lru_shrinker_id(struct list_lru *lru)
 }
 
 static inline struct list_lru_one *
-list_lru_from_memcg_idx(struct list_lru_node *nlru, int idx)
+list_lru_from_memcg_idx(struct list_lru *lru, int nid, int idx)
 {
-       struct list_lru_memcg *memcg_lrus;
-       /*
-        * Either lock or RCU protects the array of per cgroup lists
-        * from relocation (see memcg_update_list_lru_node).
-        */
-       memcg_lrus = rcu_dereference_check(nlru->memcg_lrus,
-                                          lockdep_is_held(&nlru->lock));
-       if (memcg_lrus && idx >= 0)
-               return memcg_lrus->lru[idx];
-       return &nlru->lru;
+       if (list_lru_memcg_aware(lru) && idx >= 0) {
+               struct list_lru_memcg *mlru = xa_load(&lru->xa, idx);
+
+               return mlru ? &mlru->node[nid] : NULL;
+       }
+       return &lru->node[nid].lru;
 }
 
 static inline struct list_lru_one *
-list_lru_from_kmem(struct list_lru_node *nlru, void *ptr,
+list_lru_from_kmem(struct list_lru *lru, int nid, void *ptr,
                   struct mem_cgroup **memcg_ptr)
 {
+       struct list_lru_node *nlru = &lru->node[nid];
        struct list_lru_one *l = &nlru->lru;
        struct mem_cgroup *memcg = NULL;
 
-       if (!nlru->memcg_lrus)
+       if (!list_lru_memcg_aware(lru))
                goto out;
 
        memcg = mem_cgroup_from_obj(ptr);
        if (!memcg)
                goto out;
 
-       l = list_lru_from_memcg_idx(nlru, memcg_cache_id(memcg));
+       l = list_lru_from_memcg_idx(lru, nid, memcg_kmem_id(memcg));
 out:
        if (memcg_ptr)
                *memcg_ptr = memcg;
@@ -103,18 +101,18 @@ static inline bool list_lru_memcg_aware(struct list_lru *lru)
 }
 
 static inline struct list_lru_one *
-list_lru_from_memcg_idx(struct list_lru_node *nlru, int idx)
+list_lru_from_memcg_idx(struct list_lru *lru, int nid, int idx)
 {
-       return &nlru->lru;
+       return &lru->node[nid].lru;
 }
 
 static inline struct list_lru_one *
-list_lru_from_kmem(struct list_lru_node *nlru, void *ptr,
+list_lru_from_kmem(struct list_lru *lru, int nid, void *ptr,
                   struct mem_cgroup **memcg_ptr)
 {
        if (memcg_ptr)
                *memcg_ptr = NULL;
-       return &nlru->lru;
+       return &lru->node[nid].lru;
 }
 #endif /* CONFIG_MEMCG_KMEM */
 
@@ -127,7 +125,7 @@ bool list_lru_add(struct list_lru *lru, struct list_head *item)
 
        spin_lock(&nlru->lock);
        if (list_empty(item)) {
-               l = list_lru_from_kmem(nlru, item, &memcg);
+               l = list_lru_from_kmem(lru, nid, item, &memcg);
                list_add_tail(item, &l->list);
                /* Set shrinker bit if the first element was added */
                if (!l->nr_items++)
@@ -150,7 +148,7 @@ bool list_lru_del(struct list_lru *lru, struct list_head *item)
 
        spin_lock(&nlru->lock);
        if (!list_empty(item)) {
-               l = list_lru_from_kmem(nlru, item, NULL);
+               l = list_lru_from_kmem(lru, nid, item, NULL);
                list_del_init(item);
                l->nr_items--;
                nlru->nr_items--;
@@ -180,13 +178,12 @@ EXPORT_SYMBOL_GPL(list_lru_isolate_move);
 unsigned long list_lru_count_one(struct list_lru *lru,
                                 int nid, struct mem_cgroup *memcg)
 {
-       struct list_lru_node *nlru = &lru->node[nid];
        struct list_lru_one *l;
        long count;
 
        rcu_read_lock();
-       l = list_lru_from_memcg_idx(nlru, memcg_cache_id(memcg));
-       count = READ_ONCE(l->nr_items);
+       l = list_lru_from_memcg_idx(lru, nid, memcg_kmem_id(memcg));
+       count = l ? READ_ONCE(l->nr_items) : 0;
        rcu_read_unlock();
 
        if (unlikely(count < 0))
@@ -206,17 +203,20 @@ unsigned long list_lru_count_node(struct list_lru *lru, int nid)
 EXPORT_SYMBOL_GPL(list_lru_count_node);
 
 static unsigned long
-__list_lru_walk_one(struct list_lru_node *nlru, int memcg_idx,
+__list_lru_walk_one(struct list_lru *lru, int nid, int memcg_idx,
                    list_lru_walk_cb isolate, void *cb_arg,
                    unsigned long *nr_to_walk)
 {
-
+       struct list_lru_node *nlru = &lru->node[nid];
        struct list_lru_one *l;
        struct list_head *item, *n;
        unsigned long isolated = 0;
 
-       l = list_lru_from_memcg_idx(nlru, memcg_idx);
 restart:
+       l = list_lru_from_memcg_idx(lru, nid, memcg_idx);
+       if (!l)
+               goto out;
+
        list_for_each_safe(item, n, &l->list) {
                enum lru_status ret;
 
@@ -260,6 +260,7 @@ restart:
                        BUG();
                }
        }
+out:
        return isolated;
 }
 
@@ -272,8 +273,8 @@ list_lru_walk_one(struct list_lru *lru, int nid, struct mem_cgroup *memcg,
        unsigned long ret;
 
        spin_lock(&nlru->lock);
-       ret = __list_lru_walk_one(nlru, memcg_cache_id(memcg), isolate, cb_arg,
-                                 nr_to_walk);
+       ret = __list_lru_walk_one(lru, nid, memcg_kmem_id(memcg), isolate,
+                                 cb_arg, nr_to_walk);
        spin_unlock(&nlru->lock);
        return ret;
 }
@@ -288,8 +289,8 @@ list_lru_walk_one_irq(struct list_lru *lru, int nid, struct mem_cgroup *memcg,
        unsigned long ret;
 
        spin_lock_irq(&nlru->lock);
-       ret = __list_lru_walk_one(nlru, memcg_cache_id(memcg), isolate, cb_arg,
-                                 nr_to_walk);
+       ret = __list_lru_walk_one(lru, nid, memcg_kmem_id(memcg), isolate,
+                                 cb_arg, nr_to_walk);
        spin_unlock_irq(&nlru->lock);
        return ret;
 }
@@ -299,16 +300,20 @@ unsigned long list_lru_walk_node(struct list_lru *lru, int nid,
                                 unsigned long *nr_to_walk)
 {
        long isolated = 0;
-       int memcg_idx;
 
        isolated += list_lru_walk_one(lru, nid, NULL, isolate, cb_arg,
                                      nr_to_walk);
+
+#ifdef CONFIG_MEMCG_KMEM
        if (*nr_to_walk > 0 && list_lru_memcg_aware(lru)) {
-               for_each_memcg_cache_index(memcg_idx) {
+               struct list_lru_memcg *mlru;
+               unsigned long index;
+
+               xa_for_each(&lru->xa, index, mlru) {
                        struct list_lru_node *nlru = &lru->node[nid];
 
                        spin_lock(&nlru->lock);
-                       isolated += __list_lru_walk_one(nlru, memcg_idx,
+                       isolated += __list_lru_walk_one(lru, nid, index,
                                                        isolate, cb_arg,
                                                        nr_to_walk);
                        spin_unlock(&nlru->lock);
@@ -317,6 +322,8 @@ unsigned long list_lru_walk_node(struct list_lru *lru, int nid,
                                break;
                }
        }
+#endif
+
        return isolated;
 }
 EXPORT_SYMBOL_GPL(list_lru_walk_node);
@@ -328,204 +335,81 @@ static void init_one_lru(struct list_lru_one *l)
 }
 
 #ifdef CONFIG_MEMCG_KMEM
-static void __memcg_destroy_list_lru_node(struct list_lru_memcg *memcg_lrus,
-                                         int begin, int end)
+static struct list_lru_memcg *memcg_init_list_lru_one(gfp_t gfp)
 {
-       int i;
-
-       for (i = begin; i < end; i++)
-               kfree(memcg_lrus->lru[i]);
-}
-
-static int __memcg_init_list_lru_node(struct list_lru_memcg *memcg_lrus,
-                                     int begin, int end)
-{
-       int i;
+       int nid;
+       struct list_lru_memcg *mlru;
 
-       for (i = begin; i < end; i++) {
-               struct list_lru_one *l;
+       mlru = kmalloc(struct_size(mlru, node, nr_node_ids), gfp);
+       if (!mlru)
+               return NULL;
 
-               l = kmalloc(sizeof(struct list_lru_one), GFP_KERNEL);
-               if (!l)
-                       goto fail;
+       for_each_node(nid)
+               init_one_lru(&mlru->node[nid]);
 
-               init_one_lru(l);
-               memcg_lrus->lru[i] = l;
-       }
-       return 0;
-fail:
-       __memcg_destroy_list_lru_node(memcg_lrus, begin, i);
-       return -ENOMEM;
+       return mlru;
 }
 
-static int memcg_init_list_lru_node(struct list_lru_node *nlru)
+static void memcg_list_lru_free(struct list_lru *lru, int src_idx)
 {
-       struct list_lru_memcg *memcg_lrus;
-       int size = memcg_nr_cache_ids;
-
-       memcg_lrus = kvmalloc(struct_size(memcg_lrus, lru, size), GFP_KERNEL);
-       if (!memcg_lrus)
-               return -ENOMEM;
-
-       if (__memcg_init_list_lru_node(memcg_lrus, 0, size)) {
-               kvfree(memcg_lrus);
-               return -ENOMEM;
-       }
-       RCU_INIT_POINTER(nlru->memcg_lrus, memcg_lrus);
-
-       return 0;
-}
+       struct list_lru_memcg *mlru = xa_erase_irq(&lru->xa, src_idx);
 
-static void memcg_destroy_list_lru_node(struct list_lru_node *nlru)
-{
-       struct list_lru_memcg *memcg_lrus;
        /*
-        * This is called when shrinker has already been unregistered,
-        * and nobody can use it. So, there is no need to use kvfree_rcu().
+        * The __list_lru_walk_one() can walk the list of this node.
+        * We need kvfree_rcu() here. And the walking of the list
+        * is under lru->node[nid]->lock, which can serve as a RCU
+        * read-side critical section.
         */
-       memcg_lrus = rcu_dereference_protected(nlru->memcg_lrus, true);
-       __memcg_destroy_list_lru_node(memcg_lrus, 0, memcg_nr_cache_ids);
-       kvfree(memcg_lrus);
+       if (mlru)
+               kvfree_rcu(mlru, rcu);
 }
 
-static int memcg_update_list_lru_node(struct list_lru_node *nlru,
-                                     int old_size, int new_size)
+static inline void memcg_init_list_lru(struct list_lru *lru, bool memcg_aware)
 {
-       struct list_lru_memcg *old, *new;
-
-       BUG_ON(old_size > new_size);
-
-       old = rcu_dereference_protected(nlru->memcg_lrus,
-                                       lockdep_is_held(&list_lrus_mutex));
-       new = kvmalloc(struct_size(new, lru, new_size), GFP_KERNEL);
-       if (!new)
-               return -ENOMEM;
-
-       if (__memcg_init_list_lru_node(new, old_size, new_size)) {
-               kvfree(new);
-               return -ENOMEM;
-       }
-
-       memcpy(&new->lru, &old->lru, flex_array_size(new, lru, old_size));
-       rcu_assign_pointer(nlru->memcg_lrus, new);
-       kvfree_rcu(old, rcu);
-       return 0;
-}
-
-static void memcg_cancel_update_list_lru_node(struct list_lru_node *nlru,
-                                             int old_size, int new_size)
-{
-       struct list_lru_memcg *memcg_lrus;
-
-       memcg_lrus = rcu_dereference_protected(nlru->memcg_lrus,
-                                              lockdep_is_held(&list_lrus_mutex));
-       /* do not bother shrinking the array back to the old size, because we
-        * cannot handle allocation failures here */
-       __memcg_destroy_list_lru_node(memcg_lrus, old_size, new_size);
-}
-
-static int memcg_init_list_lru(struct list_lru *lru, bool memcg_aware)
-{
-       int i;
-
+       if (memcg_aware)
+               xa_init_flags(&lru->xa, XA_FLAGS_LOCK_IRQ);
        lru->memcg_aware = memcg_aware;
-
-       if (!memcg_aware)
-               return 0;
-
-       for_each_node(i) {
-               if (memcg_init_list_lru_node(&lru->node[i]))
-                       goto fail;
-       }
-       return 0;
-fail:
-       for (i = i - 1; i >= 0; i--) {
-               if (!lru->node[i].memcg_lrus)
-                       continue;
-               memcg_destroy_list_lru_node(&lru->node[i]);
-       }
-       return -ENOMEM;
 }
 
 static void memcg_destroy_list_lru(struct list_lru *lru)
 {
-       int i;
+       XA_STATE(xas, &lru->xa, 0);
+       struct list_lru_memcg *mlru;
 
        if (!list_lru_memcg_aware(lru))
                return;
 
-       for_each_node(i)
-               memcg_destroy_list_lru_node(&lru->node[i]);
-}
-
-static int memcg_update_list_lru(struct list_lru *lru,
-                                int old_size, int new_size)
-{
-       int i;
-
-       for_each_node(i) {
-               if (memcg_update_list_lru_node(&lru->node[i],
-                                              old_size, new_size))
-                       goto fail;
-       }
-       return 0;
-fail:
-       for (i = i - 1; i >= 0; i--) {
-               if (!lru->node[i].memcg_lrus)
-                       continue;
-
-               memcg_cancel_update_list_lru_node(&lru->node[i],
-                                                 old_size, new_size);
+       xas_lock_irq(&xas);
+       xas_for_each(&xas, mlru, ULONG_MAX) {
+               kfree(mlru);
+               xas_store(&xas, NULL);
        }
-       return -ENOMEM;
-}
-
-static void memcg_cancel_update_list_lru(struct list_lru *lru,
-                                        int old_size, int new_size)
-{
-       int i;
-
-       for_each_node(i)
-               memcg_cancel_update_list_lru_node(&lru->node[i],
-                                                 old_size, new_size);
+       xas_unlock_irq(&xas);
 }
 
-int memcg_update_all_list_lrus(int new_size)
-{
-       int ret = 0;
-       struct list_lru *lru;
-       int old_size = memcg_nr_cache_ids;
-
-       mutex_lock(&list_lrus_mutex);
-       list_for_each_entry(lru, &memcg_list_lrus, list) {
-               ret = memcg_update_list_lru(lru, old_size, new_size);
-               if (ret)
-                       goto fail;
-       }
-out:
-       mutex_unlock(&list_lrus_mutex);
-       return ret;
-fail:
-       list_for_each_entry_continue_reverse(lru, &memcg_list_lrus, list)
-               memcg_cancel_update_list_lru(lru, old_size, new_size);
-       goto out;
-}
-
-static void memcg_drain_list_lru_node(struct list_lru *lru, int nid,
-                                     int src_idx, struct mem_cgroup *dst_memcg)
+static void memcg_reparent_list_lru_node(struct list_lru *lru, int nid,
+                                        int src_idx, struct mem_cgroup *dst_memcg)
 {
        struct list_lru_node *nlru = &lru->node[nid];
        int dst_idx = dst_memcg->kmemcg_id;
        struct list_lru_one *src, *dst;
 
+       /*
+        * If there is no lru entry in this nlru, we can skip it immediately.
+        */
+       if (!READ_ONCE(nlru->nr_items))
+               return;
+
        /*
         * Since list_lru_{add,del} may be called under an IRQ-safe lock,
         * we have to use IRQ-safe primitives here to avoid deadlock.
         */
        spin_lock_irq(&nlru->lock);
 
-       src = list_lru_from_memcg_idx(nlru, src_idx);
-       dst = list_lru_from_memcg_idx(nlru, dst_idx);
+       src = list_lru_from_memcg_idx(lru, nid, src_idx);
+       if (!src)
+               goto out;
+       dst = list_lru_from_memcg_idx(lru, nid, dst_idx);
 
        list_splice_init(&src->list, &dst->list);
 
@@ -534,32 +418,143 @@ static void memcg_drain_list_lru_node(struct list_lru *lru, int nid,
                set_shrinker_bit(dst_memcg, nid, lru_shrinker_id(lru));
                src->nr_items = 0;
        }
-
+out:
        spin_unlock_irq(&nlru->lock);
 }
 
-static void memcg_drain_list_lru(struct list_lru *lru,
-                                int src_idx, struct mem_cgroup *dst_memcg)
+static void memcg_reparent_list_lru(struct list_lru *lru,
+                                   int src_idx, struct mem_cgroup *dst_memcg)
 {
        int i;
 
        for_each_node(i)
-               memcg_drain_list_lru_node(lru, i, src_idx, dst_memcg);
+               memcg_reparent_list_lru_node(lru, i, src_idx, dst_memcg);
+
+       memcg_list_lru_free(lru, src_idx);
 }
 
-void memcg_drain_all_list_lrus(int src_idx, struct mem_cgroup *dst_memcg)
+void memcg_reparent_list_lrus(struct mem_cgroup *memcg, struct mem_cgroup *parent)
 {
+       struct cgroup_subsys_state *css;
        struct list_lru *lru;
+       int src_idx = memcg->kmemcg_id;
+
+       /*
+        * Change kmemcg_id of this cgroup and all its descendants to the
+        * parent's id, and then move all entries from this cgroup's list_lrus
+        * to ones of the parent.
+        *
+        * After we have finished, all list_lrus corresponding to this cgroup
+        * are guaranteed to remain empty. So we can safely free this cgroup's
+        * list lrus in memcg_list_lru_free().
+        *
+        * Changing ->kmemcg_id to the parent can prevent memcg_list_lru_alloc()
+        * from allocating list lrus for this cgroup after memcg_list_lru_free()
+        * call.
+        */
+       rcu_read_lock();
+       css_for_each_descendant_pre(css, &memcg->css) {
+               struct mem_cgroup *child;
+
+               child = mem_cgroup_from_css(css);
+               WRITE_ONCE(child->kmemcg_id, parent->kmemcg_id);
+       }
+       rcu_read_unlock();
 
        mutex_lock(&list_lrus_mutex);
        list_for_each_entry(lru, &memcg_list_lrus, list)
-               memcg_drain_list_lru(lru, src_idx, dst_memcg);
+               memcg_reparent_list_lru(lru, src_idx, parent);
        mutex_unlock(&list_lrus_mutex);
 }
+
+static inline bool memcg_list_lru_allocated(struct mem_cgroup *memcg,
+                                           struct list_lru *lru)
+{
+       int idx = memcg->kmemcg_id;
+
+       return idx < 0 || xa_load(&lru->xa, idx);
+}
+
+int memcg_list_lru_alloc(struct mem_cgroup *memcg, struct list_lru *lru,
+                        gfp_t gfp)
+{
+       int i;
+       unsigned long flags;
+       struct list_lru_memcg_table {
+               struct list_lru_memcg *mlru;
+               struct mem_cgroup *memcg;
+       } *table;
+       XA_STATE(xas, &lru->xa, 0);
+
+       if (!list_lru_memcg_aware(lru) || memcg_list_lru_allocated(memcg, lru))
+               return 0;
+
+       gfp &= GFP_RECLAIM_MASK;
+       table = kmalloc_array(memcg->css.cgroup->level, sizeof(*table), gfp);
+       if (!table)
+               return -ENOMEM;
+
+       /*
+        * Because the list_lru can be reparented to the parent cgroup's
+        * list_lru, we should make sure that this cgroup and all its
+        * ancestors have allocated list_lru_memcg.
+        */
+       for (i = 0; memcg; memcg = parent_mem_cgroup(memcg), i++) {
+               if (memcg_list_lru_allocated(memcg, lru))
+                       break;
+
+               table[i].memcg = memcg;
+               table[i].mlru = memcg_init_list_lru_one(gfp);
+               if (!table[i].mlru) {
+                       while (i--)
+                               kfree(table[i].mlru);
+                       kfree(table);
+                       return -ENOMEM;
+               }
+       }
+
+       xas_lock_irqsave(&xas, flags);
+       while (i--) {
+               int index = READ_ONCE(table[i].memcg->kmemcg_id);
+               struct list_lru_memcg *mlru = table[i].mlru;
+
+               xas_set(&xas, index);
+retry:
+               if (unlikely(index < 0 || xas_error(&xas) || xas_load(&xas))) {
+                       kfree(mlru);
+               } else {
+                       xas_store(&xas, mlru);
+                       if (xas_error(&xas) == -ENOMEM) {
+                               xas_unlock_irqrestore(&xas, flags);
+                               if (xas_nomem(&xas, gfp))
+                                       xas_set_err(&xas, 0);
+                               xas_lock_irqsave(&xas, flags);
+                               /*
+                                * The xas lock has been released, this memcg
+                                * can be reparented before us. So reload
+                                * memcg id. More details see the comments
+                                * in memcg_reparent_list_lrus().
+                                */
+                               index = READ_ONCE(table[i].memcg->kmemcg_id);
+                               if (index < 0)
+                                       xas_set_err(&xas, 0);
+                               else if (!xas_error(&xas) && index != xas.xa_index)
+                                       xas_set(&xas, index);
+                               goto retry;
+                       }
+               }
+       }
+       /* xas_nomem() is used to free memory instead of memory allocation. */
+       if (xas.xa_alloc)
+               xas_nomem(&xas, gfp);
+       xas_unlock_irqrestore(&xas, flags);
+       kfree(table);
+
+       return xas_error(&xas);
+}
 #else
-static int memcg_init_list_lru(struct list_lru *lru, bool memcg_aware)
+static inline void memcg_init_list_lru(struct list_lru *lru, bool memcg_aware)
 {
-       return 0;
 }
 
 static void memcg_destroy_list_lru(struct list_lru *lru)
@@ -571,7 +566,6 @@ int __list_lru_init(struct list_lru *lru, bool memcg_aware,
                    struct lock_class_key *key, struct shrinker *shrinker)
 {
        int i;
-       int err = -ENOMEM;
 
 #ifdef CONFIG_MEMCG_KMEM
        if (shrinker)
@@ -579,11 +573,10 @@ int __list_lru_init(struct list_lru *lru, bool memcg_aware,
        else
                lru->shrinker_id = -1;
 #endif
-       memcg_get_cache_ids();
 
        lru->node = kcalloc(nr_node_ids, sizeof(*lru->node), GFP_KERNEL);
        if (!lru->node)
-               goto out;
+               return -ENOMEM;
 
        for_each_node(i) {
                spin_lock_init(&lru->node[i].lock);
@@ -592,18 +585,10 @@ int __list_lru_init(struct list_lru *lru, bool memcg_aware,
                init_one_lru(&lru->node[i].lru);
        }
 
-       err = memcg_init_list_lru(lru, memcg_aware);
-       if (err) {
-               kfree(lru->node);
-               /* Do this so a list_lru_destroy() doesn't crash: */
-               lru->node = NULL;
-               goto out;
-       }
-
+       memcg_init_list_lru(lru, memcg_aware);
        list_lru_register(lru);
-out:
-       memcg_put_cache_ids();
-       return err;
+
+       return 0;
 }
 EXPORT_SYMBOL_GPL(__list_lru_init);
 
@@ -613,8 +598,6 @@ void list_lru_destroy(struct list_lru *lru)
        if (!lru->node)
                return;
 
-       memcg_get_cache_ids();
-
        list_lru_unregister(lru);
 
        memcg_destroy_list_lru(lru);
@@ -624,6 +607,5 @@ void list_lru_destroy(struct list_lru *lru)
 #ifdef CONFIG_MEMCG_KMEM
        lru->shrinker_id = -1;
 #endif
-       memcg_put_cache_ids();
 }
 EXPORT_SYMBOL_GPL(list_lru_destroy);