io_uring: fix ctx-exit io_rsrc_put_work() deadlock
[linux-2.6-microblaze.git] / fs / io_uring.c
index e55b21f..04c6d05 100644 (file)
@@ -78,6 +78,7 @@
 #include <linux/task_work.h>
 #include <linux/pagemap.h>
 #include <linux/io_uring.h>
+#include <linux/tracehook.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/io_uring.h>
@@ -465,7 +466,8 @@ struct io_ring_ctx {
                struct mm_struct                *mm_account;
 
                /* ctx exit and cancelation */
-               struct callback_head            *exit_task_work;
+               struct llist_head               fallback_llist;
+               struct delayed_work             fallback_work;
                struct work_struct              exit_work;
                struct list_head                tctx_list;
                struct completion               ref_comp;
@@ -784,9 +786,14 @@ struct async_poll {
        struct io_poll_iocb     *double_poll;
 };
 
+typedef void (*io_req_tw_func_t)(struct io_kiocb *req);
+
 struct io_task_work {
-       struct io_wq_work_node  node;
-       task_work_func_t        func;
+       union {
+               struct io_wq_work_node  node;
+               struct llist_node       fallback_node;
+       };
+       io_req_tw_func_t                func;
 };
 
 enum {
@@ -849,10 +856,7 @@ struct io_kiocb {
 
        /* used with ctx->iopoll_list with reads/writes */
        struct list_head                inflight_entry;
-       union {
-               struct io_task_work     io_task_work;
-               struct callback_head    task_work;
-       };
+       struct io_task_work             io_task_work;
        /* for polled requests, i.e. IORING_OP_POLL_ADD and async armed poll */
        struct hlist_node               hash_node;
        struct async_poll               *apoll;
@@ -1071,6 +1075,8 @@ static void io_submit_flush_completions(struct io_ring_ctx *ctx);
 static bool io_poll_remove_waitqs(struct io_kiocb *req);
 static int io_req_prep_async(struct io_kiocb *req);
 
+static void io_fallback_req_func(struct work_struct *unused);
+
 static struct kmem_cache *req_cachep;
 
 static const struct file_operations io_uring_fops;
@@ -1202,6 +1208,7 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        INIT_LIST_HEAD(&ctx->tctx_list);
        INIT_LIST_HEAD(&ctx->submit_state.comp.free_list);
        INIT_LIST_HEAD(&ctx->locked_free_list);
+       INIT_DELAYED_WORK(&ctx->fallback_work, io_fallback_req_func);
        return ctx;
 err:
        kfree(ctx->dummy_ubuf);
@@ -1273,8 +1280,17 @@ static void io_prep_async_link(struct io_kiocb *req)
 {
        struct io_kiocb *cur;
 
-       io_for_each_link(cur, req)
-               io_prep_async_work(cur);
+       if (req->flags & REQ_F_LINK_TIMEOUT) {
+               struct io_ring_ctx *ctx = req->ctx;
+
+               spin_lock_irq(&ctx->completion_lock);
+               io_for_each_link(cur, req)
+                       io_prep_async_work(cur);
+               spin_unlock_irq(&ctx->completion_lock);
+       } else {
+               io_for_each_link(cur, req)
+                       io_prep_async_work(cur);
+       }
 }
 
 static void io_queue_async_work(struct io_kiocb *req)
@@ -1288,6 +1304,17 @@ static void io_queue_async_work(struct io_kiocb *req)
 
        /* init ->work of the whole link before punting */
        io_prep_async_link(req);
+
+       /*
+        * Not expected to happen, but if we do have a bug where this _can_
+        * happen, catch it here and ensure the request is marked as
+        * canceled. That will make io-wq go through the usual work cancel
+        * procedure rather than attempt to run this request (or create a new
+        * worker for it).
+        */
+       if (WARN_ON_ONCE(!same_thread_group(req->task, current)))
+               req->work.flags |= IO_WQ_WORK_CANCEL;
+
        trace_io_uring_queue_async_work(ctx, io_wq_is_hashed(&req->work), req,
                                        &req->work, req->flags);
        io_wq_enqueue(tctx->io_wq, &req->work);
@@ -1473,7 +1500,8 @@ static bool __io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
        all_flushed = list_empty(&ctx->cq_overflow_list);
        if (all_flushed) {
                clear_bit(0, &ctx->check_cq_overflow);
-               ctx->rings->sq_flags &= ~IORING_SQ_CQ_OVERFLOW;
+               WRITE_ONCE(ctx->rings->sq_flags,
+                          ctx->rings->sq_flags & ~IORING_SQ_CQ_OVERFLOW);
        }
 
        if (posted)
@@ -1552,7 +1580,9 @@ static bool io_cqring_event_overflow(struct io_ring_ctx *ctx, u64 user_data,
        }
        if (list_empty(&ctx->cq_overflow_list)) {
                set_bit(0, &ctx->check_cq_overflow);
-               ctx->rings->sq_flags |= IORING_SQ_CQ_OVERFLOW;
+               WRITE_ONCE(ctx->rings->sq_flags,
+                          ctx->rings->sq_flags | IORING_SQ_CQ_OVERFLOW);
+
        }
        ocqe->cqe.user_data = user_data;
        ocqe->cqe.res = res;
@@ -1929,13 +1959,17 @@ static void tctx_task_work(struct callback_head *cb)
                                ctx = req->ctx;
                                percpu_ref_get(&ctx->refs);
                        }
-                       req->task_work.func(&req->task_work);
+                       req->io_task_work.func(req);
                        node = next;
                }
                if (wq_list_empty(&tctx->task_list)) {
+                       spin_lock_irq(&tctx->task_lock);
                        clear_bit(0, &tctx->task_state);
-                       if (wq_list_empty(&tctx->task_list))
+                       if (wq_list_empty(&tctx->task_list)) {
+                               spin_unlock_irq(&tctx->task_lock);
                                break;
+                       }
+                       spin_unlock_irq(&tctx->task_lock);
                        /* another tctx_task_work() is enqueued, yield */
                        if (test_and_set_bit(0, &tctx->task_state))
                                break;
@@ -1946,17 +1980,13 @@ static void tctx_task_work(struct callback_head *cb)
        ctx_flush_and_put(ctx);
 }
 
-static int io_req_task_work_add(struct io_kiocb *req)
+static void io_req_task_work_add(struct io_kiocb *req)
 {
        struct task_struct *tsk = req->task;
        struct io_uring_task *tctx = tsk->io_uring;
        enum task_work_notify_mode notify;
-       struct io_wq_work_node *node, *prev;
+       struct io_wq_work_node *node;
        unsigned long flags;
-       int ret = 0;
-
-       if (unlikely(tsk->flags & PF_EXITING))
-               return -ESRCH;
 
        WARN_ON_ONCE(!tctx);
 
@@ -1967,7 +1997,9 @@ static int io_req_task_work_add(struct io_kiocb *req)
        /* task_work already pending, we're done */
        if (test_bit(0, &tctx->task_state) ||
            test_and_set_bit(0, &tctx->task_state))
-               return 0;
+               return;
+       if (unlikely(tsk->flags & PF_EXITING))
+               goto fail;
 
        /*
         * SQPOLL kernel thread doesn't need notification, just a wakeup. For
@@ -1976,72 +2008,28 @@ static int io_req_task_work_add(struct io_kiocb *req)
         * will do the job.
         */
        notify = (req->ctx->flags & IORING_SETUP_SQPOLL) ? TWA_NONE : TWA_SIGNAL;
-
        if (!task_work_add(tsk, &tctx->task_work, notify)) {
                wake_up_process(tsk);
-               return 0;
+               return;
        }
-
-       /*
-        * Slow path - we failed, find and delete work. if the work is not
-        * in the list, it got run and we're fine.
-        */
+fail:
+       clear_bit(0, &tctx->task_state);
        spin_lock_irqsave(&tctx->task_lock, flags);
-       wq_list_for_each(node, prev, &tctx->task_list) {
-               if (&req->io_task_work.node == node) {
-                       wq_list_del(&tctx->task_list, node, prev);
-                       ret = 1;
-                       break;
-               }
-       }
+       node = tctx->task_list.first;
+       INIT_WQ_LIST(&tctx->task_list);
        spin_unlock_irqrestore(&tctx->task_lock, flags);
-       clear_bit(0, &tctx->task_state);
-       return ret;
-}
-
-static bool io_run_task_work_head(struct callback_head **work_head)
-{
-       struct callback_head *work, *next;
-       bool executed = false;
-
-       do {
-               work = xchg(work_head, NULL);
-               if (!work)
-                       break;
-
-               do {
-                       next = work->next;
-                       work->func(work);
-                       work = next;
-                       cond_resched();
-               } while (work);
-               executed = true;
-       } while (1);
-
-       return executed;
-}
-
-static void io_task_work_add_head(struct callback_head **work_head,
-                                 struct callback_head *task_work)
-{
-       struct callback_head *head;
 
-       do {
-               head = READ_ONCE(*work_head);
-               task_work->next = head;
-       } while (cmpxchg(work_head, head, task_work) != head);
-}
-
-static void io_req_task_work_add_fallback(struct io_kiocb *req,
-                                         task_work_func_t cb)
-{
-       init_task_work(&req->task_work, cb);
-       io_task_work_add_head(&req->ctx->exit_task_work, &req->task_work);
+       while (node) {
+               req = container_of(node, struct io_kiocb, io_task_work.node);
+               node = node->next;
+               if (llist_add(&req->io_task_work.fallback_node,
+                             &req->ctx->fallback_llist))
+                       schedule_delayed_work(&req->ctx->fallback_work, 1);
+       }
 }
 
-static void io_req_task_cancel(struct callback_head *cb)
+static void io_req_task_cancel(struct io_kiocb *req)
 {
-       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
        struct io_ring_ctx *ctx = req->ctx;
 
        /* ctx is guaranteed to stay alive while we hold uring_lock */
@@ -2050,41 +2038,36 @@ static void io_req_task_cancel(struct callback_head *cb)
        mutex_unlock(&ctx->uring_lock);
 }
 
-static void __io_req_task_submit(struct io_kiocb *req)
+static void io_req_task_submit(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
        /* ctx stays valid until unlock, even if we drop all ours ctx->refs */
        mutex_lock(&ctx->uring_lock);
-       if (!(current->flags & PF_EXITING) && !current->in_execve)
+       if (!(req->task->flags & PF_EXITING) && !req->task->in_execve)
                __io_queue_sqe(req);
        else
                io_req_complete_failed(req, -EFAULT);
        mutex_unlock(&ctx->uring_lock);
 }
 
-static void io_req_task_submit(struct callback_head *cb)
-{
-       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
-
-       __io_req_task_submit(req);
-}
-
 static void io_req_task_queue_fail(struct io_kiocb *req, int ret)
 {
        req->result = ret;
-       req->task_work.func = io_req_task_cancel;
-
-       if (unlikely(io_req_task_work_add(req)))
-               io_req_task_work_add_fallback(req, io_req_task_cancel);
+       req->io_task_work.func = io_req_task_cancel;
+       io_req_task_work_add(req);
 }
 
 static void io_req_task_queue(struct io_kiocb *req)
 {
-       req->task_work.func = io_req_task_submit;
+       req->io_task_work.func = io_req_task_submit;
+       io_req_task_work_add(req);
+}
 
-       if (unlikely(io_req_task_work_add(req)))
-               io_req_task_queue_fail(req, -ECANCELED);
+static void io_req_task_queue_reissue(struct io_kiocb *req)
+{
+       req->io_task_work.func = io_queue_async_work;
+       io_req_task_work_add(req);
 }
 
 static inline void io_queue_next(struct io_kiocb *req)
@@ -2195,18 +2178,10 @@ static inline void io_put_req(struct io_kiocb *req)
                io_free_req(req);
 }
 
