ALSA: mtpav: Don't call card private_free at probe error path
[linux-2.6-microblaze.git] / fs / io_uring.c
index d5ab0e9..e54c412 100644 (file)
@@ -57,7 +57,7 @@
 #include <linux/mman.h>
 #include <linux/percpu.h>
 #include <linux/slab.h>
-#include <linux/blkdev.h>
+#include <linux/blk-mq.h>
 #include <linux/bvec.h>
 #include <linux/net.h>
 #include <net/sock.h>
 #define SQE_COMMON_FLAGS (IOSQE_FIXED_FILE | IOSQE_IO_LINK | \
                          IOSQE_IO_HARDLINK | IOSQE_ASYNC)
 
-#define SQE_VALID_FLAGS        (SQE_COMMON_FLAGS|IOSQE_BUFFER_SELECT|IOSQE_IO_DRAIN)
+#define SQE_VALID_FLAGS        (SQE_COMMON_FLAGS | IOSQE_BUFFER_SELECT | \
+                       IOSQE_IO_DRAIN | IOSQE_CQE_SKIP_SUCCESS)
 
 #define IO_REQ_CLEAN_FLAGS (REQ_F_BUFFER_SELECTED | REQ_F_NEED_CLEANUP | \
                                REQ_F_POLLED | REQ_F_INFLIGHT | REQ_F_CREDS | \
@@ -320,6 +321,7 @@ struct io_submit_state {
 
        bool                    plug_started;
        bool                    need_plug;
+       bool                    flush_cqes;
        unsigned short          submit_nr;
        struct blk_plug         plug;
 };
@@ -337,6 +339,7 @@ struct io_ring_ctx {
                unsigned int            restricted: 1;
                unsigned int            off_timeout_used: 1;
                unsigned int            drain_active: 1;
+               unsigned int            drain_disabled: 1;
        } ____cacheline_aligned_in_smp;
 
        /* submission data */
@@ -471,6 +474,7 @@ struct io_uring_task {
 
        spinlock_t              task_lock;
        struct io_wq_work_list  task_list;
+       struct io_wq_work_list  prior_task_list;
        struct callback_head    task_work;
        bool                    task_running;
 };
@@ -483,8 +487,6 @@ struct io_poll_iocb {
        struct file                     *file;
        struct wait_queue_head          *head;
        __poll_t                        events;
-       bool                            done;
-       bool                            canceled;
        struct wait_queue_entry         wait;
 };
 
@@ -721,6 +723,7 @@ enum {
        REQ_F_HARDLINK_BIT      = IOSQE_IO_HARDLINK_BIT,
        REQ_F_FORCE_ASYNC_BIT   = IOSQE_ASYNC_BIT,
        REQ_F_BUFFER_SELECT_BIT = IOSQE_BUFFER_SELECT_BIT,
+       REQ_F_CQE_SKIP_BIT      = IOSQE_CQE_SKIP_SUCCESS_BIT,
 
        /* first byte is taken by user flags, shift it to not overlap */
        REQ_F_FAIL_BIT          = 8,
@@ -737,6 +740,7 @@ enum {
        REQ_F_REFCOUNT_BIT,
        REQ_F_ARM_LTIMEOUT_BIT,
        REQ_F_ASYNC_DATA_BIT,
+       REQ_F_SKIP_LINK_CQES_BIT,
        /* keep async read/write and isreg together and in order */
        REQ_F_SUPPORT_NOWAIT_BIT,
        REQ_F_ISREG_BIT,
@@ -758,6 +762,8 @@ enum {
        REQ_F_FORCE_ASYNC       = BIT(REQ_F_FORCE_ASYNC_BIT),
        /* IOSQE_BUFFER_SELECT */
        REQ_F_BUFFER_SELECT     = BIT(REQ_F_BUFFER_SELECT_BIT),
+       /* IOSQE_CQE_SKIP_SUCCESS */
+       REQ_F_CQE_SKIP          = BIT(REQ_F_CQE_SKIP_BIT),
 
        /* fail rest of links */
        REQ_F_FAIL              = BIT(REQ_F_FAIL_BIT),
@@ -791,6 +797,8 @@ enum {
        REQ_F_ARM_LTIMEOUT      = BIT(REQ_F_ARM_LTIMEOUT_BIT),
        /* ->async_data allocated */
        REQ_F_ASYNC_DATA        = BIT(REQ_F_ASYNC_DATA_BIT),
+       /* don't post CQEs while failing linked requests */
+       REQ_F_SKIP_LINK_CQES    = BIT(REQ_F_SKIP_LINK_CQES_BIT),
 };
 
 struct async_poll {
@@ -882,6 +890,7 @@ struct io_kiocb {
        const struct cred               *creds;
        /* stores selected buf, valid IFF REQ_F_BUFFER_SELECTED is set */
        struct io_buffer                *kbuf;
+       atomic_t                        poll_refs;
 };
 
 struct io_tctx_node {
@@ -1108,8 +1117,8 @@ static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                                         bool cancel_all);
 static void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 
-static bool io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
-                                s32 res, u32 cflags);
+static void io_fill_cqe_req(struct io_kiocb *req, s32 res, u32 cflags);
+
 static void io_put_req(struct io_kiocb *req);
 static void io_put_req_deferred(struct io_kiocb *req);
 static void io_dismantle_req(struct io_kiocb *req);
@@ -1183,12 +1192,6 @@ static inline bool req_ref_put_and_test(struct io_kiocb *req)
        return atomic_dec_and_test(&req->refs);
 }
 
-static inline void req_ref_put(struct io_kiocb *req)
-{
-       WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
-       WARN_ON_ONCE(req_ref_put_and_test(req));
-}
-
 static inline void req_ref_get(struct io_kiocb *req)
 {
        WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
@@ -1264,6 +1267,26 @@ static inline void io_req_set_rsrc_node(struct io_kiocb *req,
        }
 }
 
+static unsigned int __io_put_kbuf(struct io_kiocb *req)
+{
+       struct io_buffer *kbuf = req->kbuf;
+       unsigned int cflags;
+
+       cflags = kbuf->bid << IORING_CQE_BUFFER_SHIFT;
+       cflags |= IORING_CQE_F_BUFFER;
+       req->flags &= ~REQ_F_BUFFER_SELECTED;
+       kfree(kbuf);
+       req->kbuf = NULL;
+       return cflags;
+}
+
+static inline unsigned int io_put_kbuf(struct io_kiocb *req)
+{
+       if (likely(!(req->flags & REQ_F_BUFFER_SELECTED)))
+               return 0;
+       return __io_put_kbuf(req);
+}
+
 static void io_refs_resurrect(struct percpu_ref *ref, struct completion *compl)
 {
        bool got = percpu_ref_tryget(ref);
@@ -1340,6 +1363,10 @@ static inline bool req_has_async_data(struct io_kiocb *req)
 static inline void req_set_fail(struct io_kiocb *req)
 {
        req->flags |= REQ_F_FAIL;
+       if (req->flags & REQ_F_CQE_SKIP) {
+               req->flags &= ~REQ_F_CQE_SKIP;
+               req->flags |= REQ_F_SKIP_LINK_CQES;
+       }
 }
 
 static inline void req_fail_link_node(struct io_kiocb *req, int res)
@@ -1553,8 +1580,11 @@ static void io_prep_async_link(struct io_kiocb *req)
 
 static inline void io_req_add_compl_list(struct io_kiocb *req)
 {
-       struct io_submit_state *state = &req->ctx->submit_state;
+       struct io_ring_ctx *ctx = req->ctx;
+       struct io_submit_state *state = &ctx->submit_state;
 
+       if (!(req->flags & REQ_F_CQE_SKIP))
+               ctx->submit_state.flush_cqes = true;
        wq_list_add_tail(&req->comp_list, &state->compl_reqs);
 }
 
@@ -1599,7 +1629,7 @@ static void io_kill_timeout(struct io_kiocb *req, int status)
                atomic_set(&req->ctx->cq_timeouts,
                        atomic_read(&req->ctx->cq_timeouts) + 1);
                list_del_init(&req->timeout.list);
-               io_cqring_fill_event(req->ctx, req->user_data, status, 0);
+               io_fill_cqe_req(req, status, 0);
                io_put_req_deferred(req);
        }
 }
@@ -1830,6 +1860,18 @@ static inline void io_get_task_refs(int nr)
                io_task_refs_refill(tctx);
 }
 
+static __cold void io_uring_drop_tctx_refs(struct task_struct *task)
+{
+       struct io_uring_task *tctx = task->io_uring;
+       unsigned int refs = tctx->cached_refs;
+
+       if (refs) {
+               tctx->cached_refs = 0;
+               percpu_counter_sub(&tctx->inflight, refs);
+               put_task_struct_many(task, refs);
+       }
+}
+
 static bool io_cqring_event_overflow(struct io_ring_ctx *ctx, u64 user_data,
                                     s32 res, u32 cflags)
 {
@@ -1858,8 +1900,8 @@ static bool io_cqring_event_overflow(struct io_ring_ctx *ctx, u64 user_data,
        return true;
 }
 
-static inline bool __io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
-                                         s32 res, u32 cflags)
+static inline bool __io_fill_cqe(struct io_ring_ctx *ctx, u64 user_data,
+                                s32 res, u32 cflags)
 {
        struct io_uring_cqe *cqe;
 
@@ -1880,20 +1922,26 @@ static inline bool __io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data
        return io_cqring_event_overflow(ctx, user_data, res, cflags);
 }
 
