io_uring: always wait for sqd exited when stopping SQPOLL thread
[linux-2.6-microblaze.git] / fs / io_uring.c
index 2a3542b..62f998b 100644 (file)
@@ -406,7 +406,8 @@ struct io_ring_ctx {
 
        struct idr              io_buffer_idr;
 
-       struct idr              personality_idr;
+       struct xarray           personalities;
+       u32                     pers_next;
 
        struct {
                unsigned                cached_cq_tail;
@@ -984,6 +985,7 @@ static const struct io_op_def io_op_defs[] = {
        [IORING_OP_UNLINKAT] = {},
 };
 
+static bool io_disarm_next(struct io_kiocb *req);
 static void io_uring_del_task_file(unsigned long index);
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                                         struct task_struct *task,
@@ -1137,7 +1139,7 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        init_completion(&ctx->ref_comp);
        init_completion(&ctx->sq_thread_comp);
        idr_init(&ctx->io_buffer_idr);
-       idr_init(&ctx->personality_idr);
+       xa_init_flags(&ctx->personalities, XA_FLAGS_ALLOC1);
        mutex_init(&ctx->uring_lock);
        init_waitqueue_head(&ctx->wait);
        spin_lock_init(&ctx->completion_lock);
@@ -1524,15 +1526,14 @@ static void io_cqring_fill_event(struct io_kiocb *req, long res)
        __io_cqring_fill_event(req, res, 0);
 }
 
-static inline void io_req_complete_post(struct io_kiocb *req, long res,
-                                       unsigned int cflags)
+static void io_req_complete_post(struct io_kiocb *req, long res,
+                                unsigned int cflags)
 {
        struct io_ring_ctx *ctx = req->ctx;
        unsigned long flags;
 
        spin_lock_irqsave(&ctx->completion_lock, flags);
        __io_cqring_fill_event(req, res, cflags);
-       io_commit_cqring(ctx);
        /*
         * If we're the last reference to this request, add to our locked
         * free_list cache.
@@ -1540,19 +1541,26 @@ static inline void io_req_complete_post(struct io_kiocb *req, long res,
        if (refcount_dec_and_test(&req->refs)) {
                struct io_comp_state *cs = &ctx->submit_state.comp;
 
+               if (req->flags & (REQ_F_LINK | REQ_F_HARDLINK)) {
+                       if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK))
+                               io_disarm_next(req);
+                       if (req->link) {
+                               io_req_task_queue(req->link);
+                               req->link = NULL;
+                       }
+               }
                io_dismantle_req(req);
                io_put_task(req->task, 1);
                list_add(&req->compl.list, &cs->locked_free_list);
                cs->locked_free_nr++;
        } else
                req = NULL;
+       io_commit_cqring(ctx);
        spin_unlock_irqrestore(&ctx->completion_lock, flags);
-
        io_cqring_ev_posted(ctx);
-       if (req) {
-               io_queue_next(req);
+
+       if (req)
                percpu_ref_put(&ctx->refs);
-       }
 }
 
 static void io_req_complete_state(struct io_kiocb *req, long res,
@@ -1704,15 +1712,11 @@ static inline void io_remove_next_linked(struct io_kiocb *req)
        nxt->link = NULL;
 }
 
-static void io_kill_linked_timeout(struct io_kiocb *req)
+static bool io_kill_linked_timeout(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
 {
-       struct io_ring_ctx *ctx = req->ctx;
-       struct io_kiocb *link;
+       struct io_kiocb *link = req->link;
        bool cancelled = false;
-       unsigned long flags;
-
-       spin_lock_irqsave(&ctx->completion_lock, flags);
-       link = req->link;
 
        /*
         * Can happen if a linked timeout fired and link had been like
@@ -1727,50 +1731,48 @@ static void io_kill_linked_timeout(struct io_kiocb *req)
                ret = hrtimer_try_to_cancel(&io->timer);
                if (ret != -1) {
                        io_cqring_fill_event(link, -ECANCELED);
-                       io_commit_cqring(ctx);
+                       io_put_req_deferred(link, 1);
                        cancelled = true;
                }
        }
        req->flags &= ~REQ_F_LINK_TIMEOUT;
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
-
-       if (cancelled) {
-               io_cqring_ev_posted(ctx);
-               io_put_req(link);
-       }
+       return cancelled;
 }
 
-
 static void io_fail_links(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
 {
-       struct io_kiocb *link, *nxt;
-       struct io_ring_ctx *ctx = req->ctx;
-       unsigned long flags;
+       struct io_kiocb *nxt, *link = req->link;
 
-       spin_lock_irqsave(&ctx->completion_lock, flags);
-       link = req->link;
        req->link = NULL;
-
        while (link) {
                nxt = link->link;
                link->link = NULL;
 
                trace_io_uring_fail_link(req, link);
                io_cqring_fill_event(link, -ECANCELED);
-
                io_put_req_deferred(link, 2);
                link = nxt;
        }
-       io_commit_cqring(ctx);
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
+}
 
-       io_cqring_ev_posted(ctx);
+static bool io_disarm_next(struct io_kiocb *req)
+       __must_hold(&req->ctx->completion_lock)
+{
+       bool posted = false;
+
+       if (likely(req->flags & REQ_F_LINK_TIMEOUT))
+               posted = io_kill_linked_timeout(req);
+       if (unlikely(req->flags & REQ_F_FAIL_LINK)) {
+               posted |= (req->link != NULL);
+               io_fail_links(req);
+       }
+       return posted;
 }
 
 static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
 {
-       if (req->flags & REQ_F_LINK_TIMEOUT)
-               io_kill_linked_timeout(req);
+       struct io_kiocb *nxt;
 
        /*
         * If LINK is set, we have dependent requests in this chain. If we
@@ -1778,14 +1780,22 @@ static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
         * dependencies to the next request. In case of failure, fail the rest
         * of the chain.
         */
-       if (likely(!(req->flags & REQ_F_FAIL_LINK))) {
-               struct io_kiocb *nxt = req->link;
+       if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK)) {
+               struct io_ring_ctx *ctx = req->ctx;
+               unsigned long flags;
+               bool posted;
 
-               req->link = NULL;
-               return nxt;
+               spin_lock_irqsave(&ctx->completion_lock, flags);
+               posted = io_disarm_next(req);
+               if (posted)
+                       io_commit_cqring(req->ctx);
+               spin_unlock_irqrestore(&ctx->completion_lock, flags);
+               if (posted)
+                       io_cqring_ev_posted(ctx);
        }
-       io_fail_links(req);
-       return NULL;
+       nxt = req->link;
+       req->link = NULL;
+       return nxt;
 }
 
 static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
