io_uring: fix !CONFIG_BLOCK compilation failure
[linux-2.6-microblaze.git] / fs / io_uring.c
index 0441701..65a17d5 100644 (file)
@@ -78,7 +78,6 @@
 #include <linux/task_work.h>
 #include <linux/pagemap.h>
 #include <linux/io_uring.h>
-#include <linux/freezer.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/io_uring.h>
@@ -258,12 +257,11 @@ enum {
 
 struct io_sq_data {
        refcount_t              refs;
+       atomic_t                park_pending;
        struct mutex            lock;
 
        /* ctx's that are using this sqd */
        struct list_head        ctx_list;
-       struct list_head        ctx_new_list;
-       struct mutex            ctx_lock;
 
        struct task_struct      *thread;
        struct wait_queue_head  wait;
@@ -271,11 +269,11 @@ struct io_sq_data {
        unsigned                sq_thread_idle;
        int                     sq_cpu;
        pid_t                   task_pid;
+       pid_t                   task_tgid;
 
        unsigned long           state;
-       struct completion       startup;
-       struct completion       completion;
        struct completion       exited;
+       struct callback_head    *park_task_work;
 };
 
 #define IO_IOPOLL_BATCH                        8
@@ -336,7 +334,6 @@ struct io_ring_ctx {
                unsigned int            drain_next: 1;
                unsigned int            eventfd_async: 1;
                unsigned int            restricted: 1;
-               unsigned int            sqo_exec: 1;
 
                /*
                 * Ring buffer of indices into array of io_uring_sqe, which is
@@ -380,6 +377,7 @@ struct io_ring_ctx {
        /* Only used for accounting purposes */
        struct mm_struct        *mm_account;
 
+       const struct cred       *sq_creds;      /* cred used for __io_sq_thread() */
        struct io_sq_data       *sq_data;       /* if using sq thread polling */
 
        struct wait_queue_head  sqo_sq_wait;
@@ -400,15 +398,15 @@ struct io_ring_ctx {
        struct user_struct      *user;
 
        struct completion       ref_comp;
-       struct completion       sq_thread_comp;
 
 #if defined(CONFIG_UNIX)
        struct socket           *ring_sock;
 #endif
 
-       struct idr              io_buffer_idr;
+       struct xarray           io_buffers;
 
-       struct idr              personality_idr;
+       struct xarray           personalities;
+       u32                     pers_next;
 
        struct {
                unsigned                cached_cq_tail;
@@ -454,6 +452,23 @@ struct io_ring_ctx {
 
        /* Keep this last, we don't need it for the fast path */
        struct work_struct              exit_work;
+       struct list_head                tctx_list;
+};
+
+struct io_uring_task {
+       /* submission side */
+       struct xarray           xa;
+       struct wait_queue_head  wait;
+       const struct io_ring_ctx *last;
+       struct io_wq            *io_wq;
+       struct percpu_counter   inflight;
+       atomic_t                in_idle;
+       bool                    sqpoll;
+
+       spinlock_t              task_lock;
+       struct io_wq_work_list  task_list;
+       unsigned long           task_state;
+       struct callback_head    task_work;
 };
 
 /*
@@ -682,6 +697,7 @@ enum {
        REQ_F_NO_FILE_TABLE_BIT,
        REQ_F_LTIMEOUT_ACTIVE_BIT,
        REQ_F_COMPLETE_INLINE_BIT,
+       REQ_F_REISSUE_BIT,
 
        /* not a real bit, just to check we're not overflowing the space */
        __REQ_F_LAST_BIT,
@@ -725,6 +741,8 @@ enum {
        REQ_F_LTIMEOUT_ACTIVE   = BIT(REQ_F_LTIMEOUT_ACTIVE_BIT),
        /* completion is deferred through io_comp_state */
        REQ_F_COMPLETE_INLINE   = BIT(REQ_F_COMPLETE_INLINE_BIT),
+       /* caller should reissue async */
+       REQ_F_REISSUE           = BIT(REQ_F_REISSUE_BIT),
 };
 
 struct async_poll {
@@ -805,6 +823,12 @@ struct io_kiocb {
        struct io_wq_work               work;
 };
 
+struct io_tctx_node {
+       struct list_head        ctx_node;
+       struct task_struct      *task;
+       struct io_ring_ctx      *ctx;
+};
+
 struct io_defer_entry {
        struct list_head        list;
        struct io_kiocb         *req;
@@ -979,6 +1003,8 @@ static const struct io_op_def io_op_defs[] = {
        [IORING_OP_UNLINKAT] = {},
 };
 
+static bool io_disarm_next(struct io_kiocb *req);
+static void io_uring_del_task_file(unsigned long index);
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                                         struct task_struct *task,
                                         struct files_struct *files);
@@ -1071,8 +1097,6 @@ static bool io_match_task(struct io_kiocb *head,
        io_for_each_link(req, head) {
                if (req->flags & REQ_F_INFLIGHT)
                        return true;
-               if (req->task->files == files)
-                       return true;
        }
        return false;
 }
@@ -1129,9 +1153,8 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        init_waitqueue_head(&ctx->cq_wait);
        INIT_LIST_HEAD(&ctx->cq_overflow_list);
        init_completion(&ctx->ref_comp);
-       init_completion(&ctx->sq_thread_comp);
-       idr_init(&ctx->io_buffer_idr);
-       idr_init(&ctx->personality_idr);
+       xa_init_flags(&ctx->io_buffers, XA_FLAGS_ALLOC1);
+       xa_init_flags(&ctx->personalities, XA_FLAGS_ALLOC1);
        mutex_init(&ctx->uring_lock);
        init_waitqueue_head(&ctx->wait);
        spin_lock_init(&ctx->completion_lock);
@@ -1144,6 +1167,7 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        INIT_LIST_HEAD(&ctx->rsrc_ref_list);
        INIT_DELAYED_WORK(&ctx->rsrc_put_work, io_rsrc_put_work);
        init_llist_head(&ctx->rsrc_put_llist);
+       INIT_LIST_HEAD(&ctx->tctx_list);
        INIT_LIST_HEAD(&ctx->submit_state.comp.free_list);
        INIT_LIST_HEAD(&ctx->submit_state.comp.locked_free_list);
        return ctx;
@@ -1183,13 +1207,16 @@ static void io_prep_async_work(struct io_kiocb *req)
        const struct io_op_def *def = &io_op_defs[req->opcode];
        struct io_ring_ctx *ctx = req->ctx;
 
+       if (!req->work.creds)
+               req->work.creds = get_current_cred();
+
        if (req->flags & REQ_F_FORCE_ASYNC)
                req->work.flags |= IO_WQ_WORK_CONCURRENT;
 
        if (req->flags & REQ_F_ISREG) {
                if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
                        io_wq_hash_work(&req->work, file_inode(req->file));
-       } else {
+       } else if (!req->file || !S_ISBLK(file_inode(req->file)->i_mode)) {
                if (def->unbound_nonreg_file)
                        req->work.flags |= IO_WQ_WORK_UNBOUND;
        }
@@ -1212,16 +1239,16 @@ static void io_queue_async_work(struct io_kiocb *req)
        BUG_ON(!tctx);
        BUG_ON(!tctx->io_wq);
 
-       trace_io_uring_queue_async_work(ctx, io_wq_is_hashed(&req->work), req,
-                                       &req->work, req->flags);
        /* init ->work of the whole link before punting */
        io_prep_async_link(req);
+       trace_io_uring_queue_async_work(ctx, io_wq_is_hashed(&req->work), req,
+                                       &req->work, req->flags);
        io_wq_enqueue(tctx->io_wq, &req->work);
        if (link)
                io_queue_linked_timeout(link);
 }
 
-static void io_kill_timeout(struct io_kiocb *req)
+static void io_kill_timeout(struct io_kiocb *req, int status)
 {
        struct io_timeout_data *io = req->async_data;
        int ret;
@@ -1231,31 +1258,11 @@ static void io_kill_timeout(struct io_kiocb *req)
                atomic_set(&req->ctx->cq_timeouts,
                        atomic_read(&req->ctx->cq_timeouts) + 1);
                list_del_init(&req->timeout.list);
-               io_cqring_fill_event(req, 0);
+               io_cqring_fill_event(req, status);
                io_put_req_deferred(req, 1);
        }
 }
 
-/*
- * Returns true if we found and killed one or more timeouts
- */
-static bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
-                            struct files_struct *files)
-{
-       struct io_kiocb *req, *tmp;
-       int canceled = 0;
-
-       spin_lock_irq(&ctx->completion_lock);
-       list_for_each_entry_safe(req, tmp, &ctx->timeout_list, timeout.list) {
-               if (io_match_task(req, tsk, files)) {
-                       io_kill_timeout(req);
-                       canceled++;
-               }
-       }
-       spin_unlock_irq(&ctx->completion_lock);
-       return canceled != 0;
-}
-
 static void __io_queue_deferred(struct io_ring_ctx *ctx)
 {
        do {
@@ -1300,7 +1307,7 @@ static void io_flush_timeouts(struct io_ring_ctx *ctx)
                        break;
 
                list_del_init(&req->timeout.list);
-               io_kill_timeout(req);
+               io_kill_timeout(req, 0);
        } while (!list_empty(&ctx->timeout_list));
 
        ctx->cq_last_tm_flush = seq;
@@ -1514,15 +1521,14 @@ static void io_cqring_fill_event(struct io_kiocb *req, long res)
        __io_cqring_fill_event(req, res, 0);
 }
 
-static inline void io_req_complete_post(struct io_kiocb *req, long res,
-                                       unsigned int cflags)
+static void io_req_complete_post(struct io_kiocb *req, long res,
+                                unsigned int cflags)
 {
        struct io_ring_ctx *ctx = req->ctx;
        unsigned long flags;
 
        spin_lock_irqsave(&ctx->completion_lock, flags);
        __io_cqring_fill_event(req, res, cflags);
-       io_commit_cqring(ctx);
        /*
         * If we're the last reference to this request, add to our locked
         * free_list cache.
@@ -1530,17 +1536,27 @@ static inline void io_req_complete_post(struct io_kiocb *req, long res,
        if (refcount_dec_and_test(&req->refs)) {
                struct io_comp_state *cs = &ctx->submit_state.comp;
 
+               if (req->flags & (REQ_F_LINK | REQ_F_HARDLINK)) {
+                       if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK))
+                               io_disarm_next(req);
+                       if (req->link) {
+                               io_req_task_queue(req->link);
+                               req->link = NULL;
+                       }
+               }
                io_dismantle_req(req);
                io_put_task(req->task, 1);
                list_add(&req->compl.list, &cs->locked_free_list);
                cs->locked_free_nr++;
-       } else
-               req = NULL;
+       } else {
+               if (!percpu_ref_tryget(&ctx->refs))
+                       req = NULL;
+       }
+       io_commit_cqring(ctx);
        spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
-       io_cqring_ev_posted(ctx);
        if (req) {
-               io_queue_next(req);
+               io_cqring_ev_posted(ctx);
                percpu_ref_put(&ctx->refs);
        }
 }
@@ -1648,6 +1664,10 @@ static void io_dismantle_req(struct io_kiocb *req)
                io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE));
        if (req->fixed_rsrc_refs)
                percpu_ref_put(req->fixed_rsrc_refs);
+       if (req->work.creds) {
+               put_cred(req->work.creds);
+               req->work.creds = NULL;
+       }
 
        if (req->flags & REQ_F_INFLIGHT) {
                struct io_ring_ctx *ctx = req->ctx;
@@ -1690,15 +1710,11 @@ static inline void io_remove_next_linked(struct io_kiocb *req)
        nxt->link = NULL;
 }
 
-static void io_kill_linked_timeout(struct io_kiocb *req)
+static bool io_kill_linked_timeout(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
 {
-       struct io_ring_ctx *ctx = req->ctx;
-       struct io_kiocb *link;
+       struct io_kiocb *link = req->link;
        bool cancelled = false;
-       unsigned long flags;
-
-       spin_lock_irqsave(&ctx->completion_lock, flags);
-       link = req->link;
 
        /*
         * Can happen if a linked timeout fired and link had been like
@@ -1713,50 +1729,48 @@ static void io_kill_linked_timeout(struct io_kiocb *req)
                ret = hrtimer_try_to_cancel(&io->timer);
                if (ret != -1) {
                        io_cqring_fill_event(link, -ECANCELED);
-                       io_commit_cqring(ctx);
+                       io_put_req_deferred(link, 1);
                        cancelled = true;
                }
        }
        req->flags &= ~REQ_F_LINK_TIMEOUT;
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
-
-       if (cancelled) {
-               io_cqring_ev_posted(ctx);
-               io_put_req(link);
-       }
+       return cancelled;
 }
 
-
 static void io_fail_links(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
 {
-       struct io_kiocb *link, *nxt;
-       struct io_ring_ctx *ctx = req->ctx;
-       unsigned long flags;
+       struct io_kiocb *nxt, *link = req->link;
 
-       spin_lock_irqsave(&ctx->completion_lock, flags);
-       link = req->link;
        req->link = NULL;
-
        while (link) {
                nxt = link->link;
                link->link = NULL;
 
                trace_io_uring_fail_link(req, link);
                io_cqring_fill_event(link, -ECANCELED);
-
                io_put_req_deferred(link, 2);
                link = nxt;
        }
-       io_commit_cqring(ctx);
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
+}
 
-       io_cqring_ev_posted(ctx);
+static bool io_disarm_next(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
+{
+       bool posted = false;
+
+       if (likely(req->flags & REQ_F_LINK_TIMEOUT))
+               posted = io_kill_linked_timeout(req);
+       if (unlikely(req->flags & REQ_F_FAIL_LINK)) {
+               posted |= (req->link != NULL);
+               io_fail_links(req);
+       }
+       return posted;
 }
 
 static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
 {
-       if (req->flags & REQ_F_LINK_TIMEOUT)
-               io_kill_linked_timeout(req);
+       struct io_kiocb *nxt;
 
        /*
         * If LINK is set, we have dependent requests in this chain. If we
@@ -1764,14 +1778,22 @@ static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
         * dependencies to the next request. In case of failure, fail the rest
         * of the chain.
         */
-       if (likely(!(req->flags & REQ_F_FAIL_LINK))) {
-               struct io_kiocb *nxt = req->link;
+       if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK)) {
+               struct io_ring_ctx *ctx = req->ctx;
+               unsigned long flags;
+               bool posted;
 
-               req->link = NULL;
-               return nxt;
+               spin_lock_irqsave(&ctx->completion_lock, flags);
+               posted = io_disarm_next(req);
+               if (posted)
+                       io_commit_cqring(req->ctx);
+               spin_unlock_irqrestore(&ctx->completion_lock, flags);
+               if (posted)
+                       io_cqring_ev_posted(ctx);
        }
-       io_fail_links(req);
-       return NULL;
+       nxt = req->link;
+       req->link = NULL;
+       return nxt;
 }
 
 static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
@@ -1904,17 +1926,44 @@ static int io_req_task_work_add(struct io_kiocb *req)
        return ret;
 }
 
-static void io_req_task_work_add_fallback(struct io_kiocb *req,
-                                         task_work_func_t cb)
+static bool io_run_task_work_head(struct callback_head **work_head)
+{
+       struct callback_head *work, *next;
+       bool executed = false;
+
+       do {
+               work = xchg(work_head, NULL);
+               if (!work)
+                       break;
+
+               do {
+                       next = work->next;
+                       work->func(work);
+                       work = next;
+                       cond_resched();
+               } while (work);
+               executed = true;
+       } while (1);
+
+       return executed;
+}
+
+static void io_task_work_add_head(struct callback_head **work_head,
+                                 struct callback_head *task_work)
 {
-       struct io_ring_ctx *ctx = req->ctx;
        struct callback_head *head;
 
-       init_task_work(&req->task_work, cb);
        do {
-               head = READ_ONCE(ctx->exit_task_work);
-               req->task_work.next = head;
-       } while (cmpxchg(&ctx->exit_task_work, head, &req->task_work) != head);
+               head = READ_ONCE(*work_head);
+               task_work->next = head;
+       } while (cmpxchg(work_head, head, task_work) != head);
+}
+
+static void io_req_task_work_add_fallback(struct io_kiocb *req,
+                                         task_work_func_t cb)
+{
+       init_task_work(&req->task_work, cb);
+       io_task_work_add_head(&req->ctx->exit_task_work, &req->task_work);
 }
 
 static void __io_req_task_cancel(struct io_kiocb *req, int error)
