1 // SPDX-License-Identifier: GPL-2.0
18 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/types.h>
25 #include <netinet/in.h>
27 #include <linux/tcp.h>
/*
 * NOTE(review): every line of this file starts with a stray number that looks
 * like an original source line number, and many original lines (braces, blank
 * lines, some declarations and statements) are missing, so this text does not
 * compile as-is. Only comments were added; restore the code from its upstream
 * source before making code changes.
 */
/* Protocol number passed to socket() when MPTCP sockets are requested. */
32 #define IPPROTO_MPTCP 262
/* poll() timeout in milliseconds; -t sets it in seconds. */
38 static int poll_timeout = 10 * 1000;
/* Listening-side flag; presumably set by -l — TODO confirm (setter not visible). */
39 static bool listen_mode;
/* I/O strategy selected with -m: poll (default), mmap or sendfile. */
48 static enum cfg_mode cfg_mode = CFG_MODE_POLL;
/* Address to listen on or connect to: the single positional argument. */
49 static const char *cfg_host;
/* Port/service string; -p overrides the default. */
50 static const char *cfg_port = "12000";
/* Protocol for data sockets: MPTCP by default, "-s TCP" selects plain TCP. */
51 static int cfg_sock_proto = IPPROTO_MPTCP;
/* -u: run only the TCP_ULP "mptcp" check (sock_test_tcpulp) from main(). */
52 static bool tcpulp_audit;
/* Address family used in getaddrinfo() hints: AF_INET unless IPv6 is requested (-6). */
53 static int pf = AF_INET;
/* SO_SNDBUF / SO_RCVBUF sizes from -S / -R (0 presumably means "keep default"). */
54 static int cfg_sndbuf;
55 static int cfg_rcvbuf;
/* Remove-address test mode; consulted by do_rnd_write, copyfd_io_poll and maybe_close. */
57 static bool cfg_remove;
/*
 * Print the command-line synopsis and option help to stderr. The rest of the
 * body (presumably an exit()) is not visible here.
 *
 * NOTE(review): the first two string literals are concatenated without a
 * separating space, so the synopsis reads "[-m mode][-l]". In the visible
 * lines, -l has no help entry, and neither do -j, -r and -h, although the
 * getopt string in parse_opts accepts all of them.
 */
60 static void die_usage(void)
62 fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode]"
63 "[-l] [-w sec] connect_address\n");
64 fprintf(stderr, "\t-6 use ipv6\n");
65 fprintf(stderr, "\t-t num -- set poll timeout to num\n");
66 fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
67 fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
68 fprintf(stderr, "\t-p num -- use port num\n");
69 fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp(default) or tcp sockets\n");
70 fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll(default)/mmap+write/sendfile\n");
71 fprintf(stderr, "\t-u -- check mptcp ulp\n");
72 fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
/* Handler that main() installs for SIGUSR1; its body is not visible here. */
76 static void handle_signal(int nr)
/*
 * Map a getaddrinfo()/getnameinfo() error code to a message. EAI_SYSTEM means
 * the real cause is in errno, so strerror(errno) is used for it. Any other
 * code is decoded by gai_strerror().
 */
81 static const char *getxinfo_strerr(int err)
83 if (err == EAI_SYSTEM)
84 return strerror(errno);
86 return gai_strerror(err);
/*
 * getnameinfo() wrapper that always requests numeric host and service strings
 * (no DNS/service lookups). On failure it prints "Fatal: getnameinfo: ..."
 * to stderr. The lines that follow, presumably exit(), are not visible here.
 */
89 static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
90 char *host, socklen_t hostlen,
91 char *serv, socklen_t servlen)
93 int flags = NI_NUMERICHOST | NI_NUMERICSERV;
94 int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
98 const char *errstr = getxinfo_strerr(err);
100 fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
/*
 * getaddrinfo() wrapper. On failure it prints the node and service (a NULL
 * one is shown as an empty string) plus the decoded error to stderr. The
 * lines that follow, presumably exit(), are not visible here.
 */