-/* not as hot to bloat with inlining */
-static noinline bool io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
-                                         s32 res, u32 cflags)
+static noinline void io_fill_cqe_req(struct io_kiocb *req, s32 res, u32 cflags)
 {
-       return __io_cqring_fill_event(ctx, user_data, res, cflags);
+       if (!(req->flags & REQ_F_CQE_SKIP))
+               __io_fill_cqe(req->ctx, req->user_data, res, cflags);
 }
 
-static void io_req_complete_post(struct io_kiocb *req, s32 res,
-                                u32 cflags)
+static noinline bool io_fill_cqe_aux(struct io_ring_ctx *ctx, u64 user_data,
+                                    s32 res, u32 cflags)
+{
+       ctx->cq_extra++;
+       return __io_fill_cqe(ctx, user_data, res, cflags);
+}
+
+static void __io_req_complete_post(struct io_kiocb *req, s32 res,
+                                  u32 cflags)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
-       spin_lock(&ctx->completion_lock);
-       __io_cqring_fill_event(ctx, req->user_data, res, cflags);
+       if (!(req->flags & REQ_F_CQE_SKIP))
+               __io_fill_cqe(ctx, req->user_data, res, cflags);
        /*
         * If we're the last reference to this request, add to our locked
         * free_list cache.
@@ -1913,6 +1961,15 @@ static void io_req_complete_post(struct io_kiocb *req, s32 res,
                wq_list_add_head(&req->comp_list, &ctx->locked_free_list);
                ctx->locked_free_nr++;
        }
+}
+
+static void io_req_complete_post(struct io_kiocb *req, s32 res,
+                                u32 cflags)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+
+       spin_lock(&ctx->completion_lock);
+       __io_req_complete_post(req, res, cflags);
        io_commit_cqring(ctx);
        spin_unlock(&ctx->completion_lock);
        io_cqring_ev_posted(ctx);
@@ -2101,8 +2158,8 @@ static bool io_kill_linked_timeout(struct io_kiocb *req)
                link->timeout.head = NULL;
                if (hrtimer_try_to_cancel(&io->timer) != -1) {
                        list_del(&link->timeout.list);
-                       io_cqring_fill_event(link->ctx, link->user_data,
-                                            -ECANCELED, 0);
+                       /* leave REQ_F_CQE_SKIP to io_fill_cqe_req */
+                       io_fill_cqe_req(link, -ECANCELED, 0);
                        io_put_req_deferred(link);
                        return true;
                }
@@ -2114,6 +2171,7 @@ static void io_fail_links(struct io_kiocb *req)
        __must_hold(&req->ctx->completion_lock)
 {
        struct io_kiocb *nxt, *link = req->link;
+       bool ignore_cqes = req->flags & REQ_F_SKIP_LINK_CQES;
 
        req->link = NULL;
        while (link) {
@@ -2126,7 +2184,10 @@ static void io_fail_links(struct io_kiocb *req)
                link->link = NULL;
 
                trace_io_uring_fail_link(req, link);
-               io_cqring_fill_event(link->ctx, link->user_data, res, 0);
+               if (!ignore_cqes) {
+                       link->flags &= ~REQ_F_CQE_SKIP;
+                       io_fill_cqe_req(link, res, 0);
+               }
                io_put_req_deferred(link);
                link = nxt;
        }
@@ -2143,8 +2204,8 @@ static bool io_disarm_next(struct io_kiocb *req)
                req->flags &= ~REQ_F_ARM_LTIMEOUT;
                if (link && link->opcode == IORING_OP_LINK_TIMEOUT) {
                        io_remove_next_linked(req);
-                       io_cqring_fill_event(link->ctx, link->user_data,
-                                            -ECANCELED, 0);
+                       /* leave REQ_F_CQE_SKIP to io_fill_cqe_req */
+                       io_fill_cqe_req(link, -ECANCELED, 0);
                        io_put_req_deferred(link);
                        posted = true;
                }
@@ -2171,7 +2232,7 @@ static void __io_req_find_next_prep(struct io_kiocb *req)
        spin_lock(&ctx->completion_lock);
        posted = io_disarm_next(req);
        if (posted)
-               io_commit_cqring(req->ctx);
+               io_commit_cqring(ctx);
        spin_unlock(&ctx->completion_lock);
        if (posted)
                io_cqring_ev_posted(ctx);
@@ -2208,51 +2269,108 @@ static void ctx_flush_and_put(struct io_ring_ctx *ctx, bool *locked)
        percpu_ref_put(&ctx->refs);
 }
 
+static inline void ctx_commit_and_unlock(struct io_ring_ctx *ctx)
+{
+       io_commit_cqring(ctx);
+       spin_unlock(&ctx->completion_lock);
+       io_cqring_ev_posted(ctx);
+}
+
+static void handle_prev_tw_list(struct io_wq_work_node *node,
+                               struct io_ring_ctx **ctx, bool *uring_locked)
+{
+       if (*ctx && !*uring_locked)
+               spin_lock(&(*ctx)->completion_lock);
+
+       do {
+               struct io_wq_work_node *next = node->next;
+               struct io_kiocb *req = container_of(node, struct io_kiocb,
+                                                   io_task_work.node);
+
+               if (req->ctx != *ctx) {
+                       if (unlikely(!*uring_locked && *ctx))
+                               ctx_commit_and_unlock(*ctx);
+
+                       ctx_flush_and_put(*ctx, uring_locked);
+                       *ctx = req->ctx;
+                       /* if not contended, grab and improve batching */
+                       *uring_locked = mutex_trylock(&(*ctx)->uring_lock);
+                       percpu_ref_get(&(*ctx)->refs);
+                       if (unlikely(!*uring_locked))
+                               spin_lock(&(*ctx)->completion_lock);
+               }
+               if (likely(*uring_locked))
+                       req->io_task_work.func(req, uring_locked);
+               else
+                       __io_req_complete_post(req, req->result, io_put_kbuf(req));
+               node = next;
+       } while (node);
+
+       if (unlikely(!*uring_locked))
+               ctx_commit_and_unlock(*ctx);
+}
+
+static void handle_tw_list(struct io_wq_work_node *node,
+                          struct io_ring_ctx **ctx, bool *locked)
+{
+       do {
+               struct io_wq_work_node *next = node->next;
+               struct io_kiocb *req = container_of(node, struct io_kiocb,
+                                                   io_task_work.node);
+
+               if (req->ctx != *ctx) {
+                       ctx_flush_and_put(*ctx, locked);
+                       *ctx = req->ctx;
+                       /* if not contended, grab and improve batching */
+                       *locked = mutex_trylock(&(*ctx)->uring_lock);
+                       percpu_ref_get(&(*ctx)->refs);
+               }
+               req->io_task_work.func(req, locked);
+               node = next;
+       } while (node);
+}
+
 static void tctx_task_work(struct callback_head *cb)
 {
-       bool locked = false;
+       bool uring_locked = false;
        struct io_ring_ctx *ctx = NULL;
        struct io_uring_task *tctx = container_of(cb, struct io_uring_task,
                                                  task_work);
 
        while (1) {
-               struct io_wq_work_node *node;
+               struct io_wq_work_node *node1, *node2;
 
-               if (!tctx->task_list.first && locked)
+               if (!tctx->task_list.first &&
+                   !tctx->prior_task_list.first && uring_locked)
                        io_submit_flush_completions(ctx);
 
                spin_lock_irq(&tctx->task_lock);
-               node = tctx->task_list.first;
+               node1 = tctx->prior_task_list.first;
+               node2 = tctx->task_list.first;
                INIT_WQ_LIST(&tctx->task_list);
-               if (!node)
+               INIT_WQ_LIST(&tctx->prior_task_list);
+               if (!node2 && !node1)
                        tctx->task_running = false;
                spin_unlock_irq(&tctx->task_lock);
-               if (!node)
+               if (!node2 && !node1)
                        break;
 
-               do {
-                       struct io_wq_work_node *next = node->next;
-                       struct io_kiocb *req = container_of(node, struct io_kiocb,
-                                                           io_task_work.node);
-
-                       if (req->ctx != ctx) {
-                               ctx_flush_and_put(ctx, &locked);
-                               ctx = req->ctx;
-                               /* if not contended, grab and improve batching */
-                               locked = mutex_trylock(&ctx->uring_lock);
-                               percpu_ref_get(&ctx->refs);
-                       }
-                       req->io_task_work.func(req, &locked);
-                       node = next;
-               } while (node);
+               if (node1)
+                       handle_prev_tw_list(node1, &ctx, &uring_locked);
 
+               if (node2)
+                       handle_tw_list(node2, &ctx, &uring_locked);
                cond_resched();
        }
 
-       ctx_flush_and_put(ctx, &locked);
+       ctx_flush_and_put(ctx, &uring_locked);
+
+       /* relaxed read is enough as only the task itself sets ->in_idle */
+       if (unlikely(atomic_read(&tctx->in_idle)))
+               io_uring_drop_tctx_refs(current);
 }
 