@@ -2430,6 +2479,11 @@ static bool io_rw_should_reissue(struct io_kiocb *req)
                return false;
        return true;
 }
+#else
+static bool io_rw_should_reissue(struct io_kiocb *req)
+{
+       return false;
+}
 #endif
 
 static bool io_rw_reissue(struct io_kiocb *req)
@@ -2455,13 +2509,14 @@ static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
 {
        int cflags = 0;
 
-       if ((res == -EAGAIN || res == -EOPNOTSUPP) && io_rw_reissue(req))
+       if (req->rw.kiocb.ki_flags & IOCB_WRITE)
+               kiocb_end_write(req);
+       if ((res == -EAGAIN || res == -EOPNOTSUPP) && io_rw_should_reissue(req)) {
+               req->flags |= REQ_F_REISSUE;
                return;
+       }
        if (res != req->result)
                req_set_fail_links(req);
-
-       if (req->rw.kiocb.ki_flags & IOCB_WRITE)
-               kiocb_end_write(req);
        if (req->flags & REQ_F_BUFFER_SELECTED)
                cflags = io_put_rw_kbuf(req);
        __io_req_complete(req, issue_flags, res, cflags);
@@ -2822,7 +2877,7 @@ static struct io_buffer *io_buffer_select(struct io_kiocb *req, size_t *len,
 
        lockdep_assert_held(&req->ctx->uring_lock);
 
-       head = idr_find(&req->ctx->io_buffer_idr, bgid);
+       head = xa_load(&req->ctx->io_buffers, bgid);
        if (head) {
                if (!list_empty(&head->list)) {
                        kbuf = list_last_entry(&head->list, struct io_buffer,
@@ -2830,7 +2885,7 @@ static struct io_buffer *io_buffer_select(struct io_kiocb *req, size_t *len,
                        list_del(&kbuf->list);
                } else {
                        kbuf = head;
-                       idr_remove(&req->ctx->io_buffer_idr, bgid);
+                       xa_erase(&req->ctx->io_buffers, bgid);
                }
                if (*len > kbuf->len)
                        *len = kbuf->len;
@@ -3238,11 +3293,7 @@ static int io_read(struct io_kiocb *req, unsigned int issue_flags)
 
        ret = io_iter_do_read(req, iter);
 
-       if (ret == -EIOCBQUEUED) {
-               if (req->async_data)
-                       iov_iter_revert(iter, io_size - iov_iter_count(iter));
-               goto out_free;
-       } else if (ret == -EAGAIN) {
+       if (ret == -EAGAIN || (req->flags & REQ_F_REISSUE)) {
                /* IOPOLL retry should happen for io-wq threads */
                if (!force_nonblock && !(req->ctx->flags & IORING_SETUP_IOPOLL))
                        goto done;
@@ -3252,6 +3303,8 @@ static int io_read(struct io_kiocb *req, unsigned int issue_flags)
                /* some cases will consume bytes even on error returns */
                iov_iter_revert(iter, io_size - iov_iter_count(iter));
                ret = 0;
+       } else if (ret == -EIOCBQUEUED) {
+               goto out_free;
        } else if (ret <= 0 || ret == io_size || !force_nonblock ||
                   (req->flags & REQ_F_NOWAIT) || !(req->flags & REQ_F_ISREG)) {
                /* read all, failed, already did sync or don't want to retry */
@@ -3286,6 +3339,7 @@ static int io_read(struct io_kiocb *req, unsigned int issue_flags)
                if (ret == -EIOCBQUEUED)
                        return 0;
                /* we got some bytes, but not all. retry. */
+               kiocb->ki_flags &= ~IOCB_WAITQ;
        } while (ret > 0 && ret < io_size);
 done:
        kiocb_done(kiocb, ret, issue_flags);
@@ -3363,6 +3417,9 @@ static int io_write(struct io_kiocb *req, unsigned int issue_flags)
        else
                ret2 = -EINVAL;
 
+       if (req->flags & REQ_F_REISSUE)
+               ret2 = -EAGAIN;
+
        /*
         * Raw bdev writes will return -EOPNOTSUPP for IOCB_NOWAIT. Just
         * retry them without IOCB_NOWAIT.
@@ -3372,8 +3429,6 @@ static int io_write(struct io_kiocb *req, unsigned int issue_flags)
        /* no retry on NONBLOCK nor RWF_NOWAIT */
        if (ret2 == -EAGAIN && (req->flags & REQ_F_NOWAIT))
                goto done;
-       if (ret2 == -EIOCBQUEUED && req->async_data)
-               iov_iter_revert(iter, io_size - iov_iter_count(iter));
        if (!force_nonblock || ret2 != -EAGAIN) {
                /* IOPOLL retry should happen for io-wq threads */
                if ((req->ctx->flags & IORING_SETUP_IOPOLL) && ret2 == -EAGAIN)
@@ -3827,7 +3882,7 @@ err:
 
 static int io_openat(struct io_kiocb *req, unsigned int issue_flags)
 {
-       return io_openat2(req, issue_flags & IO_URING_F_NONBLOCK);
+       return io_openat2(req, issue_flags);
 }
 
 static int io_remove_buffers_prep(struct io_kiocb *req,
@@ -3870,7 +3925,7 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx, struct io_buffer *buf,
        }
        i++;
        kfree(buf);
-       idr_remove(&ctx->io_buffer_idr, bgid);
+       xa_erase(&ctx->io_buffers, bgid);
 
        return i;
 }
@@ -3888,7 +3943,7 @@ static int io_remove_buffers(struct io_kiocb *req, unsigned int issue_flags)
        lockdep_assert_held(&ctx->uring_lock);
 
        ret = -ENOENT;
-       head = idr_find(&ctx->io_buffer_idr, p->bgid);
+       head = xa_load(&ctx->io_buffers, p->bgid);
        if (head)
                ret = __io_remove_buffers(ctx, head, p->bgid, p->nbufs);
        if (ret < 0)
@@ -3908,6 +3963,7 @@ static int io_remove_buffers(struct io_kiocb *req, unsigned int issue_flags)
 static int io_provide_buffers_prep(struct io_kiocb *req,
                                   const struct io_uring_sqe *sqe)
 {
+       unsigned long size;
        struct io_provide_buf *p = &req->pbuf;
        u64 tmp;
 
@@ -3921,7 +3977,8 @@ static int io_provide_buffers_prep(struct io_kiocb *req,
        p->addr = READ_ONCE(sqe->addr);
        p->len = READ_ONCE(sqe->len);
 
-       if (!access_ok(u64_to_user_ptr(p->addr), (p->len * p->nbufs)))
+       size = (unsigned long)p->len * p->nbufs;
+       if (!access_ok(u64_to_user_ptr(p->addr), size))
                return -EFAULT;
 
        p->bgid = READ_ONCE(sqe->buf_group);
@@ -3971,21 +4028,14 @@ static int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
 
        lockdep_assert_held(&ctx->uring_lock);
 
-       list = head = idr_find(&ctx->io_buffer_idr, p->bgid);
+       list = head = xa_load(&ctx->io_buffers, p->bgid);
 
        ret = io_add_buffers(p, &head);
-       if (ret < 0)
-               goto out;
-
-       if (!list) {
-               ret = idr_alloc(&ctx->io_buffer_idr, head, p->bgid, p->bgid + 1,
-                                       GFP_KERNEL);
-               if (ret < 0) {
+       if (ret >= 0 && !list) {
+               ret = xa_insert(&ctx->io_buffers, p->bgid, head, GFP_KERNEL);
+               if (ret < 0)
                        __io_remove_buffers(ctx, head, p->bgid, -1U);
-                       goto out;
-               }
        }
-out:
        if (ret < 0)
                req_set_fail_links(req);
 
@@ -4323,6 +4373,7 @@ static int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
        struct io_async_msghdr iomsg, *kmsg;
        struct socket *sock;
        unsigned flags;
+       int min_ret = 0;
        int ret;
 
        sock = sock_from_file(req->file);
@@ -4337,12 +4388,15 @@ static int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
                kmsg = &iomsg;
        }
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (issue_flags & IO_URING_F_NONBLOCK)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&kmsg->msg.msg_iter);
+
        ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
        if ((issue_flags & IO_URING_F_NONBLOCK) && ret == -EAGAIN)
                return io_setup_async_msg(req, kmsg);
@@ -4353,7 +4407,7 @@ static int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
        if (kmsg->free_iov)
                kfree(kmsg->free_iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < 0)
+       if (ret < min_ret)
                req_set_fail_links(req);
        __io_req_complete(req, issue_flags, ret, 0);
        return 0;
@@ -4366,6 +4420,7 @@ static int io_send(struct io_kiocb *req, unsigned int issue_flags)
        struct iovec iov;
        struct socket *sock;
        unsigned flags;
+       int min_ret = 0;
        int ret;
 
        sock = sock_from_file(req->file);
@@ -4381,12 +4436,15 @@ static int io_send(struct io_kiocb *req, unsigned int issue_flags)
        msg.msg_controllen = 0;
        msg.msg_namelen = 0;
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (issue_flags & IO_URING_F_NONBLOCK)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&msg.msg_iter);
+
        msg.msg_flags = flags;
        ret = sock_sendmsg(sock, &msg);
        if ((issue_flags & IO_URING_F_NONBLOCK) && ret == -EAGAIN)
@@ -4394,7 +4452,7 @@ static int io_send(struct io_kiocb *req, unsigned int issue_flags)
        if (ret == -ERESTARTSYS)
                ret = -EINTR;
 
-       if (ret < 0)
+       if (ret < min_ret)
                req_set_fail_links(req);
        __io_req_complete(req, issue_flags, ret, 0);
        return 0;
@@ -4546,6 +4604,7 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
        struct socket *sock;
        struct io_buffer *kbuf;
        unsigned flags;
+       int min_ret = 0;
        int ret, cflags = 0;
        bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
 
@@ -4571,12 +4630,15 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
                                1, req->sr_msg.len);
        }
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&kmsg->msg.msg_iter);
+
        ret = __sys_recvmsg_sock(sock, &kmsg->msg, req->sr_msg.umsg,
                                        kmsg->uaddr, flags);
        if (force_nonblock && ret == -EAGAIN)
@@ -4590,7 +4652,7 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
        if (kmsg->free_iov)
                kfree(kmsg->free_iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < 0)
+       if (ret < min_ret || ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
                req_set_fail_links(req);
        __io_req_complete(req, issue_flags, ret, cflags);
        return 0;
@@ -4605,6 +4667,7 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
        struct socket *sock;
        struct iovec iov;
        unsigned flags;
+       int min_ret = 0;
        int ret, cflags = 0;
        bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
 
@@ -4630,12 +4693,15 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
        msg.msg_iocb = NULL;
        msg.msg_flags = 0;
 
-       flags = req->sr_msg.msg_flags;
+       flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
        if (flags & MSG_DONTWAIT)
                req->flags |= REQ_F_NOWAIT;
        else if (force_nonblock)
                flags |= MSG_DONTWAIT;
 
+       if (flags & MSG_WAITALL)
+               min_ret = iov_iter_count(&msg.msg_iter);
+
        ret = sock_recvmsg(sock, &msg, flags);
        if (force_nonblock && ret == -EAGAIN)
                return -EAGAIN;
@@ -4644,7 +4710,7 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
 out_free:
        if (req->flags & REQ_F_BUFFER_SELECTED)
                cflags = io_put_recv_kbuf(req);
-       if (ret < 0)
+       if (ret < min_ret || ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
                req_set_fail_links(req);
        __io_req_complete(req, issue_flags, ret, cflags);
        return 0;
@@ -4741,7 +4807,6 @@ static int io_connect(struct io_kiocb *req, unsigned int issue_flags)
                        ret = -ENOMEM;
                        goto out;
                }
-               io = req->async_data;
                memcpy(req->async_data, &__io, sizeof(__io));
                return -EAGAIN;
        }
@@ -5504,7 +5569,8 @@ static int io_timeout_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 
        data->mode = io_translate_timeout_mode(flags);
        hrtimer_init(&data->timer, CLOCK_MONOTONIC, data->mode);
-       io_req_track_inflight(req);
+       if (is_timeout_link)
+               io_req_track_inflight(req);
        return 0;
 }
 
@@ -5558,22 +5624,30 @@ add:
        return 0;
 }
 
+struct io_cancel_data {
+       struct io_ring_ctx *ctx;
+       u64 user_data;
+};
+
 static bool io_cancel_cb(struct io_wq_work *work, void *data)
 {
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+       struct io_cancel_data *cd = data;
 
-       return req->user_data == (unsigned long) data;
+       return req->ctx == cd->ctx && req->user_data == cd->user_data;
 }
 
-static int io_async_cancel_one(struct io_uring_task *tctx, void *sqe_addr)
+static int io_async_cancel_one(struct io_uring_task *tctx, u64 user_data,
+                              struct io_ring_ctx *ctx)
 {
+       struct io_cancel_data data = { .ctx = ctx, .user_data = user_data, };
        enum io_wq_cancel cancel_ret;
        int ret = 0;
 
-       if (!tctx->io_wq)
+       if (!tctx || !tctx->io_wq)
                return -ENOENT;
 
-       cancel_ret = io_wq_cancel_cb(tctx->io_wq, io_cancel_cb, sqe_addr, false);
+       cancel_ret = io_wq_cancel_cb(tctx->io_wq, io_cancel_cb, &data, false);
        switch (cancel_ret) {
        case IO_WQ_CANCEL_OK:
                ret = 0;
@@ -5596,8 +5670,7 @@ static void io_async_find_and_cancel(struct io_ring_ctx *ctx,
        unsigned long flags;
        int ret;
 
-       ret = io_async_cancel_one(req->task->io_uring,
-                                       (void *) (unsigned long) sqe_addr);
+       ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
        if (ret != -ENOENT) {
                spin_lock_irqsave(&ctx->completion_lock, flags);
                goto done;
@@ -5638,8 +5711,47 @@ static int io_async_cancel_prep(struct io_kiocb *req,
 static int io_async_cancel(struct io_kiocb *req, unsigned int issue_flags)
 {
        struct io_ring_ctx *ctx = req->ctx;
+       u64 sqe_addr = req->cancel.addr;
+       struct io_tctx_node *node;
+       int ret;
+
+       /* tasks should wait for their io-wq threads, so safe w/o sync */
+       ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
+       spin_lock_irq(&ctx->completion_lock);
+       if (ret != -ENOENT)
+               goto done;
+       ret = io_timeout_cancel(ctx, sqe_addr);
+       if (ret != -ENOENT)
+               goto done;
+       ret = io_poll_cancel(ctx, sqe_addr);
+       if (ret != -ENOENT)
+               goto done;
+       spin_unlock_irq(&ctx->completion_lock);
+
+       /* slow path, try all io-wq's */
+       io_ring_submit_lock(ctx, !(issue_flags & IO_URING_F_NONBLOCK));
+       ret = -ENOENT;
+       list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
+               struct io_uring_task *tctx = node->task->io_uring;
+
+               if (!tctx || !tctx->io_wq)
+                       continue;
+               ret = io_async_cancel_one(tctx, req->cancel.addr, ctx);
+               if (ret != -ENOENT)
+                       break;
+       }
+       io_ring_submit_unlock(ctx, !(issue_flags & IO_URING_F_NONBLOCK));
 
-       io_async_find_and_cancel(ctx, req, req->cancel.addr, 0);
+       spin_lock_irq(&ctx->completion_lock);
+done:
+       io_cqring_fill_event(req, ret);
+       io_commit_cqring(ctx);
+       spin_unlock_irq(&ctx->completion_lock);
+       io_cqring_ev_posted(ctx);
+
+       if (ret < 0)
+               req_set_fail_links(req);
+       io_put_req(req);
        return 0;
 }
 
@@ -5915,18 +6027,8 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
        const struct cred *creds = NULL;
        int ret;
 
-       if (req->work.personality) {
-               const struct cred *new_creds;
-
-               if (!(issue_flags & IO_URING_F_NONBLOCK))
-                       mutex_lock(&ctx->uring_lock);
-               new_creds = idr_find(&ctx->personality_idr, req->work.personality);
-               if (!(issue_flags & IO_URING_F_NONBLOCK))
-                       mutex_unlock(&ctx->uring_lock);
-               if (!new_creds)
-                       return -EINVAL;
-               creds = override_creds(new_creds);
-       }
+       if (req->work.creds && req->work.creds != current_cred())
+               creds = override_creds(req->work.creds);
 
        switch (req->opcode) {
        case IORING_OP_NOP:
@@ -6071,6 +6173,7 @@ static void io_wq_submit_work(struct io_wq_work *work)
                ret = -ECANCELED;
 
        if (!ret) {
+               req->flags &= ~REQ_F_REISSUE;
                do {
                        ret = io_issue_sqe(req, 0);
                        /*
@@ -6146,7 +6249,6 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
        spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
        if (prev) {
-               req_set_fail_links(prev);
                io_async_find_and_cancel(ctx, req, prev->user_data, -ETIME);
                io_put_req_deferred(prev, 1);
        } else {
@@ -6290,7 +6392,7 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 {
        struct io_submit_state *state;
        unsigned int sqe_flags;
-       int ret = 0;
+       int personality, ret = 0;
 
        req->opcode = READ_ONCE(sqe->opcode);
        /* same numerical values with corresponding REQ_F_*, safe to copy */
@@ -6305,6 +6407,9 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        refcount_set(&req->refs, 2);
        req->task = current;
        req->result = 0;
+       req->work.list.next = NULL;
+       req->work.creds = NULL;
+       req->work.flags = 0;
 
        /* enforce forwards compatibility on users */
        if (unlikely(sqe_flags & ~SQE_VALID_FLAGS)) {
@@ -6322,9 +6427,13 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
            !io_op_defs[req->opcode].buffer_select)
                return -EOPNOTSUPP;
 
-       req->work.list.next = NULL;
-       req->work.flags = 0;
-       req->work.personality = READ_ONCE(sqe->personality);
+       personality = READ_ONCE(sqe->personality);
+       if (personality) {
+               req->work.creds = xa_load(&ctx->personalities, personality);
+               if (!req->work.creds)
+                       return -EINVAL;
+               get_cred(req->work.creds);
+       }
        state = &ctx->submit_state;
 
        /*
@@ -6358,8 +6467,6 @@ static int io_submit_sqe(struct io_ring_ctx *ctx, struct io_kiocb *req,
        ret = io_init_req(ctx, req, sqe);
        if (unlikely(ret)) {
 fail_req:
-               io_put_req(req);
-               io_req_complete(req, ret);
                if (link->head) {
                        /* fail even hard links since we don't submit */
                        link->head->flags |= REQ_F_FAIL_LINK;
@@ -6367,6 +6474,8 @@ fail_req:
                        io_req_complete(link->head, -ECANCELED);
                        link->head = NULL;
                }
+               io_put_req(req);
+               io_req_complete(req, ret);
                return ret;
        }
        ret = io_req_prep(req, sqe);
@@ -6586,7 +6695,8 @@ static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
                if (!list_empty(&ctx->iopoll_list))
                        io_do_iopoll(ctx, &nr_events, 0);
 
-               if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)))
+               if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)) &&
+                   !(ctx->flags & IORING_SETUP_R_DISABLED))
                        ret = io_submit_sqes(ctx, to_submit);
                mutex_unlock(&ctx->uring_lock);
        }