105 static void xgetaddrinfo(const char *node, const char *service,
106 const struct addrinfo *hints,
107 struct addrinfo **res)
109 int err = getaddrinfo(node, service, hints, res);
112 const char *errstr = getxinfo_strerr(err);
114 fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
115 node ? node : "", service ? service : "", errstr);
/*
 * Set SO_RCVBUF on fd to size; a failure is reported with perror(). Any
 * further error handling is on lines not visible here.
 * NOTE(review): callers pass the signed int cfg_rcvbuf into this unsigned
 * parameter.
 */
120 static void set_rcvbuf(int fd, unsigned int size)
124 err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size));
126 perror("set SO_RCVBUF");
/*
 * Set SO_SNDBUF on fd to size; a failure is reported with perror(). Any
 * further error handling is on lines not visible here.
 * NOTE(review): callers pass the signed int cfg_sndbuf into this unsigned
 * parameter.
 */
131 static void set_sndbuf(int fd, unsigned int size)
135 err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size));
137 perror("set SO_SNDBUF");
/*
 * Create a passive stream socket bound to listenaddr:port (numeric host,
 * family pf) using protocol cfg_sock_proto, with SO_REUSEADDR. Then listen()
 * with a backlog of 20. "Could not create listen socket" is printed when no
 * resolved address could be bound.
 */
