io_uring: disable multishot poll for double poll add cases
[linux-2.6-microblaze.git] / fs / io_uring.c
index ffcb3ec..4803e31 100644 (file)
@@ -500,11 +500,6 @@ struct io_poll_update {
        bool                            update_user_data;
 };
 
-struct io_poll_remove {
-       struct file                     *file;
-       u64                             addr;
-};
-
 struct io_close {
        struct file                     *file;
        int                             fd;
@@ -714,7 +709,6 @@ enum {
        REQ_F_COMPLETE_INLINE_BIT,
        REQ_F_REISSUE_BIT,
        REQ_F_DONT_REISSUE_BIT,
-       REQ_F_POLL_UPDATE_BIT,
        /* keep async read/write and isreg together and in order */
        REQ_F_ASYNC_READ_BIT,
        REQ_F_ASYNC_WRITE_BIT,
@@ -762,8 +756,6 @@ enum {
        REQ_F_REISSUE           = BIT(REQ_F_REISSUE_BIT),
        /* don't attempt request reissue, see io_rw_reissue() */
        REQ_F_DONT_REISSUE      = BIT(REQ_F_DONT_REISSUE_BIT),
-       /* switches between poll and poll update */
-       REQ_F_POLL_UPDATE       = BIT(REQ_F_POLL_UPDATE_BIT),
        /* supports async reads */
        REQ_F_ASYNC_READ        = BIT(REQ_F_ASYNC_READ_BIT),
        /* supports async writes */
@@ -794,7 +786,6 @@ struct io_kiocb {
                struct io_rw            rw;
                struct io_poll_iocb     poll;
                struct io_poll_update   poll_update;
-               struct io_poll_remove   poll_remove;
                struct io_accept        accept;
                struct io_sync          sync;
                struct io_cancel        cancel;
@@ -2329,27 +2320,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
        return ret;
 }
 
-/*
- * Poll for a minimum of 'min' events. Note that if min == 0 we consider that a
- * non-spinning poll check - we'll still enter the driver poll loop, but only
- * as a non-spinning completion check.
- */
-static int io_iopoll_getevents(struct io_ring_ctx *ctx, unsigned int *nr_events,
-                               long min)
-{
-       while (!list_empty(&ctx->iopoll_list) && !need_resched()) {
-               int ret;
-
-               ret = io_do_iopoll(ctx, nr_events, min);
-               if (ret < 0)
-                       return ret;
-               if (*nr_events >= min)
-                       return 0;
-       }
-
-       return 1;
-}
-
 /*
  * We can't just wait for polled events to come to us, we have to actively
  * find and complete them.
@@ -2393,17 +2363,16 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
         * that got punted to a workqueue.
         */
        mutex_lock(&ctx->uring_lock);
+       /*
+        * Don't enter poll loop if we already have events pending.
+        * If we do, we can potentially be spinning for commands that
+        * already triggered a CQE (eg in error).
+        */
+       if (test_bit(0, &ctx->cq_check_overflow))
+               __io_cqring_overflow_flush(ctx, false);
+       if (io_cqring_events(ctx))
+               goto out;
        do {
-               /*
-                * Don't enter poll loop if we already have events pending.
-                * If we do, we can potentially be spinning for commands that
-                * already triggered a CQE (eg in error).
-                */
-               if (test_bit(0, &ctx->cq_check_overflow))
-                       __io_cqring_overflow_flush(ctx, false);
-               if (io_cqring_events(ctx))
-                       break;
-
                /*
                 * If a submit got punted to a workqueue, we can have the
                 * application entering polling for a command before it gets
@@ -2422,13 +2391,9 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
                        if (list_empty(&ctx->iopoll_list))
                                break;
                }
-
-               ret = io_iopoll_getevents(ctx, &nr_events, min);
-               if (ret <= 0)
-                       break;
-               ret = 0;
-       } while (min && !nr_events && !need_resched());
-
+               ret = io_do_iopoll(ctx, &nr_events, min);
+       } while (!ret && nr_events < min && !need_resched());
+out:
        mutex_unlock(&ctx->uring_lock);
        return ret;
 }
@@ -2539,7 +2504,7 @@ static void io_complete_rw_iopoll(struct kiocb *kiocb, long res, long res2)
 /*
  * After the iocb has been issued, it's safe to be found on the poll list.
  * Adding the kiocb to the list AFTER submission ensures that we don't
- * find it from a io_iopoll_getevents() thread before the issuer is done
+ * find it from a io_do_iopoll() thread before the issuer is done
  * accessing the kiocb cookie.
  */
 static void io_iopoll_req_issued(struct io_kiocb *req, bool in_async)
@@ -5011,6 +4976,12 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                        pt->error = -EINVAL;
                        return;
                }
+               /*
+                * Can't handle multishot for double wait for now, turn it
+                * into one-shot mode.
+                */
+               if (!(req->poll.events & EPOLLONESHOT))
+                       req->poll.events |= EPOLLONESHOT;
                /* double add on the same waitqueue head, ignore */
                if (poll->head == head)
                        return;
@@ -5275,7 +5246,8 @@ static bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
        return posted != 0;
 }
 
-static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, __u64 sqe_addr)
+static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, __u64 sqe_addr,
+                                    bool poll_only)
        __must_hold(&ctx->completion_lock)
 {
        struct hlist_head *list;
@@ -5285,18 +5257,20 @@ static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, __u64 sqe_addr)
        hlist_for_each_entry(req, list, hash_node) {
                if (sqe_addr != req->user_data)
                        continue;
+               if (poll_only && req->opcode != IORING_OP_POLL_ADD)
+                       continue;
                return req;
        }
