Merge tag 'mfd-next-5.11' of git://git.kernel.org/pub/scm/linux/kernel/git/lee/mfd
[linux-2.6-microblaze.git] / fs / io-wq.c
index 414beb5..f72d538 100644 (file)
 #include <linux/rculist_nulls.h>
 #include <linux/fs_struct.h>
 #include <linux/task_work.h>
+#include <linux/blk-cgroup.h>
+#include <linux/audit.h>
+#include <linux/cpu.h>
 
+#include "../kernel/sched/sched.h"
 #include "io-wq.h"
 
 #define WORKER_IDLE_TIMEOUT    (5 * HZ)
@@ -26,9 +30,8 @@ enum {
        IO_WORKER_F_UP          = 1,    /* up and active */
        IO_WORKER_F_RUNNING     = 2,    /* account as running */
        IO_WORKER_F_FREE        = 4,    /* worker on free list */
-       IO_WORKER_F_EXITING     = 8,    /* worker exiting */
-       IO_WORKER_F_FIXED       = 16,   /* static idle worker */
-       IO_WORKER_F_BOUND       = 32,   /* is doing bounded work */
+       IO_WORKER_F_FIXED       = 8,    /* static idle worker */
+       IO_WORKER_F_BOUND       = 16,   /* is doing bounded work */
 };
 
 enum {
@@ -57,9 +60,13 @@ struct io_worker {
 
        struct rcu_head rcu;
        struct mm_struct *mm;
+#ifdef CONFIG_BLK_CGROUP
+       struct cgroup_subsys_state *blkcg_css;
+#endif
        const struct cred *cur_creds;
        const struct cred *saved_creds;
        struct files_struct *restore_files;
+       struct nsproxy *restore_nsproxy;
        struct fs_struct *restore_fs;
 };
 
@@ -87,7 +94,7 @@ enum {
  */
 struct io_wqe {
        struct {
-               spinlock_t lock;
+               raw_spinlock_t lock;
                struct io_wq_work_list work_list;
                unsigned long hash_map;
                unsigned flags;
@@ -118,9 +125,13 @@ struct io_wq {
        refcount_t refs;
        struct completion done;
 
+       struct hlist_node cpuhp_node;
+
        refcount_t use_refs;
 };
 
+static enum cpuhp_state io_wq_online;
+
 static bool io_worker_get(struct io_worker *worker)
 {
        return refcount_inc_not_zero(&worker->ref);
@@ -148,11 +159,12 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
 
        if (current->files != worker->restore_files) {
                __acquire(&wqe->lock);
-               spin_unlock_irq(&wqe->lock);
+               raw_spin_unlock_irq(&wqe->lock);
                dropped_lock = true;
 
                task_lock(current);
                current->files = worker->restore_files;
+               current->nsproxy = worker->restore_nsproxy;
                task_unlock(current);
        }
 
@@ -166,7 +178,7 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
        if (worker->mm) {
                if (!dropped_lock) {
                        __acquire(&wqe->lock);
-                       spin_unlock_irq(&wqe->lock);
+                       raw_spin_unlock_irq(&wqe->lock);
                        dropped_lock = true;
                }
                __set_current_state(TASK_RUNNING);
@@ -175,6 +187,14 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
                worker->mm = NULL;
        }
 
+#ifdef CONFIG_BLK_CGROUP
+       if (worker->blkcg_css) {
+               kthread_associate_blkcg(NULL);
+               worker->blkcg_css = NULL;
+       }
+#endif
+       if (current->signal->rlim[RLIMIT_FSIZE].rlim_cur != RLIM_INFINITY)
+               current->signal->rlim[RLIMIT_FSIZE].rlim_cur = RLIM_INFINITY;
        return dropped_lock;
 }
 
@@ -200,7 +220,6 @@ static void io_worker_exit(struct io_worker *worker)
 {
        struct io_wqe *wqe = worker->wqe;
        struct io_wqe_acct *acct = io_wqe_get_acct(wqe, worker);
-       unsigned nr_workers;
 
        /*
         * If we're not at zero, someone else is holding a brief reference
@@ -220,23 +239,19 @@ static void io_worker_exit(struct io_worker *worker)
        worker->flags = 0;
        preempt_enable();
 
-       spin_lock_irq(&wqe->lock);
+       raw_spin_lock_irq(&wqe->lock);
        hlist_nulls_del_rcu(&worker->nulls_node);
        list_del_rcu(&worker->all_list);
        if (__io_worker_unuse(wqe, worker)) {
                __release(&wqe->lock);
-               spin_lock_irq(&wqe->lock);
+               raw_spin_lock_irq(&wqe->lock);
        }
        acct->nr_workers--;
-       nr_workers = wqe->acct[IO_WQ_ACCT_BOUND].nr_workers +
-                       wqe->acct[IO_WQ_ACCT_UNBOUND].nr_workers;
-       spin_unlock_irq(&wqe->lock);
-
-       /* all workers gone, wq exit can proceed */
-       if (!nr_workers && refcount_dec_and_test(&wqe->wq->refs))
-               complete(&wqe->wq->done);
+       raw_spin_unlock_irq(&wqe->lock);
 
        kfree_rcu(worker, rcu);
+       if (refcount_dec_and_test(&wqe->wq->refs))
+               complete(&wqe->wq->done);
 }
 
 static inline bool io_wqe_run_queue(struct io_wqe *wqe)
@@ -318,6 +333,7 @@ static void io_worker_start(struct io_wqe *wqe, struct io_worker *worker)
 
        worker->flags |= (IO_WORKER_F_UP | IO_WORKER_F_RUNNING);
        worker->restore_files = current->files;
+       worker->restore_nsproxy = current->nsproxy;
        worker->restore_fs = current->fs;
        io_wqe_inc_running(wqe, worker);
 }
@@ -421,14 +437,10 @@ static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work)
                mmput(worker->mm);
                worker->mm = NULL;
        }
-       if (!work->mm)
-               return;
 
-       if (mmget_not_zero(work->mm)) {
-               kthread_use_mm(work->mm);
-               worker->mm = work->mm;
-               /* hang on to this mm */
-               work->mm = NULL;
+       if (mmget_not_zero(work->identity->mm)) {
+               kthread_use_mm(work->identity->mm);
+               worker->mm = work->identity->mm;
                return;
        }
 
@@ -436,12 +448,25 @@ static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work)
        work->flags |= IO_WQ_WORK_CANCEL;
 }
 