@@ -6610,58 +6720,6 @@ static void io_sqd_update_thread_idle(struct io_sq_data *sqd)
        sqd->sq_thread_idle = sq_thread_idle;
 }
 
-static void io_sqd_init_new(struct io_sq_data *sqd)
-{
-       struct io_ring_ctx *ctx;
-
-       while (!list_empty(&sqd->ctx_new_list)) {
-               ctx = list_first_entry(&sqd->ctx_new_list, struct io_ring_ctx, sqd_list);
-               list_move_tail(&ctx->sqd_list, &sqd->ctx_list);
-               complete(&ctx->sq_thread_comp);
-       }
-
-       io_sqd_update_thread_idle(sqd);
-}
-
-static bool io_sq_thread_should_stop(struct io_sq_data *sqd)
-{
-       return test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-}
-
-static bool io_sq_thread_should_park(struct io_sq_data *sqd)
-{
-       return test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-}
-
-static void io_sq_thread_parkme(struct io_sq_data *sqd)
-{
-       for (;;) {
-               /*
-                * TASK_PARKED is a special state; we must serialize against
-                * possible pending wakeups to avoid store-store collisions on
-                * task->state.
-                *
-                * Such a collision might possibly result in the task state
-                * changin from TASK_PARKED and us failing the
-                * wait_task_inactive() in kthread_park().
-                */
-               set_special_state(TASK_PARKED);
-               if (!test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state))
-                       break;
-
-               /*
-                * Thread is going to call schedule(), do not preempt it,
-                * or the caller of kthread_park() may spend more time in
-                * wait_task_inactive().
-                */
-               preempt_disable();
-               complete(&sqd->completion);
-               schedule_preempt_disabled();
-               preempt_enable();
-       }
-       __set_current_state(TASK_RUNNING);
-}
-
 static int io_sq_thread(void *data)
 {
        struct io_sq_data *sqd = data;
@@ -6670,7 +6728,7 @@ static int io_sq_thread(void *data)
        char buf[TASK_COMM_LEN];
        DEFINE_WAIT(wait);
 
-       sprintf(buf, "iou-sqp-%d", sqd->task_pid);
+       snprintf(buf, sizeof(buf), "iou-sqp-%d", sqd->task_pid);
        set_task_comm(current, buf);
        current->pf_io_worker = NULL;
 
@@ -6680,31 +6738,40 @@ static int io_sq_thread(void *data)
                set_cpus_allowed_ptr(current, cpu_online_mask);
        current->flags |= PF_NO_SETAFFINITY;
 
-       wait_for_completion(&sqd->startup);
-
-       while (!io_sq_thread_should_stop(sqd)) {
+       mutex_lock(&sqd->lock);
+       while (!test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state)) {
                int ret;
                bool cap_entries, sqt_spin, needs_sched;
 
-               /*
-                * Any changes to the sqd lists are synchronized through the
-                * thread parking. This synchronizes the thread vs users,
-                * the users are synchronized on the sqd->ctx_lock.
-                */
-               if (io_sq_thread_should_park(sqd)) {
-                       io_sq_thread_parkme(sqd);
-                       continue;
-               }
-               if (unlikely(!list_empty(&sqd->ctx_new_list))) {
-                       io_sqd_init_new(sqd);
+               if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state) ||
+                   signal_pending(current)) {
+                       bool did_sig = false;
+
+                       mutex_unlock(&sqd->lock);
+                       if (signal_pending(current)) {
+                               struct ksignal ksig;
+
+                               did_sig = get_signal(&ksig);
+                       }
+                       cond_resched();
+                       mutex_lock(&sqd->lock);
+                       if (did_sig)
+                               break;
+                       io_run_task_work();
+                       io_run_task_work_head(&sqd->park_task_work);
                        timeout = jiffies + sqd->sq_thread_idle;
+                       continue;
                }
-               if (fatal_signal_pending(current))
-                       break;
                sqt_spin = false;
                cap_entries = !list_is_singular(&sqd->ctx_list);
                list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+                       const struct cred *creds = NULL;
+
+                       if (ctx->sq_creds != current_cred())
+                               creds = override_creds(ctx->sq_creds);
                        ret = __io_sq_thread(ctx, cap_entries);
+                       if (creds)
+                               revert_creds(creds);
                        if (!sqt_spin && (ret > 0 || !list_empty(&ctx->iopoll_list)))
                                sqt_spin = true;
                }
