io_uring: mutex locked poll hashing
authorPavel Begunkov <asml.silence@gmail.com>
Thu, 16 Jun 2022 09:22:12 +0000 (10:22 +0100)
committerJens Axboe <axboe@kernel.dk>
Mon, 25 Jul 2022 00:39:14 +0000 (18:39 -0600)
Currently we do two extra spin lock/unlock pairs to add a poll/apoll
request to the cancellation hash table and remove it from there.

On the submission side we often already hold ->uring_lock and tw
completion is likely to hold it as well. Add a second cancellation hash
table protected by ->uring_lock. In concerns for latency because of a
need to have the mutex locked on the completion side, use the new table
only in following cases:

1) IORING_SETUP_SINGLE_ISSUER: only one task grabs uring_lock, so there
   is little to no contention and so the main tw hander will almost
   always end up grabbing it before calling callbacks.

2) IORING_SETUP_SQPOLL: same as with single issuer, only one task is
   a major user of ->uring_lock.

3) apoll: we normally grab the lock on the completion side anyway to
   execute the request, so it's free.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Link: https://lore.kernel.org/r/1bbad9c78c454b7b92f100bbf46730a37df7194f.1655371007.git.asml.silence@gmail.com
Reviewed-by: Hao Xu <howeyxu@tencent.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
io_uring/io_uring.c
io_uring/io_uring_types.h
io_uring/poll.c

index 0b3851a..eeda167 100644 (file)
@@ -275,6 +275,8 @@ static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        hash_bits = clamp(hash_bits, 1, 8);
        if (io_alloc_hash_table(&ctx->cancel_table, hash_bits))
                goto err;
+       if (io_alloc_hash_table(&ctx->cancel_table_locked, hash_bits))
+               goto err;
 
        ctx->dummy_ubuf = kzalloc(sizeof(*ctx->dummy_ubuf), GFP_KERNEL);
        if (!ctx->dummy_ubuf)
@@ -317,6 +319,7 @@ static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 err:
        kfree(ctx->dummy_ubuf);
        kfree(ctx->cancel_table.hbs);
+       kfree(ctx->cancel_table_locked.hbs);
        kfree(ctx->io_bl);
        xa_destroy(&ctx->io_bl_xa);
        kfree(ctx);
@@ -2493,6 +2496,7 @@ static __cold void io_ring_ctx_free(struct io_ring_ctx *ctx)
        if (ctx->hash_map)
                io_wq_put_hash(ctx->hash_map);
        kfree(ctx->cancel_table.hbs);
+       kfree(ctx->cancel_table_locked.hbs);
        kfree(ctx->dummy_ubuf);
        kfree(ctx->io_bl);
        xa_destroy(&ctx->io_bl_xa);
@@ -2654,12 +2658,13 @@ static __cold void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
                __io_cqring_overflow_flush(ctx, true);
        xa_for_each(&ctx->personalities, index, creds)
                io_unregister_personality(ctx, index);
+       if (ctx->rings)
+               io_poll_remove_all(ctx, NULL, true);
        mutex_unlock(&ctx->uring_lock);
 
        /* failed during ring init, it couldn't have issued any requests */
        if (ctx->rings) {
                io_kill_timeouts(ctx, NULL, true);
-               io_poll_remove_all(ctx, NULL, true);
                /* if we failed setting up the ctx, we might not have any rings */
                io_iopoll_try_reap_events(ctx);
        }
@@ -2784,7 +2789,9 @@ static __cold void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                }
 
                ret |= io_cancel_defer_files(ctx, task, cancel_all);
+               mutex_lock(&ctx->uring_lock);
                ret |= io_poll_remove_all(ctx, task, cancel_all);
+               mutex_unlock(&ctx->uring_lock);
                ret |= io_kill_timeouts(ctx, task, cancel_all);
                if (task)
                        ret |= io_run_task_work();
