io_uring: move SQPOLL related handling into its own file
author		Jens Axboe <axboe@kernel.dk>
		Wed, 25 May 2022 15:13:39 +0000 (09:13 -0600)
committer	Jens Axboe <axboe@kernel.dk>
		Mon, 25 Jul 2022 00:39:12 +0000 (18:39 -0600)
Signed-off-by: Jens Axboe <axboe@kernel.dk>
io_uring/Makefile
io_uring/io_uring.c
io_uring/io_uring.h
io_uring/sqpoll.c [new file with mode: 0644]
io_uring/sqpoll.h [new file with mode: 0644]

diff --git a/io_uring/Makefile b/io_uring/Makefile
index 6ae4e45..c59a9ca 100644
--- a/io_uring/Makefile
+++ b/io_uring/Makefile
@@ -5,5 +5,6 @@
 obj-$(CONFIG_IO_URING)         += io_uring.o xattr.o nop.o fs.o splice.o \
                                        sync.o advise.o filetable.o \
                                        openclose.o uring_cmd.o epoll.o \
-                                       statx.o net.o msg_ring.o timeout.o
+                                       statx.o net.o msg_ring.o timeout.o \
+                                       sqpoll.o
 obj-$(CONFIG_IO_WQ)            += io-wq.o
diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index 3fc59a2..17c555a 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -92,6 +92,7 @@
 #include "io_uring_types.h"
 #include "io_uring.h"
 #include "refs.h"
+#include "sqpoll.h"
 
 #include "xattr.h"
 #include "nop.h"
 
 #define IORING_MAX_ENTRIES     32768
 #define IORING_MAX_CQ_ENTRIES  (2 * IORING_MAX_ENTRIES)
-#define IORING_SQPOLL_CAP_ENTRIES_VALUE 8
 
 /* only define max */
 #define IORING_MAX_FIXED_FILES (1U << 20)
@@ -214,31 +214,6 @@ struct io_buffer {
        __u16 bgid;
 };
 
-enum {
-       IO_SQ_THREAD_SHOULD_STOP = 0,
-       IO_SQ_THREAD_SHOULD_PARK,
-};
-
-struct io_sq_data {
-       refcount_t              refs;
-       atomic_t                park_pending;
-       struct mutex            lock;
-
-       /* ctx's that are using this sqd */
-       struct list_head        ctx_list;
-
-       struct task_struct      *thread;
-       struct wait_queue_head  wait;
-
-       unsigned                sq_thread_idle;
-       int                     sq_cpu;
-       pid_t                   task_pid;
-       pid_t                   task_tgid;
-
-       unsigned long           state;
-       struct completion       exited;
-};
-
 #define IO_COMPL_BATCH                 32
 #define IO_REQ_CACHE_SIZE              32
 #define IO_REQ_ALLOC_BATCH             8
@@ -402,7 +377,6 @@ static void io_uring_del_tctx_node(unsigned long index);
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                                         struct task_struct *task,
                                         bool cancel_all);
-static void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 
 static void io_dismantle_req(struct io_kiocb *req);
 static int __io_register_rsrc_update(struct io_ring_ctx *ctx, unsigned type,
@@ -1079,13 +1053,6 @@ static void __io_commit_cqring_flush(struct io_ring_ctx *ctx)
                io_eventfd_signal(ctx);
 }
 
-static inline bool io_sqring_full(struct io_ring_ctx *ctx)
-{
-       struct io_rings *r = ctx->rings;
-
-       return READ_ONCE(r->sq.tail) - ctx->cached_sq_head == ctx->sq_entries;
-}
-
 static inline unsigned int __io_cqring_events(struct io_ring_ctx *ctx)
 {
        return ctx->cached_cq_tail - READ_ONCE(ctx->rings->cq.head);
@@ -1974,28 +1941,7 @@ static unsigned io_cqring_events(struct io_ring_ctx *ctx)
        return __io_cqring_events(ctx);
 }
 
-static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
-{
-       struct io_rings *rings = ctx->rings;
-
-       /* make sure SQ entry isn't read before tail */
-       return smp_load_acquire(&rings->sq.tail) - ctx->cached_sq_head;
-}
-
-static inline bool io_run_task_work(void)
-{
-       if (test_thread_flag(TIF_NOTIFY_SIGNAL) || task_work_pending(current)) {
-               __set_current_state(TASK_RUNNING);
-               clear_notify_signal();
-               if (task_work_pending(current))
-                       task_work_run();
-               return true;
-       }
-
-       return false;
-}
-
-static int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin)
+int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin)
 {
        struct io_wq_work_node *pos, *start, *prev;
        unsigned int poll_flags = BLK_POLL_NOSLEEP;
@@ -5297,7 +5243,7 @@ static const struct io_uring_sqe *io_get_sqe(struct io_ring_ctx *ctx)
        return NULL;
 }
 
-static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
+int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
        __must_hold(&ctx->uring_lock)
 {
        unsigned int entries = io_sqring_entries(ctx);
@@ -5349,173 +5295,6 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
        return ret;
 }
 
-static inline bool io_sqd_events_pending(struct io_sq_data *sqd)
-{
-       return READ_ONCE(sqd->state);
-}
-
-static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
-{
-       unsigned int to_submit;
-       int ret = 0;
-
-       to_submit = io_sqring_entries(ctx);
-       /* if we're handling multiple rings, cap submit size for fairness */
-       if (cap_entries && to_submit > IORING_SQPOLL_CAP_ENTRIES_VALUE)
-               to_submit = IORING_SQPOLL_CAP_ENTRIES_VALUE;
-
-       if (!wq_list_empty(&ctx->iopoll_list) || to_submit) {
-               const struct cred *creds = NULL;
-
-               if (ctx->sq_creds != current_cred())
-                       creds = override_creds(ctx->sq_creds);
-
-               mutex_lock(&ctx->uring_lock);
-               if (!wq_list_empty(&ctx->iopoll_list))
-                       io_do_iopoll(ctx, true);
-
-               /*
-                * Don't submit if refs are dying, good for io_uring_register(),
-                * but also it is relied upon by io_ring_exit_work()
-                */
-               if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)) &&
-                   !(ctx->flags & IORING_SETUP_R_DISABLED))
-                       ret = io_submit_sqes(ctx, to_submit);
-               mutex_unlock(&ctx->uring_lock);
-
-               if (to_submit && wq_has_sleeper(&ctx->sqo_sq_wait))
-                       wake_up(&ctx->sqo_sq_wait);
-               if (creds)
-                       revert_creds(creds);
-       }
-
-       return ret;
-}
-
-static __cold void io_sqd_update_thread_idle(struct io_sq_data *sqd)
-{
-       struct io_ring_ctx *ctx;
-       unsigned sq_thread_idle = 0;
-
-       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-               sq_thread_idle = max(sq_thread_idle, ctx->sq_thread_idle);
-       sqd->sq_thread_idle = sq_thread_idle;
-}
-
-static bool io_sqd_handle_event(struct io_sq_data *sqd)
-{
-       bool did_sig = false;
-       struct ksignal ksig;
-
-       if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state) ||
-           signal_pending(current)) {
-               mutex_unlock(&sqd->lock);
-               if (signal_pending(current))
-                       did_sig = get_signal(&ksig);
-               cond_resched();
-               mutex_lock(&sqd->lock);
-       }
-       return did_sig || test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-}
-
-static int io_sq_thread(void *data)
-{
-       struct io_sq_data *sqd = data;
-       struct io_ring_ctx *ctx;
-       unsigned long timeout = 0;
-       char buf[TASK_COMM_LEN];
-       DEFINE_WAIT(wait);
-
-       snprintf(buf, sizeof(buf), "iou-sqp-%d", sqd->task_pid);
-       set_task_comm(current, buf);
-
-       if (sqd->sq_cpu != -1)
-               set_cpus_allowed_ptr(current, cpumask_of(sqd->sq_cpu));
-       else
-               set_cpus_allowed_ptr(current, cpu_online_mask);
-       current->flags |= PF_NO_SETAFFINITY;
-
-       audit_alloc_kernel(current);
-
-       mutex_lock(&sqd->lock);
-       while (1) {
-               bool cap_entries, sqt_spin = false;
-
-               if (io_sqd_events_pending(sqd) || signal_pending(current)) {
-                       if (io_sqd_handle_event(sqd))
-                               break;
-                       timeout = jiffies + sqd->sq_thread_idle;
-               }
-
-               cap_entries = !list_is_singular(&sqd->ctx_list);
-               list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
-                       int ret = __io_sq_thread(ctx, cap_entries);
-
-                       if (!sqt_spin && (ret > 0 || !wq_list_empty(&ctx->iopoll_list)))
-                               sqt_spin = true;
-               }
-               if (io_run_task_work())
-                       sqt_spin = true;
-
-               if (sqt_spin || !time_after(jiffies, timeout)) {
-                       cond_resched();
-                       if (sqt_spin)
-                               timeout = jiffies + sqd->sq_thread_idle;
-                       continue;
-               }
-
-               prepare_to_wait(&sqd->wait, &wait, TASK_INTERRUPTIBLE);
-               if (!io_sqd_events_pending(sqd) && !task_work_pending(current)) {
-                       bool needs_sched = true;
-
-                       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
-                               atomic_or(IORING_SQ_NEED_WAKEUP,
-                                               &ctx->rings->sq_flags);
-                               if ((ctx->flags & IORING_SETUP_IOPOLL) &&
-                                   !wq_list_empty(&ctx->iopoll_list)) {
-                                       needs_sched = false;
-                                       break;
-                               }
-
-                               /*
-                                * Ensure the store of the wakeup flag is not
-                                * reordered with the load of the SQ tail
-                                */
-                               smp_mb__after_atomic();
-
-                               if (io_sqring_entries(ctx)) {
-                                       needs_sched = false;
-                                       break;
-                               }
-                       }
-
-                       if (needs_sched) {
-                               mutex_unlock(&sqd->lock);
-                               schedule();
-                               mutex_lock(&sqd->lock);
-                       }
-                       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-                               atomic_andnot(IORING_SQ_NEED_WAKEUP,
-                                               &ctx->rings->sq_flags);
-               }
-
-               finish_wait(&sqd->wait, &wait);
-               timeout = jiffies + sqd->sq_thread_idle;
-       }
-
-       io_uring_cancel_generic(true, sqd);
-       sqd->thread = NULL;
-       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-               atomic_or(IORING_SQ_NEED_WAKEUP, &ctx->rings->sq_flags);
-       io_run_task_work();
-       mutex_unlock(&sqd->lock);
-
-       audit_free(current);
-
-       complete(&sqd->exited);
-       do_exit(0);
-}
-
 struct io_wait_queue {
        struct wait_queue_entry wq;
        struct io_ring_ctx *ctx;
@@ -5934,131 +5713,6 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
        return ret;
 }
 
-static void io_sq_thread_unpark(struct io_sq_data *sqd)
-       __releases(&sqd->lock)
-{
-       WARN_ON_ONCE(sqd->thread == current);
-
-       /*
-        * Do the dance but not conditional clear_bit() because it'd race with
-        * other threads incrementing park_pending and setting the bit.
-        */
-       clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-       if (atomic_dec_return(&sqd->park_pending))
-               set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-       mutex_unlock(&sqd->lock);
-}
-
-static void io_sq_thread_park(struct io_sq_data *sqd)
-       __acquires(&sqd->lock)
-{
-       WARN_ON_ONCE(sqd->thread == current);
-
-       atomic_inc(&sqd->park_pending);
-       set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-       mutex_lock(&sqd->lock);
-       if (sqd->thread)
-               wake_up_process(sqd->thread);
-}
-
-static void io_sq_thread_stop(struct io_sq_data *sqd)
-{
-       WARN_ON_ONCE(sqd->thread == current);
-       WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state));
-
-       set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-       mutex_lock(&sqd->lock);
-       if (sqd->thread)
-               wake_up_process(sqd->thread);
-       mutex_unlock(&sqd->lock);
-       wait_for_completion(&sqd->exited);
-}
-
-static void io_put_sq_data(struct io_sq_data *sqd)
-{
-       if (refcount_dec_and_test(&sqd->refs)) {
-               WARN_ON_ONCE(atomic_read(&sqd->park_pending));
-
-               io_sq_thread_stop(sqd);
-               kfree(sqd);
-       }
-}
-
-static void io_sq_thread_finish(struct io_ring_ctx *ctx)
-{
-       struct io_sq_data *sqd = ctx->sq_data;
-
-       if (sqd) {
-               io_sq_thread_park(sqd);
-               list_del_init(&ctx->sqd_list);
-               io_sqd_update_thread_idle(sqd);
-               io_sq_thread_unpark(sqd);
-
-               io_put_sq_data(sqd);
-               ctx->sq_data = NULL;
-       }
-}
-
-static struct io_sq_data *io_attach_sq_data(struct io_uring_params *p)
-{
-       struct io_ring_ctx *ctx_attach;
-       struct io_sq_data *sqd;
-       struct fd f;
-
-       f = fdget(p->wq_fd);
-       if (!f.file)
-               return ERR_PTR(-ENXIO);
-       if (f.file->f_op != &io_uring_fops) {
-               fdput(f);
-               return ERR_PTR(-EINVAL);
-       }
-
-       ctx_attach = f.file->private_data;
-       sqd = ctx_attach->sq_data;
-       if (!sqd) {
-               fdput(f);
-               return ERR_PTR(-EINVAL);
-       }
-       if (sqd->task_tgid != current->tgid) {
-               fdput(f);
-               return ERR_PTR(-EPERM);
-       }
-
-       refcount_inc(&sqd->refs);
-       fdput(f);
-       return sqd;
-}
-
-static struct io_sq_data *io_get_sq_data(struct io_uring_params *p,
-                                        bool *attached)
-{
-       struct io_sq_data *sqd;
-
-       *attached = false;
-       if (p->flags & IORING_SETUP_ATTACH_WQ) {
-               sqd = io_attach_sq_data(p);
-               if (!IS_ERR(sqd)) {
-                       *attached = true;
-                       return sqd;
-               }
-               /* fall through for EPERM case, setup new sqd/task */
-               if (PTR_ERR(sqd) != -EPERM)
-                       return sqd;
-       }
-
-       sqd = kzalloc(sizeof(*sqd), GFP_KERNEL);
-       if (!sqd)
-               return ERR_PTR(-ENOMEM);
-
-       atomic_set(&sqd->park_pending, 0);
-       refcount_set(&sqd->refs, 1);
-       INIT_LIST_HEAD(&sqd->ctx_list);
-       mutex_init(&sqd->lock);
-       init_waitqueue_head(&sqd->wait);
-       init_completion(&sqd->exited);
-       return sqd;
-}
-
 /*
  * Ensure the UNIX gc is aware of our file set, so we are certain that
  * the io_uring can be safely unregistered on process exit, even if we have
@@ -6495,8 +6149,8 @@ static struct io_wq *io_init_wq_offload(struct io_ring_ctx *ctx,
        return io_wq_create(concurrency, &data);
 }
 
-static __cold int io_uring_alloc_task_context(struct task_struct *task,
-                                             struct io_ring_ctx *ctx)
+__cold int io_uring_alloc_task_context(struct task_struct *task,
+                                      struct io_ring_ctx *ctx)
 {
        struct io_uring_task *tctx;
        int ret;
@@ -6554,96 +6208,6 @@ void __io_uring_free(struct task_struct *tsk)
        tsk->io_uring = NULL;
 }
 
-static __cold int io_sq_offload_create(struct io_ring_ctx *ctx,
-                                      struct io_uring_params *p)
-{
-       int ret;
-
-       /* Retain compatibility with failing for an invalid attach attempt */
-       if ((ctx->flags & (IORING_SETUP_ATTACH_WQ | IORING_SETUP_SQPOLL)) ==
-                               IORING_SETUP_ATTACH_WQ) {
-               struct fd f;
-
-               f = fdget(p->wq_fd);
-               if (!f.file)
-                       return -ENXIO;
-               if (f.file->f_op != &io_uring_fops) {
-                       fdput(f);
-                       return -EINVAL;
-               }
-               fdput(f);
-       }
-       if (ctx->flags & IORING_SETUP_SQPOLL) {
-               struct task_struct *tsk;
-               struct io_sq_data *sqd;
-               bool attached;
-
-               ret = security_uring_sqpoll();
-               if (ret)
-                       return ret;
-
-               sqd = io_get_sq_data(p, &attached);
-               if (IS_ERR(sqd)) {
-                       ret = PTR_ERR(sqd);
-                       goto err;
-               }
-
-               ctx->sq_creds = get_current_cred();
-               ctx->sq_data = sqd;
-               ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
-               if (!ctx->sq_thread_idle)
-                       ctx->sq_thread_idle = HZ;
-
-               io_sq_thread_park(sqd);
-               list_add(&ctx->sqd_list, &sqd->ctx_list);
-               io_sqd_update_thread_idle(sqd);
-               /* don't attach to a dying SQPOLL thread, would be racy */
-               ret = (attached && !sqd->thread) ? -ENXIO : 0;
-               io_sq_thread_unpark(sqd);
-
-               if (ret < 0)
-                       goto err;
-               if (attached)
-                       return 0;
-
-               if (p->flags & IORING_SETUP_SQ_AFF) {
-                       int cpu = p->sq_thread_cpu;
-
-                       ret = -EINVAL;
-                       if (cpu >= nr_cpu_ids || !cpu_online(cpu))
-                               goto err_sqpoll;
-                       sqd->sq_cpu = cpu;
-               } else {
-                       sqd->sq_cpu = -1;
-               }
-
-               sqd->task_pid = current->pid;
-               sqd->task_tgid = current->tgid;
-               tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
-               if (IS_ERR(tsk)) {
-                       ret = PTR_ERR(tsk);
-                       goto err_sqpoll;
-               }
-
-               sqd->thread = tsk;
-               ret = io_uring_alloc_task_context(tsk, ctx);
-               wake_up_new_task(tsk);
-               if (ret)
-                       goto err;
-       } else if (p->flags & IORING_SETUP_SQ_AFF) {
-               /* Can't have SQ_AFF without SQPOLL */
-               ret = -EINVAL;
-               goto err;
-       }
-
-       return 0;
-err_sqpoll:
-       complete(&ctx->sq_data->exited);
-err:
-       io_sq_thread_finish(ctx);
-       return ret;
-}
-
 static inline void __io_unaccount_mem(struct user_struct *user,
                                      unsigned long nr_pages)
 {
@@ -7755,8 +7319,7 @@ static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
  * Find any io_uring ctx that this task has registered or done IO on, and cancel
  * requests. @sqd should be not-null IFF it's an SQPOLL thread cancellation.
  */
-static __cold void io_uring_cancel_generic(bool cancel_all,
-                                          struct io_sq_data *sqd)
+__cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
 {
        struct io_uring_task *tctx = current->io_uring;
        struct io_ring_ctx *ctx;
@@ -8034,24 +7597,6 @@ static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
 
 #endif /* !CONFIG_MMU */
 
-static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
-{
-       DEFINE_WAIT(wait);
-
-       do {
-               if (!io_sqring_full(ctx))
-                       break;
-               prepare_to_wait(&ctx->sqo_sq_wait, &wait, TASK_INTERRUPTIBLE);
-
-               if (!io_sqring_full(ctx))
-                       break;
-               schedule();
-       } while (!signal_pending(current));
-
-       finish_wait(&ctx->sqo_sq_wait, &wait);
-       return 0;
-}
-
 static int io_validate_ext_arg(unsigned flags, const void __user *argp, size_t argsz)
 {
        if (flags & IORING_ENTER_EXT_ARG) {
diff --git a/io_uring/io_uring.h b/io_uring/io_uring.h
index e285e12..1da8e66 100644
--- a/io_uring/io_uring.h
+++ b/io_uring/io_uring.h
@@ -64,6 +64,34 @@ static inline void io_commit_cqring(struct io_ring_ctx *ctx)
        smp_store_release(&ctx->rings->cq.tail, ctx->cached_cq_tail);
 }
 
+static inline bool io_sqring_full(struct io_ring_ctx *ctx)
+{
+       struct io_rings *r = ctx->rings;
+
+       return READ_ONCE(r->sq.tail) - ctx->cached_sq_head == ctx->sq_entries;
+}
+
+static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
+{
+       struct io_rings *rings = ctx->rings;
+
+       /* make sure SQ entry isn't read before tail */
+       return smp_load_acquire(&rings->sq.tail) - ctx->cached_sq_head;
+}
+
+static inline bool io_run_task_work(void)
+{
+       if (test_thread_flag(TIF_NOTIFY_SIGNAL) || task_work_pending(current)) {
+               __set_current_state(TASK_RUNNING);
+               clear_notify_signal();
+               if (task_work_pending(current))
+                       task_work_run();
+               return true;
+       }
+
+       return false;
+}
+
 void __io_req_complete(struct io_kiocb *req, unsigned issue_flags);
 void io_req_complete_post(struct io_kiocb *req);
 void __io_req_complete_post(struct io_kiocb *req);
@@ -101,6 +129,12 @@ void io_req_tw_post_queue(struct io_kiocb *req, s32 res, u32 cflags);
 void io_req_task_complete(struct io_kiocb *req, bool *locked);
 void io_req_task_queue_fail(struct io_kiocb *req, int ret);
 int io_try_cancel(struct io_kiocb *req, struct io_cancel_data *cd);
+__cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
+int io_uring_alloc_task_context(struct task_struct *task,
+                               struct io_ring_ctx *ctx);
+
+int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr);
+int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin);
 
 void io_free_req(struct io_kiocb *req);
 void io_queue_next(struct io_kiocb *req);
diff --git a/io_uring/sqpoll.c b/io_uring/sqpoll.c
new file mode 100644
index 0000000..149d5c9
--- /dev/null
+++ b/io_uring/sqpoll.c
@@ -0,0 +1,426 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Contains the core associated with submission side polling of the SQ
+ * ring, offloading submissions from the application to a kernel thread.
+ */
+#include <linux/kernel.h>
+#include <linux/errno.h>
+#include <linux/file.h>
+#include <linux/mm.h>
+#include <linux/slab.h>
+#include <linux/audit.h>
+#include <linux/security.h>
+#include <linux/io_uring.h>
+
+#include <uapi/linux/io_uring.h>
+
+#include "io_uring_types.h"
+#include "io_uring.h"
+#include "sqpoll.h"
+
+#define IORING_SQPOLL_CAP_ENTRIES_VALUE 8
+
+enum {
+       IO_SQ_THREAD_SHOULD_STOP = 0,
+       IO_SQ_THREAD_SHOULD_PARK,
+};
+
+void io_sq_thread_unpark(struct io_sq_data *sqd)
+       __releases(&sqd->lock)
+{
+       WARN_ON_ONCE(sqd->thread == current);
+
+       /*
+        * Do the dance but not conditional clear_bit() because it'd race with
+        * other threads incrementing park_pending and setting the bit.
+        */
+       clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+       if (atomic_dec_return(&sqd->park_pending))
+               set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+       mutex_unlock(&sqd->lock);
+}
+
+void io_sq_thread_park(struct io_sq_data *sqd)
+       __acquires(&sqd->lock)
+{
+       WARN_ON_ONCE(sqd->thread == current);
+
+       atomic_inc(&sqd->park_pending);
+       set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+       mutex_lock(&sqd->lock);
+       if (sqd->thread)
+               wake_up_process(sqd->thread);
+}
+
+void io_sq_thread_stop(struct io_sq_data *sqd)
+{
+       WARN_ON_ONCE(sqd->thread == current);
+       WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state));
+
+       set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+       mutex_lock(&sqd->lock);
+       if (sqd->thread)
+               wake_up_process(sqd->thread);
+       mutex_unlock(&sqd->lock);
+       wait_for_completion(&sqd->exited);
+}
+
+void io_put_sq_data(struct io_sq_data *sqd)
+{
+       if (refcount_dec_and_test(&sqd->refs)) {
+               WARN_ON_ONCE(atomic_read(&sqd->park_pending));
+
+               io_sq_thread_stop(sqd);
+               kfree(sqd);
+       }
+}
+
+static __cold void io_sqd_update_thread_idle(struct io_sq_data *sqd)
+{
+       struct io_ring_ctx *ctx;
+       unsigned sq_thread_idle = 0;
+
+       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+               sq_thread_idle = max(sq_thread_idle, ctx->sq_thread_idle);
+       sqd->sq_thread_idle = sq_thread_idle;
+}
+
+void io_sq_thread_finish(struct io_ring_ctx *ctx)
+{
+       struct io_sq_data *sqd = ctx->sq_data;
+
+       if (sqd) {
+               io_sq_thread_park(sqd);
+               list_del_init(&ctx->sqd_list);
+               io_sqd_update_thread_idle(sqd);
+               io_sq_thread_unpark(sqd);
+
+               io_put_sq_data(sqd);
+               ctx->sq_data = NULL;
+       }
+}
+
+static struct io_sq_data *io_attach_sq_data(struct io_uring_params *p)
+{
+       struct io_ring_ctx *ctx_attach;
+       struct io_sq_data *sqd;
+       struct fd f;
+
+       f = fdget(p->wq_fd);
+       if (!f.file)
+               return ERR_PTR(-ENXIO);
+       if (!io_is_uring_fops(f.file)) {
+               fdput(f);
+               return ERR_PTR(-EINVAL);
+       }
+
+       ctx_attach = f.file->private_data;
+       sqd = ctx_attach->sq_data;
+       if (!sqd) {
+               fdput(f);
+               return ERR_PTR(-EINVAL);
+       }
+       if (sqd->task_tgid != current->tgid) {
+               fdput(f);
+               return ERR_PTR(-EPERM);
+       }
+
+       refcount_inc(&sqd->refs);
+       fdput(f);
+       return sqd;
+}
+
+static struct io_sq_data *io_get_sq_data(struct io_uring_params *p,
+                                        bool *attached)
+{
+       struct io_sq_data *sqd;
+
+       *attached = false;
+       if (p->flags & IORING_SETUP_ATTACH_WQ) {
+               sqd = io_attach_sq_data(p);
+               if (!IS_ERR(sqd)) {
+                       *attached = true;
+                       return sqd;
+               }
+               /* fall through for EPERM case, setup new sqd/task */
+               if (PTR_ERR(sqd) != -EPERM)
+                       return sqd;
+       }
+
+       sqd = kzalloc(sizeof(*sqd), GFP_KERNEL);
+       if (!sqd)
+               return ERR_PTR(-ENOMEM);
+
+       atomic_set(&sqd->park_pending, 0);
+       refcount_set(&sqd->refs, 1);
+       INIT_LIST_HEAD(&sqd->ctx_list);
+       mutex_init(&sqd->lock);
+       init_waitqueue_head(&sqd->wait);
+       init_completion(&sqd->exited);
+       return sqd;
+}
+
+static inline bool io_sqd_events_pending(struct io_sq_data *sqd)
+{
+       return READ_ONCE(sqd->state);
+}
+
+static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
+{
+       unsigned int to_submit;
+       int ret = 0;
+
+       to_submit = io_sqring_entries(ctx);
+       /* if we're handling multiple rings, cap submit size for fairness */
+       if (cap_entries && to_submit > IORING_SQPOLL_CAP_ENTRIES_VALUE)
+               to_submit = IORING_SQPOLL_CAP_ENTRIES_VALUE;
+
+       if (!wq_list_empty(&ctx->iopoll_list) || to_submit) {
+               const struct cred *creds = NULL;
+
+               if (ctx->sq_creds != current_cred())
+                       creds = override_creds(ctx->sq_creds);
+
+               mutex_lock(&ctx->uring_lock);
+               if (!wq_list_empty(&ctx->iopoll_list))
+                       io_do_iopoll(ctx, true);
+
+               /*
+                * Don't submit if refs are dying, good for io_uring_register(),
+                * but also it is relied upon by io_ring_exit_work()
+                */
+               if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)) &&
+                   !(ctx->flags & IORING_SETUP_R_DISABLED))
+                       ret = io_submit_sqes(ctx, to_submit);
+               mutex_unlock(&ctx->uring_lock);
+
+               if (to_submit && wq_has_sleeper(&ctx->sqo_sq_wait))
+                       wake_up(&ctx->sqo_sq_wait);
+               if (creds)
+                       revert_creds(creds);
+       }
+
+       return ret;
+}
+
+static bool io_sqd_handle_event(struct io_sq_data *sqd)
+{
+       bool did_sig = false;
+       struct ksignal ksig;
+
+       if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state) ||
+           signal_pending(current)) {
+               mutex_unlock(&sqd->lock);
+               if (signal_pending(current))
+                       did_sig = get_signal(&ksig);
+               cond_resched();
+               mutex_lock(&sqd->lock);
+       }
+       return did_sig || test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+}
+
+static int io_sq_thread(void *data)
+{
+       struct io_sq_data *sqd = data;
+       struct io_ring_ctx *ctx;
+       unsigned long timeout = 0;
+       char buf[TASK_COMM_LEN];
+       DEFINE_WAIT(wait);
+
+       snprintf(buf, sizeof(buf), "iou-sqp-%d", sqd->task_pid);
+       set_task_comm(current, buf);
+
+       if (sqd->sq_cpu != -1)
+               set_cpus_allowed_ptr(current, cpumask_of(sqd->sq_cpu));
+       else
+               set_cpus_allowed_ptr(current, cpu_online_mask);
+       current->flags |= PF_NO_SETAFFINITY;
+
+       audit_alloc_kernel(current);
+
+       mutex_lock(&sqd->lock);
+       while (1) {
+               bool cap_entries, sqt_spin = false;
+
+               if (io_sqd_events_pending(sqd) || signal_pending(current)) {
+                       if (io_sqd_handle_event(sqd))
+                               break;
+                       timeout = jiffies + sqd->sq_thread_idle;
+               }
+
+               cap_entries = !list_is_singular(&sqd->ctx_list);
+               list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+                       int ret = __io_sq_thread(ctx, cap_entries);
+
+                       if (!sqt_spin && (ret > 0 || !wq_list_empty(&ctx->iopoll_list)))
+                               sqt_spin = true;
+               }
+               if (io_run_task_work())
+                       sqt_spin = true;
+
+               if (sqt_spin || !time_after(jiffies, timeout)) {
+                       cond_resched();
+                       if (sqt_spin)
+                               timeout = jiffies + sqd->sq_thread_idle;
+                       continue;
+               }
+
+               prepare_to_wait(&sqd->wait, &wait, TASK_INTERRUPTIBLE);
+               if (!io_sqd_events_pending(sqd) && !task_work_pending(current)) {
+                       bool needs_sched = true;
+
+                       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+                               atomic_or(IORING_SQ_NEED_WAKEUP,
+                                               &ctx->rings->sq_flags);
+                               if ((ctx->flags & IORING_SETUP_IOPOLL) &&
+                                   !wq_list_empty(&ctx->iopoll_list)) {
+                                       needs_sched = false;
+                                       break;
+                               }
+
+                               /*
+                                * Ensure the store of the wakeup flag is not
+                                * reordered with the load of the SQ tail
+                                */
+                               smp_mb__after_atomic();
+
+                               if (io_sqring_entries(ctx)) {
+                                       needs_sched = false;
+                                       break;
+                               }
+                       }
+
+                       if (needs_sched) {
+                               mutex_unlock(&sqd->lock);
+                               schedule();
+                               mutex_lock(&sqd->lock);
+                       }
+                       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+                               atomic_andnot(IORING_SQ_NEED_WAKEUP,
+                                               &ctx->rings->sq_flags);
+               }
+
+               finish_wait(&sqd->wait, &wait);
+               timeout = jiffies + sqd->sq_thread_idle;
+       }
+
+       io_uring_cancel_generic(true, sqd);
+       sqd->thread = NULL;
+       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+               atomic_or(IORING_SQ_NEED_WAKEUP, &ctx->rings->sq_flags);
+       io_run_task_work();
+       mutex_unlock(&sqd->lock);
+
+       audit_free(current);
+
+       complete(&sqd->exited);
+       do_exit(0);
+}
+
+int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
+{
+       DEFINE_WAIT(wait);
+
+       do {
+               if (!io_sqring_full(ctx))
+                       break;
+               prepare_to_wait(&ctx->sqo_sq_wait, &wait, TASK_INTERRUPTIBLE);
+
+               if (!io_sqring_full(ctx))
+                       break;
+               schedule();
+       } while (!signal_pending(current));
+
+       finish_wait(&ctx->sqo_sq_wait, &wait);
+       return 0;
+}
+
+__cold int io_sq_offload_create(struct io_ring_ctx *ctx,
+                               struct io_uring_params *p)
+{
+       int ret;
+
+       /* Retain compatibility with failing for an invalid attach attempt */
+       if ((ctx->flags & (IORING_SETUP_ATTACH_WQ | IORING_SETUP_SQPOLL)) ==
+                               IORING_SETUP_ATTACH_WQ) {
+               struct fd f;
+
+               f = fdget(p->wq_fd);
+               if (!f.file)
+                       return -ENXIO;
+               if (!io_is_uring_fops(f.file)) {
+                       fdput(f);
+                       return -EINVAL;
+               }
+               fdput(f);
+       }
+       if (ctx->flags & IORING_SETUP_SQPOLL) {
+               struct task_struct *tsk;
+               struct io_sq_data *sqd;
+               bool attached;
+
+               ret = security_uring_sqpoll();
+               if (ret)
+                       return ret;
+
+               sqd = io_get_sq_data(p, &attached);
+               if (IS_ERR(sqd)) {
+                       ret = PTR_ERR(sqd);
+                       goto err;
+               }
+
+               ctx->sq_creds = get_current_cred();
+               ctx->sq_data = sqd;
+               ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
+               if (!ctx->sq_thread_idle)
+                       ctx->sq_thread_idle = HZ;
+
+               io_sq_thread_park(sqd);
+               list_add(&ctx->sqd_list, &sqd->ctx_list);
+               io_sqd_update_thread_idle(sqd);
+               /* don't attach to a dying SQPOLL thread, would be racy */
+               ret = (attached && !sqd->thread) ? -ENXIO : 0;
+               io_sq_thread_unpark(sqd);
+
+               if (ret < 0)
+                       goto err;
+               if (attached)
+                       return 0;
+
+               if (p->flags & IORING_SETUP_SQ_AFF) {
+                       int cpu = p->sq_thread_cpu;
+
+                       ret = -EINVAL;
+                       if (cpu >= nr_cpu_ids || !cpu_online(cpu))
+                               goto err_sqpoll;
+                       sqd->sq_cpu = cpu;
+               } else {
+                       sqd->sq_cpu = -1;
+               }
+
+               sqd->task_pid = current->pid;
+               sqd->task_tgid = current->tgid;
+               tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
+               if (IS_ERR(tsk)) {
+                       ret = PTR_ERR(tsk);
+                       goto err_sqpoll;
+               }
+
+               sqd->thread = tsk;
+               ret = io_uring_alloc_task_context(tsk, ctx);
+               wake_up_new_task(tsk);
+               if (ret)
+                       goto err;
+       } else if (p->flags & IORING_SETUP_SQ_AFF) {
+               /* Can't have SQ_AFF without SQPOLL */
+               ret = -EINVAL;
+               goto err;
+       }
+
+       return 0;
+err_sqpoll:
+       complete(&ctx->sq_data->exited);
+err:
+       io_sq_thread_finish(ctx);
+       return ret;
+}
diff --git a/io_uring/sqpoll.h b/io_uring/sqpoll.h
new file mode 100644
index 0000000..0c3fbcd
--- /dev/null
+++ b/io_uring/sqpoll.h
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: GPL-2.0
+
+struct io_sq_data {
+       refcount_t              refs;
+       atomic_t                park_pending;
+       struct mutex            lock;
+
+       /* ctx's that are using this sqd */
+       struct list_head        ctx_list;
+
+       struct task_struct      *thread;
+       struct wait_queue_head  wait;
+
+       unsigned                sq_thread_idle;
+       int                     sq_cpu;
+       pid_t                   task_pid;
+       pid_t                   task_tgid;
+
+       unsigned long           state;
+       struct completion       exited;
+};
+
+int io_sq_offload_create(struct io_ring_ctx *ctx, struct io_uring_params *p);
+void io_sq_thread_finish(struct io_ring_ctx *ctx);
+void io_sq_thread_stop(struct io_sq_data *sqd);
+void io_sq_thread_park(struct io_sq_data *sqd);
+void io_sq_thread_unpark(struct io_sq_data *sqd);
+void io_put_sq_data(struct io_sq_data *sqd);
+int io_sqpoll_wait_sq(struct io_ring_ctx *ctx);
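
A minimal userspace sketch (not part of this commit, added for illustration) of the behaviour the code above implements on the kernel side: the ring is created with IORING_SETUP_SQPOLL so io_sq_thread() submits SQEs on the application's behalf, liburing only calls io_uring_enter() when the polling thread has set IORING_SQ_NEED_WAKEUP, and a second ring can share the same SQPOLL thread via IORING_SETUP_ATTACH_WQ (the io_attach_sq_data() path). With IORING_SETUP_SQ_AFF the thread can additionally be pinned to p.sq_thread_cpu, matching the cpumask handling in io_sq_thread(). Assumes liburing is available; on older kernels SQPOLL may also require elevated privileges.

/* sqpoll_example.c: illustrative sketch only; build with: gcc sqpoll_example.c -luring */
#include <liburing.h>
#include <stdio.h>
#include <string.h>

int main(void)
{
	struct io_uring_params p, p2;
	struct io_uring ring, ring2;
	struct io_uring_sqe *sqe;
	struct io_uring_cqe *cqe;
	int ret;

	/* Ask for an SQPOLL ring; the kernel spawns an iou-sqp-<pid> thread */
	memset(&p, 0, sizeof(p));
	p.flags = IORING_SETUP_SQPOLL;
	p.sq_thread_idle = 2000;	/* thread idles after 2000ms without work */

	ret = io_uring_queue_init_params(8, &ring, &p);
	if (ret < 0) {
		fprintf(stderr, "queue_init: %s\n", strerror(-ret));
		return 1;
	}

	/* Queue a no-op request */
	sqe = io_uring_get_sqe(&ring);
	io_uring_prep_nop(sqe);

	/*
	 * With SQPOLL, liburing only issues io_uring_enter() here if the SQ
	 * thread went idle and set IORING_SQ_NEED_WAKEUP (see io_sq_thread()
	 * above); otherwise the kernel thread picks the SQE up on its own.
	 */
	io_uring_submit(&ring);

	ret = io_uring_wait_cqe(&ring, &cqe);
	if (!ret)
		io_uring_cqe_seen(&ring, cqe);

	/*
	 * Optional: share the same SQPOLL thread with a second ring, which
	 * exercises the io_attach_sq_data() path in sqpoll.c.
	 */
	memset(&p2, 0, sizeof(p2));
	p2.flags = IORING_SETUP_SQPOLL | IORING_SETUP_ATTACH_WQ;
	p2.wq_fd = ring.ring_fd;
	ret = io_uring_queue_init_params(8, &ring2, &p2);
	if (!ret)
		io_uring_queue_exit(&ring2);

	io_uring_queue_exit(&ring);
	return 0;
}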