@@ -6731,41 +6798,32 @@ static int io_sq_thread(void *data)
                        }
                }
 
-               if (needs_sched && !io_sq_thread_should_park(sqd)) {
+               if (needs_sched && !test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state)) {
                        list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
                                io_ring_set_wakeup_flag(ctx);
 
+                       mutex_unlock(&sqd->lock);
                        schedule();
-                       try_to_freeze();
+                       mutex_lock(&sqd->lock);
                        list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
                                io_ring_clear_wakeup_flag(ctx);
                }
 
                finish_wait(&sqd->wait, &wait);
+               io_run_task_work_head(&sqd->park_task_work);
                timeout = jiffies + sqd->sq_thread_idle;
        }
 
        list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
                io_uring_cancel_sqpoll(ctx);
-
-       io_run_task_work();
-
-       if (io_sq_thread_should_park(sqd))
-               io_sq_thread_parkme(sqd);
-
-       /*
-        * Clear thread under lock so that concurrent parks work correctly
-        */
-       complete(&sqd->completion);
-       mutex_lock(&sqd->lock);
        sqd->thread = NULL;
-       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
-               ctx->sqo_exec = 1;
+       list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
                io_ring_set_wakeup_flag(ctx);
-       }
+       mutex_unlock(&sqd->lock);
 