-
        return NULL;
 }
 
-static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr)
+static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr,
+                         bool poll_only)
        __must_hold(&ctx->completion_lock)
 {
        struct io_kiocb *req;
 
-       req = io_poll_find(ctx, sqe_addr);
+       req = io_poll_find(ctx, sqe_addr, poll_only);
        if (!req)
                return -ENOENT;
        if (io_poll_remove_one(req))
@@ -5305,35 +5279,50 @@ static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr)
        return -EALREADY;
 }
 
-static int io_poll_remove_prep(struct io_kiocb *req,
+static __poll_t io_poll_parse_events(const struct io_uring_sqe *sqe,
+                                    unsigned int flags)
+{
+       u32 events;
+
+       events = READ_ONCE(sqe->poll32_events);
+#ifdef __BIG_ENDIAN
+       events = swahw32(events);
+#endif
+       if (!(flags & IORING_POLL_ADD_MULTI))
+               events |= EPOLLONESHOT;
+       return demangle_poll(events) | (events & (EPOLLEXCLUSIVE|EPOLLONESHOT));
+}
+
+static int io_poll_update_prep(struct io_kiocb *req,
                               const struct io_uring_sqe *sqe)
 {
+       struct io_poll_update *upd = &req->poll_update;
+       u32 flags;
+
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
-       if (sqe->ioprio || sqe->off || sqe->len || sqe->buf_index ||
-           sqe->poll_events)
+       if (sqe->ioprio || sqe->buf_index)
+               return -EINVAL;
+       flags = READ_ONCE(sqe->len);
+       if (flags & ~(IORING_POLL_UPDATE_EVENTS | IORING_POLL_UPDATE_USER_DATA |
+                     IORING_POLL_ADD_MULTI))
+               return -EINVAL;
+       /* meaningless without update */
+       if (flags == IORING_POLL_ADD_MULTI)
                return -EINVAL;
 
-       req->poll_remove.addr = READ_ONCE(sqe->addr);
-       return 0;
-}
+       upd->old_user_data = READ_ONCE(sqe->addr);
+       upd->update_events = flags & IORING_POLL_UPDATE_EVENTS;
+       upd->update_user_data = flags & IORING_POLL_UPDATE_USER_DATA;
 
-/*
- * Find a running poll command that matches one specified in sqe->addr,
- * and remove it if found.
- */
-static int io_poll_remove(struct io_kiocb *req, unsigned int issue_flags)
-{
-       struct io_ring_ctx *ctx = req->ctx;
-       int ret;
-
-       spin_lock_irq(&ctx->completion_lock);
-       ret = io_poll_cancel(ctx, req->poll_remove.addr);
-       spin_unlock_irq(&ctx->completion_lock);
+       upd->new_user_data = READ_ONCE(sqe->off);
+       if (!upd->update_user_data && upd->new_user_data)
+               return -EINVAL;
+       if (upd->update_events)
+               upd->events = io_poll_parse_events(sqe, flags);
+       else if (sqe->poll32_events)
+               return -EINVAL;
 
-       if (ret < 0)
-               req_set_fail_links(req);
-       __io_req_complete(req, issue_flags, ret, 0);
        return 0;
 }
 
@@ -5356,46 +5345,22 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
 
 static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
-       u32 events, flags;
+       struct io_poll_iocb *poll = &req->poll;
+       u32 flags;
 
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
-       if (sqe->ioprio || sqe->buf_index)
+       if (sqe->ioprio || sqe->buf_index || sqe->off || sqe->addr)
                return -EINVAL;
        flags = READ_ONCE(sqe->len);
-       if (flags & ~(IORING_POLL_ADD_MULTI | IORING_POLL_UPDATE_EVENTS |
-                       IORING_POLL_UPDATE_USER_DATA))
+       if (flags & ~IORING_POLL_ADD_MULTI)
                return -EINVAL;
-       events = READ_ONCE(sqe->poll32_events);
-#ifdef __BIG_ENDIAN
-       events = swahw32(events);
-#endif
-       if (!(flags & IORING_POLL_ADD_MULTI))
-               events |= EPOLLONESHOT;
-       events = demangle_poll(events) |
-                               (events & (EPOLLEXCLUSIVE|EPOLLONESHOT));
-
-       if (flags & (IORING_POLL_UPDATE_EVENTS|IORING_POLL_UPDATE_USER_DATA)) {
-               struct io_poll_update *poll_upd = &req->poll_update;
-
-               req->flags |= REQ_F_POLL_UPDATE;
-               poll_upd->events = events;
-               poll_upd->old_user_data = READ_ONCE(sqe->addr);
-               poll_upd->update_events = flags & IORING_POLL_UPDATE_EVENTS;
-               poll_upd->update_user_data = flags & IORING_POLL_UPDATE_USER_DATA;
-               if (poll_upd->update_user_data)
-                       poll_upd->new_user_data = READ_ONCE(sqe->off);
-       } else {
-               struct io_poll_iocb *poll = &req->poll;
 
-               poll->events = events;
-               if (sqe->off || sqe->addr)
-                       return -EINVAL;
-       }
+       poll->events = io_poll_parse_events(sqe, flags);
        return 0;
 }
 