-static void io_req_task_work_add(struct io_kiocb *req)
+static void io_req_task_work_add(struct io_kiocb *req, bool priority)
 {
        struct task_struct *tsk = req->task;
        struct io_uring_task *tctx = tsk->io_uring;
@@ -2264,7 +2382,10 @@ static void io_req_task_work_add(struct io_kiocb *req)
        WARN_ON_ONCE(!tctx);
 
        spin_lock_irqsave(&tctx->task_lock, flags);
-       wq_list_add_tail(&req->io_task_work.node, &tctx->task_list);
+       if (priority)
+               wq_list_add_tail(&req->io_task_work.node, &tctx->prior_task_list);
+       else
+               wq_list_add_tail(&req->io_task_work.node, &tctx->task_list);
        running = tctx->task_running;
        if (!running)
                tctx->task_running = true;
@@ -2289,8 +2410,7 @@ static void io_req_task_work_add(struct io_kiocb *req)
 
        spin_lock_irqsave(&tctx->task_lock, flags);
        tctx->task_running = false;
-       node = tctx->task_list.first;
-       INIT_WQ_LIST(&tctx->task_list);
+       node = wq_list_merge(&tctx->prior_task_list, &tctx->task_list);
        spin_unlock_irqrestore(&tctx->task_lock, flags);
 
        while (node) {
@@ -2327,19 +2447,19 @@ static void io_req_task_queue_fail(struct io_kiocb *req, int ret)
 {
        req->result = ret;
        req->io_task_work.func = io_req_task_cancel;
-       io_req_task_work_add(req);
+       io_req_task_work_add(req, false);
 }
 
 static void io_req_task_queue(struct io_kiocb *req)
 {
        req->io_task_work.func = io_req_task_submit;
-       io_req_task_work_add(req);
+       io_req_task_work_add(req, false);
 }
 
 static void io_req_task_queue_reissue(struct io_kiocb *req)
 {
        req->io_task_work.func = io_queue_async_work;
-       io_req_task_work_add(req);
+       io_req_task_work_add(req, false);
 }
 
 static inline void io_queue_next(struct io_kiocb *req)
@@ -2403,17 +2523,22 @@ static void __io_submit_flush_completions(struct io_ring_ctx *ctx)
        struct io_wq_work_node *node, *prev;
        struct io_submit_state *state = &ctx->submit_state;
 
-       spin_lock(&ctx->completion_lock);
-       wq_list_for_each(node, prev, &state->compl_reqs) {
-               struct io_kiocb *req = container_of(node, struct io_kiocb,
+       if (state->flush_cqes) {
+               spin_lock(&ctx->completion_lock);
+               wq_list_for_each(node, prev, &state->compl_reqs) {
+                       struct io_kiocb *req = container_of(node, struct io_kiocb,
                                                    comp_list);
 
-               __io_cqring_fill_event(ctx, req->user_data, req->result,
-                                       req->cflags);
+                       if (!(req->flags & REQ_F_CQE_SKIP))
+                               __io_fill_cqe(ctx, req->user_data, req->result,
+                                             req->cflags);
+               }
+
+               io_commit_cqring(ctx);
+               spin_unlock(&ctx->completion_lock);
+               io_cqring_ev_posted(ctx);
+               state->flush_cqes = false;
        }
-       io_commit_cqring(ctx);
-       spin_unlock(&ctx->completion_lock);
-       io_cqring_ev_posted(ctx);
 
        io_free_batch_list(ctx, state->compl_reqs.first);
        INIT_WQ_LIST(&state->compl_reqs);
@@ -2444,7 +2569,7 @@ static inline void io_put_req_deferred(struct io_kiocb *req)
 {
        if (req_ref_put_and_test(req)) {
                req->io_task_work.func = io_free_req_work;
-               io_req_task_work_add(req);
+               io_req_task_work_add(req, false);
        }
 }
 
@@ -2463,24 +2588,6 @@ static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
        return smp_load_acquire(&rings->sq.tail) - ctx->cached_sq_head;
 }
 
-static unsigned int io_put_kbuf(struct io_kiocb *req, struct io_buffer *kbuf)
-{
-       unsigned int cflags;
-
-       cflags = kbuf->bid << IORING_CQE_BUFFER_SHIFT;
-       cflags |= IORING_CQE_F_BUFFER;
-       req->flags &= ~REQ_F_BUFFER_SELECTED;
-       kfree(kbuf);
-       return cflags;
-}
-
-static inline unsigned int io_put_rw_kbuf(struct io_kiocb *req)
-{
-       if (likely(!(req->flags & REQ_F_BUFFER_SELECTED)))
-               return 0;
-       return io_put_kbuf(req, req->kbuf);
-}
-
 static inline bool io_run_task_work(void)
 {
        if (test_thread_flag(TIF_NOTIFY_SIGNAL) || current->task_works) {
@@ -2543,8 +2650,10 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin)
                /* order with io_complete_rw_iopoll(), e.g. ->result updates */
                if (!smp_load_acquire(&req->iopoll_completed))
                        break;
-               __io_cqring_fill_event(ctx, req->user_data, req->result,
-                                       io_put_rw_kbuf(req));
+               if (unlikely(req->flags & REQ_F_CQE_SKIP))
+                       continue;
+
+               __io_fill_cqe(ctx, req->user_data, req->result, io_put_kbuf(req));
                nr_events++;
        }
 
@@ -2718,9 +2827,9 @@ static bool __io_complete_rw_common(struct io_kiocb *req, long res)
        return false;
 }
 
-static void io_req_task_complete(struct io_kiocb *req, bool *locked)
+static inline void io_req_task_complete(struct io_kiocb *req, bool *locked)
 {
-       unsigned int cflags = io_put_rw_kbuf(req);
+       unsigned int cflags = io_put_kbuf(req);
        int res = req->result;
 
        if (*locked) {
@@ -2731,12 +2840,12 @@ static void io_req_task_complete(struct io_kiocb *req, bool *locked)
        }
 }
 
-static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
+static void __io_complete_rw(struct io_kiocb *req, long res,
                             unsigned int issue_flags)
 {
        if (__io_complete_rw_common(req, res))
                return;
-       __io_req_complete(req, issue_flags, req->result, io_put_rw_kbuf(req));
+       __io_req_complete(req, issue_flags, req->result, io_put_kbuf(req));
 }
 
 static void io_complete_rw(struct kiocb *kiocb, long res)
@@ -2747,7 +2856,7 @@ static void io_complete_rw(struct kiocb *kiocb, long res)
                return;
        req->result = res;
        req->io_task_work.func = io_req_task_complete;
-       io_req_task_work_add(req);
+       io_req_task_work_add(req, !!(req->ctx->flags & IORING_SETUP_SQPOLL));
 }
 
 static void io_complete_rw_iopoll(struct kiocb *kiocb, long res)
@@ -2891,9 +3000,13 @@ static int io_prep_rw(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                req->flags |= io_file_get_flags(file) << REQ_F_SUPPORT_NOWAIT_BIT;
 
        kiocb->ki_pos = READ_ONCE(sqe->off);
-       if (kiocb->ki_pos == -1 && !(file->f_mode & FMODE_STREAM)) {
-               req->flags |= REQ_F_CUR_POS;
-               kiocb->ki_pos = file->f_pos;
+       if (kiocb->ki_pos == -1) {
+               if (!(file->f_mode & FMODE_STREAM)) {
+                       req->flags |= REQ_F_CUR_POS;
+                       kiocb->ki_pos = file->f_pos;
+               } else {
+                       kiocb->ki_pos = 0;
+               }
        }
        kiocb->ki_flags = iocb_flags(file);
        ret = kiocb_set_rw_flags(kiocb, READ_ONCE(sqe->rw_flags));
@@ -2961,10 +3074,9 @@ static inline void io_rw_done(struct kiocb *kiocb, ssize_t ret)
        }
 }
 
-static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
+static void kiocb_done(struct io_kiocb *req, ssize_t ret,
                       unsigned int issue_flags)
 {
-       struct io_kiocb *req = container_of(kiocb, struct io_kiocb, rw.kiocb);
        struct io_async_rw *io = req->async_data;
 
        /* add previously done IO, if any */
@@ -2976,28 +3088,21 @@ static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
        }
 
        if (req->flags & REQ_F_CUR_POS)
-               req->file->f_pos = kiocb->ki_pos;
-       if (ret >= 0 && (kiocb->ki_complete == io_complete_rw))
-               __io_complete_rw(req, ret, 0, issue_flags);
+               req->file->f_pos = req->rw.kiocb.ki_pos;
+       if (ret >= 0 && (req->rw.kiocb.ki_complete == io_complete_rw))
+               __io_complete_rw(req, ret, issue_flags);
        else
-               io_rw_done(kiocb, ret);
+               io_rw_done(&req->rw.kiocb, ret);
 
        if (req->flags & REQ_F_REISSUE) {
                req->flags &= ~REQ_F_REISSUE;
                if (io_resubmit_prep(req)) {
                        io_req_task_queue_reissue(req);
                } else {
-                       unsigned int cflags = io_put_rw_kbuf(req);
-                       struct io_ring_ctx *ctx = req->ctx;
-
                        req_set_fail(req);
-                       if (issue_flags & IO_URING_F_UNLOCKED) {
-                               mutex_lock(&ctx->uring_lock);
-                               __io_req_complete(req, issue_flags, ret, cflags);
-                               mutex_unlock(&ctx->uring_lock);
-                       } else {
-                               __io_req_complete(req, issue_flags, ret, cflags);
-                       }
+                       req->result = ret;
+                       req->io_task_work.func = io_req_task_complete;
+                       io_req_task_work_add(req, false);
                }
        }
 }
