io_uring: hold 'ctx' reference around task_work queue + execute
[linux-2.6-microblaze.git] / fs / io_uring.c
index 2a3af95..99582cf 100644 (file)
@@ -898,6 +898,7 @@ static void io_put_req(struct io_kiocb *req);
 static void io_double_put_req(struct io_kiocb *req);
 static void __io_double_put_req(struct io_kiocb *req);
 static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req);
+static void __io_queue_linked_timeout(struct io_kiocb *req);
 static void io_queue_linked_timeout(struct io_kiocb *req);
 static int __io_sqe_files_update(struct io_ring_ctx *ctx,
                                 struct io_uring_files_update *ip,
@@ -1107,10 +1108,16 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
        }
 }
 
-static void io_req_clean_work(struct io_kiocb *req)
+/*
+ * Returns true if we need to defer file table putting. This can only happen
+ * from the error path with REQ_F_COMP_LOCKED set.
+ */
+static bool io_req_clean_work(struct io_kiocb *req)
 {
        if (!(req->flags & REQ_F_WORK_INITIALIZED))
-               return;
+               return false;
+
+       req->flags &= ~REQ_F_WORK_INITIALIZED;
 
        if (req->work.mm) {
                mmdrop(req->work.mm);
@@ -1123,6 +1130,9 @@ static void io_req_clean_work(struct io_kiocb *req)
        if (req->work.fs) {
                struct fs_struct *fs = req->work.fs;
 
+               if (req->flags & REQ_F_COMP_LOCKED)
+                       return true;
+
                spin_lock(&req->work.fs->lock);
                if (--fs->users)
                        fs = NULL;
@@ -1131,7 +1141,8 @@ static void io_req_clean_work(struct io_kiocb *req)
                        free_fs_struct(fs);
                req->work.fs = NULL;
        }
-       req->flags &= ~REQ_F_WORK_INITIALIZED;
+
+       return false;
 }
 
 static void io_prep_async_work(struct io_kiocb *req)
@@ -1179,7 +1190,7 @@ static void io_prep_async_link(struct io_kiocb *req)
                        io_prep_async_work(cur);
 }
 
-static void __io_queue_async_work(struct io_kiocb *req)
+static struct io_kiocb *__io_queue_async_work(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *link = io_prep_linked_timeout(req);
@@ -1187,16 +1198,19 @@ static void __io_queue_async_work(struct io_kiocb *req)
        trace_io_uring_queue_async_work(ctx, io_wq_is_hashed(&req->work), req,
                                        &req->work, req->flags);
        io_wq_enqueue(ctx->io_wq, &req->work);
-
-       if (link)
-               io_queue_linked_timeout(link);
+       return link;
 }
 
 static void io_queue_async_work(struct io_kiocb *req)
 {
+       struct io_kiocb *link;
+
        /* init ->work of the whole link before punting */
        io_prep_async_link(req);
-       __io_queue_async_work(req);
+       link = __io_queue_async_work(req);
+
+       if (link)
+               io_queue_linked_timeout(link);
 }
 
 static void io_kill_timeout(struct io_kiocb *req)
@@ -1229,12 +1243,19 @@ static void __io_queue_deferred(struct io_ring_ctx *ctx)
        do {
                struct io_defer_entry *de = list_first_entry(&ctx->defer_list,
                                                struct io_defer_entry, list);
+               struct io_kiocb *link;
 
                if (req_need_defer(de->req, de->seq))
                        break;
                list_del_init(&de->list);
                /* punt-init is done before queueing for defer */
-               __io_queue_async_work(de->req);
+               link = __io_queue_async_work(de->req);
+               if (link) {
+                       __io_queue_linked_timeout(link);
+                       /* drop submission reference */
+                       link->flags |= REQ_F_COMP_LOCKED;
+                       io_put_req(link);
+               }
                kfree(de);
        } while (!list_empty(&ctx->defer_list));
 }
@@ -1533,7 +1554,7 @@ static inline void io_put_file(struct io_kiocb *req, struct file *file,
                fput(file);
 }
 
-static void io_dismantle_req(struct io_kiocb *req)
+static bool io_dismantle_req(struct io_kiocb *req)
 {
        io_clean_op(req);
 
@@ -1541,7 +1562,6 @@ static void io_dismantle_req(struct io_kiocb *req)
                kfree(req->io);
        if (req->file)
                io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE));
-       io_req_clean_work(req);
 
        if (req->flags & REQ_F_INFLIGHT) {
                struct io_ring_ctx *ctx = req->ctx;
@@ -1553,15 +1573,15 @@ static void io_dismantle_req(struct io_kiocb *req)
                        wake_up(&ctx->inflight_wait);
                spin_unlock_irqrestore(&ctx->inflight_lock, flags);
        }
