diff --git a/mm/vmscan.c b/mm/vmscan.c
index 9c1c5e8..6d0cd28 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -35,7 +35,7 @@
 #include <linux/cpuset.h>
 #include <linux/compaction.h>
 #include <linux/notifier.h>
-#include <linux/rwsem.h>
+#include <linux/mutex.h>
 #include <linux/delay.h>
 #include <linux/kthread.h>
 #include <linux/freezer.h>
@@ -57,6 +57,7 @@
 #include <linux/khugepaged.h>
 #include <linux/rculist_nulls.h>
 #include <linux/random.h>
+#include <linux/srcu.h>
 
 #include <asm/tlbflush.h>
 #include <asm/div64.h>
@@ -188,20 +189,10 @@ struct scan_control {
  */
 int vm_swappiness = 60;
 
-static void set_task_reclaim_state(struct task_struct *task,
-                                  struct reclaim_state *rs)
-{
-       /* Check for an overwrite */
-       WARN_ON_ONCE(rs && task->reclaim_state);
-
-       /* Check for the nulling of an already-nulled member */
-       WARN_ON_ONCE(!rs && !task->reclaim_state);
-
-       task->reclaim_state = rs;
-}
-
 LIST_HEAD(shrinker_list);
-DECLARE_RWSEM(shrinker_rwsem);
+DEFINE_MUTEX(shrinker_mutex);
+DEFINE_SRCU(shrinker_srcu);
+static atomic_t shrinker_srcu_generation = ATOMIC_INIT(0);
 
 #ifdef CONFIG_MEMCG
 static int shrinker_nr_max;
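
The hunk above is the core of the locking change in this series: the global shrinker_rwsem becomes shrinker_mutex, which only serializes writers (registration, unregistration, shrinker_info expansion), while readers walk the shrinker list and shrinker_info under SRCU; shrinker_srcu_generation lets an in-flight reader notice that a shrinker was unregistered underneath it. A minimal, hedged sketch of the resulting pattern (illustrative only; the real code is in the register/unregister and shrink_slab hunks further down):

        /* Writer: list mutation under the mutex, with RCU-safe list ops. */
        struct shrinker *shrinker;

        mutex_lock(&shrinker_mutex);
        list_del_rcu(&shrinker->list);
        mutex_unlock(&shrinker_mutex);
        atomic_inc(&shrinker_srcu_generation);  /* ask in-flight readers to restart */
        synchronize_srcu(&shrinker_srcu);       /* wait out readers that saw it */

        /* Reader: no trylock and no bail-out, just an SRCU read-side section. */
        int srcu_idx = srcu_read_lock(&shrinker_srcu);
        list_for_each_entry_srcu(shrinker, &shrinker_list, list,
                                 srcu_read_lock_held(&shrinker_srcu)) {
                /* ... do_shrink_slab() ... */
        }
        srcu_read_unlock(&shrinker_srcu, srcu_idx);
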
@@ -220,13 +211,27 @@ static inline int shrinker_defer_size(int nr_items)
 static struct shrinker_info *shrinker_info_protected(struct mem_cgroup *memcg,
                                                     int nid)
 {
-       return rcu_dereference_protected(memcg->nodeinfo[nid]->shrinker_info,
-                                        lockdep_is_held(&shrinker_rwsem));
+       return srcu_dereference_check(memcg->nodeinfo[nid]->shrinker_info,
+                                     &shrinker_srcu,
+                                     lockdep_is_held(&shrinker_mutex));
+}
+
+static struct shrinker_info *shrinker_info_srcu(struct mem_cgroup *memcg,
+                                                    int nid)
+{
+       return srcu_dereference(memcg->nodeinfo[nid]->shrinker_info,
+                               &shrinker_srcu);
+}
+
+static void free_shrinker_info_rcu(struct rcu_head *head)
+{
+       kvfree(container_of(head, struct shrinker_info, rcu));
 }
 
 static int expand_one_shrinker_info(struct mem_cgroup *memcg,
                                    int map_size, int defer_size,
-                                   int old_map_size, int old_defer_size)
+                                   int old_map_size, int old_defer_size,
+                                   int new_nr_max)
 {
        struct shrinker_info *new, *old;
        struct mem_cgroup_per_node *pn;
@@ -240,12 +245,17 @@ static int expand_one_shrinker_info(struct mem_cgroup *memcg,
                if (!old)
                        return 0;
 
+               /* Already expanded this shrinker_info */
+               if (new_nr_max <= old->map_nr_max)
+                       continue;
+
                new = kvmalloc_node(sizeof(*new) + size, GFP_KERNEL, nid);
                if (!new)
                        return -ENOMEM;
 
                new->nr_deferred = (atomic_long_t *)(new + 1);
                new->map = (void *)new->nr_deferred + defer_size;
+               new->map_nr_max = new_nr_max;
 
                /* map: set all old bits, clear all new bits */
                memset(new->map, (int)0xff, old_map_size);
@@ -256,7 +266,7 @@ static int expand_one_shrinker_info(struct mem_cgroup *memcg,
                       defer_size - old_defer_size);
 
                rcu_assign_pointer(pn->shrinker_info, new);
-               kvfree_rcu(old, rcu);
+               call_srcu(&shrinker_srcu, &old->rcu, free_shrinker_info_rcu);
        }
 
        return 0;
@@ -282,7 +292,7 @@ int alloc_shrinker_info(struct mem_cgroup *memcg)
        int nid, size, ret = 0;
        int map_size, defer_size = 0;
 
-       down_write(&shrinker_rwsem);
+       mutex_lock(&shrinker_mutex);
        map_size = shrinker_map_size(shrinker_nr_max);
        defer_size = shrinker_defer_size(shrinker_nr_max);
        size = map_size + defer_size;
@@ -295,34 +305,26 @@ int alloc_shrinker_info(struct mem_cgroup *memcg)
                }
                info->nr_deferred = (atomic_long_t *)(info + 1);
                info->map = (void *)info->nr_deferred + defer_size;
+               info->map_nr_max = shrinker_nr_max;
                rcu_assign_pointer(memcg->nodeinfo[nid]->shrinker_info, info);
        }
-       up_write(&shrinker_rwsem);
+       mutex_unlock(&shrinker_mutex);
 
        return ret;
 }
 
-static inline bool need_expand(int nr_max)
-{
-       return round_up(nr_max, BITS_PER_LONG) >
-              round_up(shrinker_nr_max, BITS_PER_LONG);
-}
-
 static int expand_shrinker_info(int new_id)
 {
        int ret = 0;
-       int new_nr_max = new_id + 1;
+       int new_nr_max = round_up(new_id + 1, BITS_PER_LONG);
        int map_size, defer_size = 0;
        int old_map_size, old_defer_size = 0;
        struct mem_cgroup *memcg;
 
-       if (!need_expand(new_nr_max))
-               goto out;
-
        if (!root_mem_cgroup)
                goto out;
 
-       lockdep_assert_held(&shrinker_rwsem);
+       lockdep_assert_held(&shrinker_mutex);
 
        map_size = shrinker_map_size(new_nr_max);
        defer_size = shrinker_defer_size(new_nr_max);
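
new_nr_max is now rounded up to BITS_PER_LONG before expanding, because info->map is a bitmap backed by whole unsigned longs and the new map_nr_max field records its capacity in bits; the old need_expand() helper, which did the equivalent rounding at comparison time, is dropped, and expand_one_shrinker_info() instead skips any node whose map already covers new_nr_max. A small standalone demonstration of the arithmetic, assuming a 64-bit build (plain userspace C, not kernel code):

        #include <stdio.h>

        #define BITS_PER_LONG (8 * sizeof(long))
        /* Same result as the kernel's round_up() for a power-of-two boundary. */
        #define round_up(x, y) ((((x) - 1) | ((y) - 1)) + 1)

        int main(void)
        {
                int new_id = 70;        /* hypothetical newly allocated shrinker id */

                /* round_up(71, 64) == 128: ids 0..127 fit in two unsigned longs. */
                printf("new_nr_max = %lu\n",
                       (unsigned long)round_up(new_id + 1, BITS_PER_LONG));
                return 0;
        }
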
@@ -332,7 +334,8 @@ static int expand_shrinker_info(int new_id)
        memcg = mem_cgroup_iter(NULL, NULL, NULL);
        do {
                ret = expand_one_shrinker_info(memcg, map_size, defer_size,
-                                              old_map_size, old_defer_size);
+                                              old_map_size, old_defer_size,
+                                              new_nr_max);
                if (ret) {
                        mem_cgroup_iter_break(NULL, memcg);
                        goto out;
@@ -349,13 +352,16 @@ void set_shrinker_bit(struct mem_cgroup *memcg, int nid, int shrinker_id)
 {
        if (shrinker_id >= 0 && memcg && !mem_cgroup_is_root(memcg)) {
                struct shrinker_info *info;
-
-               rcu_read_lock();
-               info = rcu_dereference(memcg->nodeinfo[nid]->shrinker_info);
-               /* Pairs with smp mb in shrink_slab() */
-               smp_mb__before_atomic();
-               set_bit(shrinker_id, info->map);
-               rcu_read_unlock();
+               int srcu_idx;
+
+               srcu_idx = srcu_read_lock(&shrinker_srcu);
+               info = shrinker_info_srcu(memcg, nid);
+               if (!WARN_ON_ONCE(shrinker_id >= info->map_nr_max)) {
+                       /* Pairs with smp mb in shrink_slab() */
+                       smp_mb__before_atomic();
+                       set_bit(shrinker_id, info->map);
+               }
+               srcu_read_unlock(&shrinker_srcu, srcu_idx);
        }
 }
 
@@ -368,8 +374,7 @@ static int prealloc_memcg_shrinker(struct shrinker *shrinker)
        if (mem_cgroup_disabled())
                return -ENOSYS;
 
-       down_write(&shrinker_rwsem);
-       /* This may call shrinker, so it must use down_read_trylock() */
+       mutex_lock(&shrinker_mutex);
        id = idr_alloc(&shrinker_idr, shrinker, 0, 0, GFP_KERNEL);
        if (id < 0)
                goto unlock;
@@ -383,7 +388,7 @@ static int prealloc_memcg_shrinker(struct shrinker *shrinker)
        shrinker->id = id;
        ret = 0;
 unlock:
-       up_write(&shrinker_rwsem);
+       mutex_unlock(&shrinker_mutex);
        return ret;
 }
 
@@ -393,7 +398,7 @@ static void unregister_memcg_shrinker(struct shrinker *shrinker)
 
        BUG_ON(id < 0);
 
-       lockdep_assert_held(&shrinker_rwsem);
+       lockdep_assert_held(&shrinker_mutex);
 
        idr_remove(&shrinker_idr, id);
 }
@@ -403,7 +408,7 @@ static long xchg_nr_deferred_memcg(int nid, struct shrinker *shrinker,
 {
        struct shrinker_info *info;
 
-       info = shrinker_info_protected(memcg, nid);
+       info = shrinker_info_srcu(memcg, nid);
        return atomic_long_xchg(&info->nr_deferred[shrinker->id], 0);
 }
 
@@ -412,7 +417,7 @@ static long add_nr_deferred_memcg(long nr, int nid, struct shrinker *shrinker,
 {
        struct shrinker_info *info;
 
-       info = shrinker_info_protected(memcg, nid);
+       info = shrinker_info_srcu(memcg, nid);
        return atomic_long_add_return(nr, &info->nr_deferred[shrinker->id]);
 }
 
@@ -428,16 +433,16 @@ void reparent_shrinker_deferred(struct mem_cgroup *memcg)
                parent = root_mem_cgroup;
 
        /* Prevent from concurrent shrinker_info expand */
-       down_read(&shrinker_rwsem);
+       mutex_lock(&shrinker_mutex);
        for_each_node(nid) {
                child_info = shrinker_info_protected(memcg, nid);
                parent_info = shrinker_info_protected(parent, nid);
-               for (i = 0; i < shrinker_nr_max; i++) {
+               for (i = 0; i < child_info->map_nr_max; i++) {
                        nr = atomic_long_read(&child_info->nr_deferred[i]);
                        atomic_long_add(nr, &parent_info->nr_deferred[i]);
                }
        }
-       up_read(&shrinker_rwsem);
+       mutex_unlock(&shrinker_mutex);
 }
 
 static bool cgroup_reclaim(struct scan_control *sc)
@@ -511,6 +516,58 @@ static bool writeback_throttling_sane(struct scan_control *sc)
 }
 #endif
 
+static void set_task_reclaim_state(struct task_struct *task,
+                                  struct reclaim_state *rs)
+{
+       /* Check for an overwrite */
+       WARN_ON_ONCE(rs && task->reclaim_state);
+
+       /* Check for the nulling of an already-nulled member */
+       WARN_ON_ONCE(!rs && !task->reclaim_state);
+
+       task->reclaim_state = rs;
+}
+
+/*
+ * flush_reclaim_state(): add pages reclaimed outside of LRU-based reclaim to
+ * scan_control->nr_reclaimed.
+ */
+static void flush_reclaim_state(struct scan_control *sc)
+{
+       /*
+        * Currently, reclaim_state->reclaimed includes three types of pages
+        * freed outside of vmscan:
+        * (1) Slab pages.
+        * (2) Clean file pages from pruned inodes (on highmem systems).
+        * (3) XFS freed buffer pages.
+        *
+        * For all of these cases, we cannot universally link the pages to a
+        * single memcg. For example, a memcg-aware shrinker can free one object
+        * charged to the target memcg, causing an entire page to be freed.
+        * If we count the entire page as reclaimed from the memcg, we end up
+        * overestimating the reclaimed amount (potentially under-reclaiming).
+        *
+        * Only count such pages for global reclaim to prevent under-reclaiming
+        * from the target memcg; preventing unnecessary retries during memcg
+        * charging and false positives from proactive reclaim.
+        *
+        * For uncommon cases where the freed pages were actually mostly
+        * charged to the target memcg, we end up underestimating the reclaimed
+        * amount. This should be fine. The freed pages will be uncharged
+        * anyway, even if they are not counted here properly, and we will be
+        * able to make forward progress in charging (which is usually in a
+        * retry loop).
+        *
+        * We can go one step further, and report the uncharged objcg pages in
+        * memcg reclaim, to make reporting more accurate and reduce
+        * underestimation, but it's probably not worth the complexity for now.
+        */
+       if (current->reclaim_state && global_reclaim(sc)) {
+               sc->nr_reclaimed += current->reclaim_state->reclaimed;
+               current->reclaim_state->reclaimed = 0;
+       }
+}
+
 static long xchg_nr_deferred(struct shrinker *shrinker,
                             struct shrink_control *sc)
 {
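
flush_reclaim_state() centralizes what used to be open-coded reclaimed_slab handling in shrink_node() and shrink_one() (both converted in later hunks), and the long comment documents why the credit is applied to global reclaim only; note the reclaim_state field is now simply called reclaimed (the rename from reclaimed_slab happens elsewhere in the series). A hedged sketch of the producer side, using a hypothetical helper name purely for illustration (the real accounting sites live in the slab allocator, inode pruning, and XFS buffer code):

        /* Hypothetical illustration, not an upstream helper. */
        static inline void account_freed_outside_lru(unsigned long nr_pages)
        {
                if (current->reclaim_state)
                        current->reclaim_state->reclaimed += nr_pages;
        }
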
@@ -686,9 +743,9 @@ void free_prealloced_shrinker(struct shrinker *shrinker)
        shrinker->name = NULL;
 #endif
        if (shrinker->flags & SHRINKER_MEMCG_AWARE) {
-               down_write(&shrinker_rwsem);
+               mutex_lock(&shrinker_mutex);
                unregister_memcg_shrinker(shrinker);
-               up_write(&shrinker_rwsem);
+               mutex_unlock(&shrinker_mutex);
                return;
        }
 
@@ -698,11 +755,11 @@ void free_prealloced_shrinker(struct shrinker *shrinker)
 
 void register_shrinker_prepared(struct shrinker *shrinker)
 {
-       down_write(&shrinker_rwsem);
-       list_add_tail(&shrinker->list, &shrinker_list);
+       mutex_lock(&shrinker_mutex);
+       list_add_tail_rcu(&shrinker->list, &shrinker_list);
        shrinker->flags |= SHRINKER_REGISTERED;
        shrinker_debugfs_add(shrinker);
-       up_write(&shrinker_rwsem);
+       mutex_unlock(&shrinker_mutex);
 }
 
 static int __register_shrinker(struct shrinker *shrinker)
@@ -748,19 +805,23 @@ EXPORT_SYMBOL(register_shrinker);
 void unregister_shrinker(struct shrinker *shrinker)
 {
        struct dentry *debugfs_entry;
+       int debugfs_id;
 
        if (!(shrinker->flags & SHRINKER_REGISTERED))
                return;
 
-       down_write(&shrinker_rwsem);
-       list_del(&shrinker->list);
+       mutex_lock(&shrinker_mutex);
+       list_del_rcu(&shrinker->list);
        shrinker->flags &= ~SHRINKER_REGISTERED;
        if (shrinker->flags & SHRINKER_MEMCG_AWARE)
                unregister_memcg_shrinker(shrinker);
-       debugfs_entry = shrinker_debugfs_remove(shrinker);
-       up_write(&shrinker_rwsem);
+       debugfs_entry = shrinker_debugfs_detach(shrinker, &debugfs_id);
+       mutex_unlock(&shrinker_mutex);
+
+       atomic_inc(&shrinker_srcu_generation);
+       synchronize_srcu(&shrinker_srcu);
 
-       debugfs_remove_recursive(debugfs_entry);
+       shrinker_debugfs_remove(debugfs_entry, debugfs_id);
 
        kfree(shrinker->nr_deferred);
        shrinker->nr_deferred = NULL;
@@ -770,15 +831,13 @@ EXPORT_SYMBOL(unregister_shrinker);
 /**
  * synchronize_shrinkers - Wait for all running shrinkers to complete.
  *
- * This is equivalent to calling unregister_shrink() and register_shrinker(),
- * but atomically and with less overhead. This is useful to guarantee that all
- * shrinker invocations have seen an update, before freeing memory, similar to
- * rcu.
+ * This is useful to guarantee that all shrinker invocations have seen an
+ * update, before freeing memory.
  */
 void synchronize_shrinkers(void)
 {
-       down_write(&shrinker_rwsem);
-       up_write(&shrinker_rwsem);
+       atomic_inc(&shrinker_srcu_generation);
+       synchronize_srcu(&shrinker_srcu);
 }
 EXPORT_SYMBOL(synchronize_shrinkers);
 
@@ -887,19 +946,20 @@ static unsigned long shrink_slab_memcg(gfp_t gfp_mask, int nid,
 {
        struct shrinker_info *info;
        unsigned long ret, freed = 0;
-       int i;
+       int srcu_idx, generation;
+       int i = 0;
 
        if (!mem_cgroup_online(memcg))
                return 0;
 
-       if (!down_read_trylock(&shrinker_rwsem))
-               return 0;
-
-       info = shrinker_info_protected(memcg, nid);
+again:
+       srcu_idx = srcu_read_lock(&shrinker_srcu);
+       info = shrinker_info_srcu(memcg, nid);
        if (unlikely(!info))
                goto unlock;
 
-       for_each_set_bit(i, info->map, shrinker_nr_max) {
+       generation = atomic_read(&shrinker_srcu_generation);
+       for_each_set_bit_from(i, info->map, info->map_nr_max) {
                struct shrink_control sc = {
                        .gfp_mask = gfp_mask,
                        .nid = nid,
@@ -945,14 +1005,14 @@ static unsigned long shrink_slab_memcg(gfp_t gfp_mask, int nid,
                                set_shrinker_bit(memcg, nid, i);
                }
                freed += ret;
-
-               if (rwsem_is_contended(&shrinker_rwsem)) {
-                       freed = freed ? : 1;
-                       break;
+               if (atomic_read(&shrinker_srcu_generation) != generation) {
+                       srcu_read_unlock(&shrinker_srcu, srcu_idx);
+                       i++;
+                       goto again;
                }
        }
 unlock:
-       up_read(&shrinker_rwsem);
+       srcu_read_unlock(&shrinker_srcu, srcu_idx);
        return freed;
 }
 #else /* CONFIG_MEMCG */
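
With the rwsem gone, shrink_slab_memcg() no longer returns early when someone wants to register a shrinker; instead it snapshots shrinker_srcu_generation, and if unregister_shrinker() or synchronize_shrinkers() bumps the counter mid-walk, it drops the SRCU read lock and restarts from the next set bit (hence the switch from for_each_set_bit() to for_each_set_bit_from() with i incremented before the goto). A just-unregistered shrinker is therefore not called again, and unregistration is never stalled by a long walk. A condensed, hedged restatement of the loop shape rather than the exact code:

        struct shrinker_info *info;
        int i = 0, srcu_idx, generation;
again:
        srcu_idx = srcu_read_lock(&shrinker_srcu);
        info = shrinker_info_srcu(memcg, nid);          /* re-fetched: may have grown */
        generation = atomic_read(&shrinker_srcu_generation);
        for_each_set_bit_from(i, info->map, info->map_nr_max) {
                /* ... do_shrink_slab() for shrinker i ... */
                if (atomic_read(&shrinker_srcu_generation) != generation) {
                        srcu_read_unlock(&shrinker_srcu, srcu_idx);
                        i++;            /* resume after the shrinker just handled */
                        goto again;
                }
        }
        srcu_read_unlock(&shrinker_srcu, srcu_idx);
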
@@ -989,6 +1049,7 @@ static unsigned long shrink_slab(gfp_t gfp_mask, int nid,
 {
        unsigned long ret, freed = 0;
        struct shrinker *shrinker;
+       int srcu_idx, generation;
 
        /*
         * The root memcg might be allocated even though memcg is disabled
@@ -1000,10 +1061,11 @@ static unsigned long shrink_slab(gfp_t gfp_mask, int nid,
        if (!mem_cgroup_disabled() && !mem_cgroup_is_root(memcg))
                return shrink_slab_memcg(gfp_mask, nid, memcg, priority);
 
-       if (!down_read_trylock(&shrinker_rwsem))
-               goto out;
+       srcu_idx = srcu_read_lock(&shrinker_srcu);
 
-       list_for_each_entry(shrinker, &shrinker_list, list) {
+       generation = atomic_read(&shrinker_srcu_generation);
+       list_for_each_entry_srcu(shrinker, &shrinker_list, list,
+                                srcu_read_lock_held(&shrinker_srcu)) {
                struct shrink_control sc = {
                        .gfp_mask = gfp_mask,
                        .nid = nid,
@@ -1014,19 +1076,14 @@ static unsigned long shrink_slab(gfp_t gfp_mask, int nid,
                if (ret == SHRINK_EMPTY)
                        ret = 0;
                freed += ret;
-               /*
-                * Bail out if someone want to register a new shrinker to
-                * prevent the registration from being stalled for long periods
-                * by parallel ongoing shrinking.
-                */
-               if (rwsem_is_contended(&shrinker_rwsem)) {
+
+               if (atomic_read(&shrinker_srcu_generation) != generation) {
                        freed = freed ? : 1;
                        break;
                }
        }
 
-       up_read(&shrinker_rwsem);
-out:
+       srcu_read_unlock(&shrinker_srcu, srcu_idx);
        cond_resched();
        return freed;
 }
@@ -1151,12 +1208,12 @@ void reclaim_throttle(pg_data_t *pgdat, enum vmscan_throttle_state reason)
        DEFINE_WAIT(wait);
 
        /*
-        * Do not throttle IO workers, kthreads other than kswapd or
+        * Do not throttle user workers, kthreads other than kswapd or
         * workqueues. They may be required for reclaim to make
         * forward progress (e.g. journalling workqueues or kthreads).
         */
        if (!current_is_kswapd() &&
-           current->flags & (PF_IO_WORKER|PF_KTHREAD)) {
+           current->flags & (PF_USER_WORKER|PF_KTHREAD)) {
                cond_resched();
                return;
        }
@@ -1911,6 +1968,16 @@ retry:
                        }
                }
 
+               /*
+                * Folio is unmapped now so it cannot be newly pinned anymore.
+                * No point in trying to reclaim folio if it is pinned.
+                * Furthermore we don't want to reclaim underlying fs metadata
+                * if the folio is pinned and thus potentially modified by the
+                * pinning process as that may upset the filesystem.
+                */
+               if (folio_maybe_dma_pinned(folio))
+                       goto activate_locked;
+
                mapping = folio_mapping(folio);
                if (folio_test_dirty(folio)) {
                        /*
@@ -3394,18 +3461,13 @@ void lru_gen_del_mm(struct mm_struct *mm)
        for_each_node(nid) {
                struct lruvec *lruvec = get_lruvec(memcg, nid);
 
-               /* where the last iteration ended (exclusive) */
+               /* where the current iteration continues after */
+               if (lruvec->mm_state.head == &mm->lru_gen.list)
+                       lruvec->mm_state.head = lruvec->mm_state.head->prev;
+
+               /* where the last iteration ended before */
                if (lruvec->mm_state.tail == &mm->lru_gen.list)
                        lruvec->mm_state.tail = lruvec->mm_state.tail->next;
-
-               /* where the current iteration continues (inclusive) */
-               if (lruvec->mm_state.head != &mm->lru_gen.list)
-                       continue;
-
-               lruvec->mm_state.head = lruvec->mm_state.head->next;
-               /* the deletion ends the current iteration */
-               if (lruvec->mm_state.head == &mm_list->fifo)
-                       WRITE_ONCE(lruvec->mm_state.seq, lruvec->mm_state.seq + 1);
        }
 
        list_del_init(&mm->lru_gen.list);
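
The head/tail adjustments change meaning here because iterate_mm_list() (rewritten in the next hunk) now advances mm_state->head before visiting it: head names the mm handed out last, i.e. "where the current iteration continues after". If that mm is being deleted, stepping head back to ->prev makes the walker's next advance land on the deleted entry's successor, and the seq increment on reaching the end of the list moves entirely into the walker, so lru_gen_del_mm() no longer touches mm_state->seq. A hedged illustration of the advance-then-visit shape (not the upstream code, which also handles wrap-around, force_scan and should_skip_mm()):

        static struct list_head *mm_list_advance(struct lru_gen_mm_state *mm_state)
        {
                /* head is the node handed out last: advance first, then visit. */
                mm_state->head = mm_state->head->next;
                return mm_state->head;
        }
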
@@ -3501,68 +3563,54 @@ static bool iterate_mm_list(struct lruvec *lruvec, struct lru_gen_mm_walk *walk,
                            struct mm_struct **iter)
 {
        bool first = false;
-       bool last = true;
+       bool last = false;
        struct mm_struct *mm = NULL;
        struct mem_cgroup *memcg = lruvec_memcg(lruvec);
        struct lru_gen_mm_list *mm_list = get_mm_list(memcg);
        struct lru_gen_mm_state *mm_state = &lruvec->mm_state;
 
        /*
-        * There are four interesting cases for this page table walker:
-        * 1. It tries to start a new iteration of mm_list with a stale max_seq;
-        *    there is nothing left to do.
-        * 2. It's the first of the current generation, and it needs to reset
-        *    the Bloom filter for the next generation.
-        * 3. It reaches the end of mm_list, and it needs to increment
-        *    mm_state->seq; the iteration is done.
-        * 4. It's the last of the current generation, and it needs to reset the
-        *    mm stats counters for the next generation.
+        * mm_state->seq is incremented after each iteration of mm_list. There
+        * are three interesting cases for this page table walker:
+        * 1. It tries to start a new iteration with a stale max_seq: there is
+        *    nothing left to do.
+        * 2. It started the next iteration: it needs to reset the Bloom filter
+        *    so that a fresh set of PTE tables can be recorded.
+        * 3. It ended the current iteration: it needs to reset the mm stats
+        *    counters and tell its caller to increment max_seq.
         */
        spin_lock(&mm_list->lock);
 
        VM_WARN_ON_ONCE(mm_state->seq + 1 < walk->max_seq);
-       VM_WARN_ON_ONCE(*iter && mm_state->seq > walk->max_seq);
-       VM_WARN_ON_ONCE(*iter && !mm_state->nr_walkers);
 
-       if (walk->max_seq <= mm_state->seq) {
-               if (!*iter)
-                       last = false;
+       if (walk->max_seq <= mm_state->seq)
                goto done;
-       }
 
-       if (!mm_state->nr_walkers) {
-               VM_WARN_ON_ONCE(mm_state->head && mm_state->head != &mm_list->fifo);
+       if (!mm_state->head)
+               mm_state->head = &mm_list->fifo;
 
-               mm_state->head = mm_list->fifo.next;
+       if (mm_state->head == &mm_list->fifo)
                first = true;
-       }
-
-       while (!mm && mm_state->head != &mm_list->fifo) {
-               mm = list_entry(mm_state->head, struct mm_struct, lru_gen.list);
 
+       do {
                mm_state->head = mm_state->head->next;
+               if (mm_state->head == &mm_list->fifo) {
+                       WRITE_ONCE(mm_state->seq, mm_state->seq + 1);
+                       last = true;
+                       break;
+               }
 
                /* force scan for those added after the last iteration */
-               if (!mm_state->tail || mm_state->tail == &mm->lru_gen.list) {
-                       mm_state->tail = mm_state->head;
+               if (!mm_state->tail || mm_state->tail == mm_state->head) {
+                       mm_state->tail = mm_state->head->next;
                        walk->force_scan = true;
                }
 
+               mm = list_entry(mm_state->head, struct mm_struct, lru_gen.list);
                if (should_skip_mm(mm, walk))
                        mm = NULL;
-       }
-
-       if (mm_state->head == &mm_list->fifo)
-               WRITE_ONCE(mm_state->seq, mm_state->seq + 1);
+       } while (!mm);
 done:
-       if (*iter && !mm)
-               mm_state->nr_walkers--;
-       if (!*iter && mm)
-               mm_state->nr_walkers++;
-
-       if (mm_state->nr_walkers)
-               last = false;
-
        if (*iter || last)
                reset_mm_stats(lruvec, walk, last);
 
@@ -3590,9 +3638,9 @@ static bool iterate_mm_list_nowalk(struct lruvec *lruvec, unsigned long max_seq)
 
        VM_WARN_ON_ONCE(mm_state->seq + 1 < max_seq);
 
-       if (max_seq > mm_state->seq && !mm_state->nr_walkers) {
-               VM_WARN_ON_ONCE(mm_state->head && mm_state->head != &mm_list->fifo);
-
+       if (max_seq > mm_state->seq) {
+               mm_state->head = NULL;
+               mm_state->tail = NULL;
                WRITE_ONCE(mm_state->seq, mm_state->seq + 1);
                reset_mm_stats(lruvec, NULL, true);
                success = true;
@@ -3604,7 +3652,7 @@ static bool iterate_mm_list_nowalk(struct lruvec *lruvec, unsigned long max_seq)
 }
 
 /******************************************************************************
- *                          refault feedback loop
+ *                          PID controller
  ******************************************************************************/
 
 /*
@@ -4192,10 +4240,6 @@ restart:
 
                walk_pmd_range(&val, addr, next, args);
 
-               /* a racy check to curtail the waiting time */
-               if (wq_has_sleeper(&walk->lruvec->mm_state.wait))
-                       return 1;
-
                if (need_resched() || walk->batched >= MAX_LRU_BATCH) {
                        end = (addr | ~PUD_MASK) + 1;
                        goto done;
@@ -4228,8 +4272,14 @@ static void walk_mm(struct lruvec *lruvec, struct mm_struct *mm, struct lru_gen_
        walk->next_addr = FIRST_USER_ADDRESS;
 
        do {
+               DEFINE_MAX_SEQ(lruvec);
+
                err = -EBUSY;
 
+               /* another thread might have called inc_max_seq() */
+               if (walk->max_seq != max_seq)
+                       break;
+
                /* folio_update_gen() requires stable folio_memcg() */
                if (!mem_cgroup_trylock_pages(memcg))
                        break;
@@ -4462,25 +4512,12 @@ static bool try_to_inc_max_seq(struct lruvec *lruvec, unsigned long max_seq,
                success = iterate_mm_list(lruvec, walk, &mm);
                if (mm)
                        walk_mm(lruvec, mm, walk);
-
-               cond_resched();
        } while (mm);
 done:
-       if (!success) {
-               if (sc->priority <= DEF_PRIORITY - 2)
-                       wait_event_killable(lruvec->mm_state.wait,
-                                           max_seq < READ_ONCE(lrugen->max_seq));
-               return false;
-       }
-
-       VM_WARN_ON_ONCE(max_seq != READ_ONCE(lrugen->max_seq));
+       if (success)
+               inc_max_seq(lruvec, can_swap, force_scan);
 
-       inc_max_seq(lruvec, can_swap, force_scan);
-       /* either this sees any waiters or they will see updated max_seq */
-       if (wq_has_sleeper(&lruvec->mm_state.wait))
-               wake_up_all(&lruvec->mm_state.wait);
-
-       return true;
+       return success;
 }
 
 /******************************************************************************
@@ -5346,8 +5383,7 @@ static int shrink_one(struct lruvec *lruvec, struct scan_control *sc)
                vmpressure(sc->gfp_mask, memcg, false, sc->nr_scanned - scanned,
                           sc->nr_reclaimed - reclaimed);
 
-       sc->nr_reclaimed += current->reclaim_state->reclaimed_slab;
-       current->reclaim_state->reclaimed_slab = 0;
+       flush_reclaim_state(sc);
 
        return success ? MEMCG_LRU_YOUNG : 0;
 }
@@ -5663,14 +5699,14 @@ unlock:
  *                          sysfs interface
  ******************************************************************************/
 
-static ssize_t show_min_ttl(struct kobject *kobj, struct kobj_attribute *attr, char *buf)
+static ssize_t min_ttl_ms_show(struct kobject *kobj, struct kobj_attribute *attr, char *buf)
 {
-       return sprintf(buf, "%u\n", jiffies_to_msecs(READ_ONCE(lru_gen_min_ttl)));
+       return sysfs_emit(buf, "%u\n", jiffies_to_msecs(READ_ONCE(lru_gen_min_ttl)));
 }
 
 /* see Documentation/admin-guide/mm/multigen_lru.rst for details */
-static ssize_t store_min_ttl(struct kobject *kobj, struct kobj_attribute *attr,
-                            const char *buf, size_t len)
+static ssize_t min_ttl_ms_store(struct kobject *kobj, struct kobj_attribute *attr,
+                               const char *buf, size_t len)
 {
        unsigned int msecs;
 
@@ -5682,11 +5718,9 @@ static ssize_t store_min_ttl(struct kobject *kobj, struct kobj_attribute *attr,
        return len;
 }
 
-static struct kobj_attribute lru_gen_min_ttl_attr = __ATTR(
-       min_ttl_ms, 0644, show_min_ttl, store_min_ttl
-);
+static struct kobj_attribute lru_gen_min_ttl_attr = __ATTR_RW(min_ttl_ms);
 
-static ssize_t show_enabled(struct kobject *kobj, struct kobj_attribute *attr, char *buf)
+static ssize_t enabled_show(struct kobject *kobj, struct kobj_attribute *attr, char *buf)
 {
        unsigned int caps = 0;
 
@@ -5703,7 +5737,7 @@ static ssize_t show_enabled(struct kobject *kobj, struct kobj_attribute *attr, c
 }
 
 /* see Documentation/admin-guide/mm/multigen_lru.rst for details */
-static ssize_t store_enabled(struct kobject *kobj, struct kobj_attribute *attr,
+static ssize_t enabled_store(struct kobject *kobj, struct kobj_attribute *attr,
                             const char *buf, size_t len)
 {
        int i;
@@ -5730,9 +5764,7 @@ static ssize_t store_enabled(struct kobject *kobj, struct kobj_attribute *attr,
        return len;
 }
 
-static struct kobj_attribute lru_gen_enabled_attr = __ATTR(
-       enabled, 0644, show_enabled, store_enabled
-);
+static struct kobj_attribute lru_gen_enabled_attr = __ATTR_RW(enabled);
 
 static struct attribute *lru_gen_attrs[] = {
        &lru_gen_min_ttl_attr.attr,
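
The sysfs handlers are renamed because __ATTR_RW() token-pastes the attribute name to find them; roughly, from include/linux/sysfs.h,

        #define __ATTR_RW(_name) __ATTR(_name, 0644, _name##_show, _name##_store)

so min_ttl_ms and enabled must be paired with min_ttl_ms_show/min_ttl_ms_store and enabled_show/enabled_store respectively. The switch from sprintf() to sysfs_emit() is the usual cleanup: sysfs_emit() enforces the PAGE_SIZE bound on the output buffer.
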
@@ -5740,7 +5772,7 @@ static struct attribute *lru_gen_attrs[] = {
        NULL
 };
 
-static struct attribute_group lru_gen_attr_group = {
+static const struct attribute_group lru_gen_attr_group = {
        .name = "lru_gen",
        .attrs = lru_gen_attrs,
 };
@@ -6122,7 +6154,6 @@ void lru_gen_init_lruvec(struct lruvec *lruvec)
                INIT_LIST_HEAD(&lrugen->folios[gen][type][zone]);
 
        lruvec->mm_state.seq = MIN_NR_GENS;
-       init_waitqueue_head(&lruvec->mm_state.wait);
 }
 
 #ifdef CONFIG_MEMCG
@@ -6155,7 +6186,6 @@ void lru_gen_exit_memcg(struct mem_cgroup *memcg)
        for_each_node(nid) {
                struct lruvec *lruvec = get_lruvec(memcg, nid);
 
-               VM_WARN_ON_ONCE(lruvec->mm_state.nr_walkers);
                VM_WARN_ON_ONCE(memchr_inv(lruvec->lrugen.nr_pages, 0,
                                           sizeof(lruvec->lrugen.nr_pages)));
 
@@ -6450,8 +6480,7 @@ static void shrink_node_memcgs(pg_data_t *pgdat, struct scan_control *sc)
 
 static void shrink_node(pg_data_t *pgdat, struct scan_control *sc)
 {
-       struct reclaim_state *reclaim_state = current->reclaim_state;
-       unsigned long nr_reclaimed, nr_scanned;
+       unsigned long nr_reclaimed, nr_scanned, nr_node_reclaimed;
        struct lruvec *target_lruvec;
        bool reclaimable = false;
 
@@ -6472,18 +6501,16 @@ again:
 
        shrink_node_memcgs(pgdat, sc);
 
-       if (reclaim_state) {
-               sc->nr_reclaimed += reclaim_state->reclaimed_slab;
-               reclaim_state->reclaimed_slab = 0;
-       }
+       flush_reclaim_state(sc);
+
+       nr_node_reclaimed = sc->nr_reclaimed - nr_reclaimed;
 
        /* Record the subtree's reclaim efficiency */
        if (!sc->proactive)
                vmpressure(sc->gfp_mask, sc->target_mem_cgroup, true,
-                          sc->nr_scanned - nr_scanned,
-                          sc->nr_reclaimed - nr_reclaimed);
+                          sc->nr_scanned - nr_scanned, nr_node_reclaimed);
 
-       if (sc->nr_reclaimed - nr_reclaimed)
+       if (nr_node_reclaimed)
                reclaimable = true;
 
        if (current_is_kswapd()) {
@@ -6545,8 +6572,7 @@ again:
            test_bit(LRUVEC_CONGESTED, &target_lruvec->flags))
                reclaim_throttle(pgdat, VMSCAN_THROTTLE_CONGESTED);
 
-       if (should_continue_reclaim(pgdat, sc->nr_reclaimed - nr_reclaimed,
-                                   sc))
+       if (should_continue_reclaim(pgdat, nr_node_reclaimed, sc))
                goto again;
 
        /*
@@ -6990,7 +7016,7 @@ unsigned long try_to_free_pages(struct zonelist *zonelist, int order,
         * scan_control uses s8 fields for order, priority, and reclaim_idx.
         * Confirm they are large enough for max values.
         */
-       BUILD_BUG_ON(MAX_ORDER > S8_MAX);
+       BUILD_BUG_ON(MAX_ORDER >= S8_MAX);
        BUILD_BUG_ON(DEF_PRIORITY > S8_MAX);
        BUILD_BUG_ON(MAX_NR_ZONES > S8_MAX);