-static int __io_poll_add(struct io_kiocb *req)
+static int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
 {
        struct io_poll_iocb *poll = &req->poll;
        struct io_ring_ctx *ctx = req->ctx;
@@ -5421,7 +5386,7 @@ static int __io_poll_add(struct io_kiocb *req)
        return ipt.error;
 }
 
-static int io_poll_update(struct io_kiocb *req)
+static int io_poll_update(struct io_kiocb *req, unsigned int issue_flags)
 {
        struct io_ring_ctx *ctx = req->ctx;
        struct io_kiocb *preq;
@@ -5429,13 +5394,15 @@ static int io_poll_update(struct io_kiocb *req)
        int ret;
 
        spin_lock_irq(&ctx->completion_lock);
-       preq = io_poll_find(ctx, req->poll_update.old_user_data);
+       preq = io_poll_find(ctx, req->poll_update.old_user_data, true);
        if (!preq) {
                ret = -ENOENT;
                goto err;
-       } else if (preq->opcode != IORING_OP_POLL_ADD) {
-               /* don't allow internal poll updates */
-               ret = -EACCES;
+       }
+
+       if (!req->poll_update.update_events && !req->poll_update.update_user_data) {
+               completing = true;
+               ret = io_poll_remove_one(preq) ? 0 : -EALREADY;
                goto err;
        }
 
@@ -5466,14 +5433,13 @@ err:
        }
        if (req->poll_update.update_user_data)
                preq->user_data = req->poll_update.new_user_data;