index d3b9bde..65ac7cd 100644 (file)
@@ -191,6 +191,7 @@ struct io_ring_ctx {
                struct xarray           io_bl_xa;
                struct list_head        io_buffers_cache;
 
+               struct io_hash_table    cancel_table_locked;
                struct list_head        cq_overflow_list;
                struct list_head        apoll_cache;
                struct xarray           personalities;
@@ -323,6 +324,7 @@ enum {
        REQ_F_CQE32_INIT_BIT,
        REQ_F_APOLL_MULTISHOT_BIT,
        REQ_F_CLEAR_POLLIN_BIT,
+       REQ_F_HASH_LOCKED_BIT,
        /* keep async read/write and isreg together and in order */
        REQ_F_SUPPORT_NOWAIT_BIT,
        REQ_F_ISREG_BIT,
@@ -393,6 +395,8 @@ enum {
        REQ_F_CQE32_INIT        = BIT(REQ_F_CQE32_INIT_BIT),
        /* recvmsg special flag, clear EPOLLIN */
        REQ_F_CLEAR_POLLIN      = BIT(REQ_F_CLEAR_POLLIN_BIT),
+       /* hashed into ->cancel_hash_locked, protected by ->uring_lock */
+       REQ_F_HASH_LOCKED       = BIT(REQ_F_HASH_LOCKED_BIT),
 };
 
 typedef void (*io_req_tw_func_t)(struct io_kiocb *req, bool *locked);
index c4edf87..9ae2982 100644 (file)
@@ -93,6 +93,32 @@ static void io_poll_req_delete(struct io_kiocb *req, struct io_ring_ctx *ctx)
        spin_unlock(lock);
 }
 
+static void io_poll_req_insert_locked(struct io_kiocb *req)
+{
+       struct io_hash_table *table = &req->ctx->cancel_table_locked;
+       u32 index = hash_long(req->cqe.user_data, table->hash_bits);
+
+       hlist_add_head(&req->hash_node, &table->hbs[index].list);
+}
+
+static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+
+       if (req->flags & REQ_F_HASH_LOCKED) {
+               /*
+                * ->cancel_table_locked is protected by ->uring_lock in
+                * contrast to per bucket spinlocks. Likely, tctx_task_work()
+                * already grabbed the mutex for us, but there is a chance it
+                * failed.
+                */
+               io_tw_lock(ctx, locked);
+               hash_del(&req->hash_node);
+       } else {
+               io_poll_req_delete(req, ctx);
+       }
+}
+
 static void io_init_poll_iocb(struct io_poll *poll, __poll_t events,
                              wait_queue_func_t wake_func)
 {
@@ -217,7 +243,6 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
 
 static void io_poll_task_func(struct io_kiocb *req, bool *locked)
 {
-       struct io_ring_ctx *ctx = req->ctx;
        int ret;
 
        ret = io_poll_check_events(req, locked);
@@ -234,7 +259,8 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
        }
 
        io_poll_remove_entries(req);
-       io_poll_req_delete(req, ctx);
+       io_poll_tw_hash_eject(req, locked);
+
        io_req_set_res(req, req->cqe.res, 0);
        io_req_task_complete(req, locked);
 }
@@ -248,7 +274,7 @@ static void io_apoll_task_func(struct io_kiocb *req, bool *locked)
                return;
 
        io_poll_remove_entries(req);
-       io_poll_req_delete(req, req->ctx);
+       io_poll_tw_hash_eject(req, locked);
 
        if (!ret)
                io_req_task_submit(req, locked);
@@ -444,7 +470,10 @@ static int __io_arm_poll_handler(struct io_kiocb *req,
                return 0;
        }
 
-       io_poll_req_insert(req);
+       if (req->flags & REQ_F_HASH_LOCKED)
+               io_poll_req_insert_locked(req);
+       else
+               io_poll_req_insert(req);
 
        if (mask && (poll->events & EPOLLET)) {
                /* can't multishot if failed, just queue the event we've got */
@@ -485,6 +514,15 @@ int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
        __poll_t mask = POLLPRI | POLLERR | EPOLLET;
        int ret;
 
+       /*
+        * apoll requests already grab the mutex to complete in the tw handler,
+        * so removal from the mutex-backed hash is free, use it by default.
+        */
+       if (issue_flags & IO_URING_F_UNLOCKED)
+               req->flags &= ~REQ_F_HASH_LOCKED;
+       else
+               req->flags |= REQ_F_HASH_LOCKED;
+
        if (!def->pollin && !def->pollout)
                return IO_APOLL_ABORTED;
        if (!file_can_poll(req->file))
@@ -534,13 +572,10 @@ int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
        return IO_APOLL_OK;
 }
 