-static void io_put_req_deferred_cb(struct callback_head *cb)
-{
-       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
-
-       io_free_req(req);
-}
-
 static void io_free_req_deferred(struct io_kiocb *req)
 {
-       req->task_work.func = io_put_req_deferred_cb;
-       if (unlikely(io_req_task_work_add(req)))
-               io_req_task_work_add_fallback(req, io_put_req_deferred_cb);
+       req->io_task_work.func = io_free_req;
+       io_req_task_work_add(req);
 }
 
 static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
@@ -2251,9 +2226,9 @@ static inline unsigned int io_put_rw_kbuf(struct io_kiocb *req)
 
 static inline bool io_run_task_work(void)
 {
-       if (current->task_works) {
+       if (test_thread_flag(TIF_NOTIFY_SIGNAL) || current->task_works) {
                __set_current_state(TASK_RUNNING);
-               task_work_run();
+               tracehook_notify_signal();
                return true;
        }
 
@@ -2264,7 +2239,7 @@ static inline bool io_run_task_work(void)
  * Find and free completed poll iocbs
  */
 static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
-                              struct list_head *done)
+                              struct list_head *done, bool resubmit)
 {
        struct req_batch rb;
        struct io_kiocb *req;
@@ -2279,11 +2254,11 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
                req = list_first_entry(done, struct io_kiocb, inflight_entry);
                list_del(&req->inflight_entry);
 
-               if (READ_ONCE(req->result) == -EAGAIN &&
+               if (READ_ONCE(req->result) == -EAGAIN && resubmit &&
                    !(req->flags & REQ_F_DONT_REISSUE)) {
                        req->iopoll_completed = 0;
                        req_ref_get(req);
-                       io_queue_async_work(req);
+                       io_req_task_queue_reissue(req);
                        continue;
                }
 
@@ -2303,7 +2278,7 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
 }
 
 static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
-                       long min)
+                       long min, bool resubmit)
 {
        struct io_kiocb *req, *tmp;
        LIST_HEAD(done);
@@ -2346,7 +2321,7 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
        }
 
        if (!list_empty(&done))