142 static int sock_listen_mptcp(const char * const listenaddr,
143 const char * const port)
146 struct addrinfo hints = {
147 .ai_protocol = IPPROTO_TCP,
148 .ai_socktype = SOCK_STREAM,
149 .ai_flags = AI_PASSIVE | AI_NUMERICHOST
152 hints.ai_family = pf;
154 struct addrinfo *a, *addr;
157 xgetaddrinfo(listenaddr, port, &hints, &addr);
/*
 * NOTE(review): redundant. ai_family was already set to pf before the lookup,
 * and hints is not used again in the visible code.
 */
158 hints.ai_family = pf;
/* Try each resolved address until bind() succeeds. */
160 for (a = addr; a; a = a->ai_next) {
/* Use cfg_sock_proto rather than hints.ai_protocol so MPTCP can be requested. */
161 sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
165 if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
167 perror("setsockopt");
169 if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
180 fprintf(stderr, "Could not create listen socket\n");
184 if (listen(sock, 20)) {
/*
 * TCP_ULP audit (-u). For each resolved address of remoteaddr:port, create a
 * plain TCP socket and try setsockopt(TCP_ULP, "mptcp"). The EOPNOTSUPP
 * outcome is presumably the passing case (its handling is on a hidden line).
 * A return of 0 is reported as unexpected, and other failures go to
 * perror(). Returns test_pass.
 *
 * NOTE(review): ai_family is hard-coded to AF_INET, so the -6 / pf setting
 * is ignored here.
 */
193 static bool sock_test_tcpulp(const char * const remoteaddr,
194 const char * const port)
196 struct addrinfo hints = {
197 .ai_protocol = IPPROTO_TCP,
198 .ai_socktype = SOCK_STREAM,
200 struct addrinfo *a, *addr;
201 int sock = -1, ret = 0;
202 bool test_pass = false;
204 hints.ai_family = AF_INET;
206 xgetaddrinfo(remoteaddr, port, &hints, &addr);
207 for (a = addr; a; a = a->ai_next) {
208 sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
213 ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
215 if (ret == -1 && errno == EOPNOTSUPP)
223 "setsockopt(TCP_ULP) returned 0\n");
225 perror("setsockopt(TCP_ULP)");
/*
 * Resolve remoteaddr:port for family pf and connect a stream socket of the
 * given protocol (MPTCP or TCP). Each resolved address is tried in turn until
 * connect() succeeds. Failure handling is on lines not visible here.
 */
230 static int sock_connect_mptcp(const char * const remoteaddr,
231 const char * const port, int proto)
233 struct addrinfo hints = {
234 .ai_protocol = IPPROTO_TCP,
235 .ai_socktype = SOCK_STREAM,
237 struct addrinfo *a, *addr;
240 hints.ai_family = pf;
242 xgetaddrinfo(remoteaddr, port, &hints, &addr);
243 for (a = addr; a; a = a->ai_next) {
244 sock = socket(a->ai_family, a->ai_socktype, proto);
250 if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
/*
 * Write a randomly sized chunk of buf to fd. The size comes from
 * rand() & 0xffff. Hidden lines adjust it when it is 0 or exceeds len, for
 * the first write in join mode (> 100), and in remove mode (> 50). After the
 * first write in join mode, the function pauses to let the join handshake
 * complete.
 *
 * NOTE(review): the return type is size_t but write() returns ssize_t. A -1
 * error therefore comes back as SIZE_MAX and reads as negative only if the
 * caller stores it in a signed variable. Confirm that callers do.
 */
262 static size_t do_rnd_write(const int fd, char *buf, const size_t len)
264 static bool first = true;
268 do_w = rand() & 0xffff;
269 if (do_w == 0 || do_w > len)
272 if (cfg_join && first && do_w > 100)
275 if (cfg_remove && do_w > 50)
278 bw = write(fd, buf, do_w);
282 /* let the join handshake complete, before going on */
283 if (cfg_join && first) {
/*
 * Write all len bytes of buf to fd, looping on short writes and advancing
 * offset by each write()'s result. Error handling is on lines not visible
 * here.
 */
294 static size_t do_write(const int fd, char *buf, const size_t len)
298 while (offset < len) {
302 bw = write(fd, buf + offset, len - offset);
308 written = (size_t)bw;
/*
 * Single read() of at most cap bytes into buf, returning read()'s result
 * unchanged. cap is computed on lines not visible here (presumably random and
 * bounded by len — TODO confirm).
 */
315 static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
326 return read(fd, buf, cap);
/* Add O_NONBLOCK to fd's file status flags, keeping the existing flags. */
329 static void set_nonblock(int fd)
331 int flags = fcntl(fd, F_GETFL);
336 fcntl(fd, F_SETFL, flags | O_NONBLOCK);
/*
 * Bidirectional copy driven by poll() on peerfd, which is made non-blocking.
 * Data received from peerfd is written to outfd. Data read from infd is sent
 * to peerfd with random-sized writes.
 * - POLLIN is dropped once the peer has closed its write side (read returns 0).
 * - POLLOUT is dropped once infd has no more data. If we are still receiving,
 *   the write side of peerfd is then shut down.
 * - A poll timeout and POLLERR/POLLNVAL are reported on stderr.
 * - In join/remove modes, extra time is left at the end for late join/announce.
 */
339 static int copyfd_io_poll(int infd, int peerfd, int outfd)
341 struct pollfd fds = {
343 .events = POLLIN | POLLOUT,
/*
 * NOTE(review): wlen is unsigned int, but below it receives read()'s ssize_t
 * result. A read error (-1) wraps to UINT_MAX instead of being negative, so
 * it could be treated as a huge pending length. Confirm against the hidden
 * lines.
 */
345 unsigned int woff = 0, wlen = 0;
348 set_nonblock(peerfd);
357 switch (poll(&fds, 1, poll_timeout)) {
/* NOTE(review): these print raw mask bits (POLLOUT is not 1), not 0/1 flags. */
364 fprintf(stderr, "%s: poll timed out (events: "
365 "POLLIN %u, POLLOUT %u)\n", __func__,
366 fds.events & POLLIN, fds.events & POLLOUT);
370 if (fds.revents & POLLIN) {
371 len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
373 /* no more data to receive:
374 * peer has closed its write side
376 fds.events &= ~POLLIN;
378 if ((fds.events & POLLOUT) == 0)
379 /* and nothing more to send */
382 /* Else, still have data to transmit */
383 } else if (len < 0) {
388 do_write(outfd, rbuf, len);
391 if (fds.revents & POLLOUT) {
394 wlen = read(infd, wbuf, sizeof(wbuf));
400 bw = do_rnd_write(peerfd, wbuf + woff, wlen);
406 } else if (wlen == 0) {
407 /* We have no more data to send. */
408 fds.events &= ~POLLOUT;
410 if ((fds.events & POLLIN) == 0)
411 /* ... and peer also closed already */
414 /* ... but we still receive.
415 * Close our write side, ev. give some time
416 * for address notification and/or checking
421 shutdown(peerfd, SHUT_WR);
430 if (fds.revents & (POLLERR | POLLNVAL)) {
431 fprintf(stderr, "Unexpected revents: "
432 "POLLERR/POLLNVAL(%x)\n", fds.revents);
437 /* leave some time for late join/announce */
438 if (cfg_join || cfg_remove)
/*
 * Copy data from infd to outfd using random-sized reads (do_rnd_read).
 * A short write to outfd is detected; its handling is on lines not visible
 * here.
 */
445 static int do_recvfile(int infd, int outfd)
452 r = do_rnd_read(infd, buf, sizeof(buf));
454 if (write(outfd, buf, r) != r)
/*
 * Map the first size bytes of infd read-only and write them to outfd starting
 * at offset off. The write loop is only partly visible here.
 *
 * NOTE(review): mmap() with a length of 0 fails (EINVAL on Linux), so an empty
 * input file takes the MAP_FAILED path. Confirm this is intended.
 */
464 static int do_mmap(int infd, int outfd, unsigned int size)
466 char *inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
467 ssize_t ret = 0, off = 0;
470 if (inbuf == MAP_FAILED) {
478 ret = write(outfd, inbuf + off, rem);
/*
 * Return the size of fd via fstat(). fd must be a regular file (the callers
 * visible here pass stdin), and sizes above INT_MAX are rejected so that the
 * result fits the int return type.
 */
493 static int get_infd_size(int fd)
499 err = fstat(fd, &sb);
505 if ((sb.st_mode & S_IFMT) != S_IFREG) {
506 fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
511 if (count > INT_MAX) {
512 fprintf(stderr, "File too large: %zu\n", count);
/*
 * Send count bytes from infd to outfd with sendfile(). With a NULL offset,
 * infd's own file offset is used and advanced. The loop and error handling
 * are on lines not visible here.
 */
519 static int do_sendfile(int infd, int outfd, unsigned int count)
524 r = sendfile(outfd, infd, NULL, count);
/*
 * mmap mode. One branch receives first and then sends the input file. The
 * other sends the file, shuts down the write side of peerfd, then receives.
 * The selecting condition is hidden; per the -m help text, presumably
 * listen_mode (-l reads input first).
 */
536 static int copyfd_io_mmap(int infd, int peerfd, int outfd,
542 err = do_recvfile(peerfd, outfd);
546 err = do_mmap(infd, peerfd, size);
548 err = do_mmap(infd, peerfd, size);
552 shutdown(peerfd, SHUT_WR);
554 err = do_recvfile(peerfd, outfd);
/*
 * sendfile mode: same ordering as copyfd_io_mmap, but the input file is sent
 * with sendfile(). The branch condition and any shutdown() call are on lines
 * not visible here.
 */
560 static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
566 err = do_recvfile(peerfd, outfd);
570 err = do_sendfile(infd, peerfd, size);
572 err = do_sendfile(infd, peerfd, size);
575 err = do_recvfile(peerfd, outfd);
/*
 * Dispatch on cfg_mode. The mmap and sendfile modes first size the input with
 * get_infd_size(); its result checking is on hidden lines. An unknown mode is
 * reported on stderr.
 */
581 static int copyfd_io(int infd, int peerfd, int outfd)
587 return copyfd_io_poll(infd, peerfd, outfd);
589 file_size = get_infd_size(infd);
592 return copyfd_io_mmap(infd, peerfd, outfd, file_size);
593 case CFG_MODE_SENDFILE:
594 file_size = get_infd_size(infd);
597 return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
600 fprintf(stderr, "Invalid mode %d\n", cfg_mode);
/*
 * Sanity-check the address returned by accept(). It checks for a non-zero
 * port, that salen matches sizeof(struct sockaddr_in / sockaddr_in6) for pf,
 * and that ss_family equals pf. Problems are only reported on stderr.
 *
 * NOTE(review) issues in the messages:
 * - the two "port 0" messages have no trailing "\n";
 * - the size-mismatch message prints the unsigned socklen_t wanted_size
 *   with %d;
 * - the pf-mismatch message passes ss_family and pf in the opposite order to
 *   its "expect %d, ss_family is %d" labels.
 */
606 static void check_sockaddr(int pf, struct sockaddr_storage *ss,
609 struct sockaddr_in6 *sin6;
610 struct sockaddr_in *sin;
611 socklen_t wanted_size = 0;
615 wanted_size = sizeof(*sin);
618 fprintf(stderr, "accept: something wrong: ip connection from port 0");
621 wanted_size = sizeof(*sin6);
623 if (!sin6->sin6_port)
624 fprintf(stderr, "accept: something wrong: ipv6 connection from port 0");
627 fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
631 if (salen != wanted_size)
632 fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
633 (int)salen, wanted_size);
635 if (ss->ss_family != pf)
636 fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
637 (int)ss->ss_family, pf);
/*
 * Verify that getpeername() on the accepted fd returns the same address
 * (length and bytes) that accept() reported in ss/salen. On a byte mismatch
 * both addresses are printed in numeric form.
 *
 * NOTE(review): the length message prints socklen_t values with %d. The
 * final message pairs "accept vs peername" for host and port but prints
 * "salen %d vs %d" as peersalen, salen, which is the reverse order.
 */
640 static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
642 struct sockaddr_storage peerss;
643 socklen_t peersalen = sizeof(peerss);
645 if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
646 perror("getpeername");
650 if (peersalen != salen) {
651 fprintf(stderr, "%s: %d vs %d\n", __func__, peersalen, salen);
655 if (memcmp(ss, &peerss, peersalen)) {
656 char a[INET6_ADDRSTRLEN];
657 char b[INET6_ADDRSTRLEN];
658 char c[INET6_ADDRSTRLEN];
659 char d[INET6_ADDRSTRLEN];
661 xgetnameinfo((struct sockaddr *)ss, salen,
662 a, sizeof(a), b, sizeof(b));
664 xgetnameinfo((struct sockaddr *)&peerss, peersalen,
665 c, sizeof(c), d, sizeof(d));
667 fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
668 __func__, a, c, b, d, peersalen, salen);
/*
 * After connect(), compare the peer's numeric host/port strings with
 * cfg_host/cfg_port and report any difference on stderr.
 * NOTE(review): this is a textual comparison. An address given in a
 * non-canonical form (different from getnameinfo()'s numeric output) would
 * be reported as a mismatch.
 */
672 static void check_getpeername_connect(int fd)
674 struct sockaddr_storage ss;
675 socklen_t salen = sizeof(ss);
676 char a[INET6_ADDRSTRLEN];
677 char b[INET6_ADDRSTRLEN];
679 if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
680 perror("getpeername");
684 xgetnameinfo((struct sockaddr *)&ss, salen,
685 a, sizeof(a), b, sizeof(b));
687 if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
688 fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
689 cfg_host, a, cfg_port, b);
/*
 * When neither join nor remove testing is enabled and rand()'s low bit is
 * set, act on fd. Presumably fd is closed (the action is on a hidden line).
 */
692 static void maybe_close(int fd)
694 unsigned int r = rand();
696 if (!(cfg_join || cfg_remove) && (r & 1))
/*
 * Server loop. Wait up to poll_timeout ms for a connection on listensock,
 * accept it, possibly close the listener (maybe_close), and sanity-check the
 * peer address. Then copy stdin -> connection -> stdout via copyfd_io().
 */
700 int main_loop_s(int listensock)
702 struct sockaddr_storage ss;
707 polls.fd = listensock;
708 polls.events = POLLIN;
710 switch (poll(&polls, 1, poll_timeout)) {
715 fprintf(stderr, "%s: timed out\n", __func__);
721 remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
722 if (remotesock >= 0) {
723 maybe_close(listensock);
724 check_sockaddr(pf, &ss, salen);
725 check_getpeername(remotesock, &ss, salen);
727 return copyfd_io(0, remotesock, 1);
/*
 * Read seed material from /dev/urandom into foo. Presumably used to seed
 * rand() on hidden lines — TODO confirm.
 * NOTE(review): read() returns ssize_t but is stored in an int.
 */
735 static void init_rng(void)
737 int fd = open("/dev/urandom", O_RDONLY);
741 int ret = read(fd, &foo, sizeof(foo));
/*
 * Client side (the tail of a function whose header is not visible here).
 * Connect to cfg_host:cfg_port with cfg_sock_proto, verify the peer address,
 * apply the configured buffer sizes (guards presumably on hidden lines), then
 * copy stdin -> connection -> stdout via copyfd_io().
 */
755 /* listener is ready. */
756 fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
760 check_getpeername_connect(fd);
763 set_rcvbuf(fd, cfg_rcvbuf);
765 set_sndbuf(fd, cfg_sndbuf);
767 return copyfd_io(0, fd, 1);
/*
 * Map the -s argument (case-insensitive "MPTCP" or "TCP") to a socket protocol
 * number. An unknown value is reported on stderr; the follow-up (presumably
 * die_usage()) is on hidden lines.
 * NOTE(review): the error message has a stray "." after its "\n".
 */
770 int parse_proto(const char *proto)
772 if (!strcasecmp(proto, "MPTCP"))
773 return IPPROTO_MPTCP;
774 if (!strcasecmp(proto, "TCP"))
777 fprintf(stderr, "Unknown protocol: %s\n.", proto);
780 /* silence compiler warning */
/*
 * Map the -m argument (case-insensitive "poll", "mmap" or "sendfile") to a
 * cfg_mode value. An unknown mode prints the list of supported modes to
 * stderr.
 */
784 int parse_mode(const char *mode)
786 if (!strcasecmp(mode, "poll"))
787 return CFG_MODE_POLL;
788 if (!strcasecmp(mode, "mmap"))
789 return CFG_MODE_MMAP;
790 if (!strcasecmp(mode, "sendfile"))
791 return CFG_MODE_SENDFILE;
793 fprintf(stderr, "Unknown test mode: %s\n", mode);
794 fprintf(stderr, "Supported modes are:\n");
795 fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
796 fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
797 fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
801 /* silence compiler warning */
/*
 * Parse a buffer size with strtoul() (base 0: decimal, 0x hex or 0 octal).
 * A conversion error and an out-of-range value (reported as ERANGE) are both
 * reported on stderr.
 * NOTE(review): the message always says "sndbuf", but this is also used for
 * -R (SO_RCVBUF). strtoul() is called with a NULL end pointer, so trailing
 * garbage is not detected.
 */
805 static int parse_int(const char *size)
811 s = strtoul(size, NULL, 0);
814 fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
815 size, strerror(errno));
820 fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
821 size, strerror(ERANGE));
/*
 * Parse command-line options into the cfg_* globals. Exactly one positional
 * argument (the host) is required. A ':' in the host is checked on the last
 * visible line, presumably to switch to IPv6.
 *
 * NOTE(review): -t and -w use atoi(), which cannot report invalid input, and
 * the "* 1000" and "* 1000000" multiplications can overflow int (undefined
 * behaviour) for large arguments. The "<= 0" check on poll_timeout does not
 * reliably catch that. Consider strtol() with range checks, as parse_int()
 * does.
 */
828 static void parse_opts(int argc, char **argv)
832 while ((c = getopt(argc, argv, "6jrlp:s:hut:m:S:R:w:")) != -1) {
836 cfg_mode = CFG_MODE_POLL;
841 cfg_mode = CFG_MODE_POLL;
851 cfg_sock_proto = parse_proto(optarg);
863 poll_timeout = atoi(optarg) * 1000;
864 if (poll_timeout <= 0)
868 cfg_mode = parse_mode(optarg);
871 cfg_sndbuf = parse_int(optarg);
874 cfg_rcvbuf = parse_int(optarg);
877 cfg_wait = atoi(optarg)*1000000;
882 if (optind + 1 != argc)
884 cfg_host = argv[optind];
886 if (strchr(cfg_host, ':'))
/*
 * Entry point. Install the SIGUSR1 handler and parse options. Then either run
 * the TCP_ULP audit (exit status 0 on pass, 1 on fail; its guard is presumably
 * tcpulp_audit) or, on the listening side, create the listen socket, apply
 * buffer sizes and run the server loop. The client path is not visible here.
 */
890 int main(int argc, char *argv[])
894 signal(SIGUSR1, handle_signal);
895 parse_opts(argc, argv);
898 return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;
901 int fd = sock_listen_mptcp(cfg_host, cfg_port);
907 set_rcvbuf(fd, cfg_rcvbuf);
909 set_sndbuf(fd, cfg_sndbuf);
911 return main_loop_s(fd);