-/*
- * Returns true if we found and killed one or more poll requests
- */
-__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
-                              bool cancel_all)
+static __cold bool io_poll_remove_all_table(struct task_struct *tsk,
+                                           struct io_hash_table *table,
+                                           bool cancel_all)
 {
-       struct io_hash_table *table = &ctx->cancel_table;
        unsigned nr_buckets = 1U << table->hash_bits;
        struct hlist_node *tmp;
        struct io_kiocb *req;
@@ -563,6 +598,17 @@ __cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
        return found;
 }
 
+/*
+ * Returns true if we found and killed one or more poll requests
+ */
+__cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
+                              bool cancel_all)
+       __must_hold(&ctx->uring_lock)
+{
+       return io_poll_remove_all_table(tsk, &ctx->cancel_table, cancel_all) |
+              io_poll_remove_all_table(tsk, &ctx->cancel_table_locked, cancel_all);
+}
+
 static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, bool poll_only,
                                     struct io_cancel_data *cd,
                                     struct io_hash_table *table,
@@ -622,13 +668,15 @@ static struct io_kiocb *io_poll_file_find(struct io_ring_ctx *ctx,
        return NULL;
 }
 
-static bool io_poll_disarm(struct io_kiocb *req)
+static int io_poll_disarm(struct io_kiocb *req)
 {
+       if (!req)
+               return -ENOENT;
        if (!io_poll_get_ownership(req))
-               return false;
+               return -EALREADY;
        io_poll_remove_entries(req);
        hash_del(&req->hash_node);
-       return true;
+       return 0;
 }
 
 static int __io_poll_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
@@ -652,7 +700,16 @@ static int __io_poll_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
 int io_poll_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
                   unsigned issue_flags)
 {
-       return __io_poll_cancel(ctx, cd, &ctx->cancel_table);
+       int ret;
+
+       ret = __io_poll_cancel(ctx, cd, &ctx->cancel_table);
+       if (ret != -ENOENT)
+               return ret;
+
+       io_ring_submit_lock(ctx, issue_flags);
+       ret = __io_poll_cancel(ctx, cd, &ctx->cancel_table_locked);
+       io_ring_submit_unlock(ctx, issue_flags);
+       return ret;
 }
 
 static __poll_t io_poll_parse_events(const struct io_uring_sqe *sqe,
@@ -727,6 +784,16 @@ int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
 
        ipt.pt._qproc = io_poll_queue_proc;
 
+       /*
+        * If sqpoll or single issuer, there is no contention for ->uring_lock
+        * and we'll end up holding it in tw handlers anyway.
+        */
+       if (!(issue_flags & IO_URING_F_UNLOCKED) &&
+           (req->ctx->flags & (IORING_SETUP_SQPOLL | IORING_SETUP_SINGLE_ISSUER)))
+               req->flags |= REQ_F_HASH_LOCKED;
+       else
+               req->flags &= ~REQ_F_HASH_LOCKED;
+
        ret = __io_arm_poll_handler(req, poll, &ipt, poll->events);
        if (ret) {
                io_req_set_res(req, ret, 0);
@@ -751,20 +818,28 @@ int io_poll_remove(struct io_kiocb *req, unsigned int issue_flags)
        bool locked;
 
        preq = io_poll_find(ctx, true, &cd, &ctx->cancel_table, &bucket);
-       if (preq)
-               ret2 = io_poll_disarm(preq);
+       ret2 = io_poll_disarm(preq);
        if (bucket)
                spin_unlock(&bucket->lock);
-
-       if (!preq) {
-               ret = -ENOENT;
+       if (!ret2)
+               goto found;
+       if (ret2 != -ENOENT) {
+               ret = ret2;
                goto out;
        }
-       if (!ret2) {
-               ret = -EALREADY;
+
+       io_ring_submit_lock(ctx, issue_flags);
+       preq = io_poll_find(ctx, true, &cd, &ctx->cancel_table_locked, &bucket);
+       ret2 = io_poll_disarm(preq);
+       if (bucket)
+               spin_unlock(&bucket->lock);
+       io_ring_submit_unlock(ctx, issue_flags);
+       if (ret2) {
+               ret = ret2;
                goto out;
        }
 
+found:
        if (poll_update->update_events || poll_update->update_user_data) {
                /* only mask one event flags, keep behavior flags */
                if (poll_update->update_events) {