@@ -3225,10 +3330,12 @@ static struct iovec *__io_import_iovec(int rw, struct io_kiocb *req,
        size_t sqe_len;
        ssize_t ret;
 
-       BUILD_BUG_ON(ERR_PTR(0) != NULL);
-
-       if (opcode == IORING_OP_READ_FIXED || opcode == IORING_OP_WRITE_FIXED)
-               return ERR_PTR(io_import_fixed(req, rw, iter));
+       if (opcode == IORING_OP_READ_FIXED || opcode == IORING_OP_WRITE_FIXED) {
+               ret = io_import_fixed(req, rw, iter);
+               if (ret)
+                       return ERR_PTR(ret);
+               return NULL;
+       }
 
        /* buffer index only valid with fixed read/write, or buffer select  */
        if (unlikely(req->buf_index && !(req->flags & REQ_F_BUFFER_SELECT)))
@@ -3246,15 +3353,18 @@ static struct iovec *__io_import_iovec(int rw, struct io_kiocb *req,
                }
 
                ret = import_single_range(rw, buf, sqe_len, s->fast_iov, iter);
-               return ERR_PTR(ret);
+               if (ret)
+                       return ERR_PTR(ret);
+               return NULL;
        }
 
        iovec = s->fast_iov;
        if (req->flags & REQ_F_BUFFER_SELECT) {
                ret = io_iov_buffer_select(req, iovec, issue_flags);
-               if (!ret)
-                       iov_iter_init(iter, rw, iovec, 1, iovec->iov_len);
-               return ERR_PTR(ret);
+               if (ret)
+                       return ERR_PTR(ret);
+               iov_iter_init(iter, rw, iovec, 1, iovec->iov_len);
+               return NULL;
        }
 
        ret = __import_iovec(rw, buf, sqe_len, UIO_FASTIOV, &iovec, iter,
@@ -3625,7 +3735,7 @@ static int io_read(struct io_kiocb *req, unsigned int issue_flags)
                iov_iter_restore(&s->iter, &s->iter_state);
        } while (ret > 0);
 done:
-       kiocb_done(kiocb, ret, issue_flags);
+       kiocb_done(req, ret, issue_flags);
 out_free:
        /* it's faster to check here then delegate to kfree */
        if (iovec)
@@ -3722,7 +3832,7 @@ static int io_write(struct io_kiocb *req, unsigned int issue_flags)
                if (ret2 == -EAGAIN && (req->ctx->flags & IORING_SETUP_IOPOLL))
                        goto copy_iov;
 done:
-               kiocb_done(kiocb, ret2, issue_flags);
+               kiocb_done(req, ret2, issue_flags);
        } else {
 copy_iov:
                iov_iter_restore(&s->iter, &s->iter_state);
@@ -4835,17 +4945,18 @@ static int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
                min_ret = iov_iter_count(&kmsg->msg.msg_iter);
 
        ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
-       if ((issue_flags & IO_URING_F_NONBLOCK) && ret == -EAGAIN)
-               return io_setup_async_msg(req, kmsg);
-       if (ret == -ERESTARTSYS)
-               ret = -EINTR;
 
+       if (ret < min_ret) {
+               if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
+                       return io_setup_async_msg(req, kmsg);
+               if (ret == -ERESTARTSYS)
+                       ret = -EINTR;
+               req_set_fail(req);
+       }
        /* fast path, check for non-NULL to avoid function call */
        if (kmsg->free_iov)
                kfree(kmsg->free_iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < min_ret)
-               req_set_fail(req);
        __io_req_complete(req, issue_flags, ret, 0);
        return 0;
 }
@@ -4881,13 +4992,13 @@ static int io_send(struct io_kiocb *req, unsigned int issue_flags)
 
        msg.msg_flags = flags;
        ret = sock_sendmsg(sock, &msg);
-       if ((issue_flags & IO_URING_F_NONBLOCK) && ret == -EAGAIN)
-               return -EAGAIN;
-       if (ret == -ERESTARTSYS)
-               ret = -EINTR;
-
-       if (ret < min_ret)
+       if (ret < min_ret) {
+               if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
+                       return -EAGAIN;
+               if (ret == -ERESTARTSYS)
+                       ret = -EINTR;
                req_set_fail(req);
+       }
        __io_req_complete(req, issue_flags, ret, 0);
        return 0;
 }
@@ -4987,11 +5098,6 @@ static struct io_buffer *io_recv_buffer_select(struct io_kiocb *req,
        return io_buffer_select(req, &sr->len, sr->bgid, issue_flags);
 }
 
-static inline unsigned int io_put_recv_kbuf(struct io_kiocb *req)
-{
-       return io_put_kbuf(req, req->kbuf);
-}
-
 static int io_recvmsg_prep_async(struct io_kiocb *req)
 {
        int ret;
@@ -5029,8 +5135,7 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
        struct socket *sock;
        struct io_buffer *kbuf;
        unsigned flags;
-       int min_ret = 0;
-       int ret, cflags = 0;
+       int ret, min_ret = 0;
        bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
 
        sock = sock_from_file(req->file);
@@ -5064,20 +5169,21 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 
        ret = __sys_recvmsg_sock(sock, &kmsg->msg, req->sr_msg.umsg,
                                        kmsg->uaddr, flags);
-       if (force_nonblock && ret == -EAGAIN)
-               return io_setup_async_msg(req, kmsg);
-       if (ret == -ERESTARTSYS)
-               ret = -EINTR;
+       if (ret < min_ret) {
+               if (ret == -EAGAIN && force_nonblock)
+                       return io_setup_async_msg(req, kmsg);
+               if (ret == -ERESTARTSYS)
+                       ret = -EINTR;
+               req_set_fail(req);
+       } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
+               req_set_fail(req);
+       }
 
-       if (req->flags & REQ_F_BUFFER_SELECTED)
-               cflags = io_put_recv_kbuf(req);
        /* fast path, check for non-NULL to avoid function call */
        if (kmsg->free_iov)
                kfree(kmsg->free_iov);
        req->flags &= ~REQ_F_NEED_CLEANUP;
-       if (ret < min_ret || ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
-               req_set_fail(req);
-       __io_req_complete(req, issue_flags, ret, cflags);
+       __io_req_complete(req, issue_flags, ret, io_put_kbuf(req));
        return 0;
 }
 
@@ -5090,8 +5196,7 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
        struct socket *sock;
        struct iovec iov;
        unsigned flags;
-       int min_ret = 0;
-       int ret, cflags = 0;
+       int ret, min_ret = 0;
        bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
 
        sock = sock_from_file(req->file);
@@ -5123,16 +5228,18 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
                min_ret = iov_iter_count(&msg.msg_iter);
 
        ret = sock_recvmsg(sock, &msg, flags);
-       if (force_nonblock && ret == -EAGAIN)
-               return -EAGAIN;
-       if (ret == -ERESTARTSYS)
-               ret = -EINTR;
 out_free:
-       if (req->flags & REQ_F_BUFFER_SELECTED)
-               cflags = io_put_recv_kbuf(req);
-       if (ret < min_ret || ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
+       if (ret < min_ret) {
+               if (ret == -EAGAIN && force_nonblock)
+                       return -EAGAIN;
+               if (ret == -ERESTARTSYS)
+                       ret = -EINTR;
+               req_set_fail(req);
+       } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
                req_set_fail(req);
-       __io_req_complete(req, issue_flags, ret, cflags);
+       }
+
+       __io_req_complete(req, issue_flags, ret, io_put_kbuf(req));
        return 0;
 }
 
@@ -5299,52 +5406,23 @@ struct io_poll_table {
        int error;
 };
 
-static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
-                          __poll_t mask, io_req_tw_func_t func)
-{
-       /* for instances that support it check for an event match first: */
-       if (mask && !(mask & poll->events))
-               return 0;
-
-       trace_io_uring_task_add(req->ctx, req->opcode, req->user_data, mask);
-
-       list_del_init(&poll->wait.entry);
+#define IO_POLL_CANCEL_FLAG    BIT(31)
+#define IO_POLL_REF_MASK       ((1u << 20)-1)
 
-       req->result = mask;
-       req->io_task_work.func = func;
-
-       /*
-        * If this fails, then the task is exiting. When a task exits, the
-        * work gets canceled, so just cancel this request as well instead
-        * of executing it. We can't safely execute it anyway, as we may not
-        * have the needed state needed for it anyway.
-        */
-       io_req_task_work_add(req);
-       return 1;
+/*
+ * If refs part of ->poll_refs (see IO_POLL_REF_MASK) is 0, it's free. We can
+ * bump it and acquire ownership. It's disallowed to modify requests while not
+ * owning it, that prevents from races for enqueueing task_work's and b/w
+ * arming poll and wakeups.
+ */
+static inline bool io_poll_get_ownership(struct io_kiocb *req)
+{
+       return !(atomic_fetch_inc(&req->poll_refs) & IO_POLL_REF_MASK);
 }
 
-static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
-       __acquires(&req->ctx->completion_lock)
+static void io_poll_mark_cancelled(struct io_kiocb *req)
 {
-       struct io_ring_ctx *ctx = req->ctx;
-
-       /* req->task == current here, checking PF_EXITING is safe */
-       if (unlikely(req->task->flags & PF_EXITING))
-               WRITE_ONCE(poll->canceled, true);
-
-       if (!req->result && !READ_ONCE(poll->canceled)) {
-               struct poll_table_struct pt = { ._key = poll->events };
-
-               req->result = vfs_poll(req->file, &pt) & poll->events;
-       }
-
-       spin_lock(&ctx->completion_lock);
-       if (!req->result && !READ_ONCE(poll->canceled)) {
-               add_wait_queue(poll->head, &poll->wait);
-               return true;
-       }
-
-       return false;
+       atomic_or(IO_POLL_CANCEL_FLAG, &req->poll_refs);
 }
 
 static struct io_poll_iocb *io_poll_get_double(struct io_kiocb *req)
@@ -5362,133 +5440,241 @@ static struct io_poll_iocb *io_poll_get_single(struct io_kiocb *req)
        return &req->apoll->poll;
 }
 