@@ -5573,22 +5583,30 @@ add:
        return 0;
 }
 
+struct io_cancel_data {
+       struct io_ring_ctx *ctx;
+       u64 user_data;
+};
+
 static bool io_cancel_cb(struct io_wq_work *work, void *data)
 {
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+       struct io_cancel_data *cd = data;
 
-       return req->user_data == (unsigned long) data;
+       return req->ctx == cd->ctx && req->user_data == cd->user_data;
 }
 
-static int io_async_cancel_one(struct io_uring_task *tctx, void *sqe_addr)
+static int io_async_cancel_one(struct io_uring_task *tctx, u64 user_data,
+                              struct io_ring_ctx *ctx)
 {
+       struct io_cancel_data data = { .ctx = ctx, .user_data = user_data, };
        enum io_wq_cancel cancel_ret;
        int ret = 0;
 
-       if (!tctx->io_wq)
+       if (!tctx || !tctx->io_wq)
                return -ENOENT;
 
-       cancel_ret = io_wq_cancel_cb(tctx->io_wq, io_cancel_cb, sqe_addr, false);
+       cancel_ret = io_wq_cancel_cb(tctx->io_wq, io_cancel_cb, &data, false);
        switch (cancel_ret) {
        case IO_WQ_CANCEL_OK:
                ret = 0;
@@ -5611,8 +5629,7 @@ static void io_async_find_and_cancel(struct io_ring_ctx *ctx,
        unsigned long flags;
        int ret;
 
-       ret = io_async_cancel_one(req->task->io_uring,
-                                       (void *) (unsigned long) sqe_addr);
+       ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
        if (ret != -ENOENT) {
                spin_lock_irqsave(&ctx->completion_lock, flags);
                goto done;
@@ -6310,6 +6327,9 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        refcount_set(&req->refs, 2);
        req->task = current;
        req->result = 0;
+       req->work.list.next = NULL;
+       req->work.creds = NULL;
+       req->work.flags = 0;
 
        /* enforce forwards compatibility on users */
        if (unlikely(sqe_flags & ~SQE_VALID_FLAGS)) {
@@ -6327,17 +6347,13 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
            !io_op_defs[req->opcode].buffer_select)
                return -EOPNOTSUPP;
 
-       req->work.list.next = NULL;
        personality = READ_ONCE(sqe->personality);
        if (personality) {
-               req->work.creds = idr_find(&ctx->personality_idr, personality);
+               req->work.creds = xa_load(&ctx->personalities, personality);
                if (!req->work.creds)
                        return -EINVAL;
                get_cred(req->work.creds);
-       } else {
-               req->work.creds = NULL;
        }
-       req->work.flags = 0;
        state = &ctx->submit_state;
 
        /*
@@ -6599,7 +6615,8 @@ static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
                if (!list_empty(&ctx->iopoll_list))
                        io_do_iopoll(ctx, &nr_events, 0);
 
-               if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)))
+               if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)) &&
+                   !(ctx->flags & IORING_SETUP_R_DISABLED))
                        ret = io_submit_sqes(ctx, to_submit);
                mutex_unlock(&ctx->uring_lock);
        }
@@ -7062,12 +7079,9 @@ static void io_sq_thread_stop(struct io_sq_data *sqd)
        if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
                return;
        down_write(&sqd->rw_lock);
-       if (!sqd->thread) {
-               up_write(&sqd->rw_lock);
-               return;
-       }
        set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-       wake_up_process(sqd->thread);
+       if (sqd->thread)
+               wake_up_process(sqd->thread);
        up_write(&sqd->rw_lock);
        wait_for_completion(&sqd->exited);
 }
@@ -7832,9 +7846,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 
                        ret = -EINVAL;
                        if (cpu >= nr_cpu_ids)
-                               goto err;
+                               goto err_sqpoll;
                        if (!cpu_online(cpu))
-                               goto err;
+                               goto err_sqpoll;
 
                        sqd->sq_cpu = cpu;
                } else {
@@ -7845,15 +7859,15 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
                tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
                if (IS_ERR(tsk)) {
                        ret = PTR_ERR(tsk);
-                       goto err;
+                       goto err_sqpoll;
                }
-               ret = io_uring_alloc_task_context(tsk, ctx);
-               if (ret)
-                       set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+
                sqd->thread = tsk;
+               ret = io_uring_alloc_task_context(tsk, ctx);
                wake_up_new_task(tsk);
                if (ret)
                        goto err;
+               complete(&sqd->startup);
        } else if (p->flags & IORING_SETUP_SQ_AFF) {
                /* Can't have SQ_AFF without SQPOLL */
                ret = -EINVAL;
@@ -7864,15 +7878,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 err:
        io_sq_thread_finish(ctx);
        return ret;
-}
-
-static void io_sq_offload_start(struct io_ring_ctx *ctx)
-{
-       struct io_sq_data *sqd = ctx->sq_data;
-
-       ctx->flags &= ~IORING_SETUP_R_DISABLED;
-       if (ctx->flags & IORING_SETUP_SQPOLL)
-               complete(&sqd->startup);
+err_sqpoll:
+       complete(&ctx->sq_data->exited);
+       goto err;
 }
 
 static inline void __io_unaccount_mem(struct user_struct *user,
@@ -8355,7 +8363,6 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
        mutex_unlock(&ctx->uring_lock);
        io_eventfd_unregister(ctx);
        io_destroy_buffers(ctx);
-       idr_destroy(&ctx->personality_idr);
 
 #if defined(CONFIG_UNIX)
        if (ctx->ring_sock) {
@@ -8420,7 +8427,7 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 {
        const struct cred *creds;
 
-       creds = idr_remove(&ctx->personality_idr, id);
+       creds = xa_erase(&ctx->personalities, id);
        if (creds) {
                put_cred(creds);
                return 0;
@@ -8429,14 +8436,6 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
        return -EINVAL;
 }
 
-static int io_remove_personalities(int id, void *p, void *data)
-{
-       struct io_ring_ctx *ctx = data;
-
-       io_unregister_personality(ctx, id);
-       return 0;
-}
-
 static bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
 {
        struct callback_head *work, *next;
@@ -8526,13 +8525,17 @@ static void io_ring_exit_work(struct work_struct *work)
 
 static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
 {
+       unsigned long index;
+       struct creds *creds;
+
        mutex_lock(&ctx->uring_lock);
        percpu_ref_kill(&ctx->refs);
        /* if force is set, the ring is going away. always drop after that */
        ctx->cq_overflow_flushed = 1;
        if (ctx->rings)
                __io_cqring_overflow_flush(ctx, true, NULL, NULL);
-       idr_for_each(&ctx->personality_idr, io_remove_personalities, ctx);
+       xa_for_each(&ctx->personalities, index, creds)
+               io_unregister_personality(ctx, index);
        mutex_unlock(&ctx->uring_lock);
 
        io_kill_timeouts(ctx, NULL, NULL);
@@ -8735,11 +8738,6 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
        struct task_struct *task = current;
 
        if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
-               /* never started, nothing to cancel */
-               if (ctx->flags & IORING_SETUP_R_DISABLED) {
-                       io_sq_offload_start(ctx);
-                       return;
-               }
                io_sq_thread_park(ctx->sq_data);
                task = ctx->sq_data->thread;
                if (task)
@@ -9024,7 +9022,6 @@ static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
 
 static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 {
-       int ret = 0;
        DEFINE_WAIT(wait);
 
        do {
@@ -9038,7 +9035,7 @@ static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
        } while (!signal_pending(current));
 
        finish_wait(&ctx->sqo_sq_wait, &wait);
-       return ret;
+       return 0;
 }
 
 static int io_get_ext_arg(unsigned flags, const void __user *argp, size_t *argsz,
@@ -9167,10 +9164,9 @@ out_fput:
 }
 
 #ifdef CONFIG_PROC_FS
-static int io_uring_show_cred(int id, void *p, void *data)
+static int io_uring_show_cred(struct seq_file *m, unsigned int id,
+               const struct cred *cred)
 {
-       const struct cred *cred = p;
-       struct seq_file *m = data;
        struct user_namespace *uns = seq_user_ns(m);
        struct group_info *gi;
        kernel_cap_t cap;
@@ -9238,9 +9234,13 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
                seq_printf(m, "%5u: 0x%llx/%u\n", i, buf->ubuf,
                                                (unsigned int) buf->len);
        }
-       if (has_lock && !idr_is_empty(&ctx->personality_idr)) {
+       if (has_lock && !xa_empty(&ctx->personalities)) {
+               unsigned long index;
+               const struct cred *cred;
+
                seq_printf(m, "Personalities:\n");
-               idr_for_each(&ctx->personality_idr, io_uring_show_cred, m);
+               xa_for_each(&ctx->personalities, index, cred)
+                       io_uring_show_cred(m, index, cred);
        }
        seq_printf(m, "PollList:\n");
        spin_lock_irq(&ctx->completion_lock);
@@ -9442,9 +9442,6 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        if (ret)
                goto err;
 
-       if (!(p->flags & IORING_SETUP_R_DISABLED))
-               io_sq_offload_start(ctx);
-
        memset(&p->sq_off, 0, sizeof(p->sq_off));
        p->sq_off.head = offsetof(struct io_rings, sq.head);
        p->sq_off.tail = offsetof(struct io_rings, sq.tail);
@@ -9572,14 +9569,16 @@ out:
 static int io_register_personality(struct io_ring_ctx *ctx)
 {
        const struct cred *creds;
+       u32 id;
        int ret;
 
        creds = get_current_cred();
 
-       ret = idr_alloc_cyclic(&ctx->personality_idr, (void *) creds, 1,
-                               USHRT_MAX, GFP_KERNEL);
-       if (ret < 0)
-               put_cred(creds);
+       ret = xa_alloc_cyclic(&ctx->personalities, &id, (void *)creds,
+                       XA_LIMIT(0, USHRT_MAX), &ctx->pers_next, GFP_KERNEL);
+       if (!ret)
+               return id;
+       put_cred(creds);
        return ret;
 }
 
@@ -9661,7 +9660,9 @@ static int io_register_enable_rings(struct io_ring_ctx *ctx)
        if (ctx->restrictions.registered)
                ctx->restricted = 1;
 
-       io_sq_offload_start(ctx);
+       ctx->flags &= ~IORING_SETUP_R_DISABLED;
+       if (ctx->sq_data && wq_has_sleeper(&ctx->sq_data->wait))
+               wake_up(&ctx->sq_data->wait);
        return 0;
 }