+static inline void io_wq_switch_blkcg(struct io_worker *worker,
+                                     struct io_wq_work *work)
+{
+#ifdef CONFIG_BLK_CGROUP
+       if (!(work->flags & IO_WQ_WORK_BLKCG))
+               return;
+       if (work->identity->blkcg_css != worker->blkcg_css) {
+               kthread_associate_blkcg(work->identity->blkcg_css);
+               worker->blkcg_css = work->identity->blkcg_css;
+       }
+#endif
+}
+
 static void io_wq_switch_creds(struct io_worker *worker,
                               struct io_wq_work *work)
 {
-       const struct cred *old_creds = override_creds(work->creds);
+       const struct cred *old_creds = override_creds(work->identity->creds);
 
-       worker->cur_creds = work->creds;
+       worker->cur_creds = work->identity->creds;
        if (worker->saved_creds)
                put_cred(old_creds); /* creds set by previous switch */
        else
@@ -451,18 +476,33 @@ static void io_wq_switch_creds(struct io_worker *worker,
 static void io_impersonate_work(struct io_worker *worker,
                                struct io_wq_work *work)
 {
-       if (work->files && current->files != work->files) {
+       if ((work->flags & IO_WQ_WORK_FILES) &&
+           current->files != work->identity->files) {
                task_lock(current);
-               current->files = work->files;
+               current->files = work->identity->files;
+               current->nsproxy = work->identity->nsproxy;
                task_unlock(current);
+               if (!work->identity->files) {
+                       /* failed grabbing files, ensure work gets cancelled */
+                       work->flags |= IO_WQ_WORK_CANCEL;
+               }
        }
-       if (work->fs && current->fs != work->fs)
-               current->fs = work->fs;
-       if (work->mm != worker->mm)
+       if ((work->flags & IO_WQ_WORK_FS) && current->fs != work->identity->fs)
+               current->fs = work->identity->fs;
+       if ((work->flags & IO_WQ_WORK_MM) && work->identity->mm != worker->mm)
                io_wq_switch_mm(worker, work);
-       if (worker->cur_creds != work->creds)
+       if ((work->flags & IO_WQ_WORK_CREDS) &&
+           worker->cur_creds != work->identity->creds)
                io_wq_switch_creds(worker, work);
-       current->signal->rlim[RLIMIT_FSIZE].rlim_cur = work->fsize;
+       if (work->flags & IO_WQ_WORK_FSIZE)
+               current->signal->rlim[RLIMIT_FSIZE].rlim_cur = work->identity->fsize;
+       else if (current->signal->rlim[RLIMIT_FSIZE].rlim_cur != RLIM_INFINITY)
+               current->signal->rlim[RLIMIT_FSIZE].rlim_cur = RLIM_INFINITY;
+       io_wq_switch_blkcg(worker, work);
+#ifdef CONFIG_AUDIT
+       current->loginuid = work->identity->loginuid;
+       current->sessionid = work->identity->sessionid;
+#endif
 }
 
 static void io_assign_current_work(struct io_worker *worker,
@@ -475,6 +515,11 @@ static void io_assign_current_work(struct io_worker *worker,
                cond_resched();
        }
 
+#ifdef CONFIG_AUDIT
+       current->loginuid = KUIDT_INIT(AUDIT_UID_UNSET);
+       current->sessionid = AUDIT_SID_UNSET;
+#endif
+
        spin_lock_irq(&worker->lock);
        worker->cur_work = work;
        spin_unlock_irq(&worker->lock);
@@ -504,7 +549,7 @@ get_next:
                else if (!wq_list_empty(&wqe->work_list))
                        wqe->flags |= IO_WQE_FLAG_STALLED;
 
-               spin_unlock_irq(&wqe->lock);
+               raw_spin_unlock_irq(&wqe->lock);
                if (!work)
                        break;
                io_assign_current_work(worker, work);
@@ -538,17 +583,17 @@ get_next:
                                io_wqe_enqueue(wqe, linked);
 
                        if (hash != -1U && !next_hashed) {
-                               spin_lock_irq(&wqe->lock);
+                               raw_spin_lock_irq(&wqe->lock);
                                wqe->hash_map &= ~BIT_ULL(hash);
                                wqe->flags &= ~IO_WQE_FLAG_STALLED;
                                /* skip unnecessary unlock-lock wqe->lock */
                                if (!work)
                                        goto get_next;
-                               spin_unlock_irq(&wqe->lock);
+                               raw_spin_unlock_irq(&wqe->lock);
                        }
                } while (work);
 
-               spin_lock_irq(&wqe->lock);
+               raw_spin_lock_irq(&wqe->lock);
        } while (1);
 }
 
@@ -563,7 +608,7 @@ static int io_wqe_worker(void *data)
        while (!test_bit(IO_WQ_BIT_EXIT, &wq->state)) {
                set_current_state(TASK_INTERRUPTIBLE);
 loop:
-               spin_lock_irq(&wqe->lock);
+               raw_spin_lock_irq(&wqe->lock);
                if (io_wqe_run_queue(wqe)) {
                        __set_current_state(TASK_RUNNING);
                        io_worker_handle_work(worker);
@@ -574,7 +619,7 @@ loop:
                        __release(&wqe->lock);
                        goto loop;
                }
-               spin_unlock_irq(&wqe->lock);
+               raw_spin_unlock_irq(&wqe->lock);
                if (signal_pending(current))
                        flush_signals(current);
                if (schedule_timeout(WORKER_IDLE_TIMEOUT))
@@ -586,11 +631,11 @@ loop:
        }
 
        if (test_bit(IO_WQ_BIT_EXIT, &wq->state)) {
-               spin_lock_irq(&wqe->lock);
+               raw_spin_lock_irq(&wqe->lock);
                if (!wq_list_empty(&wqe->work_list))
                        io_worker_handle_work(worker);
                else
-                       spin_unlock_irq(&wqe->lock);
+                       raw_spin_unlock_irq(&wqe->lock);
        }
 
        io_worker_exit(worker);
@@ -630,14 +675,14 @@ void io_wq_worker_sleeping(struct task_struct *tsk)
 
        worker->flags &= ~IO_WORKER_F_RUNNING;
 
-       spin_lock_irq(&wqe->lock);
+       raw_spin_lock_irq(&wqe->lock);
        io_wqe_dec_running(wqe, worker);
-       spin_unlock_irq(&wqe->lock);
+       raw_spin_unlock_irq(&wqe->lock);
 }
 
 static bool create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
 {
-       struct io_wqe_acct *acct =&wqe->acct[index];
+       struct io_wqe_acct *acct = &wqe->acct[index];
        struct io_worker *worker;
 
        worker = kzalloc_node(sizeof(*worker), GFP_KERNEL, wqe->node);
@@ -655,8 +700,9 @@ static bool create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
                kfree(worker);
                return false;
        }