+       io_run_task_work();
+       io_run_task_work_head(&sqd->park_task_work);
        complete(&sqd->exited);
-       mutex_unlock(&sqd->lock);
        do_exit(0);
 }
 
@@ -6810,7 +6868,7 @@ static int io_run_task_work_sig(void)
                return 1;
        if (!signal_pending(current))
                return 0;
-       if (test_tsk_thread_flag(current, TIF_NOTIFY_SIGNAL))
+       if (test_thread_flag(TIF_NOTIFY_SIGNAL))
                return -ERESTARTSYS;
        return -EINTR;
 }
@@ -7066,51 +7124,47 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 static void io_sq_thread_unpark(struct io_sq_data *sqd)
        __releases(&sqd->lock)
 {
-       if (!sqd->thread)
-               return;
-       if (sqd->thread == current)
-               return;
+       WARN_ON_ONCE(sqd->thread == current);
+
+       /*
+        * Do the dance but not conditional clear_bit() because it'd race with
+        * other threads incrementing park_pending and setting the bit.
+        */
        clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-       wake_up_state(sqd->thread, TASK_PARKED);
+       if (atomic_dec_return(&sqd->park_pending))
+               set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
        mutex_unlock(&sqd->lock);
 }
 
-static bool io_sq_thread_park(struct io_sq_data *sqd)
+static void io_sq_thread_park(struct io_sq_data *sqd)
        __acquires(&sqd->lock)
 {
-       if (sqd->thread == current)
-               return true;
-       mutex_lock(&sqd->lock);
-       if (!sqd->thread) {
-               mutex_unlock(&sqd->lock);
-               return false;
-       }
+       WARN_ON_ONCE(sqd->thread == current);
+
+       atomic_inc(&sqd->park_pending);
        set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-       wake_up_process(sqd->thread);
-       wait_for_completion(&sqd->completion);
-       return true;
+       mutex_lock(&sqd->lock);
+       if (sqd->thread)
+               wake_up_process(sqd->thread);
 }
 
 static void io_sq_thread_stop(struct io_sq_data *sqd)
 {
-       if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
-               return;
+       WARN_ON_ONCE(sqd->thread == current);
+
        mutex_lock(&sqd->lock);
-       if (sqd->thread) {
-               set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-               WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state));
+       set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+       if (sqd->thread)
                wake_up_process(sqd->thread);
-               mutex_unlock(&sqd->lock);
-               wait_for_completion(&sqd->exited);
-               WARN_ON_ONCE(sqd->thread);
-       } else {
-               mutex_unlock(&sqd->lock);
-       }
+       mutex_unlock(&sqd->lock);
+       wait_for_completion(&sqd->exited);
 }
 
 static void io_put_sq_data(struct io_sq_data *sqd)
 {
        if (refcount_dec_and_test(&sqd->refs)) {
+               WARN_ON_ONCE(atomic_read(&sqd->park_pending));
+
                io_sq_thread_stop(sqd);
                kfree(sqd);
        }
@@ -7121,22 +7175,15 @@ static void io_sq_thread_finish(struct io_ring_ctx *ctx)
        struct io_sq_data *sqd = ctx->sq_data;
 
        if (sqd) {
-               complete(&sqd->startup);
-               if (sqd->thread) {
-                       wait_for_completion(&ctx->sq_thread_comp);
-                       io_sq_thread_park(sqd);
-               }
-
-               mutex_lock(&sqd->ctx_lock);
-               list_del(&ctx->sqd_list);
+               io_sq_thread_park(sqd);
+               list_del_init(&ctx->sqd_list);
                io_sqd_update_thread_idle(sqd);
-               mutex_unlock(&sqd->ctx_lock);
-
-               if (sqd->thread)
-                       io_sq_thread_unpark(sqd);
+               io_sq_thread_unpark(sqd);
 
                io_put_sq_data(sqd);
                ctx->sq_data = NULL;
+               if (ctx->sq_creds)
+                       put_cred(ctx->sq_creds);
        }
 }
 
@@ -7160,31 +7207,42 @@ static struct io_sq_data *io_attach_sq_data(struct io_uring_params *p)
                fdput(f);
                return ERR_PTR(-EINVAL);
        }
+       if (sqd->task_tgid != current->tgid) {
+               fdput(f);
+               return ERR_PTR(-EPERM);
+       }
 
        refcount_inc(&sqd->refs);
        fdput(f);
        return sqd;
 }
 
-static struct io_sq_data *io_get_sq_data(struct io_uring_params *p)
+static struct io_sq_data *io_get_sq_data(struct io_uring_params *p,
+                                        bool *attached)
 {
        struct io_sq_data *sqd;
 
-       if (p->flags & IORING_SETUP_ATTACH_WQ)
-               return io_attach_sq_data(p);
+       *attached = false;
+       if (p->flags & IORING_SETUP_ATTACH_WQ) {
+               sqd = io_attach_sq_data(p);
+               if (!IS_ERR(sqd)) {
+                       *attached = true;
+                       return sqd;
+               }
+               /* fall through for EPERM case, setup new sqd/task */
+               if (PTR_ERR(sqd) != -EPERM)
+                       return sqd;
+       }
 
        sqd = kzalloc(sizeof(*sqd), GFP_KERNEL);
        if (!sqd)
                return ERR_PTR(-ENOMEM);
 
+       atomic_set(&sqd->park_pending, 0);
        refcount_set(&sqd->refs, 1);
        INIT_LIST_HEAD(&sqd->ctx_list);
-       INIT_LIST_HEAD(&sqd->ctx_new_list);
-       mutex_init(&sqd->ctx_lock);
        mutex_init(&sqd->lock);
        init_waitqueue_head(&sqd->wait);
-       init_completion(&sqd->startup);
-       init_completion(&sqd->completion);
        init_completion(&sqd->exited);
        return sqd;
 }