-               io_iopoll_complete(ctx, nr_events, &done);
+               io_iopoll_complete(ctx, nr_events, &done, resubmit);
 
        return ret;
 }
@@ -2364,7 +2339,7 @@ static void io_iopoll_try_reap_events(struct io_ring_ctx *ctx)
        while (!list_empty(&ctx->iopoll_list)) {
                unsigned int nr_events = 0;
 
-               io_do_iopoll(ctx, &nr_events, 0);
+               io_do_iopoll(ctx, &nr_events, 0, false);
 
                /* let it sleep and repeat later if can't complete a request */
                if (nr_events == 0)
@@ -2415,14 +2390,18 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
                 * very same mutex.
                 */
                if (list_empty(&ctx->iopoll_list)) {
+                       u32 tail = ctx->cached_cq_tail;
+
                        mutex_unlock(&ctx->uring_lock);
                        io_run_task_work();
                        mutex_lock(&ctx->uring_lock);
 
-                       if (list_empty(&ctx->iopoll_list))
+                       /* some requests don't go through iopoll_list */
+                       if (tail != ctx->cached_cq_tail ||
+                           list_empty(&ctx->iopoll_list))
                                break;
                }
-               ret = io_do_iopoll(ctx, &nr_events, min);
+               ret = io_do_iopoll(ctx, &nr_events, min, true);
        } while (!ret && nr_events < min && !need_resched());
 out:
        mutex_unlock(&ctx->uring_lock);