+       kthread_bind_mask(worker->task, cpumask_of_node(wqe->node));
 
-       spin_lock_irq(&wqe->lock);
+       raw_spin_lock_irq(&wqe->lock);
        hlist_nulls_add_head_rcu(&worker->nulls_node, &wqe->free_list);
        list_add_tail_rcu(&worker->all_list, &wqe->all_list);
        worker->flags |= IO_WORKER_F_FREE;
@@ -665,11 +711,12 @@ static bool create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
        if (!acct->nr_workers && (worker->flags & IO_WORKER_F_BOUND))
                worker->flags |= IO_WORKER_F_FIXED;
        acct->nr_workers++;
-       spin_unlock_irq(&wqe->lock);
+       raw_spin_unlock_irq(&wqe->lock);
 
        if (index == IO_WQ_ACCT_UNBOUND)
                atomic_inc(&wq->user->processes);
 
+       refcount_inc(&wq->refs);
        wake_up_process(worker->task);
        return true;
 }
@@ -685,28 +732,63 @@ static inline bool io_wqe_need_worker(struct io_wqe *wqe, int index)
        return acct->nr_workers < acct->max_workers;
 }
 
+static bool io_wqe_worker_send_sig(struct io_worker *worker, void *data)
+{
+       send_sig(SIGINT, worker->task, 1);
+       return false;
+}
+
+/*
+ * Iterate the passed in list and call the specific function for each
+ * worker that isn't exiting
+ */
+static bool io_wq_for_each_worker(struct io_wqe *wqe,
+                                 bool (*func)(struct io_worker *, void *),
+                                 void *data)
+{
+       struct io_worker *worker;
+       bool ret = false;
+
+       list_for_each_entry_rcu(worker, &wqe->all_list, all_list) {
+               if (io_worker_get(worker)) {
+                       /* no task if node is/was offline */
+                       if (worker->task)
+                               ret = func(worker, data);
+                       io_worker_release(worker);
+                       if (ret)
+                               break;
+               }
+       }
+
+       return ret;
+}
+
+static bool io_wq_worker_wake(struct io_worker *worker, void *data)
+{
+       wake_up_process(worker->task);
+       return false;
+}
+
 /*
  * Manager thread. Tasked with creating new workers, if we need them.
  */
 static int io_wq_manager(void *data)
 {
        struct io_wq *wq = data;
-       int workers_to_create = num_possible_nodes();
        int node;
 
        /* create fixed workers */
-       refcount_set(&wq->refs, workers_to_create);
+       refcount_set(&wq->refs, 1);
        for_each_node(node) {
                if (!node_online(node))
                        continue;
-               if (!create_io_worker(wq, wq->wqes[node], IO_WQ_ACCT_BOUND))
-                       goto err;
-               workers_to_create--;
+               if (create_io_worker(wq, wq->wqes[node], IO_WQ_ACCT_BOUND))
+                       continue;
+               set_bit(IO_WQ_BIT_ERROR, &wq->state);
+               set_bit(IO_WQ_BIT_EXIT, &wq->state);
+               goto out;
        }
 
-       while (workers_to_create--)
-               refcount_dec(&wq->refs);
-
        complete(&wq->done);
 
        while (!kthread_should_stop()) {
@@ -720,12 +802,12 @@ static int io_wq_manager(void *data)
                        if (!node_online(node))
                                continue;
 
-                       spin_lock_irq(&wqe->lock);
+                       raw_spin_lock_irq(&wqe->lock);
                        if (io_wqe_need_worker(wqe, IO_WQ_ACCT_BOUND))
                                fork_worker[IO_WQ_ACCT_BOUND] = true;
                        if (io_wqe_need_worker(wqe, IO_WQ_ACCT_UNBOUND))
                                fork_worker[IO_WQ_ACCT_UNBOUND] = true;
-                       spin_unlock_irq(&wqe->lock);
+                       raw_spin_unlock_irq(&wqe->lock);
                        if (fork_worker[IO_WQ_ACCT_BOUND])
                                create_io_worker(wq, wqe, IO_WQ_ACCT_BOUND);
                        if (fork_worker[IO_WQ_ACCT_UNBOUND])
@@ -738,12 +820,18 @@ static int io_wq_manager(void *data)
        if (current->task_works)
                task_work_run();
 
-       return 0;
-err:
-       set_bit(IO_WQ_BIT_ERROR, &wq->state);
-       set_bit(IO_WQ_BIT_EXIT, &wq->state);
-       if (refcount_sub_and_test(workers_to_create, &wq->refs))
+out:
+       if (refcount_dec_and_test(&wq->refs)) {
                complete(&wq->done);
+               return 0;
+       }
+       /* if ERROR is set and we get here, we have workers to wake */
+       if (test_bit(IO_WQ_BIT_ERROR, &wq->state)) {
+               rcu_read_lock();
+               for_each_node(node)
+                       io_wq_for_each_worker(wq->wqes[node], io_wq_worker_wake, NULL);
+               rcu_read_unlock();
+       }
        return 0;
 }
 
@@ -821,10 +909,10 @@ static void io_wqe_enqueue(struct io_wqe *wqe, struct io_wq_work *work)
        }
 
        work_flags = work->flags;
