io_uring: split provided buffers handling into its own file
[linux-2.6-microblaze.git] / io_uring / net.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/slab.h>
6 #include <linux/net.h>
7 #include <linux/compat.h>
8 #include <net/compat.h>
9 #include <linux/io_uring.h>
10
11 #include <uapi/linux/io_uring.h>
12
13 #include "io_uring_types.h"
14 #include "io_uring.h"
15 #include "kbuf.h"
16 #include "net.h"
17
18 #if defined(CONFIG_NET)
19 struct io_shutdown {
20         struct file                     *file;
21         int                             how;
22 };
23
24 struct io_accept {
25         struct file                     *file;
26         struct sockaddr __user          *addr;
27         int __user                      *addr_len;
28         int                             flags;
29         u32                             file_slot;
30         unsigned long                   nofile;
31 };
32
33 struct io_socket {
34         struct file                     *file;
35         int                             domain;
36         int                             type;
37         int                             protocol;
38         int                             flags;
39         u32                             file_slot;
40         unsigned long                   nofile;
41 };
42
43 struct io_connect {
44         struct file                     *file;
45         struct sockaddr __user          *addr;
46         int                             addr_len;
47 };
48
49 struct io_sr_msg {
50         struct file                     *file;
51         union {
52                 struct compat_msghdr __user     *umsg_compat;
53                 struct user_msghdr __user       *umsg;
54                 void __user                     *buf;
55         };
56         int                             msg_flags;
57         size_t                          len;
58         size_t                          done_io;
59         unsigned int                    flags;
60 };
61
62 #define IO_APOLL_MULTI_POLLED (REQ_F_APOLL_MULTISHOT | REQ_F_POLLED)
63
64 int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
65 {
66         struct io_shutdown *shutdown = io_kiocb_to_cmd(req);
67
68         if (unlikely(sqe->off || sqe->addr || sqe->rw_flags ||
69                      sqe->buf_index || sqe->splice_fd_in))
70                 return -EINVAL;
71
72         shutdown->how = READ_ONCE(sqe->len);
73         return 0;
74 }
75
76 int io_shutdown(struct io_kiocb *req, unsigned int issue_flags)
77 {
78         struct io_shutdown *shutdown = io_kiocb_to_cmd(req);
79         struct socket *sock;
80         int ret;
81
82         if (issue_flags & IO_URING_F_NONBLOCK)
83                 return -EAGAIN;
84
85         sock = sock_from_file(req->file);
86         if (unlikely(!sock))
87                 return -ENOTSOCK;
88
89         ret = __sys_shutdown_sock(sock, shutdown->how);
90         io_req_set_res(req, ret, 0);
91         return IOU_OK;
92 }
93
94 static bool io_net_retry(struct socket *sock, int flags)
95 {
96         if (!(flags & MSG_WAITALL))
97                 return false;
98         return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
99 }
100
101 static int io_setup_async_msg(struct io_kiocb *req,
102                               struct io_async_msghdr *kmsg)
103 {
104         struct io_async_msghdr *async_msg = req->async_data;
105
106         if (async_msg)
107                 return -EAGAIN;
108         if (io_alloc_async_data(req)) {
109                 kfree(kmsg->free_iov);
110                 return -ENOMEM;
111         }
112         async_msg = req->async_data;
113         req->flags |= REQ_F_NEED_CLEANUP;
114         memcpy(async_msg, kmsg, sizeof(*kmsg));
115         async_msg->msg.msg_name = &async_msg->addr;
116         /* if were using fast_iov, set it to the new one */
117         if (!async_msg->free_iov)
118                 async_msg->msg.msg_iter.iov = async_msg->fast_iov;
119
120         return -EAGAIN;
121 }
122
123 static int io_sendmsg_copy_hdr(struct io_kiocb *req,
124                                struct io_async_msghdr *iomsg)
125 {
126         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
127
128         iomsg->msg.msg_name = &iomsg->addr;
129         iomsg->free_iov = iomsg->fast_iov;
130         return sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
131                                         &iomsg->free_iov);
132 }
133
134 int io_sendmsg_prep_async(struct io_kiocb *req)
135 {
136         int ret;
137
138         ret = io_sendmsg_copy_hdr(req, req->async_data);
139         if (!ret)
140                 req->flags |= REQ_F_NEED_CLEANUP;
141         return ret;
142 }
143
144 void io_sendmsg_recvmsg_cleanup(struct io_kiocb *req)
145 {
146         struct io_async_msghdr *io = req->async_data;
147
148         kfree(io->free_iov);
149 }
150
151 int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
152 {
153         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
154
155         if (unlikely(sqe->file_index || sqe->addr2))
156                 return -EINVAL;
157
158         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
159         sr->len = READ_ONCE(sqe->len);
160         sr->flags = READ_ONCE(sqe->ioprio);
161         if (sr->flags & ~IORING_RECVSEND_POLL_FIRST)
162                 return -EINVAL;
163         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
164         if (sr->msg_flags & MSG_DONTWAIT)
165                 req->flags |= REQ_F_NOWAIT;
166
167 #ifdef CONFIG_COMPAT
168         if (req->ctx->compat)
169                 sr->msg_flags |= MSG_CMSG_COMPAT;
170 #endif
171         sr->done_io = 0;
172         return 0;
173 }
174
175 int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
176 {
177         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
178         struct io_async_msghdr iomsg, *kmsg;
179         struct socket *sock;
180         unsigned flags;
181         int min_ret = 0;
182         int ret;
183
184         sock = sock_from_file(req->file);
185         if (unlikely(!sock))
186                 return -ENOTSOCK;
187
188         if (req_has_async_data(req)) {
189                 kmsg = req->async_data;
190         } else {
191                 ret = io_sendmsg_copy_hdr(req, &iomsg);
192                 if (ret)
193                         return ret;
194                 kmsg = &iomsg;
195         }
196
197         if (!(req->flags & REQ_F_POLLED) &&
198             (sr->flags & IORING_RECVSEND_POLL_FIRST))
199                 return io_setup_async_msg(req, kmsg);
200
201         flags = sr->msg_flags;
202         if (issue_flags & IO_URING_F_NONBLOCK)
203                 flags |= MSG_DONTWAIT;
204         if (flags & MSG_WAITALL)
205                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
206
207         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
208
209         if (ret < min_ret) {
210                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
211                         return io_setup_async_msg(req, kmsg);
212                 if (ret == -ERESTARTSYS)
213                         ret = -EINTR;
214                 if (ret > 0 && io_net_retry(sock, flags)) {
215                         sr->done_io += ret;
216                         req->flags |= REQ_F_PARTIAL_IO;
217                         return io_setup_async_msg(req, kmsg);
218                 }
219                 req_set_fail(req);
220         }
221         /* fast path, check for non-NULL to avoid function call */
222         if (kmsg->free_iov)
223                 kfree(kmsg->free_iov);
224         req->flags &= ~REQ_F_NEED_CLEANUP;
225         if (ret >= 0)
226                 ret += sr->done_io;
227         else if (sr->done_io)
228                 ret = sr->done_io;
229         io_req_set_res(req, ret, 0);
230         return IOU_OK;
231 }
232
233 int io_send(struct io_kiocb *req, unsigned int issue_flags)
234 {
235         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
236         struct msghdr msg;
237         struct iovec iov;
238         struct socket *sock;
239         unsigned flags;
240         int min_ret = 0;
241         int ret;
242
243         if (!(req->flags & REQ_F_POLLED) &&
244             (sr->flags & IORING_RECVSEND_POLL_FIRST))
245                 return -EAGAIN;
246
247         sock = sock_from_file(req->file);
248         if (unlikely(!sock))
249                 return -ENOTSOCK;
250
251         ret = import_single_range(WRITE, sr->buf, sr->len, &iov, &msg.msg_iter);
252         if (unlikely(ret))
253                 return ret;
254
255         msg.msg_name = NULL;
256         msg.msg_control = NULL;
257         msg.msg_controllen = 0;
258         msg.msg_namelen = 0;
259
260         flags = sr->msg_flags;
261         if (issue_flags & IO_URING_F_NONBLOCK)
262                 flags |= MSG_DONTWAIT;
263         if (flags & MSG_WAITALL)
264                 min_ret = iov_iter_count(&msg.msg_iter);
265
266         msg.msg_flags = flags;
267         ret = sock_sendmsg(sock, &msg);
268         if (ret < min_ret) {
269                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
270                         return -EAGAIN;
271                 if (ret == -ERESTARTSYS)
272                         ret = -EINTR;
273                 if (ret > 0 && io_net_retry(sock, flags)) {
274                         sr->len -= ret;
275                         sr->buf += ret;
276                         sr->done_io += ret;
277                         req->flags |= REQ_F_PARTIAL_IO;
278                         return -EAGAIN;
279                 }
280                 req_set_fail(req);
281         }
282         if (ret >= 0)
283                 ret += sr->done_io;
284         else if (sr->done_io)
285                 ret = sr->done_io;
286         io_req_set_res(req, ret, 0);
287         return IOU_OK;
288 }
289
290 static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
291                                  struct io_async_msghdr *iomsg)
292 {
293         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
294         struct iovec __user *uiov;
295         size_t iov_len;
296         int ret;
297
298         ret = __copy_msghdr_from_user(&iomsg->msg, sr->umsg,
299                                         &iomsg->uaddr, &uiov, &iov_len);
300         if (ret)
301                 return ret;
302
303         if (req->flags & REQ_F_BUFFER_SELECT) {
304                 if (iov_len > 1)
305                         return -EINVAL;
306                 if (copy_from_user(iomsg->fast_iov, uiov, sizeof(*uiov)))
307                         return -EFAULT;
308                 sr->len = iomsg->fast_iov[0].iov_len;
309                 iomsg->free_iov = NULL;
310         } else {
311                 iomsg->free_iov = iomsg->fast_iov;
312                 ret = __import_iovec(READ, uiov, iov_len, UIO_FASTIOV,
313                                      &iomsg->free_iov, &iomsg->msg.msg_iter,
314                                      false);
315                 if (ret > 0)
316                         ret = 0;
317         }
318
319         return ret;
320 }
321
322 #ifdef CONFIG_COMPAT
323 static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
324                                         struct io_async_msghdr *iomsg)
325 {
326         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
327         struct compat_iovec __user *uiov;
328         compat_uptr_t ptr;
329         compat_size_t len;
330         int ret;
331
332         ret = __get_compat_msghdr(&iomsg->msg, sr->umsg_compat, &iomsg->uaddr,
333                                   &ptr, &len);
334         if (ret)
335                 return ret;
336
337         uiov = compat_ptr(ptr);
338         if (req->flags & REQ_F_BUFFER_SELECT) {
339                 compat_ssize_t clen;
340
341                 if (len > 1)
342                         return -EINVAL;
343                 if (!access_ok(uiov, sizeof(*uiov)))
344                         return -EFAULT;
345                 if (__get_user(clen, &uiov->iov_len))
346                         return -EFAULT;
347                 if (clen < 0)
348                         return -EINVAL;
349                 sr->len = clen;
350                 iomsg->free_iov = NULL;
351         } else {
352                 iomsg->free_iov = iomsg->fast_iov;
353                 ret = __import_iovec(READ, (struct iovec __user *)uiov, len,
354                                    UIO_FASTIOV, &iomsg->free_iov,
355                                    &iomsg->msg.msg_iter, true);
356                 if (ret < 0)
357                         return ret;
358         }
359
360         return 0;
361 }
362 #endif
363
364 static int io_recvmsg_copy_hdr(struct io_kiocb *req,
365                                struct io_async_msghdr *iomsg)
366 {
367         iomsg->msg.msg_name = &iomsg->addr;
368
369 #ifdef CONFIG_COMPAT
370         if (req->ctx->compat)
371                 return __io_compat_recvmsg_copy_hdr(req, iomsg);
372 #endif
373
374         return __io_recvmsg_copy_hdr(req, iomsg);
375 }
376
377 int io_recvmsg_prep_async(struct io_kiocb *req)
378 {
379         int ret;
380
381         ret = io_recvmsg_copy_hdr(req, req->async_data);
382         if (!ret)
383                 req->flags |= REQ_F_NEED_CLEANUP;
384         return ret;
385 }
386
387 int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
388 {
389         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
390
391         if (unlikely(sqe->file_index || sqe->addr2))
392                 return -EINVAL;
393
394         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
395         sr->len = READ_ONCE(sqe->len);
396         sr->flags = READ_ONCE(sqe->ioprio);
397         if (sr->flags & ~IORING_RECVSEND_POLL_FIRST)
398                 return -EINVAL;
399         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
400         if (sr->msg_flags & MSG_DONTWAIT)
401                 req->flags |= REQ_F_NOWAIT;
402         if (sr->msg_flags & MSG_ERRQUEUE)
403                 req->flags |= REQ_F_CLEAR_POLLIN;
404
405 #ifdef CONFIG_COMPAT
406         if (req->ctx->compat)
407                 sr->msg_flags |= MSG_CMSG_COMPAT;
408 #endif
409         sr->done_io = 0;
410         return 0;
411 }
412
413 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
414 {
415         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
416         struct io_async_msghdr iomsg, *kmsg;
417         struct socket *sock;
418         unsigned int cflags;
419         unsigned flags;
420         int ret, min_ret = 0;
421         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
422
423         sock = sock_from_file(req->file);
424         if (unlikely(!sock))
425                 return -ENOTSOCK;
426
427         if (req_has_async_data(req)) {
428                 kmsg = req->async_data;
429         } else {
430                 ret = io_recvmsg_copy_hdr(req, &iomsg);
431                 if (ret)
432                         return ret;
433                 kmsg = &iomsg;
434         }
435
436         if (!(req->flags & REQ_F_POLLED) &&
437             (sr->flags & IORING_RECVSEND_POLL_FIRST))
438                 return io_setup_async_msg(req, kmsg);
439
440         if (io_do_buffer_select(req)) {
441                 void __user *buf;
442
443                 buf = io_buffer_select(req, &sr->len, issue_flags);
444                 if (!buf)
445                         return -ENOBUFS;
446                 kmsg->fast_iov[0].iov_base = buf;
447                 kmsg->fast_iov[0].iov_len = sr->len;
448                 iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->fast_iov, 1,
449                                 sr->len);
450         }
451
452         flags = sr->msg_flags;
453         if (force_nonblock)
454                 flags |= MSG_DONTWAIT;
455         if (flags & MSG_WAITALL)
456                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
457
458         kmsg->msg.msg_get_inq = 1;
459         ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg, kmsg->uaddr, flags);
460         if (ret < min_ret) {
461                 if (ret == -EAGAIN && force_nonblock)
462                         return io_setup_async_msg(req, kmsg);
463                 if (ret == -ERESTARTSYS)
464                         ret = -EINTR;
465                 if (ret > 0 && io_net_retry(sock, flags)) {
466                         sr->done_io += ret;
467                         req->flags |= REQ_F_PARTIAL_IO;
468                         return io_setup_async_msg(req, kmsg);
469                 }
470                 req_set_fail(req);
471         } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
472                 req_set_fail(req);
473         }
474
475         /* fast path, check for non-NULL to avoid function call */
476         if (kmsg->free_iov)
477                 kfree(kmsg->free_iov);
478         req->flags &= ~REQ_F_NEED_CLEANUP;
479         if (ret >= 0)
480                 ret += sr->done_io;
481         else if (sr->done_io)
482                 ret = sr->done_io;
483         cflags = io_put_kbuf(req, issue_flags);
484         if (kmsg->msg.msg_inq)
485                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
486         io_req_set_res(req, ret, cflags);
487         return IOU_OK;
488 }
489
490 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
491 {
492         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
493         struct msghdr msg;
494         struct socket *sock;
495         struct iovec iov;
496         unsigned int cflags;
497         unsigned flags;
498         int ret, min_ret = 0;
499         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
500
501         if (!(req->flags & REQ_F_POLLED) &&
502             (sr->flags & IORING_RECVSEND_POLL_FIRST))
503                 return -EAGAIN;
504
505         sock = sock_from_file(req->file);
506         if (unlikely(!sock))
507                 return -ENOTSOCK;
508
509         if (io_do_buffer_select(req)) {
510                 void __user *buf;
511
512                 buf = io_buffer_select(req, &sr->len, issue_flags);
513                 if (!buf)
514                         return -ENOBUFS;
515                 sr->buf = buf;
516         }
517
518         ret = import_single_range(READ, sr->buf, sr->len, &iov, &msg.msg_iter);
519         if (unlikely(ret))
520                 goto out_free;
521
522         msg.msg_name = NULL;
523         msg.msg_namelen = 0;
524         msg.msg_control = NULL;
525         msg.msg_get_inq = 1;
526         msg.msg_flags = 0;
527         msg.msg_controllen = 0;
528         msg.msg_iocb = NULL;
529
530         flags = sr->msg_flags;
531         if (force_nonblock)
532                 flags |= MSG_DONTWAIT;
533         if (flags & MSG_WAITALL)
534                 min_ret = iov_iter_count(&msg.msg_iter);
535
536         ret = sock_recvmsg(sock, &msg, flags);
537         if (ret < min_ret) {
538                 if (ret == -EAGAIN && force_nonblock)
539                         return -EAGAIN;
540                 if (ret == -ERESTARTSYS)
541                         ret = -EINTR;
542                 if (ret > 0 && io_net_retry(sock, flags)) {
543                         sr->len -= ret;
544                         sr->buf += ret;
545                         sr->done_io += ret;
546                         req->flags |= REQ_F_PARTIAL_IO;
547                         return -EAGAIN;
548                 }
549                 req_set_fail(req);
550         } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
551 out_free:
552                 req_set_fail(req);
553         }
554
555         if (ret >= 0)
556                 ret += sr->done_io;
557         else if (sr->done_io)
558                 ret = sr->done_io;
559         cflags = io_put_kbuf(req, issue_flags);
560         if (msg.msg_inq)
561                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
562         io_req_set_res(req, ret, cflags);
563         return IOU_OK;
564 }
565
566 int io_accept_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
567 {
568         struct io_accept *accept = io_kiocb_to_cmd(req);
569         unsigned flags;
570
571         if (sqe->len || sqe->buf_index)
572                 return -EINVAL;
573
574         accept->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
575         accept->addr_len = u64_to_user_ptr(READ_ONCE(sqe->addr2));
576         accept->flags = READ_ONCE(sqe->accept_flags);
577         accept->nofile = rlimit(RLIMIT_NOFILE);
578         flags = READ_ONCE(sqe->ioprio);
579         if (flags & ~IORING_ACCEPT_MULTISHOT)
580                 return -EINVAL;
581
582         accept->file_slot = READ_ONCE(sqe->file_index);
583         if (accept->file_slot) {
584                 if (accept->flags & SOCK_CLOEXEC)
585                         return -EINVAL;
586                 if (flags & IORING_ACCEPT_MULTISHOT &&
587                     accept->file_slot != IORING_FILE_INDEX_ALLOC)
588                         return -EINVAL;
589         }
590         if (accept->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
591                 return -EINVAL;
592         if (SOCK_NONBLOCK != O_NONBLOCK && (accept->flags & SOCK_NONBLOCK))
593                 accept->flags = (accept->flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
594         if (flags & IORING_ACCEPT_MULTISHOT)
595                 req->flags |= REQ_F_APOLL_MULTISHOT;
596         return 0;
597 }
598
599 int io_accept(struct io_kiocb *req, unsigned int issue_flags)
600 {
601         struct io_ring_ctx *ctx = req->ctx;
602         struct io_accept *accept = io_kiocb_to_cmd(req);
603         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
604         unsigned int file_flags = force_nonblock ? O_NONBLOCK : 0;
605         bool fixed = !!accept->file_slot;
606         struct file *file;
607         int ret, fd;
608
609 retry:
610         if (!fixed) {
611                 fd = __get_unused_fd_flags(accept->flags, accept->nofile);
612                 if (unlikely(fd < 0))
613                         return fd;
614         }
615         file = do_accept(req->file, file_flags, accept->addr, accept->addr_len,
616                          accept->flags);
617         if (IS_ERR(file)) {
618                 if (!fixed)
619                         put_unused_fd(fd);
620                 ret = PTR_ERR(file);
621                 if (ret == -EAGAIN && force_nonblock) {
622                         /*
623                          * if it's multishot and polled, we don't need to
624                          * return EAGAIN to arm the poll infra since it
625                          * has already been done
626                          */
627                         if ((req->flags & IO_APOLL_MULTI_POLLED) ==
628                             IO_APOLL_MULTI_POLLED)
629                                 ret = IOU_ISSUE_SKIP_COMPLETE;
630                         return ret;
631                 }
632                 if (ret == -ERESTARTSYS)
633                         ret = -EINTR;
634                 req_set_fail(req);
635         } else if (!fixed) {
636                 fd_install(fd, file);
637                 ret = fd;
638         } else {
639                 ret = io_fixed_fd_install(req, issue_flags, file,
640                                                 accept->file_slot);
641         }
642
643         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
644                 io_req_set_res(req, ret, 0);
645                 return IOU_OK;
646         }
647         if (ret >= 0) {
648                 bool filled;
649
650                 spin_lock(&ctx->completion_lock);
651                 filled = io_fill_cqe_aux(ctx, req->cqe.user_data, ret,
652                                          IORING_CQE_F_MORE);
653                 io_commit_cqring(ctx);
654                 spin_unlock(&ctx->completion_lock);
655                 if (filled) {
656                         io_cqring_ev_posted(ctx);
657                         goto retry;
658                 }
659                 ret = -ECANCELED;
660         }
661
662         return ret;
663 }
664
665 int io_socket_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
666 {
667         struct io_socket *sock = io_kiocb_to_cmd(req);
668
669         if (sqe->addr || sqe->rw_flags || sqe->buf_index)
670                 return -EINVAL;
671
672         sock->domain = READ_ONCE(sqe->fd);
673         sock->type = READ_ONCE(sqe->off);
674         sock->protocol = READ_ONCE(sqe->len);
675         sock->file_slot = READ_ONCE(sqe->file_index);
676         sock->nofile = rlimit(RLIMIT_NOFILE);
677
678         sock->flags = sock->type & ~SOCK_TYPE_MASK;
679         if (sock->file_slot && (sock->flags & SOCK_CLOEXEC))
680                 return -EINVAL;
681         if (sock->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
682                 return -EINVAL;
683         return 0;
684 }
685
686 int io_socket(struct io_kiocb *req, unsigned int issue_flags)
687 {
688         struct io_socket *sock = io_kiocb_to_cmd(req);
689         bool fixed = !!sock->file_slot;
690         struct file *file;
691         int ret, fd;
692
693         if (!fixed) {
694                 fd = __get_unused_fd_flags(sock->flags, sock->nofile);
695                 if (unlikely(fd < 0))
696                         return fd;
697         }
698         file = __sys_socket_file(sock->domain, sock->type, sock->protocol);
699         if (IS_ERR(file)) {
700                 if (!fixed)
701                         put_unused_fd(fd);
702                 ret = PTR_ERR(file);
703                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
704                         return -EAGAIN;
705                 if (ret == -ERESTARTSYS)
706                         ret = -EINTR;
707                 req_set_fail(req);
708         } else if (!fixed) {
709                 fd_install(fd, file);
710                 ret = fd;
711         } else {
712                 ret = io_fixed_fd_install(req, issue_flags, file,
713                                             sock->file_slot);
714         }
715         io_req_set_res(req, ret, 0);
716         return IOU_OK;
717 }
718
719 int io_connect_prep_async(struct io_kiocb *req)
720 {
721         struct io_async_connect *io = req->async_data;
722         struct io_connect *conn = io_kiocb_to_cmd(req);
723
724         return move_addr_to_kernel(conn->addr, conn->addr_len, &io->address);
725 }
726
727 int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
728 {
729         struct io_connect *conn = io_kiocb_to_cmd(req);
730
731         if (sqe->len || sqe->buf_index || sqe->rw_flags || sqe->splice_fd_in)
732                 return -EINVAL;
733
734         conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
735         conn->addr_len =  READ_ONCE(sqe->addr2);
736         return 0;
737 }
738
739 int io_connect(struct io_kiocb *req, unsigned int issue_flags)
740 {
741         struct io_connect *connect = io_kiocb_to_cmd(req);
742         struct io_async_connect __io, *io;
743         unsigned file_flags;
744         int ret;
745         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
746
747         if (req_has_async_data(req)) {
748                 io = req->async_data;
749         } else {
750                 ret = move_addr_to_kernel(connect->addr,
751                                                 connect->addr_len,
752                                                 &__io.address);
753                 if (ret)
754                         goto out;
755                 io = &__io;
756         }
757
758         file_flags = force_nonblock ? O_NONBLOCK : 0;
759
760         ret = __sys_connect_file(req->file, &io->address,
761                                         connect->addr_len, file_flags);
762         if ((ret == -EAGAIN || ret == -EINPROGRESS) && force_nonblock) {
763                 if (req_has_async_data(req))
764                         return -EAGAIN;
765                 if (io_alloc_async_data(req)) {
766                         ret = -ENOMEM;
767                         goto out;
768                 }
769                 memcpy(req->async_data, &__io, sizeof(__io));
770                 return -EAGAIN;
771         }
772         if (ret == -ERESTARTSYS)
773                 ret = -EINTR;
774 out:
775         if (ret < 0)
776                 req_set_fail(req);
777         io_req_set_res(req, ret, 0);
778         return IOU_OK;
779 }
780 #endif