-static void io_poll_remove_double(struct io_kiocb *req)
-       __must_hold(&req->ctx->completion_lock)
+static void io_poll_req_insert(struct io_kiocb *req)
 {
-       struct io_poll_iocb *poll = io_poll_get_double(req);
+       struct io_ring_ctx *ctx = req->ctx;
+       struct hlist_head *list;
+
+       list = &ctx->cancel_hash[hash_long(req->user_data, ctx->cancel_hash_bits)];
+       hlist_add_head(&req->hash_node, list);
+}
 
-       lockdep_assert_held(&req->ctx->completion_lock);
+static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
+                             wait_queue_func_t wake_func)
+{
+       poll->head = NULL;
+#define IO_POLL_UNMASK (EPOLLERR|EPOLLHUP|EPOLLNVAL|EPOLLRDHUP)
+       /* mask in events that we always want/need */
+       poll->events = events | IO_POLL_UNMASK;
+       INIT_LIST_HEAD(&poll->wait.entry);
+       init_waitqueue_func_entry(&poll->wait, wake_func);
+}
 
-       if (poll && poll->head) {
-               struct wait_queue_head *head = poll->head;
+static inline void io_poll_remove_entry(struct io_poll_iocb *poll)
+{
+       struct wait_queue_head *head = smp_load_acquire(&poll->head);
 
+       if (head) {
                spin_lock_irq(&head->lock);
                list_del_init(&poll->wait.entry);
-               if (poll->wait.private)
-                       req_ref_put(req);
                poll->head = NULL;
                spin_unlock_irq(&head->lock);
        }
 }
 
-static bool __io_poll_complete(struct io_kiocb *req, __poll_t mask)
-       __must_hold(&req->ctx->completion_lock)
+static void io_poll_remove_entries(struct io_kiocb *req)
 {
-       struct io_ring_ctx *ctx = req->ctx;
-       unsigned flags = IORING_CQE_F_MORE;
-       int error;
-
-       if (READ_ONCE(req->poll.canceled)) {
-               error = -ECANCELED;
-               req->poll.events |= EPOLLONESHOT;
-       } else {
-               error = mangle_poll(mask);
-       }
-       if (req->poll.events & EPOLLONESHOT)
-               flags = 0;
-       if (!io_cqring_fill_event(ctx, req->user_data, error, flags)) {
-               req->poll.events |= EPOLLONESHOT;
-               flags = 0;
-       }
-       if (flags & IORING_CQE_F_MORE)
-               ctx->cq_extra++;
+       struct io_poll_iocb *poll = io_poll_get_single(req);
+       struct io_poll_iocb *poll_double = io_poll_get_double(req);
 
-       return !(flags & IORING_CQE_F_MORE);
+       /*
+        * While we hold the waitqueue lock and the waitqueue is nonempty,
+        * wake_up_pollfree() will wait for us.  However, taking the waitqueue
+        * lock in the first place can race with the waitqueue being freed.
+        *
+        * We solve this as eventpoll does: by taking advantage of the fact that
+        * all users of wake_up_pollfree() will RCU-delay the actual free.  If
+        * we enter rcu_read_lock() and see that the pointer to the queue is
+        * non-NULL, we can then lock it without the memory being freed out from
+        * under us.
+        *
+        * Keep holding rcu_read_lock() as long as we hold the queue lock, in
+        * case the caller deletes the entry from the queue, leaving it empty.
+        * In that case, only RCU prevents the queue memory from being freed.
+        */
+       rcu_read_lock();
+       io_poll_remove_entry(poll);
+       if (poll_double)
+               io_poll_remove_entry(poll_double);
+       rcu_read_unlock();
 }
 
-static void io_poll_task_func(struct io_kiocb *req, bool *locked)
+/*
+ * All poll tw should go through this. Checks for poll events, manages
+ * references, does rewait, etc.
+ *
+ * Returns a negative error on failure. >0 when no action require, which is
+ * either spurious wakeup or multishot CQE is served. 0 when it's done with
+ * the request, then the mask is stored in req->result.
+ */
+static int io_poll_check_events(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
-       struct io_kiocb *nxt;
+       struct io_poll_iocb *poll = io_poll_get_single(req);
+       int v;
 
-       if (io_poll_rewait(req, &req->poll)) {
-               spin_unlock(&ctx->completion_lock);
-       } else {
-               bool done;
+       /* req->task == current here, checking PF_EXITING is safe */
+       if (unlikely(req->task->flags & PF_EXITING))
+               io_poll_mark_cancelled(req);
 
-               if (req->poll.done) {
-                       spin_unlock(&ctx->completion_lock);
-                       return;
-               }
-               done = __io_poll_complete(req, req->result);
-               if (done) {
-                       io_poll_remove_double(req);
-                       hash_del(&req->hash_node);
-                       req->poll.done = true;
-               } else {
-                       req->result = 0;
-                       add_wait_queue(req->poll.head, &req->poll.wait);
-               }
-               io_commit_cqring(ctx);
-               spin_unlock(&ctx->completion_lock);
-               io_cqring_ev_posted(ctx);
+       do {
+               v = atomic_read(&req->poll_refs);
 
-               if (done) {
-                       nxt = io_put_req_find_next(req);
-                       if (nxt)
-                               io_req_task_submit(nxt, locked);
+               /* tw handler should be the owner, and so have some references */
+               if (WARN_ON_ONCE(!(v & IO_POLL_REF_MASK)))
+                       return 0;
+               if (v & IO_POLL_CANCEL_FLAG)
+                       return -ECANCELED;
+
+               if (!req->result) {
+                       struct poll_table_struct pt = { ._key = poll->events };
+
+                       req->result = vfs_poll(req->file, &pt) & poll->events;
                }
+
+               /* multishot, just fill an CQE and proceed */
+               if (req->result && !(poll->events & EPOLLONESHOT)) {
+                       __poll_t mask = mangle_poll(req->result & poll->events);
+                       bool filled;
+
+                       spin_lock(&ctx->completion_lock);
+                       filled = io_fill_cqe_aux(ctx, req->user_data, mask,
+                                                IORING_CQE_F_MORE);
+                       io_commit_cqring(ctx);
+                       spin_unlock(&ctx->completion_lock);
+                       if (unlikely(!filled))
+                               return -ECANCELED;
+                       io_cqring_ev_posted(ctx);
+               } else if (req->result) {
+                       return 0;
+               }
+
+               /*
+                * Release all references, retry if someone tried to restart
+                * task_work while we were executing it.
+                */
+       } while (atomic_sub_return(v & IO_POLL_REF_MASK, &req->poll_refs));
+
+       return 1;
+}
+
+static void io_poll_task_func(struct io_kiocb *req, bool *locked)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+       int ret;
+
+       ret = io_poll_check_events(req);
+       if (ret > 0)
+               return;
+
+       if (!ret) {
+               req->result = mangle_poll(req->result & req->poll.events);
+       } else {
+               req->result = ret;
+               req_set_fail(req);
        }
+
+       io_poll_remove_entries(req);
+       spin_lock(&ctx->completion_lock);
+       hash_del(&req->hash_node);
+       __io_req_complete_post(req, req->result, 0);
+       io_commit_cqring(ctx);
+       spin_unlock(&ctx->completion_lock);
+       io_cqring_ev_posted(ctx);
 }
 
-static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
-                              int sync, void *key)
+static void io_apoll_task_func(struct io_kiocb *req, bool *locked)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+       int ret;
+
+       ret = io_poll_check_events(req);
+       if (ret > 0)
+               return;
+
+       io_poll_remove_entries(req);
+       spin_lock(&ctx->completion_lock);
+       hash_del(&req->hash_node);
+       spin_unlock(&ctx->completion_lock);
+
+       if (!ret)
+               io_req_task_submit(req, locked);
+       else
+               io_req_complete_failed(req, ret);
+}
+
+static void __io_poll_execute(struct io_kiocb *req, int mask)
+{
+       req->result = mask;
+       if (req->opcode == IORING_OP_POLL_ADD)
+               req->io_task_work.func = io_poll_task_func;
+       else
+               req->io_task_work.func = io_apoll_task_func;
+
+       trace_io_uring_task_add(req->ctx, req->opcode, req->user_data, mask);
+       io_req_task_work_add(req, false);
+}
+
+static inline void io_poll_execute(struct io_kiocb *req, int res)
+{
+       if (io_poll_get_ownership(req))
+               __io_poll_execute(req, res);
+}
+
+static void io_poll_cancel_req(struct io_kiocb *req)
+{
+       io_poll_mark_cancelled(req);
+       /* kick tw, which should complete the request */
+       io_poll_execute(req, 0);
+}
+
+static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
+                       void *key)
 {
        struct io_kiocb *req = wait->private;
-       struct io_poll_iocb *poll = io_poll_get_single(req);
+       struct io_poll_iocb *poll = container_of(wait, struct io_poll_iocb,
+                                                wait);
        __poll_t mask = key_to_poll(key);
-       unsigned long flags;
 
-       /* for instances that support it check for an event match first: */
-       if (mask && !(mask & poll->events))
-               return 0;
-       if (!(poll->events & EPOLLONESHOT))
-               return poll->wait.func(&poll->wait, mode, sync, key);
+       if (unlikely(mask & POLLFREE)) {
+               io_poll_mark_cancelled(req);
+               /* we have to kick tw in case it's not already */
+               io_poll_execute(req, 0);
 
-       list_del_init(&wait->entry);
+               /*
+                * If the waitqueue is being freed early but someone is already
+                * holds ownership over it, we have to tear down the request as
+                * best we can. That means immediately removing the request from
+                * its waitqueue and preventing all further accesses to the
+                * waitqueue via the request.
+                */
+               list_del_init(&poll->wait.entry);
 
-       if (poll->head) {
-               bool done;
+               /*
+                * Careful: this *must* be the last step, since as soon
+                * as req->head is NULL'ed out, the request can be
+                * completed and freed, since aio_poll_complete_work()
+                * will no longer need to take the waitqueue lock.
+                */
+               smp_store_release(&poll->head, NULL);
+               return 1;
+       }
+
+       /* for instances that support it check for an event match first */
+       if (mask && !(mask & poll->events))
+               return 0;
 
-               spin_lock_irqsave(&poll->head->lock, flags);
-               done = list_empty(&poll->wait.entry);
-               if (!done)
+       if (io_poll_get_ownership(req)) {
+               /* optional, saves extra locking for removal in tw handler */
+               if (mask && poll->events & EPOLLONESHOT) {
                        list_del_init(&poll->wait.entry);
-               /* make sure double remove sees this as being gone */
-               wait->private = NULL;
-               spin_unlock_irqrestore(&poll->head->lock, flags);
-               if (!done) {
-                       /* use wait func handler, so it matches the rq type */
-                       poll->wait.func(&poll->wait, mode, sync, key);
+                       poll->head = NULL;
                }
+               __io_poll_execute(req, mask);
        }
-       req_ref_put(req);
        return 1;
 }
 
-static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
-                             wait_queue_func_t wake_func)
-{
-       poll->head = NULL;
-       poll->done = false;
-       poll->canceled = false;
-#define IO_POLL_UNMASK (EPOLLERR|EPOLLHUP|EPOLLNVAL|EPOLLRDHUP)
-       /* mask in events that we always want/need */
-       poll->events = events | IO_POLL_UNMASK;
-       INIT_LIST_HEAD(&poll->wait.entry);
-       init_waitqueue_func_entry(&poll->wait, wake_func);
-}
-
 static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                            struct wait_queue_head *head,
                            struct io_poll_iocb **poll_ptr)