@@ -2472,6 +2451,12 @@ static bool io_rw_should_reissue(struct io_kiocb *req)
         */
        if (percpu_ref_is_dying(&ctx->refs))
                return false;
+       /*
+        * Play it safe and assume not safe to re-import and reissue if we're
+        * not in the original thread group (or in task context).
+        */
+       if (!same_thread_group(req->task, current) || !in_task())
+               return false;
        return true;
 }
 #else
@@ -2485,6 +2470,17 @@ static bool io_rw_should_reissue(struct io_kiocb *req)
 }
 #endif
 
+static void io_fallback_req_func(struct work_struct *work)
+{
+       struct io_ring_ctx *ctx = container_of(work, struct io_ring_ctx,
+                                               fallback_work.work);
+       struct llist_node *node = llist_del_all(&ctx->fallback_llist);
+       struct io_kiocb *req, *tmp;
+
+       llist_for_each_entry_safe(req, tmp, node, io_task_work.fallback_node)
+               req->io_task_work.func(req);
+}
+
 static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
                             unsigned int issue_flags)
 {
@@ -2791,7 +2787,7 @@ static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
                req->flags &= ~REQ_F_REISSUE;
                if (io_resubmit_prep(req)) {
                        req_ref_get(req);
-                       io_queue_async_work(req);
+                       io_req_task_queue_reissue(req);
                } else {
                        int cflags = 0;
 
@@ -4846,14 +4842,13 @@ IO_NETOP_FN(recv);
 struct io_poll_table {
        struct poll_table_struct pt;
        struct io_kiocb *req;
+       int nr_entries;
        int error;
 };
 
 static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
-                          __poll_t mask, task_work_func_t func)
+                          __poll_t mask, io_req_tw_func_t func)
 {
-       int ret;
-
        /* for instances that support it check for an event match first: */
        if (mask && !(mask & poll->events))
                return 0;
@@ -4863,7 +4858,7 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
        list_del_init(&poll->wait.entry);
 
        req->result = mask;
-       req->task_work.func = func;
+       req->io_task_work.func = func;
 
        /*
         * If this fails, then the task is exiting. When a task exits, the
@@ -4871,11 +4866,7 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
         * of executing it. We can't safely execute it anyway, as we may not
         * have the needed state needed for it anyway.
         */
-       ret = io_req_task_work_add(req);
-       if (unlikely(ret)) {
-               WRITE_ONCE(poll->canceled, true);
-               io_req_task_work_add_fallback(req, func);
-       }
+       io_req_task_work_add(req);
        return 1;
 }
 