@@ -7801,7 +7859,6 @@ static int io_uring_alloc_task_context(struct task_struct *task,
        init_waitqueue_head(&tctx->wait);
        tctx->last = NULL;
        atomic_set(&tctx->in_idle, 0);
-       tctx->sqpoll = false;
        task->io_uring = tctx;
        spin_lock_init(&tctx->task_lock);
        INIT_WQ_LIST(&tctx->task_list);
@@ -7822,26 +7879,6 @@ void __io_uring_free(struct task_struct *tsk)
        tsk->io_uring = NULL;
 }
 
-static int io_sq_thread_fork(struct io_sq_data *sqd, struct io_ring_ctx *ctx)
-{
-       struct task_struct *tsk;
-       int ret;
-
-       clear_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-       reinit_completion(&sqd->completion);
-       ctx->sqo_exec = 0;
-       sqd->task_pid = current->pid;
-       tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
-       if (IS_ERR(tsk))
-               return PTR_ERR(tsk);
-       ret = io_uring_alloc_task_context(tsk, ctx);
-       if (ret)
-               set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-       sqd->thread = tsk;
-       wake_up_new_task(tsk);
-       return ret;
-}
-
 static int io_sq_offload_create(struct io_ring_ctx *ctx,
                                struct io_uring_params *p)
 {
@@ -7864,29 +7901,36 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
        if (ctx->flags & IORING_SETUP_SQPOLL) {
                struct task_struct *tsk;
                struct io_sq_data *sqd;
+               bool attached;
 
                ret = -EPERM;
                if (!capable(CAP_SYS_ADMIN) && !capable(CAP_SYS_NICE))
                        goto err;
 
-               sqd = io_get_sq_data(p);
+               sqd = io_get_sq_data(p, &attached);
                if (IS_ERR(sqd)) {
                        ret = PTR_ERR(sqd);
                        goto err;
                }
 
+               ctx->sq_creds = get_current_cred();
                ctx->sq_data = sqd;
-               io_sq_thread_park(sqd);
-               mutex_lock(&sqd->ctx_lock);
-               list_add(&ctx->sqd_list, &sqd->ctx_new_list);
-               mutex_unlock(&sqd->ctx_lock);
-               io_sq_thread_unpark(sqd);
-
                ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
                if (!ctx->sq_thread_idle)
                        ctx->sq_thread_idle = HZ;
 
-               if (sqd->thread)
+               ret = 0;
+               io_sq_thread_park(sqd);
+               list_add(&ctx->sqd_list, &sqd->ctx_list);
+               io_sqd_update_thread_idle(sqd);
+               /* don't attach to a dying SQPOLL thread, would be racy */
+               if (attached && !sqd->thread)
+                       ret = -ENXIO;
+               io_sq_thread_unpark(sqd);
+
+               if (ret < 0)
+                       goto err;
+               if (attached)
                        return 0;
 
                if (p->flags & IORING_SETUP_SQ_AFF) {
@@ -7894,9 +7938,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 
                        ret = -EINVAL;
                        if (cpu >= nr_cpu_ids)
-                               goto err;
+                               goto err_sqpoll;
                        if (!cpu_online(cpu))
-                               goto err;
+                               goto err_sqpoll;
 
                        sqd->sq_cpu = cpu;
                } else {
@@ -7904,15 +7948,15 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
                }
 
                sqd->task_pid = current->pid;
+               sqd->task_tgid = current->tgid;
                tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
                if (IS_ERR(tsk)) {
                        ret = PTR_ERR(tsk);
-                       goto err;
+                       goto err_sqpoll;
                }
-               ret = io_uring_alloc_task_context(tsk, ctx);
-               if (ret)
-                       set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+
                sqd->thread = tsk;
+               ret = io_uring_alloc_task_context(tsk, ctx);
                wake_up_new_task(tsk);
                if (ret)
                        goto err;
@@ -7926,15 +7970,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 err:
        io_sq_thread_finish(ctx);
        return ret;
-}
-
-static void io_sq_offload_start(struct io_ring_ctx *ctx)
-{
-       struct io_sq_data *sqd = ctx->sq_data;
-
-       ctx->flags &= ~IORING_SETUP_R_DISABLED;
-       if (ctx->flags & IORING_SETUP_SQPOLL)
-               complete(&sqd->startup);
+err_sqpoll:
+       complete(&ctx->sq_data->exited);
+       goto err;
 }
 
 static inline void __io_unaccount_mem(struct user_struct *user,
@@ -8344,19 +8382,13 @@ static int io_eventfd_unregister(struct io_ring_ctx *ctx)
        return -ENXIO;
 }
 
-static int __io_destroy_buffers(int id, void *p, void *data)
-{
-       struct io_ring_ctx *ctx = data;
-       struct io_buffer *buf = p;
-
-       __io_remove_buffers(ctx, buf, id, -1U);
-       return 0;
-}
-
 static void io_destroy_buffers(struct io_ring_ctx *ctx)
 {
-       idr_for_each(&ctx->io_buffer_idr, __io_destroy_buffers, ctx);
-       idr_destroy(&ctx->io_buffer_idr);
+       struct io_buffer *buf;
+       unsigned long index;
+
+       xa_for_each(&ctx->io_buffers, index, buf)
+               __io_remove_buffers(ctx, buf, index, -1U);
 }
 
 static void io_req_cache_free(struct list_head *list, struct task_struct *tsk)
@@ -8398,11 +8430,13 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 {
        /*
         * Some may use context even when all refs and requests have been put,
-        * and they are free to do so while still holding uring_lock, see
-        * __io_req_task_submit(). Wait for them to finish.
+        * and they are free to do so while still holding uring_lock or
+        * completion_lock, see __io_req_task_submit(). Wait for them to finish.
         */
        mutex_lock(&ctx->uring_lock);
        mutex_unlock(&ctx->uring_lock);
+       spin_lock_irq(&ctx->completion_lock);
+       spin_unlock_irq(&ctx->completion_lock);
 
        io_sq_thread_finish(ctx);
        io_sqe_buffers_unregister(ctx);
@@ -8417,7 +8451,6 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
        mutex_unlock(&ctx->uring_lock);
        io_eventfd_unregister(ctx);
        io_destroy_buffers(ctx);
-       idr_destroy(&ctx->personality_idr);
 
 #if defined(CONFIG_UNIX)
        if (ctx->ring_sock) {
@@ -8482,7 +8515,7 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 {
        const struct cred *creds;
 
-       creds = idr_remove(&ctx->personality_idr, id);
+       creds = xa_erase(&ctx->personalities, id);
        if (creds) {
                put_cred(creds);
                return 0;
@@ -8491,40 +8524,47 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
        return -EINVAL;
 }
 
-static int io_remove_personalities(int id, void *p, void *data)
+static inline bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
 {
-       struct io_ring_ctx *ctx = data;
-
-       io_unregister_personality(ctx, id);
-       return 0;
+       return io_run_task_work_head(&ctx->exit_task_work);
 }
 
-static bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
-{
-       struct callback_head *work, *next;
-       bool executed = false;
-
-       do {
-               work = xchg(&ctx->exit_task_work, NULL);
-               if (!work)
-                       break;
+struct io_tctx_exit {
+       struct callback_head            task_work;
+       struct completion               completion;
+       struct io_ring_ctx              *ctx;
+};
 
-               do {
-                       next = work->next;
-                       work->func(work);
-                       work = next;
-                       cond_resched();
-               } while (work);
-               executed = true;
-       } while (1);
+static void io_tctx_exit_cb(struct callback_head *cb)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       struct io_tctx_exit *work;
 
-       return executed;
+       work = container_of(cb, struct io_tctx_exit, task_work);
+       /*
+        * When @in_idle, we're in cancellation and it's racy to remove the
+        * node. It'll be removed by the end of cancellation, just ignore it.
+        */
+       if (!atomic_read(&tctx->in_idle))
+               io_uring_del_task_file((unsigned long)work->ctx);
+       complete(&work->completion);
 }
 
 static void io_ring_exit_work(struct work_struct *work)
 {
-       struct io_ring_ctx *ctx = container_of(work, struct io_ring_ctx,
-                                              exit_work);
+       struct io_ring_ctx *ctx = container_of(work, struct io_ring_ctx, exit_work);
+       unsigned long timeout = jiffies + HZ * 60 * 5;
+       struct io_tctx_exit exit;
+       struct io_tctx_node *node;
+       int ret;
+
+       /* prevent SQPOLL from submitting new requests */
+       if (ctx->sq_data) {
+               io_sq_thread_park(ctx->sq_data);
+               list_del_init(&ctx->sqd_list);
+               io_sqd_update_thread_idle(ctx->sq_data);
+               io_sq_thread_unpark(ctx->sq_data);
+       }
 
        /*
         * If we're doing polled IO and end up having requests being
@@ -8534,19 +8574,69 @@ static void io_ring_exit_work(struct work_struct *work)
         */
        do {
                io_uring_try_cancel_requests(ctx, NULL, NULL);
+
+               WARN_ON_ONCE(time_after(jiffies, timeout));
        } while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20));
+
+       mutex_lock(&ctx->uring_lock);
+       while (!list_empty(&ctx->tctx_list)) {
+               WARN_ON_ONCE(time_after(jiffies, timeout));
+
+               node = list_first_entry(&ctx->tctx_list, struct io_tctx_node,
+                                       ctx_node);
+               exit.ctx = ctx;
+               init_completion(&exit.completion);
+               init_task_work(&exit.task_work, io_tctx_exit_cb);
+               ret = task_work_add(node->task, &exit.task_work, TWA_SIGNAL);
+               if (WARN_ON_ONCE(ret))
+                       continue;
+               wake_up_process(node->task);
+
+               mutex_unlock(&ctx->uring_lock);
+               wait_for_completion(&exit.completion);
+               cond_resched();
+               mutex_lock(&ctx->uring_lock);
+       }
+       mutex_unlock(&ctx->uring_lock);
+
        io_ring_ctx_free(ctx);
 }
 
+/* Returns true if we found and killed one or more timeouts */
+static bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
+                            struct files_struct *files)
+{
+       struct io_kiocb *req, *tmp;
+       int canceled = 0;
+
+       spin_lock_irq(&ctx->completion_lock);
+       list_for_each_entry_safe(req, tmp, &ctx->timeout_list, timeout.list) {
+               if (io_match_task(req, tsk, files)) {
+                       io_kill_timeout(req, -ECANCELED);
+                       canceled++;
+               }
+       }
+       if (canceled != 0)
+               io_commit_cqring(ctx);
+       spin_unlock_irq(&ctx->completion_lock);
+       if (canceled != 0)
+               io_cqring_ev_posted(ctx);
+       return canceled != 0;
+}
+
 static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
 {
+       unsigned long index;
+       struct creds *creds;
+
        mutex_lock(&ctx->uring_lock);
        percpu_ref_kill(&ctx->refs);
        /* if force is set, the ring is going away. always drop after that */
        ctx->cq_overflow_flushed = 1;
        if (ctx->rings)
                __io_cqring_overflow_flush(ctx, true, NULL, NULL);
-       idr_for_each(&ctx->personality_idr, io_remove_personalities, ctx);
+       xa_for_each(&ctx->personalities, index, creds)
+               io_unregister_personality(ctx, index);
        mutex_unlock(&ctx->uring_lock);
 
        io_kill_timeouts(ctx, NULL, NULL);
@@ -8599,11 +8689,11 @@ static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
        return ret;
 }
 
-static void io_cancel_defer_files(struct io_ring_ctx *ctx,
+static bool io_cancel_defer_files(struct io_ring_ctx *ctx,
                                  struct task_struct *task,
                                  struct files_struct *files)
 {
-       struct io_defer_entry *de = NULL;
+       struct io_defer_entry *de;
        LIST_HEAD(list);
 
        spin_lock_irq(&ctx->completion_lock);
@@ -8614,6 +8704,8 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
                }
        }
        spin_unlock_irq(&ctx->completion_lock);
+       if (list_empty(&list))
+               return false;
 
        while (!list_empty(&list)) {
                de = list_first_entry(&list, struct io_defer_entry, list);
@@ -8623,6 +8715,38 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
                io_req_complete(de->req, -ECANCELED);
                kfree(de);
        }
+       return true;
+}
+
+static bool io_cancel_ctx_cb(struct io_wq_work *work, void *data)
+{
+       struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+
+       return req->ctx == data;
+}
+
+static bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx)
+{
+       struct io_tctx_node *node;
+       enum io_wq_cancel cret;
+       bool ret = false;
+
+       mutex_lock(&ctx->uring_lock);
+       list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
+               struct io_uring_task *tctx = node->task->io_uring;
+
+               /*
+                * io_wq will stay alive while we hold uring_lock, because it's
+                * killed after ctx nodes, which requires to take the lock.
+                */
+               if (!tctx || !tctx->io_wq)
+                       continue;
+               cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_ctx_cb, ctx, true);
+               ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
+       }
+       mutex_unlock(&ctx->uring_lock);
+
+       return ret;
 }
 
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
@@ -8630,27 +8754,34 @@ static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                                         struct files_struct *files)
 {
        struct io_task_cancel cancel = { .task = task, .files = files, };
-       struct task_struct *tctx_task = task ?: current;
-       struct io_uring_task *tctx = tctx_task->io_uring;
+       struct io_uring_task *tctx = task ? task->io_uring : NULL;
 
        while (1) {
                enum io_wq_cancel cret;
                bool ret = false;
 
-               if (tctx && tctx->io_wq) {
+               if (!task) {
+                       ret |= io_uring_try_cancel_iowq(ctx);
+               } else if (tctx && tctx->io_wq) {
+                       /*
+                        * Cancels requests of all rings, not only @ctx, but
+                        * it's fine as the task is in exit/exec.
+                        */
                        cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_task_cb,
                                               &cancel, true);
                        ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
                }
 
                /* SQPOLL thread does its own polling */
-               if (!(ctx->flags & IORING_SETUP_SQPOLL) && !files) {
+               if ((!(ctx->flags & IORING_SETUP_SQPOLL) && !files) ||
+                   (ctx->sq_data && ctx->sq_data->thread == current)) {
                        while (!list_empty_careful(&ctx->iopoll_list)) {
                                io_iopoll_try_reap_events(ctx);
                                ret = true;
                        }
                }
 
+               ret |= io_cancel_defer_files(ctx, task, files);
                ret |= io_poll_remove_all(ctx, task, files);
                ret |= io_kill_timeouts(ctx, task, files);
                ret |= io_run_task_work();
@@ -8690,60 +8821,21 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
 
                io_uring_try_cancel_requests(ctx, task, files);
 
-               if (ctx->sq_data)
-                       io_sq_thread_unpark(ctx->sq_data);
                prepare_to_wait(&task->io_uring->wait, &wait,
                                TASK_UNINTERRUPTIBLE);
                if (inflight == io_uring_count_inflight(ctx, task, files))
                        schedule();
                finish_wait(&task->io_uring->wait, &wait);
-               if (ctx->sq_data)
-                       io_sq_thread_park(ctx->sq_data);
-       }
-}
-
-/*
- * We need to iteratively cancel requests, in case a request has dependent
- * hard links. These persist even for failure of cancelations, hence keep
- * looping until none are found.
- */
-static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
-                                         struct files_struct *files)
-{
-       struct task_struct *task = current;
-       bool did_park = false;
-
-       if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
-               /* never started, nothing to cancel */
-               if (ctx->flags & IORING_SETUP_R_DISABLED) {
-                       io_sq_offload_start(ctx);
-                       return;
-               }
-               did_park = io_sq_thread_park(ctx->sq_data);
-               if (did_park) {
-                       task = ctx->sq_data->thread;
-                       atomic_inc(&task->io_uring->in_idle);
-               }
-       }
-
-       io_cancel_defer_files(ctx, task, files);
-
-       io_uring_cancel_files(ctx, task, files);
-       if (!files)
-               io_uring_try_cancel_requests(ctx, task, NULL);
-
-       if (did_park) {
-               atomic_dec(&task->io_uring->in_idle);
-               io_sq_thread_unpark(ctx->sq_data);
        }
 }
 
 /*
  * Note that this task has used io_uring. We use it for cancelation purposes.
  */
