Merge tag 'nios2-v5.7-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git/lftan...
[linux-2.6-microblaze.git] / mm / memcontrol.c
index 6f6dc87..5beea03 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -334,7 +334,7 @@ static int memcg_expand_one_shrinker_map(struct mem_cgroup *memcg,
                if (!old)
                        return 0;
 
-               new = kvmalloc(sizeof(*new) + size, GFP_KERNEL);
+               new = kvmalloc_node(sizeof(*new) + size, GFP_KERNEL, nid);
                if (!new)
                        return -ENOMEM;
 
@@ -378,7 +378,7 @@ static int memcg_alloc_shrinker_maps(struct mem_cgroup *memcg)
        mutex_lock(&memcg_shrinker_map_mutex);
        size = memcg_shrinker_map_size;
        for_each_node(nid) {
-               map = kvzalloc(sizeof(*map) + size, GFP_KERNEL);
+               map = kvzalloc_node(sizeof(*map) + size, GFP_KERNEL, nid);
                if (!map) {
                        memcg_free_shrinker_maps(memcg);
                        ret = -ENOMEM;
@@ -409,8 +409,10 @@ int memcg_expand_shrinker_maps(int new_id)
                if (mem_cgroup_is_root(memcg))
                        continue;
                ret = memcg_expand_one_shrinker_map(memcg, size, old_size);
-               if (ret)
+               if (ret) {
+                       mem_cgroup_iter_break(NULL, memcg);
                        goto unlock;
+               }
        }
 unlock:
        if (!ret)
@@ -654,7 +656,7 @@ retry:
         */
        __mem_cgroup_remove_exceeded(mz, mctz);
        if (!soft_limit_excess(mz->memcg) ||
-           !css_tryget_online(&mz->memcg->css))
+           !css_tryget(&mz->memcg->css))
                goto retry;
 done:
        return mz;
@@ -757,13 +759,12 @@ void __mod_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
 
 void __mod_lruvec_slab_state(void *p, enum node_stat_item idx, int val)
 {
-       struct page *page = virt_to_head_page(p);
-       pg_data_t *pgdat = page_pgdat(page);
+       pg_data_t *pgdat = page_pgdat(virt_to_page(p));
        struct mem_cgroup *memcg;
        struct lruvec *lruvec;
 
        rcu_read_lock();
-       memcg = memcg_from_slab_page(page);
+       memcg = mem_cgroup_from_obj(p);
 
        /* Untracked pages have no memcg, no lruvec. Update only the node */
        if (!memcg || memcg == root_mem_cgroup) {
@@ -775,6 +776,17 @@ void __mod_lruvec_slab_state(void *p, enum node_stat_item idx, int val)
        rcu_read_unlock();
 }
 
+void mod_memcg_obj_state(void *p, int idx, int val)
+{
+       struct mem_cgroup *memcg;
+
+       rcu_read_lock();
+       memcg = mem_cgroup_from_obj(p);
+       if (memcg)
+               mod_memcg_state(memcg, idx, val);
+       rcu_read_unlock();
+}
+
 /**
  * __count_memcg_events - account VM events in a cgroup
  * @memcg: the memory cgroup
@@ -960,7 +972,8 @@ struct mem_cgroup *get_mem_cgroup_from_page(struct page *page)
                return NULL;
 
        rcu_read_lock();
-       if (!memcg || !css_tryget_online(&memcg->css))
+       /* Page should not get uncharged and freed memcg under us. */
+       if (!memcg || WARN_ON_ONCE(!css_tryget(&memcg->css)))
                memcg = root_mem_cgroup;
        rcu_read_unlock();
        return memcg;
@@ -973,10 +986,13 @@ EXPORT_SYMBOL(get_mem_cgroup_from_page);
 static __always_inline struct mem_cgroup *get_mem_cgroup_from_current(void)
 {
        if (unlikely(current->active_memcg)) {
-               struct mem_cgroup *memcg = root_mem_cgroup;
+               struct mem_cgroup *memcg;
 
                rcu_read_lock();
-               if (css_tryget_online(&current->active_memcg->css))
+               /* current->active_memcg must hold a ref. */
+               if (WARN_ON_ONCE(!css_tryget(&current->active_memcg->css)))
+                       memcg = root_mem_cgroup;
+               else
                        memcg = current->active_memcg;
                rcu_read_unlock();
                return memcg;
@@ -1505,11 +1521,11 @@ void mem_cgroup_print_oom_meminfo(struct mem_cgroup *memcg)
 
        pr_info("memory: usage %llukB, limit %llukB, failcnt %lu\n",
                K((u64)page_counter_read(&memcg->memory)),
-               K((u64)memcg->memory.max), memcg->memory.failcnt);
+               K((u64)READ_ONCE(memcg->memory.max)), memcg->memory.failcnt);
        if (cgroup_subsys_on_dfl(memory_cgrp_subsys))
                pr_info("swap: usage %llukB, limit %llukB, failcnt %lu\n",
                        K((u64)page_counter_read(&memcg->swap)),
-                       K((u64)memcg->swap.max), memcg->swap.failcnt);
+                       K((u64)READ_ONCE(memcg->swap.max)), memcg->swap.failcnt);
        else {
                pr_info("memory+swap: usage %llukB, limit %llukB, failcnt %lu\n",
                        K((u64)page_counter_read(&memcg->memsw)),
@@ -1536,13 +1552,13 @@ unsigned long mem_cgroup_get_max(struct mem_cgroup *memcg)
 {
        unsigned long max;
 
-       max = memcg->memory.max;
+       max = READ_ONCE(memcg->memory.max);
        if (mem_cgroup_swappiness(memcg)) {
                unsigned long memsw_max;
                unsigned long swap_max;
 
                memsw_max = memcg->memsw.max;
-               swap_max = memcg->swap.max;
+               swap_max = READ_ONCE(memcg->swap.max);
                swap_max = min(swap_max, (unsigned long)total_swap_pages);
                max = min(max + swap_max, memsw_max);
        }
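
The READ_ONCE() annotations added in the hunks above and below pair with WRITE_ONCE() on the store side (see the memcg->high updates later in this patch): limits such as memory.max, swap.max and memcg->high can be changed locklessly from the cgroup interface, so racy readers take a single marked load. A minimal userspace sketch of the idiom, using stand-in macros and a made-up limit variable rather than the kernel's page_counter:

#include <stdio.h>

/* Stand-ins for the kernel's READ_ONCE()/WRITE_ONCE() annotations. */
#define READ_ONCE(x)		(*(volatile __typeof__(x) *)&(x))
#define WRITE_ONCE(x, val)	((*(volatile __typeof__(x) *)&(x)) = (val))

/* Made-up limit; the real code reads page_counter::max and memcg->high. */
static unsigned long limit_max;

static int over_limit(unsigned long usage)
{
	/* Reader side: one non-torn load even if the limit changes underneath. */
	return usage > READ_ONCE(limit_max);
}

int main(void)
{
	/* Writer side: the store is annotated the same way. */
	WRITE_ONCE(limit_max, 1024);
	printf("over: %d\n", over_limit(2048));
	return 0;
}
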
@@ -1914,6 +1930,14 @@ struct mem_cgroup *mem_cgroup_get_oom_group(struct task_struct *victim,
        if (memcg == root_mem_cgroup)
                goto out;
 
+       /*
+        * If the victim task has been asynchronously moved to a different
+        * memory cgroup, we might end up killing tasks outside oom_domain.
+        * In this case it's better to ignore memory.group.oom.
+        */
+       if (unlikely(!mem_cgroup_is_descendant(memcg, oom_domain)))
+               goto out;
+
        /*
         * Traverse the memory cgroup hierarchy from the victim task's
         * cgroup up to the OOMing cgroup (or root) to find the
@@ -2226,11 +2250,12 @@ static void reclaim_high(struct mem_cgroup *memcg,
                         gfp_t gfp_mask)
 {
        do {
-               if (page_counter_read(&memcg->memory) <= memcg->high)
+               if (page_counter_read(&memcg->memory) <= READ_ONCE(memcg->high))
                        continue;
                memcg_memory_event(memcg, MEMCG_HIGH);
                try_to_free_mem_cgroup_pages(memcg, nr_pages, gfp_mask, true);
-       } while ((memcg = parent_mem_cgroup(memcg)));
+       } while ((memcg = parent_mem_cgroup(memcg)) &&
+                !mem_cgroup_is_root(memcg));
 }
 
 static void high_work_func(struct work_struct *work)
@@ -2295,28 +2320,44 @@ static void high_work_func(struct work_struct *work)
  #define MEMCG_DELAY_SCALING_SHIFT 14
 
 /*
- * Scheduled by try_charge() to be executed from the userland return path
- * and reclaims memory over the high limit.
+ * Get the number of jiffies that we should penalise a mischievous cgroup which
+ * is exceeding its memory.high by checking both it and its ancestors.
  */
-void mem_cgroup_handle_over_high(void)
+static unsigned long calculate_high_delay(struct mem_cgroup *memcg,
+                                         unsigned int nr_pages)
 {
-       unsigned long usage, high, clamped_high;
-       unsigned long pflags;
-       unsigned long penalty_jiffies, overage;
-       unsigned int nr_pages = current->memcg_nr_pages_over_high;
-       struct mem_cgroup *memcg;
+       unsigned long penalty_jiffies;
+       u64 max_overage = 0;
 
-       if (likely(!nr_pages))
-               return;
+       do {
+               unsigned long usage, high;
+               u64 overage;
 
-       memcg = get_mem_cgroup_from_mm(current->mm);
-       reclaim_high(memcg, nr_pages, GFP_KERNEL);
-       current->memcg_nr_pages_over_high = 0;
+               usage = page_counter_read(&memcg->memory);
+               high = READ_ONCE(memcg->high);
+
+               if (usage <= high)
+                       continue;
+
+               /*
+                * Prevent division by 0 in overage calculation by acting as if
+                * it was a threshold of 1 page
+                */
+               high = max(high, 1UL);
+
+               overage = usage - high;
+               overage <<= MEMCG_DELAY_PRECISION_SHIFT;
+               overage = div64_u64(overage, high);
+
+               if (overage > max_overage)
+                       max_overage = overage;
+       } while ((memcg = parent_mem_cgroup(memcg)) &&
+                !mem_cgroup_is_root(memcg));
+
+       if (!max_overage)
+               return 0;
 
        /*
-        * memory.high is breached and reclaim is unable to keep up. Throttle
-        * allocators proactively to slow down excessive growth.
-        *
         * We use overage compared to memory.high to calculate the number of
         * jiffies to sleep (penalty_jiffies). Ideally this value should be
         * fairly lenient on small overages, and increasingly harsh when the
@@ -2324,24 +2365,9 @@ void mem_cgroup_handle_over_high(void)
         * its crazy behaviour, so we exponentially increase the delay based on
         * overage amount.
         */
-
-       usage = page_counter_read(&memcg->memory);
-       high = READ_ONCE(memcg->high);
-
-       if (usage <= high)
-               goto out;
-
-       /*
-        * Prevent division by 0 in overage calculation by acting as if it was a
-        * threshold of 1 page
-        */
-       clamped_high = max(high, 1UL);
-
-       overage = div_u64((u64)(usage - high) << MEMCG_DELAY_PRECISION_SHIFT,
-                         clamped_high);
-
-       penalty_jiffies = ((u64)overage * overage * HZ)
-               >> (MEMCG_DELAY_PRECISION_SHIFT + MEMCG_DELAY_SCALING_SHIFT);
+       penalty_jiffies = max_overage * max_overage * HZ;
+       penalty_jiffies >>= MEMCG_DELAY_PRECISION_SHIFT;
+       penalty_jiffies >>= MEMCG_DELAY_SCALING_SHIFT;
 
        /*
         * Factor in the task's own contribution to the overage, such that four
@@ -2358,7 +2384,32 @@ void mem_cgroup_handle_over_high(void)
         * application moving forwards and also permit diagnostics, albeit
         * extremely slowly.
         */
-       penalty_jiffies = min(penalty_jiffies, MEMCG_MAX_HIGH_DELAY_JIFFIES);
+       return min(penalty_jiffies, MEMCG_MAX_HIGH_DELAY_JIFFIES);
+}
+
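
The helper above folds the old inline overage math into a per-ancestor loop and keeps only the worst offender. A standalone sketch of the single-level arithmetic with made-up numbers, assuming MEMCG_DELAY_PRECISION_SHIFT is 20 and HZ is 100 (only the MEMCG_DELAY_SCALING_SHIFT of 14 is visible in the context above), and omitting the nr_pages scaling and MEMCG_MAX_HIGH_DELAY_JIFFIES clamp applied later in the function:

#include <stdio.h>
#include <stdint.h>

/*
 * Assumptions for illustration: the precision shift of 20 and HZ of 100 are
 * not shown in this hunk; the scaling shift of 14 is visible in the context.
 */
#define MEMCG_DELAY_PRECISION_SHIFT	20
#define MEMCG_DELAY_SCALING_SHIFT	14
#define HZ				100ULL

/* Mirror of the overage -> penalty math in calculate_high_delay(). */
static unsigned long penalty_jiffies(unsigned long usage, unsigned long high)
{
	uint64_t overage, penalty;

	if (usage <= high)
		return 0;

	/* Act as if the threshold were one page to avoid dividing by zero. */
	if (high < 1)
		high = 1;

	overage = ((uint64_t)(usage - high) << MEMCG_DELAY_PRECISION_SHIFT) / high;

	penalty = overage * overage * HZ;
	penalty >>= MEMCG_DELAY_PRECISION_SHIFT;
	penalty >>= MEMCG_DELAY_SCALING_SHIFT;

	return (unsigned long)penalty;
}

int main(void)
{
	printf("10%% over high: %lu jiffies\n", penalty_jiffies(1100, 1000));
	printf("100%% over high: %lu jiffies\n", penalty_jiffies(2000, 1000));
	return 0;
}
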
+/*
+ * Scheduled by try_charge() to be executed from the userland return path
+ * and reclaims memory over the high limit.
+ */
+void mem_cgroup_handle_over_high(void)
+{
+       unsigned long penalty_jiffies;
+       unsigned long pflags;
+       unsigned int nr_pages = current->memcg_nr_pages_over_high;
+       struct mem_cgroup *memcg;
+
+       if (likely(!nr_pages))
+               return;
+
+       memcg = get_mem_cgroup_from_mm(current->mm);
+       reclaim_high(memcg, nr_pages, GFP_KERNEL);
+       current->memcg_nr_pages_over_high = 0;
+
+       /*
+        * memory.high is breached and reclaim is unable to keep up. Throttle
+        * allocators proactively to slow down excessive growth.
+        */
+       penalty_jiffies = calculate_high_delay(memcg, nr_pages);
 
        /*
         * Don't sleep if the amount of jiffies this memcg owes us is so low
@@ -2543,7 +2594,7 @@ done_restock:
         * reclaim, the cost of mismatch is negligible.
         */
        do {
-               if (page_counter_read(&memcg->memory) > memcg->high) {
+               if (page_counter_read(&memcg->memory) > READ_ONCE(memcg->high)) {
                        /* Don't bother a random interrupted task */
                        if (in_interrupt()) {
                                schedule_work(&memcg->high_work);
@@ -2636,6 +2687,33 @@ static void commit_charge(struct page *page, struct mem_cgroup *memcg,
 }
 
 #ifdef CONFIG_MEMCG_KMEM
+/*
+ * Returns a pointer to the memory cgroup to which the kernel object is charged.
+ *
+ * The caller must ensure the memcg lifetime, e.g. by taking rcu_read_lock(),
+ * cgroup_mutex, etc.
+ */
+struct mem_cgroup *mem_cgroup_from_obj(void *p)
+{
+       struct page *page;
+
+       if (mem_cgroup_disabled())
+               return NULL;
+
+       page = virt_to_head_page(p);
+
+       /*
+        * Slab pages don't have page->mem_cgroup set because corresponding
+        * kmem caches can be reparented during the lifetime. That's why
+        * memcg_from_slab_page() should be used instead.
+        */
+       if (PageSlab(page))
+               return memcg_from_slab_page(page);
+
+       /* All other pages use page->mem_cgroup */
+       return page->mem_cgroup;
+}
+
 static int memcg_alloc_cache_id(void)
 {
        int id, size;
@@ -2819,18 +2897,16 @@ void memcg_kmem_put_cache(struct kmem_cache *cachep)
 }
 
 /**
- * __memcg_kmem_charge_memcg: charge a kmem page
- * @page: page to charge
- * @gfp: reclaim mode
- * @order: allocation order
+ * __memcg_kmem_charge: charge a number of kernel pages to a memcg
  * @memcg: memory cgroup to charge
+ * @gfp: reclaim mode
+ * @nr_pages: number of pages to charge
  *
  * Returns 0 on success, an error code on failure.
  */
-int __memcg_kmem_charge_memcg(struct page *page, gfp_t gfp, int order,
-                           struct mem_cgroup *memcg)
+int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
+                       unsigned int nr_pages)
 {
-       unsigned int nr_pages = 1 << order;
        struct page_counter *counter;
        int ret;
 
@@ -2857,14 +2933,29 @@ int __memcg_kmem_charge_memcg(struct page *page, gfp_t gfp, int order,
 }
 
 /**
- * __memcg_kmem_charge: charge a kmem page to the current memory cgroup
+ * __memcg_kmem_uncharge: uncharge a number of kernel pages from a memcg
+ * @memcg: memcg to uncharge
+ * @nr_pages: number of pages to uncharge
+ */
+void __memcg_kmem_uncharge(struct mem_cgroup *memcg, unsigned int nr_pages)
+{
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               page_counter_uncharge(&memcg->kmem, nr_pages);
+
+       page_counter_uncharge(&memcg->memory, nr_pages);
+       if (do_memsw_account())
+               page_counter_uncharge(&memcg->memsw, nr_pages);
+}
+
+/**
+ * __memcg_kmem_charge_page: charge a kmem page to the current memory cgroup
  * @page: page to charge
  * @gfp: reclaim mode
  * @order: allocation order
  *
  * Returns 0 on success, an error code on failure.
  */
-int __memcg_kmem_charge(struct page *page, gfp_t gfp, int order)
+int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
 {
        struct mem_cgroup *memcg;
        int ret = 0;
@@ -2874,7 +2965,7 @@ int __memcg_kmem_charge(struct page *page, gfp_t gfp, int order)
 
        memcg = get_mem_cgroup_from_current();
        if (!mem_cgroup_is_root(memcg)) {
-               ret = __memcg_kmem_charge_memcg(page, gfp, order, memcg);
+               ret = __memcg_kmem_charge(memcg, gfp, 1 << order);
                if (!ret) {
                        page->mem_cgroup = memcg;
                        __SetPageKmemcg(page);
@@ -2885,26 +2976,11 @@ int __memcg_kmem_charge(struct page *page, gfp_t gfp, int order)
 }
 
 /**
- * __memcg_kmem_uncharge_memcg: uncharge a kmem page
- * @memcg: memcg to uncharge
- * @nr_pages: number of pages to uncharge
- */
-void __memcg_kmem_uncharge_memcg(struct mem_cgroup *memcg,
-                                unsigned int nr_pages)
-{
-       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
-               page_counter_uncharge(&memcg->kmem, nr_pages);
-
-       page_counter_uncharge(&memcg->memory, nr_pages);
-       if (do_memsw_account())
-               page_counter_uncharge(&memcg->memsw, nr_pages);
-}
-/**
- * __memcg_kmem_uncharge: uncharge a kmem page
+ * __memcg_kmem_uncharge_page: uncharge a kmem page
  * @page: page to uncharge
  * @order: allocation order
  */
-void __memcg_kmem_uncharge(struct page *page, int order)
+void __memcg_kmem_uncharge_page(struct page *page, int order)
 {
        struct mem_cgroup *memcg = page->mem_cgroup;
        unsigned int nr_pages = 1 << order;
@@ -2913,7 +2989,7 @@ void __memcg_kmem_uncharge(struct page *page, int order)
                return;
 
        VM_BUG_ON_PAGE(mem_cgroup_is_root(memcg), page);
-       __memcg_kmem_uncharge_memcg(memcg, nr_pages);
+       __memcg_kmem_uncharge(memcg, nr_pages);
        page->mem_cgroup = NULL;
 
        /* slab pages do not have PageKmemcg flag set */
@@ -3004,7 +3080,7 @@ static int mem_cgroup_resize_max(struct mem_cgroup *memcg,
                 * Make sure that the new limit (memsw or memory limit) doesn't
                 * break our basic invariant rule memory.max <= memsw.max.
                 */
-               limits_invariant = memsw ? max >= memcg->memory.max :
+               limits_invariant = memsw ? max >= READ_ONCE(memcg->memory.max) :
                                           max <= memcg->memsw.max;
                if (!limits_invariant) {
                        mutex_unlock(&memcg_max_mutex);
@@ -3751,8 +3827,8 @@ static int memcg_stat_show(struct seq_file *m, void *v)
        /* Hierarchical information */
        memory = memsw = PAGE_COUNTER_MAX;
        for (mi = memcg; mi; mi = parent_mem_cgroup(mi)) {
-               memory = min(memory, mi->memory.max);
-               memsw = min(memsw, mi->memsw.max);
+               memory = min(memory, READ_ONCE(mi->memory.max));
+               memsw = min(memsw, READ_ONCE(mi->memsw.max));
        }
        seq_printf(m, "hierarchical_memory_limit %llu\n",
                   (u64)memory * PAGE_SIZE);
@@ -4025,7 +4101,7 @@ static void __mem_cgroup_usage_unregister_event(struct mem_cgroup *memcg,
        struct mem_cgroup_thresholds *thresholds;
        struct mem_cgroup_threshold_ary *new;
        unsigned long usage;
-       int i, j, size;
+       int i, j, size, entries;
 
        mutex_lock(&memcg->thresholds_lock);
 
@@ -4045,14 +4121,20 @@ static void __mem_cgroup_usage_unregister_event(struct mem_cgroup *memcg,
        __mem_cgroup_threshold(memcg, type == _MEMSWAP);
 
        /* Calculate new number of threshold */
-       size = 0;
+       size = entries = 0;
        for (i = 0; i < thresholds->primary->size; i++) {
                if (thresholds->primary->entries[i].eventfd != eventfd)
                        size++;
+               else
+                       entries++;
        }
 
        new = thresholds->spare;
 
+       /* If no items related to eventfd have been cleared, nothing to do */
+       if (!entries)
+               goto unlock;
+
        /* Set thresholds array to NULL if we don't have thresholds */
        if (!size) {
                kfree(new);
@@ -4255,7 +4337,8 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
        *pheadroom = PAGE_COUNTER_MAX;
 
        while ((parent = parent_mem_cgroup(memcg))) {
-               unsigned long ceiling = min(memcg->memory.max, memcg->high);
+               unsigned long ceiling = min(READ_ONCE(memcg->memory.max),
+                                           READ_ONCE(memcg->high));
                unsigned long used = page_counter_read(&memcg->memory);
 
                *pheadroom = min(*pheadroom, ceiling - min(ceiling, used));
@@ -4723,7 +4806,8 @@ static struct cftype mem_cgroup_legacy_files[] = {
                .write = mem_cgroup_reset,
                .read_u64 = mem_cgroup_read_u64,
        },
-#if defined(CONFIG_SLAB) || defined(CONFIG_SLUB_DEBUG)
+#if defined(CONFIG_MEMCG_KMEM) && \
+       (defined(CONFIG_SLAB) || defined(CONFIG_SLUB_DEBUG))
        {
                .name = "kmem.slabinfo",
                .seq_start = memcg_slab_start,
@@ -4792,7 +4876,8 @@ static void mem_cgroup_id_remove(struct mem_cgroup *memcg)
        }
 }
 
-static void mem_cgroup_id_get_many(struct mem_cgroup *memcg, unsigned int n)
+static void __maybe_unused mem_cgroup_id_get_many(struct mem_cgroup *memcg,
+                                                 unsigned int n)
 {
        refcount_add(n, &memcg->id.ref);
 }
@@ -4975,7 +5060,7 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
        if (!memcg)
                return ERR_PTR(error);
 
-       memcg->high = PAGE_COUNTER_MAX;
+       WRITE_ONCE(memcg->high, PAGE_COUNTER_MAX);
        memcg->soft_limit = PAGE_COUNTER_MAX;
        if (parent) {
                memcg->swappiness = mem_cgroup_swappiness(parent);
@@ -5128,7 +5213,7 @@ static void mem_cgroup_css_reset(struct cgroup_subsys_state *css)
        page_counter_set_max(&memcg->tcpmem, PAGE_COUNTER_MAX);
        page_counter_set_min(&memcg->memory, 0);
        page_counter_set_low(&memcg->memory, 0);
-       memcg->high = PAGE_COUNTER_MAX;
+       WRITE_ONCE(memcg->high, PAGE_COUNTER_MAX);
        memcg->soft_limit = PAGE_COUNTER_MAX;
        memcg_wb_domain_size_changed(memcg);
 }
@@ -5731,7 +5816,7 @@ retry:
                switch (get_mctgt_type(vma, addr, ptent, &target)) {
                case MC_TARGET_DEVICE:
                        device = true;
-                       /* fall through */
+                       fallthrough;
                case MC_TARGET_PAGE:
                        page = target.page;
                        /*
@@ -5944,7 +6029,7 @@ static ssize_t memory_high_write(struct kernfs_open_file *of,
        if (err)
                return err;
 
-       memcg->high = high;
+       WRITE_ONCE(memcg->high, high);
 
        for (;;) {
                unsigned long nr_pages = page_counter_read(&memcg->memory);
@@ -6167,6 +6252,117 @@ struct cgroup_subsys memory_cgrp_subsys = {
        .early_init = 0,
 };
 
+/*
+ * This function calculates an individual cgroup's effective
+ * protection which is derived from its own memory.min/low, its
+ * parent's and siblings' settings, as well as the actual memory
+ * distribution in the tree.
+ *
+ * The following rules apply to the effective protection values:
+ *
+ * 1. At the first level of reclaim, effective protection is equal to
+ *    the declared protection in memory.min and memory.low.
+ *
+ * 2. To enable safe delegation of the protection configuration, at
+ *    subsequent levels the effective protection is capped to the
+ *    parent's effective protection.
+ *
+ * 3. To make complex and dynamic subtrees easier to configure, the
+ *    user is allowed to overcommit the declared protection at a given
+ *    level. If that is the case, the parent's effective protection is
+ *    distributed to the children in proportion to how much protection
+ *    they have declared and how much of it they are utilizing.
+ *
+ *    This makes distribution proportional, but also work-conserving:
+ *    if one cgroup claims much more protection than it uses memory,
+ *    the unused remainder is available to its siblings.
+ *
+ * 4. Conversely, when the declared protection is undercommitted at a
+ *    given level, the distribution of the larger parental protection
+ *    budget is NOT proportional. A cgroup's protection from a sibling
+ *    is capped to its own memory.min/low setting.
+ *
+ * 5. However, to allow protecting recursive subtrees from each other
+ *    without having to declare each individual cgroup's fixed share
+ *    of the ancestor's claim to protection, any unutilized -
+ *    "floating" - protection from up the tree is distributed in
+ *    proportion to each cgroup's *usage*. This makes the protection
+ *    neutral wrt sibling cgroups and lets them compete freely over
+ *    the shared parental protection budget, but it protects the
+ *    subtree as a whole from neighboring subtrees.
+ *
+ * Note that 4. and 5. are not in conflict: 4. is about protecting
+ * against immediate siblings whereas 5. is about protecting against
+ * neighboring subtrees.
+ */
+static unsigned long effective_protection(unsigned long usage,
+                                         unsigned long parent_usage,
+                                         unsigned long setting,
+                                         unsigned long parent_effective,
+                                         unsigned long siblings_protected)
+{
+       unsigned long protected;
+       unsigned long ep;
+
+       protected = min(usage, setting);
+       /*
+        * If all cgroups at this level combined claim and use more
+        * protection then what the parent affords them, distribute
+        * shares in proportion to utilization.
+        *
+        * We are using actual utilization rather than the statically
+        * claimed protection in order to be work-conserving: claimed
+        * but unused protection is available to siblings that would
+        * otherwise get a smaller chunk than what they claimed.
+        */
+       if (siblings_protected > parent_effective)
+               return protected * parent_effective / siblings_protected;
+
+       /*
+        * Ok, utilized protection of all children is within what the
+        * parent affords them, so we know whatever this child claims
+        * and utilizes is effectively protected.
+        *
+        * If there is unprotected usage beyond this value, reclaim
+        * will apply pressure in proportion to that amount.
+        *
+        * If there is unutilized protection, the cgroup will be fully
+        * shielded from reclaim, but we do return a smaller value for
+        * protection than what the group could enjoy in theory. This
+        * is okay. With the overcommit distribution above, effective
+        * protection is always dependent on how memory is actually
+        * consumed among the siblings anyway.
+        */
+       ep = protected;
+
+       /*
+        * If the children aren't claiming (all of) the protection
+        * afforded to them by the parent, distribute the remainder in
+        * proportion to the (unprotected) memory of each cgroup. That
+        * way, cgroups that aren't explicitly prioritized wrt each
+        * other compete freely over the allowance, but they are
+        * collectively protected from neighboring trees.
+        *
+        * We're using unprotected memory for the weight so that if
+        * some cgroups DO claim explicit protection, we don't protect
+        * the same bytes twice.
+        */
+       if (!(cgrp_dfl_root.flags & CGRP_ROOT_MEMORY_RECURSIVE_PROT))
+               return ep;
+
+       if (parent_effective > siblings_protected && usage > protected) {
+               unsigned long unclaimed;
+
+               unclaimed = parent_effective - siblings_protected;
+               unclaimed *= usage - protected;
+               unclaimed /= parent_usage - siblings_protected;
+
+               ep += unclaimed;
+       }
+
+       return ep;
+}
+
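
A worked illustration of rule 3. above (overcommit at one level), as a standalone sketch with made-up numbers and a simplified signature that drops the parent_usage argument and the recursive-protection branch: with a parent affording 10 units of effective protection and two children claiming and using 8 and 4 units, the children receive 8*10/12 and 4*10/12 units respectively.

#include <stdio.h>

/*
 * Simplified, standalone mirror of effective_protection() above; the
 * CGRP_ROOT_MEMORY_RECURSIVE_PROT branch is omitted and all numbers
 * below are made up.
 */
static unsigned long effective_protection(unsigned long usage,
					  unsigned long setting,
					  unsigned long parent_effective,
					  unsigned long siblings_protected)
{
	unsigned long protected = usage < setting ? usage : setting;

	/* Overcommitted level: distribute in proportion to utilization. */
	if (siblings_protected > parent_effective)
		return protected * parent_effective / siblings_protected;

	/* Undercommitted level: the claimed and utilized protection holds. */
	return protected;
}

int main(void)
{
	/* Parent affords 10 units; children claim and use 8 + 4 = 12 units. */
	unsigned long parent_effective = 10, siblings_protected = 12;

	printf("A: %lu\n", effective_protection(8, 8, parent_effective,
						siblings_protected));	/* 80/12 = 6 */
	printf("B: %lu\n", effective_protection(4, 4, parent_effective,
						siblings_protected));	/* 40/12 = 3 */
	return 0;
}
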
 /**
  * mem_cgroup_protected - check if memory consumption is in the normal range
  * @root: the top ancestor of the sub-tree being checked
@@ -6180,70 +6376,12 @@ struct cgroup_subsys memory_cgrp_subsys = {
  *   MEMCG_PROT_LOW: cgroup memory is protected as long there is
  *     an unprotected supply of reclaimable memory from other cgroups.
  *   MEMCG_PROT_MIN: cgroup memory is protected
- *
- * @root is exclusive; it is never protected when looked at directly
- *
- * To provide a proper hierarchical behavior, effective memory.min/low values
- * are used. Below is the description of how effective memory.low is calculated.
- * Effective memory.min values is calculated in the same way.
- *
- * Effective memory.low is always equal or less than the original memory.low.
- * If there is no memory.low overcommittment (which is always true for
- * top-level memory cgroups), these two values are equal.
- * Otherwise, it's a part of parent's effective memory.low,
- * calculated as a cgroup's memory.low usage divided by sum of sibling's
- * memory.low usages, where memory.low usage is the size of actually
- * protected memory.
- *
- *                                             low_usage
- * elow = min( memory.low, parent->elow * ------------------ ),
- *                                        siblings_low_usage
- *
- *             | memory.current, if memory.current < memory.low
- * low_usage = |
- *            | 0, otherwise.
- *
- *
- * Such definition of the effective memory.low provides the expected
- * hierarchical behavior: parent's memory.low value is limiting
- * children, unprotected memory is reclaimed first and cgroups,
- * which are not using their guarantee do not affect actual memory
- * distribution.
- *
- * For example, if there are memcgs A, A/B, A/C, A/D and A/E:
- *
- *     A      A/memory.low = 2G, A/memory.current = 6G
- *    //\\
- *   BC  DE   B/memory.low = 3G  B/memory.current = 2G
- *            C/memory.low = 1G  C/memory.current = 2G
- *            D/memory.low = 0   D/memory.current = 2G
- *            E/memory.low = 10G E/memory.current = 0
- *
- * and the memory pressure is applied, the following memory distribution
- * is expected (approximately):
- *
- *     A/memory.current = 2G
- *
- *     B/memory.current = 1.3G
- *     C/memory.current = 0.6G
- *     D/memory.current = 0
- *     E/memory.current = 0
- *
- * These calculations require constant tracking of the actual low usages
- * (see propagate_protected_usage()), as well as recursive calculation of
- * effective memory.low values. But as we do call mem_cgroup_protected()
- * path for each memory cgroup top-down from the reclaim,
- * it's possible to optimize this part, and save calculated elow
- * for next usage. This part is intentionally racy, but it's ok,
- * as memory.low is a best-effort mechanism.
  */
 enum mem_cgroup_protection mem_cgroup_protected(struct mem_cgroup *root,
                                                struct mem_cgroup *memcg)
 {
+       unsigned long usage, parent_usage;
        struct mem_cgroup *parent;
-       unsigned long emin, parent_emin;
-       unsigned long elow, parent_elow;
-       unsigned long usage;
 
        if (mem_cgroup_disabled())
                return MEMCG_PROT_NONE;
@@ -6257,52 +6395,32 @@ enum mem_cgroup_protection mem_cgroup_protected(struct mem_cgroup *root,
        if (!usage)
                return MEMCG_PROT_NONE;
 
-       emin = memcg->memory.min;
-       elow = memcg->memory.low;
-
        parent = parent_mem_cgroup(memcg);
        /* No parent means a non-hierarchical mode on v1 memcg */
        if (!parent)
                return MEMCG_PROT_NONE;
 
-       if (parent == root)
-               goto exit;
-
-       parent_emin = READ_ONCE(parent->memory.emin);
-       emin = min(emin, parent_emin);
-       if (emin && parent_emin) {
-               unsigned long min_usage, siblings_min_usage;
-
-               min_usage = min(usage, memcg->memory.min);
-               siblings_min_usage = atomic_long_read(
-                       &parent->memory.children_min_usage);
-
-               if (min_usage && siblings_min_usage)
-                       emin = min(emin, parent_emin * min_usage /
-                                  siblings_min_usage);
+       if (parent == root) {
+               memcg->memory.emin = READ_ONCE(memcg->memory.min);
+               memcg->memory.elow = memcg->memory.low;
+               goto out;
        }
 
-       parent_elow = READ_ONCE(parent->memory.elow);
-       elow = min(elow, parent_elow);
-       if (elow && parent_elow) {
-               unsigned long low_usage, siblings_low_usage;
-
-               low_usage = min(usage, memcg->memory.low);
-               siblings_low_usage = atomic_long_read(
-                       &parent->memory.children_low_usage);
+       parent_usage = page_counter_read(&parent->memory);
 
-               if (low_usage && siblings_low_usage)
-                       elow = min(elow, parent_elow * low_usage /
-                                  siblings_low_usage);
-       }
+       WRITE_ONCE(memcg->memory.emin, effective_protection(usage, parent_usage,
+                       READ_ONCE(memcg->memory.min),
+                       READ_ONCE(parent->memory.emin),
+                       atomic_long_read(&parent->memory.children_min_usage)));
 
-exit:
-       memcg->memory.emin = emin;
-       memcg->memory.elow = elow;
+       WRITE_ONCE(memcg->memory.elow, effective_protection(usage, parent_usage,
+                       memcg->memory.low, READ_ONCE(parent->memory.elow),
+                       atomic_long_read(&parent->memory.children_low_usage)));
 
-       if (usage <= emin)
+out:
+       if (usage <= memcg->memory.emin)
                return MEMCG_PROT_MIN;
-       else if (usage <= elow)
+       else if (usage <= memcg->memory.elow)
                return MEMCG_PROT_LOW;
        else
                return MEMCG_PROT_NONE;
@@ -6680,19 +6798,9 @@ void mem_cgroup_sk_alloc(struct sock *sk)
        if (!mem_cgroup_sockets_enabled)
                return;
 
-       /*
-        * Socket cloning can throw us here with sk_memcg already
-        * filled. It won't however, necessarily happen from
-        * process context. So the test for root memcg given
-        * the current task's memcg won't help us in this case.
-        *
-        * Respecting the original socket's memcg is a better
-        * decision in this case.
-        */
-       if (sk->sk_memcg) {
-               css_get(&sk->sk_memcg->css);
+       /* Do not associate the sock with unrelated interrupted task's memcg. */
+       if (in_interrupt())
                return;
-       }
 
        rcu_read_lock();
        memcg = mem_cgroup_from_task(current);
@@ -6700,7 +6808,7 @@ void mem_cgroup_sk_alloc(struct sock *sk)
                goto out;
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && !memcg->tcpmem_active)
                goto out;
-       if (css_tryget_online(&memcg->css))
+       if (css_tryget(&memcg->css))
                sk->sk_memcg = memcg;
 out:
        rcu_read_unlock();
@@ -7021,7 +7129,8 @@ bool mem_cgroup_swap_full(struct page *page)
                return false;
 
        for (; memcg != root_mem_cgroup; memcg = parent_mem_cgroup(memcg))
-               if (page_counter_read(&memcg->swap) * 2 >= memcg->swap.max)
+               if (page_counter_read(&memcg->swap) * 2 >=
+                   READ_ONCE(memcg->swap.max))
                        return true;
 
        return false;