-       spin_lock_irqsave(&wqe->lock, flags);
+       raw_spin_lock_irqsave(&wqe->lock, flags);
        io_wqe_insert_work(wqe, work);
        wqe->flags &= ~IO_WQE_FLAG_STALLED;
-       spin_unlock_irqrestore(&wqe->lock, flags);
+       raw_spin_unlock_irqrestore(&wqe->lock, flags);
 
        if ((work_flags & IO_WQ_WORK_CONCURRENT) ||
            !atomic_read(&acct->nr_running))
@@ -850,37 +938,6 @@ void io_wq_hash_work(struct io_wq_work *work, void *val)
        work->flags |= (IO_WQ_WORK_HASHED | (bit << IO_WQ_HASH_SHIFT));
 }
 
-static bool io_wqe_worker_send_sig(struct io_worker *worker, void *data)
-{
-       send_sig(SIGINT, worker->task, 1);
-       return false;
-}
-
-/*
- * Iterate the passed in list and call the specific function for each
- * worker that isn't exiting
- */
-static bool io_wq_for_each_worker(struct io_wqe *wqe,
-                                 bool (*func)(struct io_worker *, void *),
-                                 void *data)
-{
-       struct io_worker *worker;
-       bool ret = false;
-
-       list_for_each_entry_rcu(worker, &wqe->all_list, all_list) {
-               if (io_worker_get(worker)) {
-                       /* no task if node is/was offline */
-                       if (worker->task)
-                               ret = func(worker, data);
-                       io_worker_release(worker);
-                       if (ret)
-                               break;
-               }
-       }
-
-       return ret;
-}
-
 void io_wq_cancel_all(struct io_wq *wq)
 {
        int node;
@@ -951,13 +1008,13 @@ static void io_wqe_cancel_pending_work(struct io_wqe *wqe,
        unsigned long flags;
 
 retry:
-       spin_lock_irqsave(&wqe->lock, flags);
+       raw_spin_lock_irqsave(&wqe->lock, flags);
        wq_list_for_each(node, prev, &wqe->work_list) {
                work = container_of(node, struct io_wq_work, list);
                if (!match->fn(work, match->data))
                        continue;
                io_wqe_remove_pending(wqe, work, prev);
-               spin_unlock_irqrestore(&wqe->lock, flags);
+               raw_spin_unlock_irqrestore(&wqe->lock, flags);
                io_run_cancel(work, wqe);
                match->nr_pending++;
                if (!match->cancel_all)
@@ -966,7 +1023,7 @@ retry:
                /* not safe to continue after unlock */
                goto retry;
        }
-       spin_unlock_irqrestore(&wqe->lock, flags);
+       raw_spin_unlock_irqrestore(&wqe->lock, flags);
 }
 
 static void io_wqe_cancel_running_work(struct io_wqe *wqe,
@@ -1021,16 +1078,6 @@ enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel,
        return IO_WQ_CANCEL_NOTFOUND;
 }
 
-static bool io_wq_io_cb_cancel_data(struct io_wq_work *work, void *data)
-{
-       return work == data;
-}
-
-enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork)
-{
-       return io_wq_cancel_cb(wq, io_wq_io_cb_cancel_data, (void *)cwork, false);
-}
-
 struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 {
        int ret = -ENOMEM, node;
@@ -1044,10 +1091,12 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
                return ERR_PTR(-ENOMEM);
 
        wq->wqes = kcalloc(nr_node_ids, sizeof(struct io_wqe *), GFP_KERNEL);
-       if (!wq->wqes) {
-               kfree(wq);
-               return ERR_PTR(-ENOMEM);
-       }
+       if (!wq->wqes)
+               goto err_wq;
+
+       ret = cpuhp_state_add_instance_nocalls(io_wq_online, &wq->cpuhp_node);
+       if (ret)
+               goto err_wqes;
 
        wq->free_work = data->free_work;
        wq->do_work = data->do_work;
@@ -1055,6 +1104,7 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
        /* caller must already hold a reference to this */
        wq->user = data->user;
 
+       ret = -ENOMEM;
        for_each_node(node) {
                struct io_wqe *wqe;
                int alloc_node = node;
@@ -1074,7 +1124,7 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
                }
                atomic_set(&wqe->acct[IO_WQ_ACCT_UNBOUND].nr_running, 0);
                wqe->wq = wq;
-               spin_lock_init(&wqe->lock);
+               raw_spin_lock_init(&wqe->lock);
                INIT_WQ_LIST(&wqe->work_list);
                INIT_HLIST_NULLS_HEAD(&wqe->free_list, 0);
                INIT_LIST_HEAD(&wqe->all_list);
@@ -1098,9 +1148,12 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
        ret = PTR_ERR(wq->manager);
        complete(&wq->done);
 err:
+       cpuhp_state_remove_instance_nocalls(io_wq_online, &wq->cpuhp_node);
        for_each_node(node)
                kfree(wq->wqes[node]);
+err_wqes:
        kfree(wq->wqes);
+err_wq:
        kfree(wq);
        return ERR_PTR(ret);
 }