+
+       return io_req_clean_work(req);
 }
 
-static void __io_free_req(struct io_kiocb *req)
+static void __io_free_req_finish(struct io_kiocb *req)
 {
-       struct io_ring_ctx *ctx;
+       struct io_ring_ctx *ctx = req->ctx;
 
-       io_dismantle_req(req);
        __io_put_req_task(req);
-       ctx = req->ctx;
        if (likely(!io_is_fallback_req(req)))
                kmem_cache_free(req_cachep, req);
        else
@@ -1569,6 +1589,39 @@ static void __io_free_req(struct io_kiocb *req)
        percpu_ref_put(&ctx->refs);
 }
 
+static void io_req_task_file_table_put(struct callback_head *cb)
+{
+       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+       struct fs_struct *fs = req->work.fs;
+
+       spin_lock(&req->work.fs->lock);
+       if (--fs->users)
+               fs = NULL;
+       spin_unlock(&req->work.fs->lock);
+       if (fs)
+               free_fs_struct(fs);
+       req->work.fs = NULL;
+       __io_free_req_finish(req);
+}
+
+static void __io_free_req(struct io_kiocb *req)
+{
+       if (!io_dismantle_req(req)) {
+               __io_free_req_finish(req);
+       } else {
+               int ret;
+
+               init_task_work(&req->task_work, io_req_task_file_table_put);
+               ret = task_work_add(req->task, &req->task_work, TWA_RESUME);
+               if (unlikely(ret)) {
+                       struct task_struct *tsk;
+
+                       tsk = io_wq_get_task(req->ctx->io_wq);
+                       task_work_add(tsk, &req->task_work, 0);
+               }
+       }
+}
+
 static bool io_link_cancel_timeout(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
@@ -1598,6 +1651,7 @@ static bool __io_kill_linked_timeout(struct io_kiocb *req)
                return false;
 
        list_del_init(&link->link_list);
+       link->flags |= REQ_F_COMP_LOCKED;
        wake_ev = io_link_cancel_timeout(link);
        req->flags &= ~REQ_F_LINK_TIMEOUT;
        return wake_ev;
@@ -1656,6 +1710,7 @@ static void __io_fail_links(struct io_kiocb *req)
                trace_io_uring_fail_link(req, link);
 
                io_cqring_fill_event(link, -ECANCELED);
+               link->flags |= REQ_F_COMP_LOCKED;
                __io_double_put_req(link);
                req->flags &= ~REQ_F_LINK_TIMEOUT;
        }
@@ -1710,22 +1765,22 @@ static int io_req_task_work_add(struct io_kiocb *req, struct callback_head *cb)
 {
        struct task_struct *tsk = req->task;
        struct io_ring_ctx *ctx = req->ctx;
-       int ret, notify = TWA_RESUME;
+       int ret, notify;
 
        /*
-        * SQPOLL kernel thread doesn't need notification, just a wakeup.
-        * If we're not using an eventfd, then TWA_RESUME is always fine,
-        * as we won't have dependencies between request completions for
-        * other kernel wait conditions.
+        * SQPOLL kernel thread doesn't need notification, just a wakeup. For
+        * all other cases, use TWA_SIGNAL unconditionally to ensure we're
+        * processing task_work. There's no reliable way to tell if TWA_RESUME
+        * will do the job.
         */
-       if (ctx->flags & IORING_SETUP_SQPOLL)
-               notify = 0;
-       else if (ctx->cq_ev_fd)
+       notify = 0;
+       if (!(ctx->flags & IORING_SETUP_SQPOLL))
                notify = TWA_SIGNAL;
 
        ret = task_work_add(tsk, cb, notify);
        if (!ret)
                wake_up_process(tsk);
+
        return ret;
 }
 
@@ -1766,8 +1821,10 @@ static void __io_req_task_submit(struct io_kiocb *req)
 static void io_req_task_submit(struct callback_head *cb)
 {
        struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+       struct io_ring_ctx *ctx = req->ctx;
 
        __io_req_task_submit(req);
+       percpu_ref_put(&ctx->refs);
 }
 
 static void io_req_task_queue(struct io_kiocb *req)
@@ -1775,6 +1832,7 @@ static void io_req_task_queue(struct io_kiocb *req)
        int ret;
 
        init_task_work(&req->task_work, io_req_task_submit);
+       percpu_ref_get(&req->ctx->refs);
 
        ret = io_req_task_work_add(req, &req->task_work);
        if (unlikely(ret)) {
@@ -1855,7 +1913,7 @@ static void io_req_free_batch(struct req_batch *rb, struct io_kiocb *req)
                req->flags &= ~REQ_F_TASK_PINNED;
        }
 