@@ -5501,10 +5687,10 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
         * if this happens.
         */
        if (unlikely(pt->nr_entries)) {
-               struct io_poll_iocb *poll_one = poll;
+               struct io_poll_iocb *first = poll;
 
                /* double add on the same waitqueue head, ignore */
-               if (poll_one->head == head)
+               if (first->head == head)
                        return;
                /* already have a 2nd entry, fail a third attempt */
                if (*poll_ptr) {
@@ -5513,21 +5699,13 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                        pt->error = -EINVAL;
                        return;
                }
-               /*
-                * Can't handle multishot for double wait for now, turn it
-                * into one-shot mode.
-                */
-               if (!(poll_one->events & EPOLLONESHOT))
-                       poll_one->events |= EPOLLONESHOT;
+
                poll = kmalloc(sizeof(*poll), GFP_ATOMIC);
                if (!poll) {
                        pt->error = -ENOMEM;
                        return;
                }
-               io_init_poll_iocb(poll, poll_one->events, io_poll_double_wake);
-               req_ref_get(req);
-               poll->wait.private = req;
-
+               io_init_poll_iocb(poll, first->events, first->wait.func);
                *poll_ptr = poll;
                if (req->opcode == IORING_OP_POLL_ADD)
                        req->flags |= REQ_F_ASYNC_DATA;
@@ -5535,6 +5713,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 
        pt->nr_entries++;
        poll->head = head;
+       poll->wait.private = req;
 
        if (poll->events & EPOLLEXCLUSIVE)
                add_wait_queue_exclusive(head, &poll->wait);
@@ -5542,70 +5721,24 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                add_wait_queue(head, &poll->wait);
 }
 
-static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
+static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
                               struct poll_table_struct *p)
 {
        struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
-       struct async_poll *apoll = pt->req->apoll;
-
-       __io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
-}
-
-static void io_async_task_func(struct io_kiocb *req, bool *locked)
-{
-       struct async_poll *apoll = req->apoll;
-       struct io_ring_ctx *ctx = req->ctx;
-
-       trace_io_uring_task_run(req->ctx, req, req->opcode, req->user_data);
-
-       if (io_poll_rewait(req, &apoll->poll)) {
-               spin_unlock(&ctx->completion_lock);
-               return;
-       }
-
-       hash_del(&req->hash_node);
-       io_poll_remove_double(req);
-       apoll->poll.done = true;
-       spin_unlock(&ctx->completion_lock);
 
-       if (!READ_ONCE(apoll->poll.canceled))
-               io_req_task_submit(req, locked);
-       else
-               io_req_complete_failed(req, -ECANCELED);
+       __io_queue_proc(&pt->req->poll, pt, head,
+                       (struct io_poll_iocb **) &pt->req->async_data);
 }
 
-static int io_async_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
-                       void *key)
-{
-       struct io_kiocb *req = wait->private;
-       struct io_poll_iocb *poll = &req->apoll->poll;
-
-       trace_io_uring_poll_wake(req->ctx, req->opcode, req->user_data,
-                                       key_to_poll(key));
-
-       return __io_async_wake(req, poll, key_to_poll(key), io_async_task_func);
-}
-
-static void io_poll_req_insert(struct io_kiocb *req)
+static int __io_arm_poll_handler(struct io_kiocb *req,
+                                struct io_poll_iocb *poll,
+                                struct io_poll_table *ipt, __poll_t mask)
 {
        struct io_ring_ctx *ctx = req->ctx;
-       struct hlist_head *list;
-
-       list = &ctx->cancel_hash[hash_long(req->user_data, ctx->cancel_hash_bits)];
-       hlist_add_head(&req->hash_node, list);
-}
-
-static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
-                                     struct io_poll_iocb *poll,
-                                     struct io_poll_table *ipt, __poll_t mask,
-                                     wait_queue_func_t wake_func)
-       __acquires(&ctx->completion_lock)
-{
-       struct io_ring_ctx *ctx = req->ctx;
-       bool cancel = false;
+       int v;
 
        INIT_HLIST_NODE(&req->hash_node);
-       io_init_poll_iocb(poll, mask, wake_func);
+       io_init_poll_iocb(poll, mask, io_poll_wake);
        poll->file = req->file;
        poll->wait.private = req;
 
@@ -5614,31 +5747,54 @@ static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
        ipt->error = 0;
        ipt->nr_entries = 0;
 
+       /*
+        * Take the ownership to delay any tw execution up until we're done
+        * with poll arming. see io_poll_get_ownership().
+        */
+       atomic_set(&req->poll_refs, 1);
        mask = vfs_poll(req->file, &ipt->pt) & poll->events;
-       if (unlikely(!ipt->nr_entries) && !ipt->error)
-               ipt->error = -EINVAL;
+
+       if (mask && (poll->events & EPOLLONESHOT)) {
+               io_poll_remove_entries(req);
+               /* no one else has access to the req, forget about the ref */
+               return mask;
+       }
+       if (!mask && unlikely(ipt->error || !ipt->nr_entries)) {
+               io_poll_remove_entries(req);
+               if (!ipt->error)
+                       ipt->error = -EINVAL;
+               return 0;
+       }
 
        spin_lock(&ctx->completion_lock);
-       if (ipt->error || (mask && (poll->events & EPOLLONESHOT)))
-               io_poll_remove_double(req);
-       if (likely(poll->head)) {
-               spin_lock_irq(&poll->head->lock);
-               if (unlikely(list_empty(&poll->wait.entry))) {
-                       if (ipt->error)
-                               cancel = true;
-                       ipt->error = 0;
-                       mask = 0;
-               }
-               if ((mask && (poll->events & EPOLLONESHOT)) || ipt->error)
-                       list_del_init(&poll->wait.entry);
-               else if (cancel)
-                       WRITE_ONCE(poll->canceled, true);
-               else if (!poll->done) /* actually waiting for an event */
-                       io_poll_req_insert(req);
-               spin_unlock_irq(&poll->head->lock);
+       io_poll_req_insert(req);
+       spin_unlock(&ctx->completion_lock);
+
+       if (mask) {
+               /* can't multishot if failed, just queue the event we've got */
+               if (unlikely(ipt->error || !ipt->nr_entries))
+                       poll->events |= EPOLLONESHOT;
+               __io_poll_execute(req, mask);
+               return 0;
        }
 
-       return mask;
+       /*
+        * Release ownership. If someone tried to queue a tw while it was
+        * locked, kick it off for them.
+        */
+       v = atomic_dec_return(&req->poll_refs);
+       if (unlikely(v & IO_POLL_REF_MASK))
+               __io_poll_execute(req, 0);
+       return 0;
+}
+
+static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
+                              struct poll_table_struct *p)
+{
+       struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
+       struct async_poll *apoll = pt->req->apoll;
+
+       __io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
 }
 
 enum {
@@ -5653,7 +5809,8 @@ static int io_arm_poll_handler(struct io_kiocb *req)
        struct io_ring_ctx *ctx = req->ctx;
        struct async_poll *apoll;
        struct io_poll_table ipt;
-       __poll_t ret, mask = EPOLLONESHOT | POLLERR | POLLPRI;
+       __poll_t mask = EPOLLONESHOT | POLLERR | POLLPRI;
+       int ret;
 
        if (!def->pollin && !def->pollout)
                return IO_APOLL_ABORTED;
@@ -5678,11 +5835,8 @@ static int io_arm_poll_handler(struct io_kiocb *req)
        req->apoll = apoll;
        req->flags |= REQ_F_POLLED;
        ipt.pt._qproc = io_async_queue_proc;
-       io_req_set_refcount(req);
 
-       ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
-                                       io_async_wake);
-       spin_unlock(&ctx->completion_lock);
+       ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask);
        if (ret || ipt.error)
                return ret ? IO_APOLL_READY : IO_APOLL_ABORTED;
 