@@ -1113,16 +1166,12 @@ bool io_wq_get(struct io_wq *wq, struct io_wq_data *data)
        return refcount_inc_not_zero(&wq->use_refs);
 }
 
-static bool io_wq_worker_wake(struct io_worker *worker, void *data)
-{
-       wake_up_process(worker->task);
-       return false;
-}
-
 static void __io_wq_destroy(struct io_wq *wq)
 {
        int node;
 
+       cpuhp_state_remove_instance_nocalls(io_wq_online, &wq->cpuhp_node);
+
        set_bit(IO_WQ_BIT_EXIT, &wq->state);
        if (wq->manager)
                kthread_stop(wq->manager);
@@ -1150,3 +1199,41 @@ struct task_struct *io_wq_get_task(struct io_wq *wq)
 {
        return wq->manager;
 }
+
+static bool io_wq_worker_affinity(struct io_worker *worker, void *data)
+{
+       struct task_struct *task = worker->task;
+       struct rq_flags rf;
+       struct rq *rq;
+
+       rq = task_rq_lock(task, &rf);
+       do_set_cpus_allowed(task, cpumask_of_node(worker->wqe->node));
+       task->flags |= PF_NO_SETAFFINITY;
+       task_rq_unlock(rq, task, &rf);
+       return false;
+}
+
+static int io_wq_cpu_online(unsigned int cpu, struct hlist_node *node)
+{
+       struct io_wq *wq = hlist_entry_safe(node, struct io_wq, cpuhp_node);
+       int i;
+
+       rcu_read_lock();
+       for_each_node(i)
+               io_wq_for_each_worker(wq->wqes[i], io_wq_worker_affinity, NULL);
+       rcu_read_unlock();
+       return 0;
+}
+
+static __init int io_wq_init(void)
+{
+       int ret;
+
+       ret = cpuhp_setup_state_multi(CPUHP_AP_ONLINE_DYN, "io-wq/online",
+                                       io_wq_cpu_online, NULL);
+       if (ret < 0)
+               return ret;
+       io_wq_online = ret;
+       return 0;
+}
+subsys_initcall(io_wq_init);