-
        spin_unlock_irq(&ctx->completion_lock);
 
        /* complete update request, we're done with it */
        io_req_complete(req, ret);
 
        if (!completing) {
-               ret = __io_poll_add(preq);
+               ret = io_poll_add(preq, issue_flags);
                if (ret < 0) {
                        req_set_fail_links(preq);
                        io_req_complete(preq, ret);
@@ -5482,13 +5448,6 @@ err:
        return 0;
 }
 
-static int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
-{
-       if (!(req->flags & REQ_F_POLL_UPDATE))
-               return __io_poll_add(req);
-       return io_poll_update(req);
-}
-
 static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
 {
        struct io_timeout_data *data = container_of(timer,
@@ -5765,7 +5724,7 @@ static void io_async_find_and_cancel(struct io_ring_ctx *ctx,
        ret = io_timeout_cancel(ctx, sqe_addr);
        if (ret != -ENOENT)
                goto done;
-       ret = io_poll_cancel(ctx, sqe_addr);
+       ret = io_poll_cancel(ctx, sqe_addr, false);
 done:
        if (!ret)
                ret = success_ret;
@@ -5807,7 +5766,7 @@ static int io_async_cancel(struct io_kiocb *req, unsigned int issue_flags)
        ret = io_timeout_cancel(ctx, sqe_addr);
        if (ret != -ENOENT)
                goto done;
-       ret = io_poll_cancel(ctx, sqe_addr);
+       ret = io_poll_cancel(ctx, sqe_addr, false);
        if (ret != -ENOENT)
                goto done;
        spin_unlock_irq(&ctx->completion_lock);
@@ -5893,7 +5852,7 @@ static int io_req_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        case IORING_OP_POLL_ADD:
                return io_poll_add_prep(req, sqe);
        case IORING_OP_POLL_REMOVE:
-               return io_poll_remove_prep(req, sqe);
+               return io_poll_update_prep(req, sqe);
        case IORING_OP_FSYNC:
                return io_fsync_prep(req, sqe);
        case IORING_OP_SYNC_FILE_RANGE:
@@ -6124,7 +6083,7 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
                ret = io_poll_add(req, issue_flags);
                break;
        case IORING_OP_POLL_REMOVE:
-               ret = io_poll_remove(req, issue_flags);
+               ret = io_poll_update(req, issue_flags);
                break;
        case IORING_OP_SYNC_FILE_RANGE:
                ret = io_sync_file_range(req, issue_flags);
@@ -8601,6 +8560,9 @@ static void io_ring_exit_work(struct work_struct *work)
                WARN_ON_ONCE(time_after(jiffies, timeout));
        } while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20));
 
+       init_completion(&exit.completion);
+       init_task_work(&exit.task_work, io_tctx_exit_cb);
+       exit.ctx = ctx;
        /*
         * Some may use context even when all refs and requests have been put,
         * and they are free to do so while still holding uring_lock or
@@ -8613,9 +8575,8 @@ static void io_ring_exit_work(struct work_struct *work)
 
                node = list_first_entry(&ctx->tctx_list, struct io_tctx_node,
                                        ctx_node);
-               exit.ctx = ctx;
-               init_completion(&exit.completion);
-               init_task_work(&exit.task_work, io_tctx_exit_cb);
+               /* don't spin on a single task if cancellation failed */
+               list_rotate_left(&ctx->tctx_list);
                ret = task_work_add(node->task, &exit.task_work, TWA_SIGNAL);
                if (WARN_ON_ONCE(ret))
                        continue;
@@ -8623,7 +8584,6 @@ static void io_ring_exit_work(struct work_struct *work)
 
                mutex_unlock(&ctx->uring_lock);
                wait_for_completion(&exit.completion);
-               cond_resched();
                mutex_lock(&ctx->uring_lock);
        }
        mutex_unlock(&ctx->uring_lock);