io_uring: ensure async punted connect requests copy data
authorJens Axboe <axboe@kernel.dk>
Mon, 2 Dec 2019 23:28:46 +0000 (16:28 -0700)
committerJens Axboe <axboe@kernel.dk>
Tue, 3 Dec 2019 14:04:30 +0000 (07:04 -0700)
Just like commit f67676d160c6 for read/write requests, this one ensures
that the sockaddr data has been copied for IORING_OP_CONNECT if we need
to punt the request to async context.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c
include/linux/socket.h
net/socket.c

index 2700382..5fcd89c 100644 (file)
@@ -308,6 +308,10 @@ struct io_timeout {
        struct io_timeout_data          *data;
 };
 
+struct io_async_connect {
+       struct sockaddr_storage         address;
+};
+
 struct io_async_msghdr {
        struct iovec                    fast_iov[UIO_FASTIOV];
        struct iovec                    *iov;
@@ -327,6 +331,7 @@ struct io_async_ctx {
        union {
                struct io_async_rw      rw;
                struct io_async_msghdr  msg;
+               struct io_async_connect connect;
        };
 };
 
@@ -2195,11 +2200,26 @@ static int io_accept(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 #endif
 }
 
+static int io_connect_prep(struct io_kiocb *req, struct io_async_ctx *io)
+{
+#if defined(CONFIG_NET)
+       const struct io_uring_sqe *sqe = req->sqe;
+       struct sockaddr __user *addr;
+       int addr_len;
+
+       addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
+       addr_len = READ_ONCE(sqe->addr2);
+       return move_addr_to_kernel(addr, addr_len, &io->connect.address);
+#else
+       return 0;
+#endif
+}
+
 static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                      struct io_kiocb **nxt, bool force_nonblock)
 {
 #if defined(CONFIG_NET)
-       struct sockaddr __user *addr;
+       struct io_async_ctx __io, *io;
        unsigned file_flags;
        int addr_len, ret;
 
@@ -2208,15 +2228,35 @@ static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe,
        if (sqe->ioprio || sqe->len || sqe->buf_index || sqe->rw_flags)
                return -EINVAL;
 
-       addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
        addr_len = READ_ONCE(sqe->addr2);
        file_flags = force_nonblock ? O_NONBLOCK : 0;
 
-       ret = __sys_connect_file(req->file, addr, addr_len, file_flags);
-       if (ret == -EAGAIN && force_nonblock)
+       if (req->io) {
+               io = req->io;
+       } else {
+               ret = io_connect_prep(req, &__io);
+               if (ret)
+                       goto out;
+               io = &__io;
+       }
+
+       ret = __sys_connect_file(req->file, &io->connect.address, addr_len,
+                                       file_flags);
+       if (ret == -EAGAIN && force_nonblock) {
+               io = kmalloc(sizeof(*io), GFP_KERNEL);
+               if (!io) {
+                       ret = -ENOMEM;
+                       goto out;
+               }
+               memcpy(&io->connect, &__io.connect, sizeof(io->connect));
+               req->io = io;
+               memcpy(&io->sqe, req->sqe, sizeof(*req->sqe));
+               req->sqe = &io->sqe;
                return -EAGAIN;
+       }
        if (ret == -ERESTARTSYS)
                ret = -EINTR;
+out:
        if (ret < 0 && (req->flags & REQ_F_LINK))
                req->flags |= REQ_F_FAIL_LINK;
        io_cqring_add_event(req, ret);
@@ -2832,6 +2872,9 @@ static int io_req_defer_prep(struct io_kiocb *req, struct io_async_ctx *io)
        case IORING_OP_RECVMSG:
                ret = io_recvmsg_prep(req, io);
                break;
+       case IORING_OP_CONNECT:
+               ret = io_connect_prep(req, io);
+               break;
        default:
                req->io = io;
                return 0;
index 903507f..2d23134 100644 (file)
@@ -406,9 +406,8 @@ extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
                         int __user *upeer_addrlen, int flags);
 extern int __sys_socket(int family, int type, int protocol);
 extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen);
-extern int __sys_connect_file(struct file *file,
-                       struct sockaddr __user *uservaddr, int addrlen,
-                       int file_flags);
+extern int __sys_connect_file(struct file *file, struct sockaddr_storage *addr,
+                             int addrlen, int file_flags);
 extern int __sys_connect(int fd, struct sockaddr __user *uservaddr,
                         int addrlen);
 extern int __sys_listen(int fd, int backlog);
index 0fb0820..b343db1 100644 (file)
@@ -1826,26 +1826,22 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr,
  *     include the -EINPROGRESS status for such sockets.
  */
 
-int __sys_connect_file(struct file *file, struct sockaddr __user *uservaddr,
+int __sys_connect_file(struct file *file, struct sockaddr_storage *address,
                       int addrlen, int file_flags)
 {
        struct socket *sock;
-       struct sockaddr_storage address;
        int err;
 
        sock = sock_from_file(file, &err);
        if (!sock)
                goto out;
-       err = move_addr_to_kernel(uservaddr, addrlen, &address);
-       if (err < 0)
-               goto out;
 
        err =
-           security_socket_connect(sock, (struct sockaddr *)&address, addrlen);
+           security_socket_connect(sock, (struct sockaddr *)address, addrlen);
        if (err)
                goto out;
 
-       err = sock->ops->connect(sock, (struct sockaddr *)&address, addrlen,
+       err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen,
                                 sock->file->f_flags | file_flags);
 out:
        return err;
@@ -1858,7 +1854,11 @@ int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)
 
        f = fdget(fd);
        if (f.file) {
-               ret = __sys_connect_file(f.file, uservaddr, addrlen, 0);
+               struct sockaddr_storage address;
+
+               ret = move_addr_to_kernel(uservaddr, addrlen, &address);
+               if (!ret)
+                       ret = __sys_connect_file(f.file, &address, addrlen, 0);
                if (f.flags)
                        fput(f.file);
        }