@@ -5691,43 +5845,6 @@ static int io_arm_poll_handler(struct io_kiocb *req)
        return IO_APOLL_OK;
 }
 
-static bool __io_poll_remove_one(struct io_kiocb *req,
-                                struct io_poll_iocb *poll, bool do_cancel)
-       __must_hold(&req->ctx->completion_lock)
-{
-       bool do_complete = false;
-
-       if (!poll->head)
-               return false;
-       spin_lock_irq(&poll->head->lock);
-       if (do_cancel)
-               WRITE_ONCE(poll->canceled, true);
-       if (!list_empty(&poll->wait.entry)) {
-               list_del_init(&poll->wait.entry);
-               do_complete = true;
-       }
-       spin_unlock_irq(&poll->head->lock);
-       hash_del(&req->hash_node);
-       return do_complete;
-}
-
-static bool io_poll_remove_one(struct io_kiocb *req)
-       __must_hold(&req->ctx->completion_lock)
-{
-       bool do_complete;
-
-       io_poll_remove_double(req);
-       do_complete = __io_poll_remove_one(req, io_poll_get_single(req), true);
-
-       if (do_complete) {
-               io_cqring_fill_event(req->ctx, req->user_data, -ECANCELED, 0);
-               io_commit_cqring(req->ctx);
-               req_set_fail(req);
-               io_put_req_deferred(req);
-       }
-       return do_complete;
-}
-
 /*
  * Returns true if we found and killed one or more poll requests
  */
@@ -5736,7 +5853,8 @@ static __cold bool io_poll_remove_all(struct io_ring_ctx *ctx,
 {
        struct hlist_node *tmp;
        struct io_kiocb *req;
-       int posted = 0, i;
+       bool found = false;
+       int i;
 
        spin_lock(&ctx->completion_lock);
        for (i = 0; i < (1U << ctx->cancel_hash_bits); i++) {
@@ -5744,16 +5862,14 @@ static __cold bool io_poll_remove_all(struct io_ring_ctx *ctx,
 
                list = &ctx->cancel_hash[i];
                hlist_for_each_entry_safe(req, tmp, list, hash_node) {
-                       if (io_match_task_safe(req, tsk, cancel_all))
-                               posted += io_poll_remove_one(req);
+                       if (io_match_task_safe(req, tsk, cancel_all)) {
+                               io_poll_cancel_req(req);
+                               found = true;
+                       }
                }
        }
        spin_unlock(&ctx->completion_lock);
-
-       if (posted)
-               io_cqring_ev_posted(ctx);
-
-       return posted != 0;
+       return found;
 }
 
 static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, __u64 sqe_addr,
@@ -5774,19 +5890,26 @@ static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, __u64 sqe_addr,
        return NULL;
 }
 
+static bool io_poll_disarm(struct io_kiocb *req)
+       __must_hold(&ctx->completion_lock)
+{
+       if (!io_poll_get_ownership(req))
+               return false;
+       io_poll_remove_entries(req);
+       hash_del(&req->hash_node);
+       return true;
+}
+
 static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr,
                          bool poll_only)
        __must_hold(&ctx->completion_lock)
 {
-       struct io_kiocb *req;
+       struct io_kiocb *req = io_poll_find(ctx, sqe_addr, poll_only);
 
-       req = io_poll_find(ctx, sqe_addr, poll_only);
        if (!req)
                return -ENOENT;
-       if (io_poll_remove_one(req))
-               return 0;
-
-       return -EALREADY;
+       io_poll_cancel_req(req);
+       return 0;
 }
 
 static __poll_t io_poll_parse_events(const struct io_uring_sqe *sqe,
@@ -5836,23 +5959,6 @@ static int io_poll_update_prep(struct io_kiocb *req,
        return 0;
 }
 
-static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
-                       void *key)
-{
-       struct io_kiocb *req = wait->private;
-       struct io_poll_iocb *poll = &req->poll;
-
-       return __io_async_wake(req, poll, key_to_poll(key), io_poll_task_func);
-}
-
-static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
-                              struct poll_table_struct *p)
-{
-       struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
-
-       __io_queue_proc(&pt->req->poll, pt, head, (struct io_poll_iocb **) &pt->req->async_data);
-}
-
 static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
        struct io_poll_iocb *poll = &req->poll;
@@ -5865,6 +5971,8 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
        flags = READ_ONCE(sqe->len);
        if (flags & ~IORING_POLL_ADD_MULTI)
                return -EINVAL;
+       if ((flags & IORING_POLL_ADD_MULTI) && (req->flags & REQ_F_CQE_SKIP))
+               return -EINVAL;
 
        io_req_set_refcount(req);
        poll->events = io_poll_parse_events(sqe, flags);
@@ -5874,100 +5982,60 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
 static int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
 {
        struct io_poll_iocb *poll = &req->poll;
-       struct io_ring_ctx *ctx = req->ctx;
        struct io_poll_table ipt;
-       __poll_t mask;
-       bool done;
+       int ret;
 
        ipt.pt._qproc = io_poll_queue_proc;
 
-       mask = __io_arm_poll_handler(req, &req->poll, &ipt, poll->events,
-                                       io_poll_wake);
-
-       if (mask) { /* no async, we'd stolen it */
-               ipt.error = 0;
-               done = __io_poll_complete(req, mask);
-               io_commit_cqring(req->ctx);
-       }
-       spin_unlock(&ctx->completion_lock);
-
-       if (mask) {
-               io_cqring_ev_posted(ctx);
-               if (done)
-                       io_put_req(req);
-       }
-       return ipt.error;
+       ret = __io_arm_poll_handler(req, &req->poll, &ipt, poll->events);
+       ret = ret ?: ipt.error;
+       if (ret)
+               __io_req_complete(req, issue_flags, ret, 0);
+       return 0;
 }
 
 static int io_poll_update(struct io_kiocb *req, unsigned int issue_flags)
 {
        struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *preq;
-       bool completing;
-       int ret;
+       int ret2, ret = 0;
+       bool locked;
 
        spin_lock(&ctx->completion_lock);
        preq = io_poll_find(ctx, req->poll_update.old_user_data, true);
-       if (!preq) {
-               ret = -ENOENT;
-               goto err;
-       }
-
-       if (!req->poll_update.update_events && !req->poll_update.update_user_data) {
-               completing = true;
-               ret = io_poll_remove_one(preq) ? 0 : -EALREADY;
-               goto err;
-       }
-
-       /*
-        * Don't allow racy completion with singleshot, as we cannot safely
-        * update those. For multishot, if we're racing with completion, just
-        * let completion re-add it.
-        */
-       completing = !__io_poll_remove_one(preq, &preq->poll, false);
-       if (completing && (preq->poll.events & EPOLLONESHOT)) {
-               ret = -EALREADY;
-               goto err;
-       }
-       /* we now have a detached poll request. reissue. */
-       ret = 0;
-err:
-       if (ret < 0) {
+       if (!preq || !io_poll_disarm(preq)) {
                spin_unlock(&ctx->completion_lock);
-               req_set_fail(req);
-               io_req_complete(req, ret);
-               return 0;
-       }
-       /* only mask one event flags, keep behavior flags */
-       if (req->poll_update.update_events) {
-               preq->poll.events &= ~0xffff;
-               preq->poll.events |= req->poll_update.events & 0xffff;
-               preq->poll.events |= IO_POLL_UNMASK;
+               ret = preq ? -EALREADY : -ENOENT;
+               goto out;
        }
-       if (req->poll_update.update_user_data)
-               preq->user_data = req->poll_update.new_user_data;
        spin_unlock(&ctx->completion_lock);
 
-       /* complete update request, we're done with it */
-       io_req_complete(req, ret);
-
-       if (!completing) {
-               ret = io_poll_add(preq, issue_flags);
-               if (ret < 0) {
-                       req_set_fail(preq);
-                       io_req_complete(preq, ret);
+       if (req->poll_update.update_events || req->poll_update.update_user_data) {
+               /* only mask one event flags, keep behavior flags */
+               if (req->poll_update.update_events) {
+                       preq->poll.events &= ~0xffff;
+                       preq->poll.events |= req->poll_update.events & 0xffff;
+                       preq->poll.events |= IO_POLL_UNMASK;
                }
-       }
-       return 0;
-}
+               if (req->poll_update.update_user_data)
+                       preq->user_data = req->poll_update.new_user_data;
 
-static void io_req_task_timeout(struct io_kiocb *req, bool *locked)
-{
-       struct io_timeout_data *data = req->async_data;
+               ret2 = io_poll_add(preq, issue_flags);
+               /* successfully updated, don't complete poll request */
+               if (!ret2)
+                       goto out;
+       }
 
-       if (!(data->flags & IORING_TIMEOUT_ETIME_SUCCESS))
+       req_set_fail(preq);
+       preq->result = -ECANCELED;
+       locked = !(issue_flags & IO_URING_F_UNLOCKED);
+       io_req_task_complete(preq, &locked);
+out:
+       if (ret < 0)
                req_set_fail(req);
-       io_req_complete_post(req, -ETIME, 0);
+       /* complete update request, we're done with it */
+       __io_req_complete(req, issue_flags, ret, 0);
+       return 0;
 }
 
 static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
@@ -5984,8 +6052,12 @@ static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
                atomic_read(&req->ctx->cq_timeouts) + 1);
        spin_unlock_irqrestore(&ctx->timeout_lock, flags);
 
-       req->io_task_work.func = io_req_task_timeout;
-       io_req_task_work_add(req);
+       if (!(data->flags & IORING_TIMEOUT_ETIME_SUCCESS))
+               req_set_fail(req);
+
+       req->result = -ETIME;
+       req->io_task_work.func = io_req_task_complete;
+       io_req_task_work_add(req, false);
        return HRTIMER_NORESTART;
 }
 