-static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
+static int io_uring_add_task_file(struct io_ring_ctx *ctx)
 {
        struct io_uring_task *tctx = current->io_uring;
+       struct io_tctx_node *node;
        int ret;
 
        if (unlikely(!tctx)) {
@@ -8752,97 +8844,143 @@ static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
                        return ret;
                tctx = current->io_uring;
        }
-       if (tctx->last != file) {
-               void *old = xa_load(&tctx->xa, (unsigned long)file);
+       if (tctx->last != ctx) {
+               void *old = xa_load(&tctx->xa, (unsigned long)ctx);
 
                if (!old) {
-                       get_file(file);
-                       ret = xa_err(xa_store(&tctx->xa, (unsigned long)file,
-                                               file, GFP_KERNEL));
+                       node = kmalloc(sizeof(*node), GFP_KERNEL);
+                       if (!node)
+                               return -ENOMEM;
+                       node->ctx = ctx;
+                       node->task = current;
+
+                       ret = xa_err(xa_store(&tctx->xa, (unsigned long)ctx,
+                                               node, GFP_KERNEL));
                        if (ret) {
-                               fput(file);
+                               kfree(node);
                                return ret;
                        }
+
+                       mutex_lock(&ctx->uring_lock);
+                       list_add(&node->ctx_node, &ctx->tctx_list);
+                       mutex_unlock(&ctx->uring_lock);
                }
-               tctx->last = file;
+               tctx->last = ctx;
        }
-
-       /*
-        * This is race safe in that the task itself is doing this, hence it
-        * cannot be going through the exit/cancel paths at the same time.
-        * This cannot be modified while exit/cancel is running.
-        */
-       if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
-               tctx->sqpoll = true;
-
        return 0;
 }
 
 /*
  * Remove this io_uring_file -> task mapping.
  */
-static void io_uring_del_task_file(struct file *file)
+static void io_uring_del_task_file(unsigned long index)
 {
        struct io_uring_task *tctx = current->io_uring;
+       struct io_tctx_node *node;
+
+       if (!tctx)
+               return;
+       node = xa_erase(&tctx->xa, index);
+       if (!node)
+               return;
 
-       if (tctx->last == file)
+       WARN_ON_ONCE(current != node->task);
+       WARN_ON_ONCE(list_empty(&node->ctx_node));
+
+       mutex_lock(&node->ctx->uring_lock);
+       list_del(&node->ctx_node);
+       mutex_unlock(&node->ctx->uring_lock);
+
+       if (tctx->last == node->ctx)
                tctx->last = NULL;
-       file = xa_erase(&tctx->xa, (unsigned long)file);
-       if (file)
-               fput(file);
+       kfree(node);
 }
 
 static void io_uring_clean_tctx(struct io_uring_task *tctx)
 {
-       struct file *file;
+       struct io_tctx_node *node;
        unsigned long index;
 
-       xa_for_each(&tctx->xa, index, file)
-               io_uring_del_task_file(file);
+       xa_for_each(&tctx->xa, index, node)
+               io_uring_del_task_file(index);
        if (tctx->io_wq) {
                io_wq_put_and_exit(tctx->io_wq);
                tctx->io_wq = NULL;
        }
 }
 
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+       return percpu_counter_sum(&tctx->inflight);
+}
+
+static void io_sqpoll_cancel_cb(struct callback_head *cb)
+{
+       struct io_tctx_exit *work = container_of(cb, struct io_tctx_exit, task_work);
+       struct io_ring_ctx *ctx = work->ctx;
+       struct io_sq_data *sqd = ctx->sq_data;
+
+       if (sqd->thread)
+               io_uring_cancel_sqpoll(ctx);
+       complete(&work->completion);
+}
+
+static void io_sqpoll_cancel_sync(struct io_ring_ctx *ctx)
+{
+       struct io_sq_data *sqd = ctx->sq_data;
+       struct io_tctx_exit work = { .ctx = ctx, };
+       struct task_struct *task;
+
+       io_sq_thread_park(sqd);
+       list_del_init(&ctx->sqd_list);
+       io_sqd_update_thread_idle(sqd);
+       task = sqd->thread;
+       if (task) {
+               init_completion(&work.completion);
+               init_task_work(&work.task_work, io_sqpoll_cancel_cb);
+               io_task_work_add_head(&sqd->park_task_work, &work.task_work);
+               wake_up_process(task);
+       }
+       io_sq_thread_unpark(sqd);
+
+       if (task)
+               wait_for_completion(&work.completion);
+}
+
 void __io_uring_files_cancel(struct files_struct *files)
 {
        struct io_uring_task *tctx = current->io_uring;
-       struct file *file;
+       struct io_tctx_node *node;
        unsigned long index;
 
        /* make sure overflow events are dropped */
        atomic_inc(&tctx->in_idle);
-       xa_for_each(&tctx->xa, index, file)
-               io_uring_cancel_task_requests(file->private_data, files);
+       xa_for_each(&tctx->xa, index, node) {
+               struct io_ring_ctx *ctx = node->ctx;
+
+               if (ctx->sq_data) {
+                       io_sqpoll_cancel_sync(ctx);
+                       continue;
+               }
+               io_uring_cancel_files(ctx, current, files);
+               if (!files)
+                       io_uring_try_cancel_requests(ctx, current, NULL);
+       }
        atomic_dec(&tctx->in_idle);
 
        if (files)
                io_uring_clean_tctx(tctx);
 }
 
