io_uring: initialise msghdr::msg_ubuf
[linux-2.6-microblaze.git] / io_uring / net.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/slab.h>
6 #include <linux/net.h>
7 #include <linux/compat.h>
8 #include <net/compat.h>
9 #include <linux/io_uring.h>
10
11 #include <uapi/linux/io_uring.h>
12
13 #include "io_uring.h"
14 #include "kbuf.h"
15 #include "alloc_cache.h"
16 #include "net.h"
17
18 #if defined(CONFIG_NET)
19 struct io_shutdown {
20         struct file                     *file;
21         int                             how;
22 };
23
24 struct io_accept {
25         struct file                     *file;
26         struct sockaddr __user          *addr;
27         int __user                      *addr_len;
28         int                             flags;
29         u32                             file_slot;
30         unsigned long                   nofile;
31 };
32
33 struct io_socket {
34         struct file                     *file;
35         int                             domain;
36         int                             type;
37         int                             protocol;
38         int                             flags;
39         u32                             file_slot;
40         unsigned long                   nofile;
41 };
42
43 struct io_connect {
44         struct file                     *file;
45         struct sockaddr __user          *addr;
46         int                             addr_len;
47 };
48
49 struct io_sr_msg {
50         struct file                     *file;
51         union {
52                 struct compat_msghdr __user     *umsg_compat;
53                 struct user_msghdr __user       *umsg;
54                 void __user                     *buf;
55         };
56         int                             msg_flags;
57         size_t                          len;
58         size_t                          done_io;
59         unsigned int                    flags;
60 };
61
62 #define IO_APOLL_MULTI_POLLED (REQ_F_APOLL_MULTISHOT | REQ_F_POLLED)
63
64 int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
65 {
66         struct io_shutdown *shutdown = io_kiocb_to_cmd(req);
67
68         if (unlikely(sqe->off || sqe->addr || sqe->rw_flags ||
69                      sqe->buf_index || sqe->splice_fd_in))
70                 return -EINVAL;
71
72         shutdown->how = READ_ONCE(sqe->len);
73         return 0;
74 }
75
76 int io_shutdown(struct io_kiocb *req, unsigned int issue_flags)
77 {
78         struct io_shutdown *shutdown = io_kiocb_to_cmd(req);
79         struct socket *sock;
80         int ret;
81
82         if (issue_flags & IO_URING_F_NONBLOCK)
83                 return -EAGAIN;
84
85         sock = sock_from_file(req->file);
86         if (unlikely(!sock))
87                 return -ENOTSOCK;
88
89         ret = __sys_shutdown_sock(sock, shutdown->how);
90         io_req_set_res(req, ret, 0);
91         return IOU_OK;
92 }
93
94 static bool io_net_retry(struct socket *sock, int flags)
95 {
96         if (!(flags & MSG_WAITALL))
97                 return false;
98         return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
99 }
100
101 static void io_netmsg_recycle(struct io_kiocb *req, unsigned int issue_flags)
102 {
103         struct io_async_msghdr *hdr = req->async_data;
104
105         if (!hdr || issue_flags & IO_URING_F_UNLOCKED)
106                 return;
107
108         /* Let normal cleanup path reap it if we fail adding to the cache */
109         if (io_alloc_cache_put(&req->ctx->netmsg_cache, &hdr->cache)) {
110                 req->async_data = NULL;
111                 req->flags &= ~REQ_F_ASYNC_DATA;
112         }
113 }
114
115 static struct io_async_msghdr *io_recvmsg_alloc_async(struct io_kiocb *req,
116                                                       unsigned int issue_flags)
117 {
118         struct io_ring_ctx *ctx = req->ctx;
119         struct io_cache_entry *entry;
120
121         if (!(issue_flags & IO_URING_F_UNLOCKED) &&
122             (entry = io_alloc_cache_get(&ctx->netmsg_cache)) != NULL) {
123                 struct io_async_msghdr *hdr;
124
125                 hdr = container_of(entry, struct io_async_msghdr, cache);
126                 req->flags |= REQ_F_ASYNC_DATA;
127                 req->async_data = hdr;
128                 return hdr;
129         }
130
131         if (!io_alloc_async_data(req))
132                 return req->async_data;
133
134         return NULL;
135 }
136
137 static int io_setup_async_msg(struct io_kiocb *req,
138                               struct io_async_msghdr *kmsg,
139                               unsigned int issue_flags)
140 {
141         struct io_async_msghdr *async_msg = req->async_data;
142
143         if (async_msg)
144                 return -EAGAIN;
145         async_msg = io_recvmsg_alloc_async(req, issue_flags);
146         if (!async_msg) {
147                 kfree(kmsg->free_iov);
148                 return -ENOMEM;
149         }
150         req->flags |= REQ_F_NEED_CLEANUP;
151         memcpy(async_msg, kmsg, sizeof(*kmsg));
152         async_msg->msg.msg_name = &async_msg->addr;
153         /* if were using fast_iov, set it to the new one */
154         if (!async_msg->free_iov)
155                 async_msg->msg.msg_iter.iov = async_msg->fast_iov;
156
157         return -EAGAIN;
158 }
159
160 static int io_sendmsg_copy_hdr(struct io_kiocb *req,
161                                struct io_async_msghdr *iomsg)
162 {
163         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
164
165         iomsg->msg.msg_name = &iomsg->addr;
166         iomsg->free_iov = iomsg->fast_iov;
167         return sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
168                                         &iomsg->free_iov);
169 }
170
171 int io_sendmsg_prep_async(struct io_kiocb *req)
172 {
173         int ret;
174
175         ret = io_sendmsg_copy_hdr(req, req->async_data);
176         if (!ret)
177                 req->flags |= REQ_F_NEED_CLEANUP;
178         return ret;
179 }
180
181 void io_sendmsg_recvmsg_cleanup(struct io_kiocb *req)
182 {
183         struct io_async_msghdr *io = req->async_data;
184
185         kfree(io->free_iov);
186 }
187
188 int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
189 {
190         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
191
192         if (unlikely(sqe->file_index || sqe->addr2))
193                 return -EINVAL;
194
195         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
196         sr->len = READ_ONCE(sqe->len);
197         sr->flags = READ_ONCE(sqe->ioprio);
198         if (sr->flags & ~IORING_RECVSEND_POLL_FIRST)
199                 return -EINVAL;
200         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
201         if (sr->msg_flags & MSG_DONTWAIT)
202                 req->flags |= REQ_F_NOWAIT;
203
204 #ifdef CONFIG_COMPAT
205         if (req->ctx->compat)
206                 sr->msg_flags |= MSG_CMSG_COMPAT;
207 #endif
208         sr->done_io = 0;
209         return 0;
210 }
211
212 int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
213 {
214         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
215         struct io_async_msghdr iomsg, *kmsg;
216         struct socket *sock;
217         unsigned flags;
218         int min_ret = 0;
219         int ret;
220
221         sock = sock_from_file(req->file);
222         if (unlikely(!sock))
223                 return -ENOTSOCK;
224
225         if (req_has_async_data(req)) {
226                 kmsg = req->async_data;
227         } else {
228                 ret = io_sendmsg_copy_hdr(req, &iomsg);
229                 if (ret)
230                         return ret;
231                 kmsg = &iomsg;
232         }
233
234         if (!(req->flags & REQ_F_POLLED) &&
235             (sr->flags & IORING_RECVSEND_POLL_FIRST))
236                 return io_setup_async_msg(req, kmsg, issue_flags);
237
238         flags = sr->msg_flags;
239         if (issue_flags & IO_URING_F_NONBLOCK)
240                 flags |= MSG_DONTWAIT;
241         if (flags & MSG_WAITALL)
242                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
243
244         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
245
246         if (ret < min_ret) {
247                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
248                         return io_setup_async_msg(req, kmsg, issue_flags);
249                 if (ret == -ERESTARTSYS)
250                         ret = -EINTR;
251                 if (ret > 0 && io_net_retry(sock, flags)) {
252                         sr->done_io += ret;
253                         req->flags |= REQ_F_PARTIAL_IO;
254                         return io_setup_async_msg(req, kmsg, issue_flags);
255                 }
256                 req_set_fail(req);
257         }
258         /* fast path, check for non-NULL to avoid function call */
259         if (kmsg->free_iov)
260                 kfree(kmsg->free_iov);
261         req->flags &= ~REQ_F_NEED_CLEANUP;
262         io_netmsg_recycle(req, issue_flags);
263         if (ret >= 0)
264                 ret += sr->done_io;
265         else if (sr->done_io)
266                 ret = sr->done_io;
267         io_req_set_res(req, ret, 0);
268         return IOU_OK;
269 }
270
271 int io_send(struct io_kiocb *req, unsigned int issue_flags)
272 {
273         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
274         struct msghdr msg;
275         struct iovec iov;
276         struct socket *sock;
277         unsigned flags;
278         int min_ret = 0;
279         int ret;
280
281         if (!(req->flags & REQ_F_POLLED) &&
282             (sr->flags & IORING_RECVSEND_POLL_FIRST))
283                 return -EAGAIN;
284
285         sock = sock_from_file(req->file);
286         if (unlikely(!sock))
287                 return -ENOTSOCK;
288
289         ret = import_single_range(WRITE, sr->buf, sr->len, &iov, &msg.msg_iter);
290         if (unlikely(ret))
291                 return ret;
292
293         msg.msg_name = NULL;
294         msg.msg_control = NULL;
295         msg.msg_controllen = 0;
296         msg.msg_namelen = 0;
297         msg.msg_ubuf = NULL;
298
299         flags = sr->msg_flags;
300         if (issue_flags & IO_URING_F_NONBLOCK)
301                 flags |= MSG_DONTWAIT;
302         if (flags & MSG_WAITALL)
303                 min_ret = iov_iter_count(&msg.msg_iter);
304
305         msg.msg_flags = flags;
306         ret = sock_sendmsg(sock, &msg);
307         if (ret < min_ret) {
308                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
309                         return -EAGAIN;
310                 if (ret == -ERESTARTSYS)
311                         ret = -EINTR;
312                 if (ret > 0 && io_net_retry(sock, flags)) {
313                         sr->len -= ret;
314                         sr->buf += ret;
315                         sr->done_io += ret;
316                         req->flags |= REQ_F_PARTIAL_IO;
317                         return -EAGAIN;
318                 }
319                 req_set_fail(req);
320         }
321         if (ret >= 0)
322                 ret += sr->done_io;
323         else if (sr->done_io)
324                 ret = sr->done_io;
325         io_req_set_res(req, ret, 0);
326         return IOU_OK;
327 }
328
329 static bool io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
330 {
331         int hdr;
332
333         if (iomsg->namelen < 0)
334                 return true;
335         if (check_add_overflow((int)sizeof(struct io_uring_recvmsg_out),
336                                iomsg->namelen, &hdr))
337                 return true;
338         if (check_add_overflow(hdr, (int)iomsg->controllen, &hdr))
339                 return true;
340
341         return false;
342 }
343
344 static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
345                                  struct io_async_msghdr *iomsg)
346 {
347         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
348         struct user_msghdr msg;
349         int ret;
350
351         if (copy_from_user(&msg, sr->umsg, sizeof(*sr->umsg)))
352                 return -EFAULT;
353
354         ret = __copy_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
355         if (ret)
356                 return ret;
357
358         if (req->flags & REQ_F_BUFFER_SELECT) {
359                 if (msg.msg_iovlen == 0) {
360                         sr->len = iomsg->fast_iov[0].iov_len = 0;
361                         iomsg->fast_iov[0].iov_base = NULL;
362                         iomsg->free_iov = NULL;
363                 } else if (msg.msg_iovlen > 1) {
364                         return -EINVAL;
365                 } else {
366                         if (copy_from_user(iomsg->fast_iov, msg.msg_iov, sizeof(*msg.msg_iov)))
367                                 return -EFAULT;
368                         sr->len = iomsg->fast_iov[0].iov_len;
369                         iomsg->free_iov = NULL;
370                 }
371
372                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
373                         iomsg->namelen = msg.msg_namelen;
374                         iomsg->controllen = msg.msg_controllen;
375                         if (io_recvmsg_multishot_overflow(iomsg))
376                                 return -EOVERFLOW;
377                 }
378         } else {
379                 iomsg->free_iov = iomsg->fast_iov;
380                 ret = __import_iovec(READ, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
381                                      &iomsg->free_iov, &iomsg->msg.msg_iter,
382                                      false);
383                 if (ret > 0)
384                         ret = 0;
385         }
386
387         return ret;
388 }
389
390 #ifdef CONFIG_COMPAT
391 static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
392                                         struct io_async_msghdr *iomsg)
393 {
394         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
395         struct compat_msghdr msg;
396         struct compat_iovec __user *uiov;
397         int ret;
398
399         if (copy_from_user(&msg, sr->umsg_compat, sizeof(msg)))
400                 return -EFAULT;
401
402         ret = __get_compat_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
403         if (ret)
404                 return ret;
405
406         uiov = compat_ptr(msg.msg_iov);
407         if (req->flags & REQ_F_BUFFER_SELECT) {
408                 compat_ssize_t clen;
409
410                 if (msg.msg_iovlen == 0) {
411                         sr->len = 0;
412                         iomsg->free_iov = NULL;
413                 } else if (msg.msg_iovlen > 1) {
414                         return -EINVAL;
415                 } else {
416                         if (!access_ok(uiov, sizeof(*uiov)))
417                                 return -EFAULT;
418                         if (__get_user(clen, &uiov->iov_len))
419                                 return -EFAULT;
420                         if (clen < 0)
421                                 return -EINVAL;
422                         sr->len = clen;
423                         iomsg->free_iov = NULL;
424                 }
425
426                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
427                         iomsg->namelen = msg.msg_namelen;
428                         iomsg->controllen = msg.msg_controllen;
429                         if (io_recvmsg_multishot_overflow(iomsg))
430                                 return -EOVERFLOW;
431                 }
432         } else {
433                 iomsg->free_iov = iomsg->fast_iov;
434                 ret = __import_iovec(READ, (struct iovec __user *)uiov, msg.msg_iovlen,
435                                    UIO_FASTIOV, &iomsg->free_iov,
436                                    &iomsg->msg.msg_iter, true);
437                 if (ret < 0)
438                         return ret;
439         }
440
441         return 0;
442 }
443 #endif
444
445 static int io_recvmsg_copy_hdr(struct io_kiocb *req,
446                                struct io_async_msghdr *iomsg)
447 {
448         iomsg->msg.msg_name = &iomsg->addr;
449
450 #ifdef CONFIG_COMPAT
451         if (req->ctx->compat)
452                 return __io_compat_recvmsg_copy_hdr(req, iomsg);
453 #endif
454
455         return __io_recvmsg_copy_hdr(req, iomsg);
456 }
457
458 int io_recvmsg_prep_async(struct io_kiocb *req)
459 {
460         int ret;
461
462         ret = io_recvmsg_copy_hdr(req, req->async_data);
463         if (!ret)
464                 req->flags |= REQ_F_NEED_CLEANUP;
465         return ret;
466 }
467
468 #define RECVMSG_FLAGS (IORING_RECVSEND_POLL_FIRST | IORING_RECV_MULTISHOT)
469
470 int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
471 {
472         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
473
474         if (unlikely(sqe->file_index || sqe->addr2))
475                 return -EINVAL;
476
477         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
478         sr->len = READ_ONCE(sqe->len);
479         sr->flags = READ_ONCE(sqe->ioprio);
480         if (sr->flags & ~(RECVMSG_FLAGS))
481                 return -EINVAL;
482         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
483         if (sr->msg_flags & MSG_DONTWAIT)
484                 req->flags |= REQ_F_NOWAIT;
485         if (sr->msg_flags & MSG_ERRQUEUE)
486                 req->flags |= REQ_F_CLEAR_POLLIN;
487         if (sr->flags & IORING_RECV_MULTISHOT) {
488                 if (!(req->flags & REQ_F_BUFFER_SELECT))
489                         return -EINVAL;
490                 if (sr->msg_flags & MSG_WAITALL)
491                         return -EINVAL;
492                 if (req->opcode == IORING_OP_RECV && sr->len)
493                         return -EINVAL;
494                 req->flags |= REQ_F_APOLL_MULTISHOT;
495         }
496
497 #ifdef CONFIG_COMPAT
498         if (req->ctx->compat)
499                 sr->msg_flags |= MSG_CMSG_COMPAT;
500 #endif
501         sr->done_io = 0;
502         return 0;
503 }
504
505 static inline void io_recv_prep_retry(struct io_kiocb *req)
506 {
507         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
508
509         sr->done_io = 0;
510         sr->len = 0; /* get from the provided buffer */
511 }
512
513 /*
514  * Finishes io_recv and io_recvmsg.
515  *
516  * Returns true if it is actually finished, or false if it should run
517  * again (for multishot).
518  */
519 static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
520                                   unsigned int cflags, bool mshot_finished)
521 {
522         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
523                 io_req_set_res(req, *ret, cflags);
524                 *ret = IOU_OK;
525                 return true;
526         }
527
528         if (!mshot_finished) {
529                 if (io_post_aux_cqe(req->ctx, req->cqe.user_data, *ret,
530                                     cflags | IORING_CQE_F_MORE, false)) {
531                         io_recv_prep_retry(req);
532                         return false;
533                 }
534                 /*
535                  * Otherwise stop multishot but use the current result.
536                  * Probably will end up going into overflow, but this means
537                  * we cannot trust the ordering anymore
538                  */
539         }
540
541         io_req_set_res(req, *ret, cflags);
542
543         if (req->flags & REQ_F_POLLED)
544                 *ret = IOU_STOP_MULTISHOT;
545         else
546                 *ret = IOU_OK;
547         return true;
548 }
549
550 static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg,
551                                      struct io_sr_msg *sr, void __user **buf,
552                                      size_t *len)
553 {
554         unsigned long ubuf = (unsigned long) *buf;
555         unsigned long hdr;
556
557         hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
558                 kmsg->controllen;
559         if (*len < hdr)
560                 return -EFAULT;
561
562         if (kmsg->controllen) {
563                 unsigned long control = ubuf + hdr - kmsg->controllen;
564
565                 kmsg->msg.msg_control_user = (void *) control;
566                 kmsg->msg.msg_controllen = kmsg->controllen;
567         }
568
569         sr->buf = *buf; /* stash for later copy */
570         *buf = (void *) (ubuf + hdr);
571         kmsg->payloadlen = *len = *len - hdr;
572         return 0;
573 }
574
575 struct io_recvmsg_multishot_hdr {
576         struct io_uring_recvmsg_out msg;
577         struct sockaddr_storage addr;
578 };
579
580 static int io_recvmsg_multishot(struct socket *sock, struct io_sr_msg *io,
581                                 struct io_async_msghdr *kmsg,
582                                 unsigned int flags, bool *finished)
583 {
584         int err;
585         int copy_len;
586         struct io_recvmsg_multishot_hdr hdr;
587
588         if (kmsg->namelen)
589                 kmsg->msg.msg_name = &hdr.addr;
590         kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
591         kmsg->msg.msg_namelen = 0;
592
593         if (sock->file->f_flags & O_NONBLOCK)
594                 flags |= MSG_DONTWAIT;
595
596         err = sock_recvmsg(sock, &kmsg->msg, flags);
597         *finished = err <= 0;
598         if (err < 0)
599                 return err;
600
601         hdr.msg = (struct io_uring_recvmsg_out) {
602                 .controllen = kmsg->controllen - kmsg->msg.msg_controllen,
603                 .flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
604         };
605
606         hdr.msg.payloadlen = err;
607         if (err > kmsg->payloadlen)
608                 err = kmsg->payloadlen;
609
610         copy_len = sizeof(struct io_uring_recvmsg_out);
611         if (kmsg->msg.msg_namelen > kmsg->namelen)
612                 copy_len += kmsg->namelen;
613         else
614                 copy_len += kmsg->msg.msg_namelen;
615
616         /*
617          *      "fromlen shall refer to the value before truncation.."
618          *                      1003.1g
619          */
620         hdr.msg.namelen = kmsg->msg.msg_namelen;
621
622         /* ensure that there is no gap between hdr and sockaddr_storage */
623         BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
624                      sizeof(struct io_uring_recvmsg_out));
625         if (copy_to_user(io->buf, &hdr, copy_len)) {
626                 *finished = true;
627                 return -EFAULT;
628         }
629
630         return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
631                         kmsg->controllen + err;
632 }
633
634 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
635 {
636         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
637         struct io_async_msghdr iomsg, *kmsg;
638         struct socket *sock;
639         unsigned int cflags;
640         unsigned flags;
641         int ret, min_ret = 0;
642         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
643         bool mshot_finished = true;
644
645         sock = sock_from_file(req->file);
646         if (unlikely(!sock))
647                 return -ENOTSOCK;
648
649         if (req_has_async_data(req)) {
650                 kmsg = req->async_data;
651         } else {
652                 ret = io_recvmsg_copy_hdr(req, &iomsg);
653                 if (ret)
654                         return ret;
655                 kmsg = &iomsg;
656         }
657
658         if (!(req->flags & REQ_F_POLLED) &&
659             (sr->flags & IORING_RECVSEND_POLL_FIRST))
660                 return io_setup_async_msg(req, kmsg, issue_flags);
661
662 retry_multishot:
663         if (io_do_buffer_select(req)) {
664                 void __user *buf;
665                 size_t len = sr->len;
666
667                 buf = io_buffer_select(req, &len, issue_flags);
668                 if (!buf)
669                         return -ENOBUFS;
670
671                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
672                         ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
673                         if (ret) {
674                                 io_kbuf_recycle(req, issue_flags);
675                                 return ret;
676                         }
677                 }
678
679                 kmsg->fast_iov[0].iov_base = buf;
680                 kmsg->fast_iov[0].iov_len = len;
681                 iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->fast_iov, 1,
682                                 len);
683         }
684
685         flags = sr->msg_flags;
686         if (force_nonblock)
687                 flags |= MSG_DONTWAIT;
688         if (flags & MSG_WAITALL)
689                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
690
691         kmsg->msg.msg_get_inq = 1;
692         if (req->flags & REQ_F_APOLL_MULTISHOT)
693                 ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
694                                            &mshot_finished);
695         else
696                 ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg,
697                                          kmsg->uaddr, flags);
698
699         if (ret < min_ret) {
700                 if (ret == -EAGAIN && force_nonblock) {
701                         ret = io_setup_async_msg(req, kmsg, issue_flags);
702                         if (ret == -EAGAIN && (req->flags & IO_APOLL_MULTI_POLLED) ==
703                                                IO_APOLL_MULTI_POLLED) {
704                                 io_kbuf_recycle(req, issue_flags);
705                                 return IOU_ISSUE_SKIP_COMPLETE;
706                         }
707                         return ret;
708                 }
709                 if (ret == -ERESTARTSYS)
710                         ret = -EINTR;
711                 if (ret > 0 && io_net_retry(sock, flags)) {
712                         sr->done_io += ret;
713                         req->flags |= REQ_F_PARTIAL_IO;
714                         return io_setup_async_msg(req, kmsg, issue_flags);
715                 }
716                 req_set_fail(req);
717         } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
718                 req_set_fail(req);
719         }
720
721         if (ret > 0)
722                 ret += sr->done_io;
723         else if (sr->done_io)
724                 ret = sr->done_io;
725         else
726                 io_kbuf_recycle(req, issue_flags);
727
728         cflags = io_put_kbuf(req, issue_flags);
729         if (kmsg->msg.msg_inq)
730                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
731
732         if (!io_recv_finish(req, &ret, cflags, mshot_finished))
733                 goto retry_multishot;
734
735         if (mshot_finished) {
736                 io_netmsg_recycle(req, issue_flags);
737                 /* fast path, check for non-NULL to avoid function call */
738                 if (kmsg->free_iov)
739                         kfree(kmsg->free_iov);
740                 req->flags &= ~REQ_F_NEED_CLEANUP;
741         }
742
743         return ret;
744 }
745
746 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
747 {
748         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
749         struct msghdr msg;
750         struct socket *sock;
751         struct iovec iov;
752         unsigned int cflags;
753         unsigned flags;
754         int ret, min_ret = 0;
755         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
756         size_t len = sr->len;
757
758         if (!(req->flags & REQ_F_POLLED) &&
759             (sr->flags & IORING_RECVSEND_POLL_FIRST))
760                 return -EAGAIN;
761
762         sock = sock_from_file(req->file);
763         if (unlikely(!sock))
764                 return -ENOTSOCK;
765
766 retry_multishot:
767         if (io_do_buffer_select(req)) {
768                 void __user *buf;
769
770                 buf = io_buffer_select(req, &len, issue_flags);
771                 if (!buf)
772                         return -ENOBUFS;
773                 sr->buf = buf;
774         }
775
776         ret = import_single_range(READ, sr->buf, len, &iov, &msg.msg_iter);
777         if (unlikely(ret))
778                 goto out_free;
779
780         msg.msg_name = NULL;
781         msg.msg_namelen = 0;
782         msg.msg_control = NULL;
783         msg.msg_get_inq = 1;
784         msg.msg_flags = 0;
785         msg.msg_controllen = 0;
786         msg.msg_iocb = NULL;
787         msg.msg_ubuf = NULL;
788
789         flags = sr->msg_flags;
790         if (force_nonblock)
791                 flags |= MSG_DONTWAIT;
792         if (flags & MSG_WAITALL)
793                 min_ret = iov_iter_count(&msg.msg_iter);
794
795         ret = sock_recvmsg(sock, &msg, flags);
796         if (ret < min_ret) {
797                 if (ret == -EAGAIN && force_nonblock) {
798                         if ((req->flags & IO_APOLL_MULTI_POLLED) == IO_APOLL_MULTI_POLLED) {
799                                 io_kbuf_recycle(req, issue_flags);
800                                 return IOU_ISSUE_SKIP_COMPLETE;
801                         }
802
803                         return -EAGAIN;
804                 }
805                 if (ret == -ERESTARTSYS)
806                         ret = -EINTR;
807                 if (ret > 0 && io_net_retry(sock, flags)) {
808                         sr->len -= ret;
809                         sr->buf += ret;
810                         sr->done_io += ret;
811                         req->flags |= REQ_F_PARTIAL_IO;
812                         return -EAGAIN;
813                 }
814                 req_set_fail(req);
815         } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
816 out_free:
817                 req_set_fail(req);
818         }
819
820         if (ret > 0)
821                 ret += sr->done_io;
822         else if (sr->done_io)
823                 ret = sr->done_io;
824         else
825                 io_kbuf_recycle(req, issue_flags);
826
827         cflags = io_put_kbuf(req, issue_flags);
828         if (msg.msg_inq)
829                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
830
831         if (!io_recv_finish(req, &ret, cflags, ret <= 0))
832                 goto retry_multishot;
833
834         return ret;
835 }
836
837 int io_accept_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
838 {
839         struct io_accept *accept = io_kiocb_to_cmd(req);
840         unsigned flags;
841
842         if (sqe->len || sqe->buf_index)
843                 return -EINVAL;
844
845         accept->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
846         accept->addr_len = u64_to_user_ptr(READ_ONCE(sqe->addr2));
847         accept->flags = READ_ONCE(sqe->accept_flags);
848         accept->nofile = rlimit(RLIMIT_NOFILE);
849         flags = READ_ONCE(sqe->ioprio);
850         if (flags & ~IORING_ACCEPT_MULTISHOT)
851                 return -EINVAL;
852
853         accept->file_slot = READ_ONCE(sqe->file_index);
854         if (accept->file_slot) {
855                 if (accept->flags & SOCK_CLOEXEC)
856                         return -EINVAL;
857                 if (flags & IORING_ACCEPT_MULTISHOT &&
858                     accept->file_slot != IORING_FILE_INDEX_ALLOC)
859                         return -EINVAL;
860         }
861         if (accept->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
862                 return -EINVAL;
863         if (SOCK_NONBLOCK != O_NONBLOCK && (accept->flags & SOCK_NONBLOCK))
864                 accept->flags = (accept->flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
865         if (flags & IORING_ACCEPT_MULTISHOT)
866                 req->flags |= REQ_F_APOLL_MULTISHOT;
867         return 0;
868 }
869
870 int io_accept(struct io_kiocb *req, unsigned int issue_flags)
871 {
872         struct io_ring_ctx *ctx = req->ctx;
873         struct io_accept *accept = io_kiocb_to_cmd(req);
874         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
875         unsigned int file_flags = force_nonblock ? O_NONBLOCK : 0;
876         bool fixed = !!accept->file_slot;
877         struct file *file;
878         int ret, fd;
879
880 retry:
881         if (!fixed) {
882                 fd = __get_unused_fd_flags(accept->flags, accept->nofile);
883                 if (unlikely(fd < 0))
884                         return fd;
885         }
886         file = do_accept(req->file, file_flags, accept->addr, accept->addr_len,
887                          accept->flags);
888         if (IS_ERR(file)) {
889                 if (!fixed)
890                         put_unused_fd(fd);
891                 ret = PTR_ERR(file);
892                 if (ret == -EAGAIN && force_nonblock) {
893                         /*
894                          * if it's multishot and polled, we don't need to
895                          * return EAGAIN to arm the poll infra since it
896                          * has already been done
897                          */
898                         if ((req->flags & IO_APOLL_MULTI_POLLED) ==
899                             IO_APOLL_MULTI_POLLED)
900                                 ret = IOU_ISSUE_SKIP_COMPLETE;
901                         return ret;
902                 }
903                 if (ret == -ERESTARTSYS)
904                         ret = -EINTR;
905                 req_set_fail(req);
906         } else if (!fixed) {
907                 fd_install(fd, file);
908                 ret = fd;
909         } else {
910                 ret = io_fixed_fd_install(req, issue_flags, file,
911                                                 accept->file_slot);
912         }
913
914         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
915                 io_req_set_res(req, ret, 0);
916                 return IOU_OK;
917         }
918
919         if (ret >= 0 &&
920             io_post_aux_cqe(ctx, req->cqe.user_data, ret, IORING_CQE_F_MORE, false))
921                 goto retry;
922
923         io_req_set_res(req, ret, 0);
924         if (req->flags & REQ_F_POLLED)
925                 return IOU_STOP_MULTISHOT;
926         return IOU_OK;
927 }
928
929 int io_socket_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
930 {
931         struct io_socket *sock = io_kiocb_to_cmd(req);
932
933         if (sqe->addr || sqe->rw_flags || sqe->buf_index)
934                 return -EINVAL;
935
936         sock->domain = READ_ONCE(sqe->fd);
937         sock->type = READ_ONCE(sqe->off);
938         sock->protocol = READ_ONCE(sqe->len);
939         sock->file_slot = READ_ONCE(sqe->file_index);
940         sock->nofile = rlimit(RLIMIT_NOFILE);
941
942         sock->flags = sock->type & ~SOCK_TYPE_MASK;
943         if (sock->file_slot && (sock->flags & SOCK_CLOEXEC))
944                 return -EINVAL;
945         if (sock->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
946                 return -EINVAL;
947         return 0;
948 }
949
950 int io_socket(struct io_kiocb *req, unsigned int issue_flags)
951 {
952         struct io_socket *sock = io_kiocb_to_cmd(req);
953         bool fixed = !!sock->file_slot;
954         struct file *file;
955         int ret, fd;
956
957         if (!fixed) {
958                 fd = __get_unused_fd_flags(sock->flags, sock->nofile);
959                 if (unlikely(fd < 0))
960                         return fd;
961         }
962         file = __sys_socket_file(sock->domain, sock->type, sock->protocol);
963         if (IS_ERR(file)) {
964                 if (!fixed)
965                         put_unused_fd(fd);
966                 ret = PTR_ERR(file);
967                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
968                         return -EAGAIN;
969                 if (ret == -ERESTARTSYS)
970                         ret = -EINTR;
971                 req_set_fail(req);
972         } else if (!fixed) {
973                 fd_install(fd, file);
974                 ret = fd;
975         } else {
976                 ret = io_fixed_fd_install(req, issue_flags, file,
977                                             sock->file_slot);
978         }
979         io_req_set_res(req, ret, 0);
980         return IOU_OK;
981 }
982
983 int io_connect_prep_async(struct io_kiocb *req)
984 {
985         struct io_async_connect *io = req->async_data;
986         struct io_connect *conn = io_kiocb_to_cmd(req);
987
988         return move_addr_to_kernel(conn->addr, conn->addr_len, &io->address);
989 }
990
991 int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
992 {
993         struct io_connect *conn = io_kiocb_to_cmd(req);
994
995         if (sqe->len || sqe->buf_index || sqe->rw_flags || sqe->splice_fd_in)
996                 return -EINVAL;
997
998         conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
999         conn->addr_len =  READ_ONCE(sqe->addr2);
1000         return 0;
1001 }
1002
1003 int io_connect(struct io_kiocb *req, unsigned int issue_flags)
1004 {
1005         struct io_connect *connect = io_kiocb_to_cmd(req);
1006         struct io_async_connect __io, *io;
1007         unsigned file_flags;
1008         int ret;
1009         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1010
1011         if (req_has_async_data(req)) {
1012                 io = req->async_data;
1013         } else {
1014                 ret = move_addr_to_kernel(connect->addr,
1015                                                 connect->addr_len,
1016                                                 &__io.address);
1017                 if (ret)
1018                         goto out;
1019                 io = &__io;
1020         }
1021
1022         file_flags = force_nonblock ? O_NONBLOCK : 0;
1023
1024         ret = __sys_connect_file(req->file, &io->address,
1025                                         connect->addr_len, file_flags);
1026         if ((ret == -EAGAIN || ret == -EINPROGRESS) && force_nonblock) {
1027                 if (req_has_async_data(req))
1028                         return -EAGAIN;
1029                 if (io_alloc_async_data(req)) {
1030                         ret = -ENOMEM;
1031                         goto out;
1032                 }
1033                 memcpy(req->async_data, &__io, sizeof(__io));
1034                 return -EAGAIN;
1035         }
1036         if (ret == -ERESTARTSYS)
1037                 ret = -EINTR;
1038 out:
1039         if (ret < 0)
1040                 req_set_fail(req);
1041         io_req_set_res(req, ret, 0);
1042         return IOU_OK;
1043 }
1044
1045 void io_netmsg_cache_free(struct io_cache_entry *entry)
1046 {
1047         kfree(container_of(entry, struct io_async_msghdr, cache));
1048 }
1049 #endif