@@ -4884,6 +4875,9 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
+       if (unlikely(req->task->flags & PF_EXITING))
+               WRITE_ONCE(poll->canceled, true);
+
        if (!req->result && !READ_ONCE(poll->canceled)) {
                struct poll_table_struct pt = { ._key = poll->events };
 
@@ -4949,7 +4943,6 @@ static bool io_poll_complete(struct io_kiocb *req, __poll_t mask)
        if (req->poll.events & EPOLLONESHOT)
                flags = 0;
        if (!io_cqring_fill_event(ctx, req->user_data, error, flags)) {
-               io_poll_remove_waitqs(req);
                req->poll.done = true;
                flags = 0;
        }
@@ -4960,9 +4953,8 @@ static bool io_poll_complete(struct io_kiocb *req, __poll_t mask)
        return !(flags & IORING_CQE_F_MORE);
 }
 
-static void io_poll_task_func(struct callback_head *cb)
+static void io_poll_task_func(struct io_kiocb *req)
 {
-       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
        struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *nxt;
 
@@ -4973,6 +4965,7 @@ static void io_poll_task_func(struct callback_head *cb)
 
                done = io_poll_complete(req, req->result);
                if (done) {
+                       io_poll_remove_double(req);
                        hash_del(&req->hash_node);
                } else {
                        req->result = 0;
@@ -4984,7 +4977,7 @@ static void io_poll_task_func(struct callback_head *cb)
                if (done) {
                        nxt = io_put_req_find_next(req);
                        if (nxt)
-                               __io_req_task_submit(nxt);
+                               io_req_task_submit(nxt);
                }
        }
 }