-       io_dismantle_req(req);
+       WARN_ON_ONCE(io_dismantle_req(req));
        rb->reqs[rb->to_free++] = req;
        if (unlikely(rb->to_free == ARRAY_SIZE(rb->reqs)))
                __io_req_free_batch_flush(req->ctx, rb);
@@ -2263,6 +2321,8 @@ static void io_rw_resubmit(struct callback_head *cb)
                refcount_inc(&req->refs);
                io_queue_async_work(req);
        }
+
+       percpu_ref_put(&ctx->refs);
 }
 #endif
 
@@ -2275,6 +2335,8 @@ static bool io_rw_reissue(struct io_kiocb *req, long res)
                return false;
 
        init_task_work(&req->task_work, io_rw_resubmit);
+       percpu_ref_get(&req->ctx->refs);
+
        ret = io_req_task_work_add(req, &req->task_work);
        if (!ret)
                return true;
@@ -2952,6 +3014,16 @@ static int io_read_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe,
        return io_rw_prep_async(req, READ, force_nonblock);
 }
 
+/*
+ * This is our waitqueue callback handler, registered through lock_page_async()
+ * when we initially tried to do the IO with the iocb armed our waitqueue.
+ * This gets called when the page is unlocked, and we generally expect that to
+ * happen when the page IO is completed and the page is now uptodate. This will
+ * queue a task_work based retry of the operation, attempting to copy the data
+ * again. If the latter fails because the page was NOT uptodate, then we will
+ * do a thread based blocking retry of the operation. That's the unexpected
+ * slow path.
+ */
 static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
                             int sync, void *arg)
 {
@@ -2965,13 +3037,11 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
        if (!wake_page_match(wpq, key))
                return 0;
 
-       /* Stop waking things up if the page is locked again */
-       if (test_bit(key->bit_nr, &key->page->flags))
-               return -1;
-
        list_del_init(&wait->entry);
 
        init_task_work(&req->task_work, io_req_task_submit);
+       percpu_ref_get(&req->ctx->refs);
+
        /* submit ref gets dropped, acquire a new one */
        refcount_inc(&req->refs);
        ret = io_req_task_work_add(req, &req->task_work);
@@ -3008,7 +3078,18 @@ static inline int kiocb_wait_page_queue_init(struct kiocb *kiocb,
        return -EOPNOTSUPP;
 }
 
-
+/*
+ * This controls whether a given IO request should be armed for async page
+ * based retry. If we return false here, the request is handed to the async
+ * worker threads for retry. If we're doing buffered reads on a regular file,
+ * we prepare a private wait_page_queue entry and retry the operation. This
+ * will either succeed because the page is now uptodate and unlocked, or it
+ * will register a callback when the page is unlocked at IO completion. Through
+ * that callback, io_uring uses task_work to setup a retry of the operation.
+ * That retry will attempt the buffered read again. The retry will generally
+ * succeed, or in rare cases where it fails, we then fall back to using the
+ * async worker threads for a blocking retry.
+ */
 static bool io_rw_should_retry(struct io_kiocb *req)
 {
        struct kiocb *kiocb = &req->rw.kiocb;
@@ -3049,7 +3130,10 @@ static int io_iter_do_read(struct io_kiocb *req, struct iov_iter *iter)
 {
        if (req->file->f_op->read_iter)
                return call_read_iter(req->file, &req->rw.kiocb, iter);
-       return loop_rw_iter(READ, req->file, &req->rw.kiocb, iter);
+       else if (req->file->f_op->read)
+               return loop_rw_iter(READ, req->file, &req->rw.kiocb, iter);
+       else
+               return -EINVAL;
 }
 
 static int io_read(struct io_kiocb *req, bool force_nonblock,
@@ -3186,8 +3270,10 @@ static int io_write(struct io_kiocb *req, bool force_nonblock,
 
        if (req->file->f_op->write_iter)
                ret2 = call_write_iter(req->file, kiocb, &iter);
-       else
+       else if (req->file->f_op->write)
                ret2 = loop_rw_iter(WRITE, req->file, kiocb, &iter);
+       else
+               ret2 = -EINVAL;
 
        /*
         * Raw bdev writes will return -EOPNOTSUPP for IOCB_NOWAIT. Just
@@ -4488,6 +4574,8 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
 
        req->result = mask;
        init_task_work(&req->task_work, func);
+       percpu_ref_get(&req->ctx->refs);
+
        /*
         * If this fails, then the task is exiting. When a task exits, the
         * work gets canceled, so just cancel this request as well instead
@@ -4575,11 +4663,13 @@ static void io_poll_task_handler(struct io_kiocb *req, struct io_kiocb **nxt)
 static void io_poll_task_func(struct callback_head *cb)
 {
        struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+       struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *nxt = NULL;
 
        io_poll_task_handler(req, &nxt);
        if (nxt)
                __io_req_task_submit(nxt);
+       percpu_ref_put(&ctx->refs);
 }
 
 static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
@@ -4675,6 +4765,7 @@ static void io_async_task_func(struct callback_head *cb)
 
        if (io_poll_rewait(req, &apoll->poll)) {
                spin_unlock_irq(&ctx->completion_lock);
+               percpu_ref_put(&ctx->refs);
                return;
        }
 
@@ -4690,6 +4781,7 @@ static void io_async_task_func(struct callback_head *cb)
        else
                __io_req_task_cancel(req, -ECANCELED);
 
+       percpu_ref_put(&ctx->refs);
        kfree(apoll->double_poll);
        kfree(apoll);
 }
@@ -5038,6 +5130,7 @@ static int io_timeout_cancel(struct io_ring_ctx *ctx, __u64 user_data)
                return -EALREADY;
 
        req_set_fail_links(req);
+       req->flags |= REQ_F_COMP_LOCKED;
        io_cqring_fill_event(req, -ECANCELED);
        io_put_req(req);
        return 0;
@@ -5917,15 +6010,12 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
        return HRTIMER_NORESTART;
 }
 
-static void io_queue_linked_timeout(struct io_kiocb *req)
+static void __io_queue_linked_timeout(struct io_kiocb *req)
 {
-       struct io_ring_ctx *ctx = req->ctx;
-
        /*
         * If the list is now empty, then our linked request finished before
         * we got a chance to setup the timer
         */
-       spin_lock_irq(&ctx->completion_lock);
        if (!list_empty(&req->link_list)) {
                struct io_timeout_data *data = &req->io->timeout;
 
@@ -5933,6 +6023,14 @@ static void io_queue_linked_timeout(struct io_kiocb *req)
                hrtimer_start(&data->timer, timespec64_to_ktime(data->ts),
                                data->mode);
        }
+}
+
+static void io_queue_linked_timeout(struct io_kiocb *req)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+
+       spin_lock_irq(&ctx->completion_lock);
+       __io_queue_linked_timeout(req);
        spin_unlock_irq(&ctx->completion_lock);
 
        /* drop submission reference */
@@ -8171,6 +8269,10 @@ static int io_allocate_scq_urings(struct io_ring_ctx *ctx,
        struct io_rings *rings;
        size_t size, sq_array_offset;
 
+       /* make sure these are sane, as we already accounted them */
+       ctx->sq_entries = p->sq_entries;
+       ctx->cq_entries = p->cq_entries;
+
        size = rings_size(p->sq_entries, p->cq_entries, &sq_array_offset);
        if (size == SIZE_MAX)
                return -EOVERFLOW;
@@ -8187,8 +8289,6 @@ static int io_allocate_scq_urings(struct io_ring_ctx *ctx,
        rings->cq_ring_entries = p->cq_entries;
        ctx->sq_mask = rings->sq_ring_mask;
        ctx->cq_mask = rings->cq_ring_mask;
-       ctx->sq_entries = rings->sq_ring_entries;
-       ctx->cq_entries = rings->cq_ring_entries;
 
        size = array_size(sizeof(struct io_uring_sqe), p->sq_entries);
        if (size == SIZE_MAX) {
@@ -8317,6 +8417,16 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        ctx->user = user;
        ctx->creds = get_current_cred();
 
+       /*
+        * Account memory _before_ installing the file descriptor. Once
+        * the descriptor is installed, it can get closed at any time. Also
+        * do this before hitting the general error path, as ring freeing
+        * will un-account as well.
+        */
+       io_account_mem(ctx, ring_pages(p->sq_entries, p->cq_entries),
+                      ACCT_LOCKED);
+       ctx->limit_mem = limit_mem;
+
        ret = io_allocate_scq_urings(ctx, p);
        if (ret)
                goto err;
@@ -8353,14 +8463,6 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
                goto err;
        }
 
-       /*
-        * Account memory _before_ installing the file descriptor. Once
-        * the descriptor is installed, it can get closed at any time.
-        */
-       io_account_mem(ctx, ring_pages(p->sq_entries, p->cq_entries),
-                      ACCT_LOCKED);
-       ctx->limit_mem = limit_mem;
-
        /*
         * Install ring fd as the very last thing, so we don't risk someone
         * having closed it before we finish setup