-static s64 tctx_inflight(struct io_uring_task *tctx)
-{
-       return percpu_counter_sum(&tctx->inflight);
-}
-
+/* should only be called by SQPOLL task */
 static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 {
        struct io_sq_data *sqd = ctx->sq_data;
-       struct io_uring_task *tctx;
+       struct io_uring_task *tctx = current->io_uring;
        s64 inflight;
        DEFINE_WAIT(wait);
 
-       if (!sqd)
-               return;
-       if (!io_sq_thread_park(sqd))
-               return;
-       tctx = ctx->sq_data->thread->io_uring;
-       /* can happen on fork/alloc failure, just ignore that state */
-       if (!tctx) {
-               io_sq_thread_unpark(sqd);
-               return;
-       }
+       WARN_ON_ONCE(!sqd || ctx->sq_data->thread != current);
 
        atomic_inc(&tctx->in_idle);
        do {
@@ -8850,7 +8988,7 @@ static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
                inflight = tctx_inflight(tctx);
                if (!inflight)
                        break;
-               io_uring_cancel_task_requests(ctx, NULL);
+               io_uring_try_cancel_requests(ctx, current, NULL);
 
                prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
                /*
@@ -8863,7 +9001,6 @@ static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
                finish_wait(&tctx->wait, &wait);
        } while (1);
        atomic_dec(&tctx->in_idle);
-       io_sq_thread_unpark(sqd);
 }
 
 /*
@@ -8878,14 +9015,7 @@ void __io_uring_task_cancel(void)
 
        /* make sure overflow events are dropped */
        atomic_inc(&tctx->in_idle);
-
-       if (tctx->sqpoll) {
-               struct file *file;
-               unsigned long index;
-
-               xa_for_each(&tctx->xa, index, file)
-                       io_uring_cancel_sqpoll(file->private_data);
-       }
+       __io_uring_files_cancel(NULL);
 
        do {
                /* read completions before cancelations */
@@ -8985,7 +9115,6 @@ static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
 
 static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 {
-       int ret = 0;
        DEFINE_WAIT(wait);
 
        do {
@@ -8999,7 +9128,7 @@ static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
        } while (!signal_pending(current));
 
        finish_wait(&ctx->sqo_sq_wait, &wait);
-       return ret;
+       return 0;
 }
 
 static int io_get_ext_arg(unsigned flags, const void __user *argp, size_t *argsz,
@@ -9073,13 +9202,10 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
        if (ctx->flags & IORING_SETUP_SQPOLL) {
                io_cqring_overflow_flush(ctx, false, NULL, NULL);
 
-               if (unlikely(ctx->sqo_exec)) {
-                       ret = io_sq_thread_fork(ctx->sq_data, ctx);
-                       if (ret)
-                               goto out;
-                       ctx->sqo_exec = 0;
-               }
                ret = -EOWNERDEAD;
+               if (unlikely(ctx->sq_data->thread == NULL)) {
+                       goto out;
+               }
                if (flags & IORING_ENTER_SQ_WAKEUP)
                        wake_up(&ctx->sq_data->wait);
                if (flags & IORING_ENTER_SQ_WAIT) {
@@ -9089,7 +9215,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
                }
                submitted = to_submit;
        } else if (to_submit) {
-               ret = io_uring_add_task_file(ctx, f.file);
+               ret = io_uring_add_task_file(ctx);
                if (unlikely(ret))
                        goto out;
                mutex_lock(&ctx->uring_lock);
@@ -9131,10 +9257,9 @@ out_fput:
 }
 
 #ifdef CONFIG_PROC_FS
-static int io_uring_show_cred(int id, void *p, void *data)
+static int io_uring_show_cred(struct seq_file *m, unsigned int id,
+               const struct cred *cred)
 {
-       const struct cred *cred = p;
-       struct seq_file *m = data;
        struct user_namespace *uns = seq_user_ns(m);
        struct group_info *gi;
        kernel_cap_t cap;
@@ -9202,9 +9327,13 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
                seq_printf(m, "%5u: 0x%llx/%u\n", i, buf->ubuf,
                                                (unsigned int) buf->len);
        }
-       if (has_lock && !idr_is_empty(&ctx->personality_idr)) {
+       if (has_lock && !xa_empty(&ctx->personalities)) {
+               unsigned long index;
+               const struct cred *cred;
+
                seq_printf(m, "Personalities:\n");
-               idr_for_each(&ctx->personality_idr, io_uring_show_cred, m);
+               xa_for_each(&ctx->personalities, index, cred)
+                       io_uring_show_cred(m, index, cred);
        }
        seq_printf(m, "PollList:\n");
        spin_lock_irq(&ctx->completion_lock);
@@ -9298,7 +9427,7 @@ static int io_uring_install_fd(struct io_ring_ctx *ctx, struct file *file)
        if (fd < 0)
                return fd;
 
-       ret = io_uring_add_task_file(ctx, file);
+       ret = io_uring_add_task_file(ctx);
        if (ret) {
                put_unused_fd(fd);
                return ret;
@@ -9406,9 +9535,6 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        if (ret)
                goto err;
 
-       if (!(p->flags & IORING_SETUP_R_DISABLED))
-               io_sq_offload_start(ctx);
-
        memset(&p->sq_off, 0, sizeof(p->sq_off));
        p->sq_off.head = offsetof(struct io_rings, sq.head);
        p->sq_off.tail = offsetof(struct io_rings, sq.tail);
@@ -9536,14 +9662,16 @@ out:
 static int io_register_personality(struct io_ring_ctx *ctx)
 {
        const struct cred *creds;
+       u32 id;
        int ret;
 
        creds = get_current_cred();
 
-       ret = idr_alloc_cyclic(&ctx->personality_idr, (void *) creds, 1,
-                               USHRT_MAX, GFP_KERNEL);
-       if (ret < 0)
-               put_cred(creds);
+       ret = xa_alloc_cyclic(&ctx->personalities, &id, (void *)creds,
+                       XA_LIMIT(0, USHRT_MAX), &ctx->pers_next, GFP_KERNEL);
+       if (!ret)
+               return id;
+       put_cred(creds);
        return ret;
 }
 
@@ -9625,7 +9753,9 @@ static int io_register_enable_rings(struct io_ring_ctx *ctx)
        if (ctx->restrictions.registered)
                ctx->restricted = 1;
 
-       io_sq_offload_start(ctx);
+       ctx->flags &= ~IORING_SETUP_R_DISABLED;
+       if (ctx->sq_data && wq_has_sleeper(&ctx->sq_data->wait))
+               wake_up(&ctx->sq_data->wait);
        return 0;
 }