io_uring: always wait for sqd exited when stopping SQPOLL thread
[linux-2.6-microblaze.git] / fs / io_uring.c
index 84eb499..62f998b 100644 (file)
@@ -985,6 +985,7 @@ static const struct io_op_def io_op_defs[] = {
        [IORING_OP_UNLINKAT] = {},
 };
 
+static bool io_disarm_next(struct io_kiocb *req);
 static void io_uring_del_task_file(unsigned long index);
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                                         struct task_struct *task,
@@ -1525,15 +1526,14 @@ static void io_cqring_fill_event(struct io_kiocb *req, long res)
        __io_cqring_fill_event(req, res, 0);
 }
 
-static inline void io_req_complete_post(struct io_kiocb *req, long res,
-                                       unsigned int cflags)
+static void io_req_complete_post(struct io_kiocb *req, long res,
+                                unsigned int cflags)
 {
        struct io_ring_ctx *ctx = req->ctx;
        unsigned long flags;
 
        spin_lock_irqsave(&ctx->completion_lock, flags);
        __io_cqring_fill_event(req, res, cflags);
-       io_commit_cqring(ctx);
        /*
         * If we're the last reference to this request, add to our locked
         * free_list cache.
@@ -1541,19 +1541,26 @@ static inline void io_req_complete_post(struct io_kiocb *req, long res,
        if (refcount_dec_and_test(&req->refs)) {
                struct io_comp_state *cs = &ctx->submit_state.comp;
 
+               if (req->flags & (REQ_F_LINK | REQ_F_HARDLINK)) {
+                       if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK))
+                               io_disarm_next(req);
+                       if (req->link) {
+                               io_req_task_queue(req->link);
+                               req->link = NULL;
+                       }
+               }
                io_dismantle_req(req);
                io_put_task(req->task, 1);
                list_add(&req->compl.list, &cs->locked_free_list);
                cs->locked_free_nr++;
        } else
                req = NULL;
+       io_commit_cqring(ctx);
        spin_unlock_irqrestore(&ctx->completion_lock, flags);
-
        io_cqring_ev_posted(ctx);
-       if (req) {
-               io_queue_next(req);
+
+       if (req)
                percpu_ref_put(&ctx->refs);
-       }
 }
 
 static void io_req_complete_state(struct io_kiocb *req, long res,
@@ -1705,15 +1712,11 @@ static inline void io_remove_next_linked(struct io_kiocb *req)
        nxt->link = NULL;
 }
 
-static void io_kill_linked_timeout(struct io_kiocb *req)
+static bool io_kill_linked_timeout(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
 {
-       struct io_ring_ctx *ctx = req->ctx;
-       struct io_kiocb *link;
+       struct io_kiocb *link = req->link;
        bool cancelled = false;
-       unsigned long flags;
-
-       spin_lock_irqsave(&ctx->completion_lock, flags);
-       link = req->link;
 
        /*
         * Can happen if a linked timeout fired and link had been like
@@ -1728,50 +1731,48 @@ static void io_kill_linked_timeout(struct io_kiocb *req)
                ret = hrtimer_try_to_cancel(&io->timer);
                if (ret != -1) {
                        io_cqring_fill_event(link, -ECANCELED);
-                       io_commit_cqring(ctx);
+                       io_put_req_deferred(link, 1);
                        cancelled = true;
                }
        }
        req->flags &= ~REQ_F_LINK_TIMEOUT;
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
-
-       if (cancelled) {
-               io_cqring_ev_posted(ctx);
-               io_put_req(link);
-       }
+       return cancelled;
 }
 
-
 static void io_fail_links(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
 {
-       struct io_kiocb *link, *nxt;
-       struct io_ring_ctx *ctx = req->ctx;
-       unsigned long flags;
+       struct io_kiocb *nxt, *link = req->link;
 
-       spin_lock_irqsave(&ctx->completion_lock, flags);
-       link = req->link;
        req->link = NULL;
-
        while (link) {
                nxt = link->link;
                link->link = NULL;
 
                trace_io_uring_fail_link(req, link);
                io_cqring_fill_event(link, -ECANCELED);
-
                io_put_req_deferred(link, 2);
                link = nxt;
        }
-       io_commit_cqring(ctx);
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
+}
 
-       io_cqring_ev_posted(ctx);
+static bool io_disarm_next(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
+{
+       bool posted = false;
+
+       if (likely(req->flags & REQ_F_LINK_TIMEOUT))
+               posted = io_kill_linked_timeout(req);
+       if (unlikely(req->flags & REQ_F_FAIL_LINK)) {
+               posted |= (req->link != NULL);
+               io_fail_links(req);
+       }
+       return posted;
 }
 
 static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
 {
-       if (req->flags & REQ_F_LINK_TIMEOUT)
-               io_kill_linked_timeout(req);
+       struct io_kiocb *nxt;
 
        /*
         * If LINK is set, we have dependent requests in this chain. If we
@@ -1779,14 +1780,22 @@ static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
         * dependencies to the next request. In case of failure, fail the rest
         * of the chain.
         */
-       if (likely(!(req->flags & REQ_F_FAIL_LINK))) {
-               struct io_kiocb *nxt = req->link;
+       if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK)) {
+               struct io_ring_ctx *ctx = req->ctx;
+               unsigned long flags;
+               bool posted;
 
-               req->link = NULL;
-               return nxt;
+               spin_lock_irqsave(&ctx->completion_lock, flags);
+               posted = io_disarm_next(req);
+               if (posted)
+                       io_commit_cqring(req->ctx);
+               spin_unlock_irqrestore(&ctx->completion_lock, flags);
+               if (posted)
+                       io_cqring_ev_posted(ctx);
        }
