Merge branch 'io_uring-5.8' into for-5.9/io_uring
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 3ce02a1..ff3851d 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -593,6 +593,7 @@ enum {
 
 struct async_poll {
        struct io_poll_iocb     poll;
+       struct io_poll_iocb     *double_poll;
        struct io_wq_work       work;
 };
 
@@ -1146,6 +1147,8 @@ static void io_prep_async_work(struct io_kiocb *req)
 {
        const struct io_op_def *def = &io_op_defs[req->opcode];
 
+       io_req_init_async(req);
+
        if (req->flags & REQ_F_ISREG) {
                if (def->hash_reg_file)
                        io_wq_hash_work(&req->work, file_inode(req->file));
@@ -1337,6 +1340,7 @@ static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
        if (cqe) {
                clear_bit(0, &ctx->sq_check_overflow);
                clear_bit(0, &ctx->cq_check_overflow);
+               ctx->rings->sq_flags &= ~IORING_SQ_CQ_OVERFLOW;
        }
        spin_unlock_irqrestore(&ctx->completion_lock, flags);
        io_cqring_ev_posted(ctx);
@@ -1374,6 +1378,7 @@ static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
                if (list_empty(&ctx->cq_overflow_list)) {
                        set_bit(0, &ctx->sq_check_overflow);
                        set_bit(0, &ctx->cq_check_overflow);
+                       ctx->rings->sq_flags |= IORING_SQ_CQ_OVERFLOW;
                }
                req->flags |= REQ_F_OVERFLOW;
                refcount_inc(&req->refs);
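
The two hunks above export the CQ-overflow state to userspace: when completions spill into ctx->cq_overflow_list the kernel now also sets IORING_SQ_CQ_OVERFLOW in the shared SQ ring flags, and clears it again once the overflow has been flushed. Below is a minimal userspace sketch of consuming that flag; it is not part of the patch and assumes the rings have already been mmap'd as described in io_uring_setup(2), so ring_fd and sq_flags (the word at offset params.sq_off.flags into the SQ ring mapping) are taken as given.

#include <linux/io_uring.h>
#include <sys/syscall.h>
#include <unistd.h>

/* Hypothetical helper: if the kernel flagged a CQ overflow, enter the
 * kernel with IORING_ENTER_GETEVENTS so it can flush the overflowed
 * completions back into the CQ ring. */
static void flush_cq_overflow_if_needed(int ring_fd, const unsigned *sq_flags)
{
        if (__atomic_load_n(sq_flags, __ATOMIC_ACQUIRE) & IORING_SQ_CQ_OVERFLOW)
                syscall(__NR_io_uring_enter, ring_fd, 0, 0,
                        IORING_ENTER_GETEVENTS, NULL, 0);
}
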
@@ -3912,6 +3917,7 @@ static int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        if (req->flags & REQ_F_NEED_CLEANUP)
                return 0;
 
+       io->msg.msg.msg_name = &io->msg.addr;
        io->msg.iov = io->msg.fast_iov;
        ret = sendmsg_copy_msghdr(&io->msg.msg, sr->msg, sr->msg_flags,
                                        &io->msg.iov);
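
The msg_name assignment above matters when a sendmsg request carries an explicit destination address (for example on an unconnected UDP socket) and gets punted to async context: the msghdr copied into req->io must point at the copied address in io->msg.addr rather than at memory the submitter may have reused. A liburing sketch of that user-visible case follows; it is illustrative only (helper name invented, assumes liburing and an unconnected UDP socket).

#include <liburing.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/uio.h>

/* Send one datagram to dst via io_uring. If the request cannot complete
 * inline, the kernel finishes it from an io-wq worker using its own copy
 * of the msghdr, which is where the copied msg_name is needed. */
static int sendto_via_uring(int sockfd, const void *buf, size_t len,
                            const struct sockaddr_in *dst)
{
        struct io_uring ring;
        struct io_uring_sqe *sqe;
        struct io_uring_cqe *cqe;
        struct iovec iov = { .iov_base = (void *)buf, .iov_len = len };
        struct msghdr msg = {
                .msg_name = (void *)dst,        /* destination address */
                .msg_namelen = sizeof(*dst),
                .msg_iov = &iov,
                .msg_iovlen = 1,
        };
        int ret;

        ret = io_uring_queue_init(8, &ring, 0);
        if (ret < 0)
                return ret;

        sqe = io_uring_get_sqe(&ring);
        io_uring_prep_sendmsg(sqe, sockfd, &msg, 0);

        io_uring_submit(&ring);
        ret = io_uring_wait_cqe(&ring, &cqe);
        if (!ret) {
                ret = cqe->res;                 /* bytes sent or -errno */
                io_uring_cqe_seen(&ring, cqe);
        }
        io_uring_queue_exit(&ring);
        return ret;
}
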
@@ -4093,6 +4099,7 @@ static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
 
 static int io_recvmsg_copy_hdr(struct io_kiocb *req, struct io_async_ctx *io)
 {
+       io->msg.msg.msg_name = &io->msg.addr;
        io->msg.iov = io->msg.fast_iov;
 
 #ifdef CONFIG_COMPAT
@@ -4202,10 +4209,16 @@ static int io_recvmsg(struct io_kiocb *req, bool force_nonblock,
 
                ret = __sys_recvmsg_sock(sock, &kmsg->msg, req->sr_msg.msg,
                                                kmsg->uaddr, flags);
-               if (force_nonblock && ret == -EAGAIN)
-                       return io_setup_async_msg(req, kmsg);
+               if (force_nonblock && ret == -EAGAIN) {
+                       ret = io_setup_async_msg(req, kmsg);
+                       if (ret != -EAGAIN)
+                               kfree(kbuf);
+                       return ret;
+               }
                if (ret == -ERESTARTSYS)
                        ret = -EINTR;
+               if (kbuf)
+                       kfree(kbuf);
        }
 
        if (kmsg && kmsg->iov != kmsg->fast_iov)
@@ -4492,9 +4505,9 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
        return false;
 }
 
-static void io_poll_remove_double(struct io_kiocb *req)
+static void io_poll_remove_double(struct io_kiocb *req, void *data)
 {
-       struct io_poll_iocb *poll = (struct io_poll_iocb *) req->io;
+       struct io_poll_iocb *poll = data;
 
        lockdep_assert_held(&req->ctx->completion_lock);
 
@@ -4514,7 +4527,7 @@ static void io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
-       io_poll_remove_double(req);
+       io_poll_remove_double(req, req->io);
        req->poll.done = true;
        io_cqring_fill_event(req, error ? error : mangle_poll(mask));
        io_commit_cqring(ctx);
@@ -4552,21 +4565,21 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
                               int sync, void *key)
 {
        struct io_kiocb *req = wait->private;
-       struct io_poll_iocb *poll = (struct io_poll_iocb *) req->io;
+       struct io_poll_iocb *poll = req->apoll->double_poll;
        __poll_t mask = key_to_poll(key);
 
        /* for instances that support it check for an event match first: */
        if (mask && !(mask & poll->events))
                return 0;
 
-       if (req->poll.head) {
+       if (poll && poll->head) {
                bool done;
 
-               spin_lock(&req->poll.head->lock);
-               done = list_empty(&req->poll.wait.entry);
+               spin_lock(&poll->head->lock);
+               done = list_empty(&poll->wait.entry);
                if (!done)
-                       list_del_init(&req->poll.wait.entry);
-               spin_unlock(&req->poll.head->lock);
+                       list_del_init(&poll->wait.entry);
+               spin_unlock(&poll->head->lock);
                if (!done)
                        __io_async_wake(req, poll, mask, io_poll_task_func);
        }
@@ -4586,7 +4599,8 @@ static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
 }
 
 static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
-                           struct wait_queue_head *head)
+                           struct wait_queue_head *head,
+                           struct io_poll_iocb **poll_ptr)
 {
        struct io_kiocb *req = pt->req;
 
@@ -4597,7 +4611,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
         */
        if (unlikely(poll->head)) {
                /* already have a 2nd entry, fail a third attempt */
-               if (req->io) {
+               if (*poll_ptr) {
                        pt->error = -EINVAL;
                        return;
                }
@@ -4609,7 +4623,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                io_init_poll_iocb(poll, req->poll.events, io_poll_double_wake);
                refcount_inc(&req->refs);
                poll->wait.private = req;
-               req->io = (void *) poll;
+               *poll_ptr = poll;
        }
 
        pt->error = 0;
@@ -4625,8 +4639,9 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
                               struct poll_table_struct *p)
 {
        struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
+       struct async_poll *apoll = pt->req->apoll;
 
-       __io_queue_proc(&pt->req->apoll->poll, pt, head);
+       __io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
 }
 
 static void io_async_task_func(struct callback_head *cb)
@@ -4646,6 +4661,7 @@ static void io_async_task_func(struct callback_head *cb)
        if (hash_hashed(&req->hash_node))
                hash_del(&req->hash_node);
 
+       io_poll_remove_double(req, apoll->double_poll);
        spin_unlock_irq(&ctx->completion_lock);
 
        /* restore ->work in case we need to retry again */
@@ -4657,6 +4673,7 @@ static void io_async_task_func(struct callback_head *cb)
        else
                __io_req_task_cancel(req, -ECANCELED);
 
+       kfree(apoll->double_poll);
        kfree(apoll);
 }
 
@@ -4728,7 +4745,6 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
        struct async_poll *apoll;
        struct io_poll_table ipt;
        __poll_t mask, ret;
-       bool had_io;
 
        if (!req->file || !file_can_poll(req->file))
                return false;
@@ -4740,11 +4756,11 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
        apoll = kmalloc(sizeof(*apoll), GFP_ATOMIC);
        if (unlikely(!apoll))
                return false;
+       apoll->double_poll = NULL;
 
        req->flags |= REQ_F_POLLED;
        if (req->flags & REQ_F_WORK_INITIALIZED)
                memcpy(&apoll->work, &req->work, sizeof(req->work));
-       had_io = req->io != NULL;
 
        io_get_req_task(req);
        req->apoll = apoll;
@@ -4762,13 +4778,11 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
        ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
                                        io_async_wake);
        if (ret) {
-               ipt.error = 0;
-               /* only remove double add if we did it here */
-               if (!had_io)
-                       io_poll_remove_double(req);
+               io_poll_remove_double(req, apoll->double_poll);
                spin_unlock_irq(&ctx->completion_lock);
                if (req->flags & REQ_F_WORK_INITIALIZED)
                        memcpy(&req->work, &apoll->work, sizeof(req->work));
+               kfree(apoll->double_poll);
                kfree(apoll);
                return false;
        }
@@ -4799,11 +4813,13 @@ static bool io_poll_remove_one(struct io_kiocb *req)
        bool do_complete;
 
        if (req->opcode == IORING_OP_POLL_ADD) {
-               io_poll_remove_double(req);
+               io_poll_remove_double(req, req->io);
                do_complete = __io_poll_remove_one(req, &req->poll);
        } else {
                struct async_poll *apoll = req->apoll;
 
+               io_poll_remove_double(req, apoll->double_poll);
+
                /* non-poll requests have submit ref still */
                do_complete = __io_poll_remove_one(req, &apoll->poll);
                if (do_complete) {
@@ -4816,6 +4832,7 @@ static bool io_poll_remove_one(struct io_kiocb *req)
                        if (req->flags & REQ_F_WORK_INITIALIZED)
                                memcpy(&req->work, &apoll->work,
                                       sizeof(req->work));
+                       kfree(apoll->double_poll);
                        kfree(apoll);
                }
        }
@@ -4915,7 +4932,7 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
 {
        struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
 
-       __io_queue_proc(&pt->req->poll, pt, head);
+       __io_queue_proc(&pt->req->poll, pt, head, (struct io_poll_iocb **) &pt->req->io);
 }
 
 static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
@@ -4948,6 +4965,10 @@ static int io_poll_add(struct io_kiocb *req)
        struct io_poll_table ipt;
        __poll_t mask;
 
+       /* ->work is in union with hash_node and others */
+       io_req_work_drop_env(req);
+       req->flags &= ~REQ_F_WORK_INITIALIZED;
+
        INIT_HLIST_NODE(&req->hash_node);
        INIT_LIST_HEAD(&req->list);
        ipt.pt._qproc = io_poll_queue_proc;
@@ -5027,7 +5048,9 @@ static int io_timeout_remove_prep(struct io_kiocb *req,
 {
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
-       if (sqe->flags || sqe->ioprio || sqe->buf_index || sqe->len)
+       if (unlikely(req->flags & (REQ_F_FIXED_FILE | REQ_F_BUFFER_SELECT)))
+               return -EINVAL;
+       if (sqe->ioprio || sqe->buf_index || sqe->len)
                return -EINVAL;
 
        req->timeout.addr = READ_ONCE(sqe->addr);
@@ -5203,8 +5226,9 @@ static int io_async_cancel_prep(struct io_kiocb *req,
 {
        if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
                return -EINVAL;
-       if (sqe->flags || sqe->ioprio || sqe->off || sqe->len ||
-           sqe->cancel_flags)
+       if (unlikely(req->flags & (REQ_F_FIXED_FILE | REQ_F_BUFFER_SELECT)))
+               return -EINVAL;
+       if (sqe->ioprio || sqe->off || sqe->len || sqe->cancel_flags)
                return -EINVAL;
 
        req->cancel.addr = READ_ONCE(sqe->addr);
@@ -5222,7 +5246,9 @@ static int io_async_cancel(struct io_kiocb *req)
 static int io_files_update_prep(struct io_kiocb *req,
                                const struct io_uring_sqe *sqe)
 {
-       if (sqe->flags || sqe->ioprio || sqe->rw_flags)
+       if (unlikely(req->flags & (REQ_F_FIXED_FILE | REQ_F_BUFFER_SELECT)))
+               return -EINVAL;
+       if (sqe->ioprio || sqe->rw_flags)
                return -EINVAL;
 
        req->files_update.offset = READ_ONCE(sqe->off);
@@ -5996,6 +6022,7 @@ fail_req:
                 * Never try inline submit if IOSQE_ASYNC is set, go straight
                 * to async execution.
                 */
+               io_req_init_async(req);
                req->work.flags |= IO_WQ_WORK_CONCURRENT;
                io_queue_async_work(req);
        } else {
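
This hunk initializes ->work before the IOSQE_ASYNC path marks the request concurrent and queues it to io-wq. For context, IOSQE_ASYNC is set by userspace on the SQE; a minimal liburing sketch of opting a read into that path (helper name invented, assumes liburing and an open file descriptor):

#include <liburing.h>

/* Read len bytes from offset 0 of fd, forcing async execution. */
static int read_forced_async(int fd, void *buf, unsigned len)
{
        struct io_uring ring;
        struct io_uring_sqe *sqe;
        struct io_uring_cqe *cqe;
        int ret;

        ret = io_uring_queue_init(8, &ring, 0);
        if (ret < 0)
                return ret;

        sqe = io_uring_get_sqe(&ring);
        io_uring_prep_read(sqe, fd, buf, len, 0);
        io_uring_sqe_set_flags(sqe, IOSQE_ASYNC);       /* skip inline submit */

        io_uring_submit(&ring);
        ret = io_uring_wait_cqe(&ring, &cqe);
        if (!ret) {
                ret = cqe->res;                 /* bytes read or -errno */
                io_uring_cqe_seen(&ring, cqe);
        }
        io_uring_queue_exit(&ring);
        return ret;
}
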
@@ -6372,9 +6399,9 @@ static int io_sq_thread(void *data)
                        }
 
                        /* Tell userspace we may need a wakeup call */
+                       spin_lock_irq(&ctx->completion_lock);
                        ctx->rings->sq_flags |= IORING_SQ_NEED_WAKEUP;
-                       /* make sure to read SQ tail after writing flags */
-                       smp_mb();
+                       spin_unlock_irq(&ctx->completion_lock);
 
                        to_submit = io_sqring_entries(ctx);
                        if (!to_submit || ret == -EBUSY) {
@@ -6391,13 +6418,17 @@ static int io_sq_thread(void *data)
                                schedule();
                                finish_wait(&ctx->sqo_wait, &wait);
 
+                               spin_lock_irq(&ctx->completion_lock);
                                ctx->rings->sq_flags &= ~IORING_SQ_NEED_WAKEUP;
+                               spin_unlock_irq(&ctx->completion_lock);
                                ret = 0;
                                continue;
                        }
                        finish_wait(&ctx->sqo_wait, &wait);
 
+                       spin_lock_irq(&ctx->completion_lock);
                        ctx->rings->sq_flags &= ~IORING_SQ_NEED_WAKEUP;
+                       spin_unlock_irq(&ctx->completion_lock);
                }
 
                mutex_lock(&ctx->uring_lock);
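
These two hunks take ctx->completion_lock around the SQ thread's updates of the shared sq_flags word. With IORING_SQ_CQ_OVERFLOW now set and cleared under that lock (see the hunks earlier in this diff), the unlocked read-modify-write of IORING_SQ_NEED_WAKEUP could otherwise race with it and lose a bit. The userspace half of the SQPOLL handshake is unchanged and looks roughly like the sketch below (same assumptions as the earlier helper: sq_flags points at the mmap'd SQ ring flags word, ring_fd comes from io_uring_setup()).

#include <linux/io_uring.h>
#include <sys/syscall.h>
#include <unistd.h>

/* Hypothetical helper: with IORING_SETUP_SQPOLL, userspace normally only
 * advances the SQ tail. Only when the poller thread has gone idle and set
 * IORING_SQ_NEED_WAKEUP does submission require an io_uring_enter() call
 * with IORING_ENTER_SQ_WAKEUP. */
static void sqpoll_wakeup_if_needed(int ring_fd, const unsigned *sq_flags)
{
        if (__atomic_load_n(sq_flags, __ATOMIC_ACQUIRE) & IORING_SQ_NEED_WAKEUP)
                syscall(__NR_io_uring_enter, ring_fd, 0, 0,
                        IORING_ENTER_SQ_WAKEUP, NULL, 0);
}
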
@@ -6982,6 +7013,7 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
                for (i = 0; i < nr_tables; i++)
                        kfree(ctx->file_data->table[i].files);
 
+               percpu_ref_exit(&ctx->file_data->refs);
                kfree(ctx->file_data->table);
                kfree(ctx->file_data);
                ctx->file_data = NULL;
@@ -7134,8 +7166,10 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
                        }
                        table->files[index] = file;
                        err = io_sqe_file_register(ctx, file, i);
-                       if (err)
+                       if (err) {
+                               fput(file);
                                break;
+                       }
                }
                nr_args--;
                done++;
@@ -7665,8 +7699,6 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
        io_mem_free(ctx->sq_sqes);
 
        percpu_ref_exit(&ctx->refs);
-       io_unaccount_mem(ctx, ring_pages(ctx->sq_entries, ctx->cq_entries),
-                        ACCT_LOCKED);
        free_uid(ctx->user);
        put_cred(ctx->creds);
        kfree(ctx->cancel_hash);
@@ -7748,6 +7780,15 @@ static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
                io_cqring_overflow_flush(ctx, true);
        io_iopoll_try_reap_events(ctx);
        idr_for_each(&ctx->personality_idr, io_remove_personalities, ctx);
+
+       /*
+        * Do this upfront, so we won't have a grace period where the ring
+        * is closed but resources aren't reaped yet. This can cause
+        * spurious failure in setting up a new ring.
+        */
+       io_unaccount_mem(ctx, ring_pages(ctx->sq_entries, ctx->cq_entries),
+                        ACCT_LOCKED);
+
        INIT_WORK(&ctx->exit_work, io_ring_exit_work);
        queue_work(system_wq, &ctx->exit_work);
 }
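
The comment added in this hunk explains the reordering: the locked-memory accounting is released synchronously from io_ring_ctx_wait_and_kill() rather than from the deferred io_ring_exit_work(), so a close-then-recreate sequence does not transiently exceed RLIMIT_MEMLOCK. A hedged liburing illustration of that sequence (helper name invented, assumes liburing and a kernel that accounts ring memory against RLIMIT_MEMLOCK):

#include <liburing.h>

/* Tear down one ring and immediately create another of the same size. */
static int recreate_ring(unsigned entries)
{
        struct io_uring ring;
        int ret;

        ret = io_uring_queue_init(entries, &ring, 0);
        if (ret < 0)
                return ret;
        io_uring_queue_exit(&ring);     /* unmaps the rings, closes the fd */

        /* With unaccounting done up front, this setup should not fail with
         * -ENOMEM even if RLIMIT_MEMLOCK only has room for one ring. */
        ret = io_uring_queue_init(entries, &ring, 0);
        if (ret < 0)
                return ret;
        io_uring_queue_exit(&ring);
        return 0;
}
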
@@ -7807,6 +7848,7 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
                        if (list_empty(&ctx->cq_overflow_list)) {
                                clear_bit(0, &ctx->sq_check_overflow);
                                clear_bit(0, &ctx->cq_check_overflow);
+                               ctx->rings->sq_flags &= ~IORING_SQ_CQ_OVERFLOW;
                        }
                        spin_unlock_irq(&ctx->completion_lock);