@@ -5004,7 +4997,7 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
 
        list_del_init(&wait->entry);
 
-       if (poll && poll->head) {
+       if (poll->head) {
                bool done;
 
                spin_lock(&poll->head->lock);
@@ -5043,11 +5036,11 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
        struct io_kiocb *req = pt->req;
 
        /*
-        * If poll->head is already set, it's because the file being polled
-        * uses multiple waitqueues for poll handling (eg one for read, one
-        * for write). Setup a separate io_poll_iocb if this happens.
+        * The file being polled uses multiple waitqueues for poll handling
+        * (e.g. one for read, one for write). Setup a separate io_poll_iocb
+        * if this happens.
         */
-       if (unlikely(poll->head)) {
+       if (unlikely(pt->nr_entries)) {
                struct io_poll_iocb *poll_one = poll;
 
                /* already have a 2nd entry, fail a third attempt */
@@ -5075,7 +5068,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                *poll_ptr = poll;
        }
 
-       pt->error = 0;
+       pt->nr_entries++;
        poll->head = head;
 
        if (poll->events & EPOLLEXCLUSIVE)
@@ -5093,9 +5086,8 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
        __io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
 }
 
-static void io_async_task_func(struct callback_head *cb)
+static void io_async_task_func(struct io_kiocb *req)
 {
-       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
        struct async_poll *apoll = req->apoll;
        struct io_ring_ctx *ctx = req->ctx;
 
@@ -5111,7 +5103,7 @@ static void io_async_task_func(struct callback_head *cb)
        spin_unlock_irq(&ctx->completion_lock);
 
        if (!READ_ONCE(apoll->poll.canceled))
-               __io_req_task_submit(req);
+               io_req_task_submit(req);
        else
                io_req_complete_failed(req, -ECANCELED);
 }
@@ -5153,11 +5145,16 @@ static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
 
        ipt->pt._key = mask;
        ipt->req = req;
-       ipt->error = -EINVAL;
+       ipt->error = 0;
+       ipt->nr_entries = 0;
 
        mask = vfs_poll(req->file, &ipt->pt) & poll->events;
+       if (unlikely(!ipt->nr_entries) && !ipt->error)
+               ipt->error = -EINVAL;
 
        spin_lock_irq(&ctx->completion_lock);
+       if (ipt->error || (mask && (poll->events & EPOLLONESHOT)))
+               io_poll_remove_double(req);
        if (likely(poll->head)) {
                spin_lock(&poll->head->lock);
                if (unlikely(list_empty(&poll->wait.entry))) {
@@ -5228,7 +5225,6 @@ static int io_arm_poll_handler(struct io_kiocb *req)
        ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
                                        io_async_wake);
        if (ret || ipt.error) {
-               io_poll_remove_double(req);
                spin_unlock_irq(&ctx->completion_lock);
                if (ret)
                        return IO_APOLL_READY;
@@ -6068,10 +6064,12 @@ static bool io_drain_req(struct io_kiocb *req)
 
        ret = io_req_prep_async(req);
        if (ret)
-               return ret;
+               goto fail;
        io_prep_async_link(req);
        de = kmalloc(sizeof(*de), GFP_KERNEL);
        if (!de) {
+               ret = -ENOMEM;
+fail:
                io_req_complete_failed(req, ret);
                return true;
        }
@@ -6809,14 +6807,16 @@ static inline void io_ring_set_wakeup_flag(struct io_ring_ctx *ctx)
 {
        /* Tell userspace we may need a wakeup call */
        spin_lock_irq(&ctx->completion_lock);
-       ctx->rings->sq_flags |= IORING_SQ_NEED_WAKEUP;
+       WRITE_ONCE(ctx->rings->sq_flags,
+                  ctx->rings->sq_flags | IORING_SQ_NEED_WAKEUP);
        spin_unlock_irq(&ctx->completion_lock);
 }
 
 static inline void io_ring_clear_wakeup_flag(struct io_ring_ctx *ctx)
 {
        spin_lock_irq(&ctx->completion_lock);
-       ctx->rings->sq_flags &= ~IORING_SQ_NEED_WAKEUP;
+       WRITE_ONCE(ctx->rings->sq_flags,
+                  ctx->rings->sq_flags & ~IORING_SQ_NEED_WAKEUP);
        spin_unlock_irq(&ctx->completion_lock);
 }
 
@@ -6839,7 +6839,7 @@ static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
 
                mutex_lock(&ctx->uring_lock);
                if (!list_empty(&ctx->iopoll_list))
-                       io_do_iopoll(ctx, &nr_events, 0);
+                       io_do_iopoll(ctx, &nr_events, 0, true);
 
                /*
                 * Don't submit if refs are dying, good for io_uring_register(),
@@ -7138,16 +7138,6 @@ static void **io_alloc_page_table(size_t size)
        return table;
 }
 
-static inline void io_rsrc_ref_lock(struct io_ring_ctx *ctx)
-{
-       spin_lock_bh(&ctx->rsrc_ref_lock);
-}
-
-static inline void io_rsrc_ref_unlock(struct io_ring_ctx *ctx)
-{
-       spin_unlock_bh(&ctx->rsrc_ref_lock);
-}
-
 static void io_rsrc_node_destroy(struct io_rsrc_node *ref_node)
 {
        percpu_ref_exit(&ref_node->refs);
@@ -7164,9 +7154,9 @@ static void io_rsrc_node_switch(struct io_ring_ctx *ctx,
                struct io_rsrc_node *rsrc_node = ctx->rsrc_node;
 
                rsrc_node->rsrc_data = data_to_kill;
-               io_rsrc_ref_lock(ctx);
+               spin_lock_irq(&ctx->rsrc_ref_lock);
                list_add_tail(&rsrc_node->node, &ctx->rsrc_ref_list);
-               io_rsrc_ref_unlock(ctx);
+               spin_unlock_irq(&ctx->rsrc_ref_lock);
 
                atomic_inc(&data_to_kill->refs);
                percpu_ref_kill(&rsrc_node->refs);
@@ -7205,17 +7195,19 @@ static int io_rsrc_ref_quiesce(struct io_rsrc_data *data, struct io_ring_ctx *ct
                /* kill initial ref, already quiesced if zero */
                if (atomic_dec_and_test(&data->refs))
                        break;
+               mutex_unlock(&ctx->uring_lock);
                flush_delayed_work(&ctx->rsrc_put_work);
                ret = wait_for_completion_interruptible(&data->done);
-               if (!ret)
+               if (!ret) {
+                       mutex_lock(&ctx->uring_lock);
                        break;
+               }
 
                atomic_inc(&data->refs);
                /* wait for all works potentially completing data->done */
                flush_delayed_work(&ctx->rsrc_put_work);
                reinit_completion(&data->done);
 
-               mutex_unlock(&ctx->uring_lock);
                ret = io_run_task_work_sig();
                mutex_lock(&ctx->uring_lock);
        } while (ret >= 0);
@@ -7674,9 +7666,10 @@ static void io_rsrc_node_ref_zero(struct percpu_ref *ref)
 {
        struct io_rsrc_node *node = container_of(ref, struct io_rsrc_node, refs);
        struct io_ring_ctx *ctx = node->rsrc_data->ctx;
+       unsigned long flags;
        bool first_add = false;
 
-       io_rsrc_ref_lock(ctx);
+       spin_lock_irqsave(&ctx->rsrc_ref_lock, flags);
        node->done = true;
 
        while (!list_empty(&ctx->rsrc_ref_list)) {
@@ -7688,7 +7681,7 @@ static void io_rsrc_node_ref_zero(struct percpu_ref *ref)
                list_del(&node->node);
                first_add |= llist_add(&node->llist, &ctx->rsrc_put_llist);
        }
-       io_rsrc_ref_unlock(ctx);
+       spin_unlock_irqrestore(&ctx->rsrc_ref_lock, flags);
 
        if (first_add)
                mod_delayed_work(system_wq, &ctx->rsrc_put_work, HZ);
@@ -7946,15 +7939,19 @@ static struct io_wq *io_init_wq_offload(struct io_ring_ctx *ctx,
        struct io_wq_data data;
        unsigned int concurrency;
 
+       mutex_lock(&ctx->uring_lock);
        hash = ctx->hash_map;
        if (!hash) {
                hash = kzalloc(sizeof(*hash), GFP_KERNEL);
-               if (!hash)
+               if (!hash) {
+                       mutex_unlock(&ctx->uring_lock);
                        return ERR_PTR(-ENOMEM);
+               }
                refcount_set(&hash->refs, 1);
                init_waitqueue_head(&hash->wait);
                ctx->hash_map = hash;
        }
+       mutex_unlock(&ctx->uring_lock);
 
        data.hash = hash;
        data.task = task;
@@ -8028,9 +8025,11 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
                f = fdget(p->wq_fd);
                if (!f.file)
                        return -ENXIO;
-               fdput(f);
-               if (f.file->f_op != &io_uring_fops)
+               if (f.file->f_op != &io_uring_fops) {
+                       fdput(f);
                        return -EINVAL;
+               }
+               fdput(f);
        }
        if (ctx->flags & IORING_SETUP_SQPOLL) {
                struct task_struct *tsk;
@@ -8653,13 +8652,10 @@ static void io_req_caches_free(struct io_ring_ctx *ctx)
        mutex_unlock(&ctx->uring_lock);
 }
 
-static bool io_wait_rsrc_data(struct io_rsrc_data *data)
+static void io_wait_rsrc_data(struct io_rsrc_data *data)
 {
-       if (!data)
-               return false;
-       if (!atomic_dec_and_test(&data->refs))
+       if (data && !atomic_dec_and_test(&data->refs))
                wait_for_completion(&data->done);
-       return true;
 }
 
 static void io_ring_ctx_free(struct io_ring_ctx *ctx)
@@ -8671,10 +8667,14 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
                ctx->mm_account = NULL;
        }
 
+       /* __io_rsrc_put_work() may need uring_lock to progress, wait w/o it */
+       io_wait_rsrc_data(ctx->buf_data);
+       io_wait_rsrc_data(ctx->file_data);
+
        mutex_lock(&ctx->uring_lock);
-       if (io_wait_rsrc_data(ctx->buf_data))
+       if (ctx->buf_data)
                __io_sqe_buffers_unregister(ctx);
-       if (io_wait_rsrc_data(ctx->file_data))
+       if (ctx->file_data)
                __io_sqe_files_unregister(ctx);
        if (ctx->rings)
                __io_cqring_overflow_flush(ctx, true);
@@ -8767,11 +8767,6 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
        return -EINVAL;
 }
 
-static inline bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
-{
-       return io_run_task_work_head(&ctx->exit_task_work);
-}
-
 struct io_tctx_exit {
        struct callback_head            task_work;
        struct completion               completion;
@@ -8837,7 +8832,7 @@ static void io_ring_exit_work(struct work_struct *work)
        /*
         * Some may use context even when all refs and requests have been put,
         * and they are free to do so while still holding uring_lock or
-        * completion_lock, see __io_req_task_submit(). Apart from other work,
+        * completion_lock, see io_req_task_submit(). Apart from other work,
         * this lock/unlock section also waits them to finish.
         */
        mutex_lock(&ctx->uring_lock);
@@ -9036,7 +9031,6 @@ static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                ret |= io_kill_timeouts(ctx, task, cancel_all);
                if (task)
                        ret |= io_run_task_work();
-               ret |= io_run_ctx_fallback(ctx);
                if (!ret)
                        break;
                cond_resched();