-       io_fail_links(req);
-       return NULL;
+       nxt = req->link;
+       req->link = NULL;
+       return nxt;
 }
 
 static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
@@ -6318,6 +6327,9 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        refcount_set(&req->refs, 2);
        req->task = current;
        req->result = 0;
+       req->work.list.next = NULL;
+       req->work.creds = NULL;
+       req->work.flags = 0;
 
        /* enforce forwards compatibility on users */
        if (unlikely(sqe_flags & ~SQE_VALID_FLAGS)) {
@@ -6335,17 +6347,13 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
            !io_op_defs[req->opcode].buffer_select)
                return -EOPNOTSUPP;
 
-       req->work.list.next = NULL;
        personality = READ_ONCE(sqe->personality);
        if (personality) {
                req->work.creds = xa_load(&ctx->personalities, personality);
                if (!req->work.creds)
                        return -EINVAL;
                get_cred(req->work.creds);
-       } else {
-               req->work.creds = NULL;
        }
-       req->work.flags = 0;
        state = &ctx->submit_state;
 
        /*
@@ -7071,12 +7079,9 @@ static void io_sq_thread_stop(struct io_sq_data *sqd)
        if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
                return;
        down_write(&sqd->rw_lock);
-       if (!sqd->thread) {
-               up_write(&sqd->rw_lock);
-               return;
-       }
        set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-       wake_up_process(sqd->thread);
+       if (sqd->thread)
+               wake_up_process(sqd->thread);
        up_write(&sqd->rw_lock);
        wait_for_completion(&sqd->exited);
 }
@@ -7841,9 +7846,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 
                        ret = -EINVAL;
                        if (cpu >= nr_cpu_ids)
-                               goto err;
+                               goto err_sqpoll;
                        if (!cpu_online(cpu))
-                               goto err;
+                               goto err_sqpoll;
 
                        sqd->sq_cpu = cpu;
                } else {
@@ -7854,12 +7859,11 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
                tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
                if (IS_ERR(tsk)) {
                        ret = PTR_ERR(tsk);
-                       goto err;
+                       goto err_sqpoll;
                }
-               ret = io_uring_alloc_task_context(tsk, ctx);
-               if (ret)
-                       set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+
                sqd->thread = tsk;
+               ret = io_uring_alloc_task_context(tsk, ctx);
                wake_up_new_task(tsk);
                if (ret)
                        goto err;
@@ -7874,6 +7878,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 err:
        io_sq_thread_finish(ctx);
        return ret;
+err_sqpoll:
+       complete(&ctx->sq_data->exited);
+       goto err;
 }
 
 static inline void __io_unaccount_mem(struct user_struct *user,
@@ -9015,7 +9022,6 @@ static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
 
 static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 {
-       int ret = 0;
        DEFINE_WAIT(wait);
 
        do {
@@ -9029,7 +9035,7 @@ static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
        } while (!signal_pending(current));
 
        finish_wait(&ctx->sqo_sq_wait, &wait);
-       return ret;
+       return 0;
 }
 
 static int io_get_ext_arg(unsigned flags, const void __user *argp, size_t *argsz,