@@ -6022,7 +6094,7 @@ static int io_timeout_cancel(struct io_ring_ctx *ctx, __u64 user_data)
                return PTR_ERR(req);
 
        req_set_fail(req);
-       io_cqring_fill_event(ctx, req->user_data, -ECANCELED, 0);
+       io_fill_cqe_req(req, -ECANCELED, 0);
        io_put_req_deferred(req);
        return 0;
 }
@@ -6111,6 +6183,8 @@ static int io_timeout_remove_prep(struct io_kiocb *req,
                        return -EINVAL;
                if (get_timespec64(&tr->ts, u64_to_user_ptr(sqe->addr2)))
                        return -EFAULT;
+               if (tr->ts.tv_sec < 0 || tr->ts.tv_nsec < 0)
+                       return -EINVAL;
        } else if (tr->flags) {
                /* timeout removal doesn't support flags */
                return -EINVAL;
@@ -6312,16 +6386,21 @@ static int io_try_cancel_userdata(struct io_kiocb *req, u64 sqe_addr)
        WARN_ON_ONCE(!io_wq_current_is_worker() && req->task != current);
 
        ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
-       if (ret != -ENOENT)
-               return ret;
+       /*
+        * Fall-through even for -EALREADY, as we may have poll armed
+        * that need unarming.
+        */
+       if (!ret)
+               return 0;
 
        spin_lock(&ctx->completion_lock);
+       ret = io_poll_cancel(ctx, sqe_addr, false);
+       if (ret != -ENOENT)
+               goto out;
+
        spin_lock_irq(&ctx->timeout_lock);
        ret = io_timeout_cancel(ctx, sqe_addr);
        spin_unlock_irq(&ctx->timeout_lock);
-       if (ret != -ENOENT)
-               goto out;
-       ret = io_poll_cancel(ctx, sqe_addr, false);
 out:
        spin_unlock(&ctx->completion_lock);
        return ret;
@@ -6540,12 +6619,15 @@ static __cold void io_drain_req(struct io_kiocb *req)
        u32 seq = io_get_sequence(req);
 
        /* Still need defer if there is pending req in defer list. */
+       spin_lock(&ctx->completion_lock);
        if (!req_need_defer(req, seq) && list_empty_careful(&ctx->defer_list)) {
+               spin_unlock(&ctx->completion_lock);
 queue:
                ctx->drain_active = false;
                io_req_task_queue(req);
                return;
        }
+       spin_unlock(&ctx->completion_lock);
 
        ret = io_req_prep_async(req);
        if (ret) {
@@ -6576,10 +6658,8 @@ fail:
 
 static void io_clean_op(struct io_kiocb *req)
 {
-       if (req->flags & REQ_F_BUFFER_SELECTED) {
-               kfree(req->kbuf);
-               req->kbuf = NULL;
-       }
+       if (req->flags & REQ_F_BUFFER_SELECTED)
+               io_put_kbuf(req);
 
        if (req->flags & REQ_F_NEED_CLEANUP) {
                switch (req->opcode) {
@@ -6961,7 +7041,7 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
        spin_unlock_irqrestore(&ctx->timeout_lock, flags);
 
        req->io_task_work.func = io_req_task_link_timeout;
-       io_req_task_work_add(req);
+       io_req_task_work_add(req, false);
        return HRTIMER_NORESTART;
 }
 
@@ -7096,10 +7176,10 @@ static void io_init_req_drain(struct io_kiocb *req)
                 * If we need to drain a request in the middle of a link, drain
                 * the head request and the next request/link after the current
                 * link. Considering sequential execution of links,
-                * IOSQE_IO_DRAIN will be maintained for every request of our
+                * REQ_F_IO_DRAIN will be maintained for every request of our
                 * link.
                 */
-               head->flags |= IOSQE_IO_DRAIN | REQ_F_FORCE_ASYNC;
+               head->flags |= REQ_F_IO_DRAIN | REQ_F_FORCE_ASYNC;
                ctx->drain_next = true;
        }
 }
@@ -7132,8 +7212,13 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
                if ((sqe_flags & IOSQE_BUFFER_SELECT) &&
                    !io_op_defs[opcode].buffer_select)
                        return -EOPNOTSUPP;
-               if (sqe_flags & IOSQE_IO_DRAIN)
+               if (sqe_flags & IOSQE_CQE_SKIP_SUCCESS)
+                       ctx->drain_disabled = true;
+               if (sqe_flags & IOSQE_IO_DRAIN) {
+                       if (ctx->drain_disabled)
+                               return -EOPNOTSUPP;
                        io_init_req_drain(req);
+               }
        }
        if (unlikely(ctx->restricted || ctx->drain_active || ctx->drain_next)) {
                if (ctx->restricted && !io_check_restriction(ctx, req, sqe_flags))
@@ -7145,7 +7230,7 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
                if (unlikely(ctx->drain_next) && !ctx->submit_state.link.head) {
                        ctx->drain_next = false;
                        ctx->drain_active = true;
-                       req->flags |= IOSQE_IO_DRAIN | REQ_F_FORCE_ASYNC;
+                       req->flags |= REQ_F_IO_DRAIN | REQ_F_FORCE_ASYNC;
                }
        }
 
@@ -8259,8 +8344,7 @@ static void __io_rsrc_put_work(struct io_rsrc_node *ref_node)
 
                        io_ring_submit_lock(ctx, lock_ring);
                        spin_lock(&ctx->completion_lock);
-                       io_cqring_fill_event(ctx, prsrc->tag, 0, 0);
-                       ctx->cq_extra++;
+                       io_fill_cqe_aux(ctx, prsrc->tag, 0, 0);
                        io_commit_cqring(ctx);
                        spin_unlock(&ctx->completion_lock);
                        io_cqring_ev_posted(ctx);
@@ -8672,6 +8756,7 @@ static __cold int io_uring_alloc_task_context(struct task_struct *task,
        task->io_uring = tctx;
        spin_lock_init(&tctx->task_lock);
        INIT_WQ_LIST(&tctx->task_list);
+       INIT_WQ_LIST(&tctx->prior_task_list);
        init_task_work(&tctx->task_work, tctx_task_work);
        return 0;
 }
@@ -9810,18 +9895,6 @@ static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
        return percpu_counter_sum(&tctx->inflight);
 }
 
-static __cold void io_uring_drop_tctx_refs(struct task_struct *task)
-{
-       struct io_uring_task *tctx = task->io_uring;
-       unsigned int refs = tctx->cached_refs;
-
-       if (refs) {
-               tctx->cached_refs = 0;
-               percpu_counter_sub(&tctx->inflight, refs);
-               put_task_struct_many(task, refs);
-       }
-}
-
 /*
  * Find any io_uring ctx that this task has registered or done IO on, and cancel
  * requests. @sqd should be not-null IFF it's an SQPOLL thread cancellation.
@@ -9879,10 +9952,14 @@ static __cold void io_uring_cancel_generic(bool cancel_all,
                        schedule();
                finish_wait(&tctx->wait, &wait);
        } while (1);
-       atomic_dec(&tctx->in_idle);
 
        io_uring_clean_tctx(tctx);
        if (cancel_all) {
+               /*
+                * We shouldn't run task_works after cancel, so just leave
+                * ->in_idle set for normal exit.
+                */
+               atomic_dec(&tctx->in_idle);
                /* for exec all current's requests should be gone, kill tctx */
                __io_uring_free(current);
        }
@@ -10160,7 +10237,7 @@ static __cold void __io_uring_show_fdinfo(struct io_ring_ctx *ctx,
         * and sq_tail and cq_head are changed by userspace. But it's ok since
         * we usually use these info when it is stuck.
         */
-       seq_printf(m, "SqMask:\t\t0x%x\n", sq_mask);
+       seq_printf(m, "SqMask:\t0x%x\n", sq_mask);
        seq_printf(m, "SqHead:\t%u\n", sq_head);
        seq_printf(m, "SqTail:\t%u\n", sq_tail);
        seq_printf(m, "CachedSqHead:\t%u\n", ctx->cached_sq_head);
@@ -10469,7 +10546,7 @@ static __cold int io_uring_create(unsigned entries, struct io_uring_params *p,
                        IORING_FEAT_CUR_PERSONALITY | IORING_FEAT_FAST_POLL |
                        IORING_FEAT_POLL_32BITS | IORING_FEAT_SQPOLL_NONFIXED |
                        IORING_FEAT_EXT_ARG | IORING_FEAT_NATIVE_WORKERS |
-                       IORING_FEAT_RSRC_TAGS;
+                       IORING_FEAT_RSRC_TAGS | IORING_FEAT_CQE_SKIP;
 
        if (copy_to_user(params, p, sizeof(*p))) {
                ret = -EFAULT;