io_uring: ensure async punted sendmsg/recvmsg requests copy data
authorJens Axboe <axboe@kernel.dk>
Tue, 3 Dec 2019 01:50:25 +0000 (18:50 -0700)
committerJens Axboe <axboe@kernel.dk>
Tue, 3 Dec 2019 14:03:35 +0000 (07:03 -0700)
Just like commit f67676d160c6 for read/write requests, this one ensures
that the msghdr data is fully copied if we need to punt a recvmsg or
sendmsg system call to async context.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c
include/linux/socket.h
net/socket.c

index 1689aea..2700382 100644 (file)
@@ -308,6 +308,13 @@ struct io_timeout {
        struct io_timeout_data          *data;
 };
 
+struct io_async_msghdr {
+       struct iovec                    fast_iov[UIO_FASTIOV];
+       struct iovec                    *iov;
+       struct sockaddr __user          *uaddr;
+       struct msghdr                   msg;
+};
+
 struct io_async_rw {
        struct iovec                    fast_iov[UIO_FASTIOV];
        struct iovec                    *iov;
@@ -319,6 +326,7 @@ struct io_async_ctx {
        struct io_uring_sqe             sqe;
        union {
                struct io_async_rw      rw;
+               struct io_async_msghdr  msg;
        };
 };
 
@@ -1991,12 +1999,25 @@ static int io_sync_file_range(struct io_kiocb *req,
        return 0;
 }
 
+static int io_sendmsg_prep(struct io_kiocb *req, struct io_async_ctx *io)
+{
 #if defined(CONFIG_NET)
-static int io_send_recvmsg(struct io_kiocb *req, const struct io_uring_sqe *sqe,
-                          struct io_kiocb **nxt, bool force_nonblock,
-                  long (*fn)(struct socket *, struct user_msghdr __user *,
-                               unsigned int))
+       const struct io_uring_sqe *sqe = req->sqe;
+       struct user_msghdr __user *msg;
+       unsigned flags;
+
+       flags = READ_ONCE(sqe->msg_flags);
+       msg = (struct user_msghdr __user *)(unsigned long) READ_ONCE(sqe->addr);
+       return sendmsg_copy_msghdr(&io->msg.msg, msg, flags, &io->msg.iov);
+#else
+       return 0;
+#endif
+}
+
+static int io_sendmsg(struct io_kiocb *req, const struct io_uring_sqe *sqe,
+                     struct io_kiocb **nxt, bool force_nonblock)
 {
+#if defined(CONFIG_NET)
        struct socket *sock;
        int ret;
 
@@ -2005,7 +2026,9 @@ static int io_send_recvmsg(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 
        sock = sock_from_file(req->file, &ret);
        if (sock) {
-               struct user_msghdr __user *msg;
+               struct io_async_ctx io, *copy;
+               struct sockaddr_storage addr;
+               struct msghdr *kmsg;
                unsigned flags;
 
                flags = READ_ONCE(sqe->msg_flags);
@@ -2014,32 +2037,59 @@ static int io_send_recvmsg(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                else if (force_nonblock)
                        flags |= MSG_DONTWAIT;
 
-               msg = (struct user_msghdr __user *) (unsigned long)
-                       READ_ONCE(sqe->addr);
+               if (req->io) {
+                       kmsg = &req->io->msg.msg;
+                       kmsg->msg_name = &addr;
+               } else {
+                       kmsg = &io.msg.msg;
+                       kmsg->msg_name = &addr;
+                       io.msg.iov = io.msg.fast_iov;
+                       ret = io_sendmsg_prep(req, &io);
+                       if (ret)
+                               goto out;
+               }
 
-               ret = fn(sock, msg, flags);
-               if (force_nonblock && ret == -EAGAIN)
+               ret = __sys_sendmsg_sock(sock, kmsg, flags);
+               if (force_nonblock && ret == -EAGAIN) {
+                       copy = kmalloc(sizeof(*copy), GFP_KERNEL);
+                       if (!copy) {
+                               ret = -ENOMEM;
+                               goto out;
+                       }
+                       memcpy(&copy->msg, &io.msg, sizeof(copy->msg));
+                       req->io = copy;
+                       memcpy(&req->io->sqe, req->sqe, sizeof(*req->sqe));
+                       req->sqe = &req->io->sqe;
                        return ret;
+               }
                if (ret == -ERESTARTSYS)
                        ret = -EINTR;
        }
 
+out:
        io_cqring_add_event(req, ret);
        if (ret < 0 && (req->flags & REQ_F_LINK))
                req->flags |= REQ_F_FAIL_LINK;
        io_put_req_find_next(req, nxt);
        return 0;
-}
+#else
+       return -EOPNOTSUPP;
 #endif
+}
 
-static int io_sendmsg(struct io_kiocb *req, const struct io_uring_sqe *sqe,
-                     struct io_kiocb **nxt, bool force_nonblock)
+static int io_recvmsg_prep(struct io_kiocb *req, struct io_async_ctx *io)
 {
 #if defined(CONFIG_NET)
-       return io_send_recvmsg(req, sqe, nxt, force_nonblock,
-                               __sys_sendmsg_sock);
+       const struct io_uring_sqe *sqe = req->sqe;
+       struct user_msghdr __user *msg;
+       unsigned flags;
+
+       flags = READ_ONCE(sqe->msg_flags);
+       msg = (struct user_msghdr __user *)(unsigned long) READ_ONCE(sqe->addr);
+       return recvmsg_copy_msghdr(&io->msg.msg, msg, flags, &io->msg.uaddr,
+                                       &io->msg.iov);
 #else
-       return -EOPNOTSUPP;
+       return 0;
 #endif
 }
 
@@ -2047,8 +2097,63 @@ static int io_recvmsg(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                      struct io_kiocb **nxt, bool force_nonblock)
 {
 #if defined(CONFIG_NET)
-       return io_send_recvmsg(req, sqe, nxt, force_nonblock,
-                               __sys_recvmsg_sock);
+       struct socket *sock;
+       int ret;
+
+       if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
+               return -EINVAL;
+
+       sock = sock_from_file(req->file, &ret);
+       if (sock) {
+               struct user_msghdr __user *msg;
+               struct io_async_ctx io, *copy;
+               struct sockaddr_storage addr;
+               struct msghdr *kmsg;
+               unsigned flags;
+
+               flags = READ_ONCE(sqe->msg_flags);
+               if (flags & MSG_DONTWAIT)
+                       req->flags |= REQ_F_NOWAIT;
+               else if (force_nonblock)
+                       flags |= MSG_DONTWAIT;
+
+               msg = (struct user_msghdr __user *) (unsigned long)
+                       READ_ONCE(sqe->addr);
+               if (req->io) {
+                       kmsg = &req->io->msg.msg;
+                       kmsg->msg_name = &addr;
+               } else {
+                       kmsg = &io.msg.msg;
+                       kmsg->msg_name = &addr;
+                       io.msg.iov = io.msg.fast_iov;
+                       ret = io_recvmsg_prep(req, &io);
+                       if (ret)
+                               goto out;
+               }
+
+               ret = __sys_recvmsg_sock(sock, kmsg, msg, io.msg.uaddr, flags);
+               if (force_nonblock && ret == -EAGAIN) {
+                       copy = kmalloc(sizeof(*copy), GFP_KERNEL);
+                       if (!copy) {
+                               ret = -ENOMEM;
+                               goto out;
+                       }
+                       memcpy(copy, &io, sizeof(*copy));
+                       req->io = copy;
+                       memcpy(&req->io->sqe, req->sqe, sizeof(*req->sqe));
+                       req->sqe = &req->io->sqe;
+                       return ret;
+               }
+               if (ret == -ERESTARTSYS)
+                       ret = -EINTR;
+       }
+
+out:
+       io_cqring_add_event(req, ret);
+       if (ret < 0 && (req->flags & REQ_F_LINK))
+               req->flags |= REQ_F_FAIL_LINK;
+       io_put_req_find_next(req, nxt);
+       return 0;
 #else
        return -EOPNOTSUPP;
 #endif
@@ -2721,6 +2826,12 @@ static int io_req_defer_prep(struct io_kiocb *req, struct io_async_ctx *io)
        case IORING_OP_WRITE_FIXED:
                ret = io_write_prep(req, &iovec, &iter, true);
                break;
+       case IORING_OP_SENDMSG:
+               ret = io_sendmsg_prep(req, io);
+               break;
+       case IORING_OP_RECVMSG:
+               ret = io_recvmsg_prep(req, io);
+               break;
        default:
                req->io = io;
                return 0;
index 4bde630..903507f 100644 (file)
@@ -378,12 +378,19 @@ extern int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg,
 extern int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg,
                          unsigned int vlen, unsigned int flags,
                          bool forbid_cmsg_compat);
-extern long __sys_sendmsg_sock(struct socket *sock,
-                              struct user_msghdr __user *msg,
+extern long __sys_sendmsg_sock(struct socket *sock, struct msghdr *msg,
                               unsigned int flags);
-extern long __sys_recvmsg_sock(struct socket *sock,
-                              struct user_msghdr __user *msg,
+extern long __sys_recvmsg_sock(struct socket *sock, struct msghdr *msg,
+                              struct user_msghdr __user *umsg,
+                              struct sockaddr __user *uaddr,
                               unsigned int flags);
+extern int sendmsg_copy_msghdr(struct msghdr *msg,
+                              struct user_msghdr __user *umsg, unsigned flags,
+                              struct iovec **iov);
+extern int recvmsg_copy_msghdr(struct msghdr *msg,
+                              struct user_msghdr __user *umsg, unsigned flags,
+                              struct sockaddr __user **uaddr,
+                              struct iovec **iov);
 
 /* helpers which do the actual work for syscalls */
 extern int __sys_recvfrom(int fd, void __user *ubuf, size_t size,
index ea28cbb..0fb0820 100644 (file)
@@ -2346,9 +2346,9 @@ out:
        return err;
 }
 
-static int sendmsg_copy_msghdr(struct msghdr *msg,
-                              struct user_msghdr __user *umsg, unsigned flags,
-                              struct iovec **iov)
+int sendmsg_copy_msghdr(struct msghdr *msg,
+                       struct user_msghdr __user *umsg, unsigned flags,
+                       struct iovec **iov)
 {
        int err;
 
@@ -2390,27 +2390,14 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
 /*
  *     BSD sendmsg interface
  */
-long __sys_sendmsg_sock(struct socket *sock, struct user_msghdr __user *umsg,
+long __sys_sendmsg_sock(struct socket *sock, struct msghdr *msg,
                        unsigned int flags)
 {
-       struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
-       struct sockaddr_storage address;
-       struct msghdr msg = { .msg_name = &address };
-       ssize_t err;
-
-       err = sendmsg_copy_msghdr(&msg, umsg, flags, &iov);
-       if (err)
-               return err;
        /* disallow ancillary data requests from this path */
-       if (msg.msg_control || msg.msg_controllen) {
-               err = -EINVAL;
-               goto out;
-       }
+       if (msg->msg_control || msg->msg_controllen)
+               return -EINVAL;
 
-       err = ____sys_sendmsg(sock, &msg, flags, NULL, 0);
-out:
-       kfree(iov);
-       return err;
+       return ____sys_sendmsg(sock, msg, flags, NULL, 0);
 }
 
 long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,
@@ -2516,10 +2503,10 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg,
        return __sys_sendmmsg(fd, mmsg, vlen, flags, true);
 }
 
-static int recvmsg_copy_msghdr(struct msghdr *msg,
-                              struct user_msghdr __user *umsg, unsigned flags,
-                              struct sockaddr __user **uaddr,
-                              struct iovec **iov)
+int recvmsg_copy_msghdr(struct msghdr *msg,
+                       struct user_msghdr __user *umsg, unsigned flags,
+                       struct sockaddr __user **uaddr,
+                       struct iovec **iov)
 {
        ssize_t err;
 
@@ -2609,28 +2596,15 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
  *     BSD recvmsg interface
  */
 
-long __sys_recvmsg_sock(struct socket *sock, struct user_msghdr __user *umsg,
-                       unsigned int flags)
+long __sys_recvmsg_sock(struct socket *sock, struct msghdr *msg,
+                       struct user_msghdr __user *umsg,
+                       struct sockaddr __user *uaddr, unsigned int flags)
 {
-       struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
-       struct sockaddr_storage address;
-       struct msghdr msg = { .msg_name = &address };
-       struct sockaddr __user *uaddr;
-       ssize_t err;
-
-       err = recvmsg_copy_msghdr(&msg, umsg, flags, &uaddr, &iov);
-       if (err)
-               return err;
        /* disallow ancillary data requests from this path */
-       if (msg.msg_control || msg.msg_controllen) {
-               err = -EINVAL;
-               goto out;
-       }
+       if (msg->msg_control || msg->msg_controllen)
+               return -EINVAL;
 
-       err = ____sys_recvmsg(sock, &msg, umsg, uaddr, flags, 0);
-out:
-       kfree(iov);
-       return err;
+       return ____sys_recvmsg(sock, msg, umsg, uaddr, flags, 0);
 }
 
 long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,