1 // SPDX-License-Identifier: GPL-2.0
18 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/types.h>
25 #include <netinet/in.h>
27 #include <linux/tcp.h>
/* Protocol number passed to socket(2) when MPTCP sockets are requested. */
32 #define IPPROTO_MPTCP 262
/* poll(2) timeout in milliseconds (default 10 s); see -t in die_usage(). */
38 static int poll_timeout = 10 * 1000;
/* Listening-side flag; presumably set by -l (option handling not visible). */
39 static bool listen_mode;
/* Transfer strategy chosen with -m (poll, mmap or sendfile). */
48 static enum cfg_mode cfg_mode = CFG_MODE_POLL;
/* Address taken from the single positional argument. */
49 static const char *cfg_host;
/* Port as a string; default "12000", see -p. */
50 static const char *cfg_port = "12000";
/* Protocol handed to socket(2); -s selects MPTCP (default) or TCP. */
51 static int cfg_sock_proto = IPPROTO_MPTCP;
/* TCP_ULP check flag; presumably set by -u ("check mptcp ulp"). */
52 static bool tcpulp_audit;
/* Address family used in getaddrinfo() hints; presumably AF_INET6 with -6. */
53 static int pf = AF_INET;
/* SO_SNDBUF / SO_RCVBUF values from -S / -R (parsed by parse_int()). */
54 static int cfg_sndbuf;
55 static int cfg_rcvbuf;
/*
 * Print the command-line help to stderr.
 * NOTE(review): the two literals of the first fprintf() concatenate to
 * "...[-m mode][-l] ..." with no separating space. The synopsis omits
 * -t, -S, -R and -j. Options -j, -l and -h from parse_opts()'s getopt()
 * string have no description line. The end of the function is not
 * visible here; presumably it exits.
 */
59 static void die_usage(void)
61 fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode]"
62 "[-l] [-w sec] connect_address\n");
63 fprintf(stderr, "\t-6 use ipv6\n");
64 fprintf(stderr, "\t-t num -- set poll timeout to num\n");
65 fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
66 fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
67 fprintf(stderr, "\t-p num -- use port num\n");
68 fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp(default) or tcp sockets\n");
69 fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll(default)/mmap+write/sendfile\n");
70 fprintf(stderr, "\t-u -- check mptcp ulp\n");
71 fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
/* Signal handler that main() installs for SIGUSR1; its body is not visible in this extract. */
75 static void handle_signal(int nr)
/*
 * Turn a getaddrinfo()/getnameinfo() error code into a message.
 * EAI_SYSTEM means the real cause is in errno, so strerror(errno) is used;
 * every other code goes through gai_strerror().
 */
80 static const char *getxinfo_strerr(int err)
82 if (err == EAI_SYSTEM)
83 return strerror(errno);
85 return gai_strerror(err);
/*
 * getnameinfo() wrapper that writes the numeric host and service
 * (NI_NUMERICHOST | NI_NUMERICSERV) into the caller's buffers.
 * On failure it prints "Fatal: getnameinfo: ..." to stderr. The lines that
 * follow are not visible here; presumably the process exits.
 */
88 static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
89 char *host, socklen_t hostlen,
90 char *serv, socklen_t servlen)
92 int flags = NI_NUMERICHOST | NI_NUMERICSERV;
93 int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
97 const char *errstr = getxinfo_strerr(err);
99 fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
/*
 * getaddrinfo() wrapper that stores the result list in *res.
 * On failure it prints "Fatal: getaddrinfo(node:service): ..." and
 * substitutes "" for NULL node/service. The rest of the error path is not
 * visible here; presumably the process exits.
 */
104 static void xgetaddrinfo(const char *node, const char *service,
105 const struct addrinfo *hints,
106 struct addrinfo **res)
108 int err = getaddrinfo(node, service, hints, res);
111 const char *errstr = getxinfo_strerr(err);
113 fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
114 node ? node : "", service ? service : "", errstr);
/*
 * Apply SO_RCVBUF = size to fd. A setsockopt() failure is reported with
 * perror(). Whatever follows the report is not visible in this extract.
 */
119 static void set_rcvbuf(int fd, unsigned int size)
123 err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size));
125 perror("set SO_RCVBUF");
/*
 * Apply SO_SNDBUF = size to fd. A setsockopt() failure is reported with
 * perror(). Whatever follows the report is not visible in this extract.
 */
130 static void set_sndbuf(int fd, unsigned int size)
134 err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size));
136 perror("set SO_SNDBUF");
/*
 * Create a listening socket for listenaddr:port.
 * Resolution uses family pf with AI_PASSIVE | AI_NUMERICHOST, so listenaddr
 * must be a numeric address. For each resolved entry a socket is created
 * with cfg_sock_proto (MPTCP or TCP), SO_REUSEADDR is set, and bind() is
 * tried. "Could not create listen socket" is printed when that fails.
 * listen() is then called with a backlog of 20.
 * NOTE(review): the second "hints.ai_family = pf" after xgetaddrinfo() is
 * redundant. getaddrinfo() has already consumed hints, and the field
 * already held pf.
 */
141 static int sock_listen_mptcp(const char * const listenaddr,
142 const char * const port)
145 struct addrinfo hints = {
146 .ai_protocol = IPPROTO_TCP,
147 .ai_socktype = SOCK_STREAM,
148 .ai_flags = AI_PASSIVE | AI_NUMERICHOST
151 hints.ai_family = pf;
153 struct addrinfo *a, *addr;
156 xgetaddrinfo(listenaddr, port, &hints, &addr);
157 hints.ai_family = pf;
159 for (a = addr; a; a = a->ai_next) {
160 sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
164 if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
166 perror("setsockopt");
168 if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
179 fprintf(stderr, "Could not create listen socket\n");
183 if (listen(sock, 20)) {
/*
 * Check whether setsockopt(TCP_ULP, "mptcp") is accepted on plain TCP
 * sockets created for each address resolved from remoteaddr:port.
 * Returns a bool verdict; main() maps true to exit status 0.
 * An EOPNOTSUPP failure is tested explicitly. Other outcomes are reported
 * as "returned 0" or via perror(). Which outcome sets test_pass is on lines
 * not visible here.
 * NOTE(review): the family is hard-coded to AF_INET instead of pf, so -6
 * has no effect on this check.
 */
192 static bool sock_test_tcpulp(const char * const remoteaddr,
193 const char * const port)
195 struct addrinfo hints = {
196 .ai_protocol = IPPROTO_TCP,
197 .ai_socktype = SOCK_STREAM,
199 struct addrinfo *a, *addr;
200 int sock = -1, ret = 0;
201 bool test_pass = false;
203 hints.ai_family = AF_INET;
205 xgetaddrinfo(remoteaddr, port, &hints, &addr);
206 for (a = addr; a; a = a->ai_next) {
207 sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
212 ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
214 if (ret == -1 && errno == EOPNOTSUPP)
222 "setsockopt(TCP_ULP) returned 0\n");
224 perror("setsockopt(TCP_ULP)");
/*
 * Resolve remoteaddr:port with family pf. For each result, create a
 * SOCK_STREAM socket with the caller-supplied proto and try connect().
 * Returns an int, presumably the connected fd; the return path is not
 * visible here.
 */
229 static int sock_connect_mptcp(const char * const remoteaddr,
230 const char * const port, int proto)
232 struct addrinfo hints = {
233 .ai_protocol = IPPROTO_TCP,
234 .ai_socktype = SOCK_STREAM,
236 struct addrinfo *a, *addr;
239 hints.ai_family = pf;
241 xgetaddrinfo(remoteaddr, port, &hints, &addr);
242 for (a = addr; a; a = a->ai_next) {
243 sock = socket(a->ai_family, a->ai_socktype, proto);
249 if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
/*
 * Write one randomly sized chunk of buf to fd. The size starts as
 * rand() & 0xffff. A zero size, or one larger than len, is replaced on a
 * line not visible here (presumably with len).
 * When cfg_join is set, the first call (static flag "first") is treated
 * specially: a case for do_w > 100, and a pause that lets the join
 * handshake complete. The exact statements are not visible.
 * NOTE(review): write() yields ssize_t but this function returns size_t.
 * If write()'s -1 is passed through, it becomes SIZE_MAX unless the caller
 * converts it back to a signed type. Consider returning ssize_t.
 */
261 static size_t do_rnd_write(const int fd, char *buf, const size_t len)
263 static bool first = true;
267 do_w = rand() & 0xffff;
268 if (do_w == 0 || do_w > len)
271 if (cfg_join && first && do_w > 100)
274 bw = write(fd, buf, do_w);
278 /* let the join handshake complete, before going on */
279 if (cfg_join && first) {
/*
 * Write len bytes of buf to fd. write() is called repeatedly on the
 * unwritten tail (buf + offset) while offset < len. Error handling and the
 * offset update are on lines not visible here.
 */
287 static size_t do_write(const int fd, char *buf, const size_t len)
291 while (offset < len) {
295 bw = write(fd, buf + offset, len - offset);
301 written = (size_t)bw;
/*
 * Single read() of at most cap bytes into buf; read()'s result is returned
 * as-is. cap is computed on lines not visible here, presumably as a random
 * size no larger than len.
 */
308 static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
319 return read(fd, buf, cap);
/*
 * Add O_NONBLOCK to fd's file status flags (read with F_GETFL).
 * NOTE(review): the return value of fcntl(F_SETFL) is discarded.
 */
322 static void set_nonblock(int fd)
324 int flags = fcntl(fd, F_GETFL);
329 fcntl(fd, F_SETFL, flags | O_NONBLOCK);
/*
 * Bidirectional copy driven by poll(). peerfd is made non-blocking; the
 * polled fd is set on a line not visible here, presumably peerfd.
 * - On POLLIN, data from peerfd is read in random-sized chunks
 *   (do_rnd_read) and written to outfd.
 * - On POLLOUT, data read from infd is sent to peerfd in random-sized
 *   chunks (do_rnd_write).
 * A zero-length read on either side clears that event. When local input is
 * exhausted but the peer may still send, the write side is shut down
 * (SHUT_WR). POLLERR/POLLNVAL are reported as unexpected.
 * NOTE(review): wlen is unsigned int but receives read()'s ssize_t result,
 * so a read() error (-1) becomes UINT_MAX. The "wlen == 0" test still
 * works. Any positive/negative test on wlen cannot detect the failure, and
 * UINT_MAX would then be passed to do_rnd_write() as the length. Consider
 * holding the read() result in ssize_t.
 * NOTE(review): the timeout message prints the masked bit values
 * (fds.events & POLLIN, fds.events & POLLOUT), not 0/1 flags.
 */
332 static int copyfd_io_poll(int infd, int peerfd, int outfd)
334 struct pollfd fds = {
336 .events = POLLIN | POLLOUT,
338 unsigned int woff = 0, wlen = 0;
341 set_nonblock(peerfd);
350 switch (poll(&fds, 1, poll_timeout)) {
357 fprintf(stderr, "%s: poll timed out (events: "
358 "POLLIN %u, POLLOUT %u)\n", __func__,
359 fds.events & POLLIN, fds.events & POLLOUT);
363 if (fds.revents & POLLIN) {
364 len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
366 /* no more data to receive:
367 * peer has closed its write side
369 fds.events &= ~POLLIN;
371 if ((fds.events & POLLOUT) == 0)
372 /* and nothing more to send */
375 /* Else, still have data to transmit */
376 } else if (len < 0) {
381 do_write(outfd, rbuf, len);
384 if (fds.revents & POLLOUT) {
387 wlen = read(infd, wbuf, sizeof(wbuf));
393 bw = do_rnd_write(peerfd, wbuf + woff, wlen);
399 } else if (wlen == 0) {
400 /* We have no more data to send. */
401 fds.events &= ~POLLOUT;
403 if ((fds.events & POLLIN) == 0)
404 /* ... and peer also closed already */
407 /* ... but we still receive.
408 * Close our write side, ev. give some time
409 * for address notification and/or checking
414 shutdown(peerfd, SHUT_WR);
423 if (fds.revents & (POLLERR | POLLNVAL)) {
424 fprintf(stderr, "Unexpected revents: "
425 "POLLERR/POLLNVAL(%x)\n", fds.revents);
430 /* leave some time for late join/announce */
/*
 * Read from infd in random-sized chunks (do_rnd_read) and write each chunk
 * to outfd. A write() that does not write exactly r bytes is checked for.
 * The loop and the error handling are on lines not visible here.
 */
438 static int do_recvfile(int infd, int outfd)
445 r = do_rnd_read(infd, buf, sizeof(buf));
447 if (write(outfd, buf, r) != r)
/*
 * Map size bytes of infd read-only (PROT_READ, MAP_SHARED) and write them
 * to outfd, issuing write() from inbuf + off for the remaining rem bytes.
 * mmap() failure is detected via MAP_FAILED. The loop and unmapping are not
 * visible here.
 */
457 static int do_mmap(int infd, int outfd, unsigned int size)
459 char *inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
460 ssize_t ret = 0, off = 0;
463 if (inbuf == MAP_FAILED) {
471 ret = write(outfd, inbuf + off, rem);
/*
 * Validate fd with fstat(): it must be a regular file (S_IFMT == S_IFREG),
 * and its byte count must not exceed INT_MAX. copyfd_io() uses the int
 * result as the file size. The "stdin" wording matches copyfd_io()'s
 * callers, which pass fd 0.
 * NOTE(review): count's declaration is not visible. If count is signed
 * (e.g. ssize_t or off_t), "%zu" is the wrong conversion; use "%zd" or
 * cast explicitly.
 */
486 static int get_infd_size(int fd)
492 err = fstat(fd, &sb);
498 if ((sb.st_mode & S_IFMT) != S_IFREG) {
499 fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
504 if (count > INT_MAX) {
505 fprintf(stderr, "File too large: %zu\n", count);
/*
 * Send up to count bytes from infd to outfd with sendfile(2). The offset
 * argument is NULL, so infd's own file offset is used and advanced. Any
 * retry loop and error handling are not visible here.
 */
512 static int do_sendfile(int infd, int outfd, unsigned int count)
517 r = sendfile(outfd, infd, NULL, count);
/*
 * mmap-mode transfer with two orderings:
 * - receive from peerfd first (do_recvfile), then send infd with do_mmap();
 * - or send first, shut down the write side (SHUT_WR), then receive.
 * parse_mode()'s help says -l reads input first. The branch condition
 * itself is not visible here.
 */
529 static int copyfd_io_mmap(int infd, int peerfd, int outfd,
535 err = do_recvfile(peerfd, outfd);
539 err = do_mmap(infd, peerfd, size);
541 err = do_mmap(infd, peerfd, size);
545 shutdown(peerfd, SHUT_WR);
547 err = do_recvfile(peerfd, outfd);
/*
 * sendfile-mode transfer, mirroring copyfd_io_mmap() but using
 * do_sendfile(). One ordering receives first, then sends; the other sends
 * first, then receives. The branch condition and any shutdown() are on
 * lines not visible here.
 */
553 static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
559 err = do_recvfile(peerfd, outfd);
563 err = do_sendfile(infd, peerfd, size);
565 err = do_sendfile(infd, peerfd, size);
568 err = do_recvfile(peerfd, outfd);
/*
 * Dispatch the transfer on cfg_mode.
 * - Poll mode copies interactively (copyfd_io_poll).
 * - The mmap and sendfile modes first obtain infd's size via
 *   get_infd_size().
 * Any other mode is reported as "Invalid mode".
 */
574 static int copyfd_io(int infd, int peerfd, int outfd)
580 return copyfd_io_poll(infd, peerfd, outfd);
582 file_size = get_infd_size(infd);
585 return copyfd_io_mmap(infd, peerfd, outfd, file_size);
586 case CFG_MODE_SENDFILE:
587 file_size = get_infd_size(infd);
590 return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
593 fprintf(stderr, "Invalid mode %d\n", cfg_mode);
/*
 * Sanity-check an address returned by accept():
 * - the port is non-zero;
 * - salen equals sizeof(struct sockaddr_in) or sizeof(struct sockaddr_in6),
 *   depending on pf;
 * - ss_family equals pf.
 * Problems found by the visible checks are reported on stderr.
 * NOTE(review): the parameter pf shadows the file-scope pf.
 * NOTE(review): the "pf mismatch" arguments are swapped. ss_family is
 * printed after "expect" and pf after "ss_family is". They should be
 * passed as: pf, (int)ss->ss_family.
 * NOTE(review): the two "port 0" messages lack a trailing newline.
 * wanted_size (socklen_t) is printed with %d without a cast.
 */
599 static void check_sockaddr(int pf, struct sockaddr_storage *ss,
602 struct sockaddr_in6 *sin6;
603 struct sockaddr_in *sin;
604 socklen_t wanted_size = 0;
608 wanted_size = sizeof(*sin);
611 fprintf(stderr, "accept: something wrong: ip connection from port 0");
614 wanted_size = sizeof(*sin6);
616 if (!sin6->sin6_port)
617 fprintf(stderr, "accept: something wrong: ipv6 connection from port 0");
620 fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
624 if (salen != wanted_size)
625 fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
626 (int)salen, wanted_size);
628 if (ss->ss_family != pf)
629 fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
630 (int)ss->ss_family, pf);
/*
 * Cross-check the address that accept() returned (ss, salen) against
 * getpeername() on fd. The lengths must match and the address bytes must
 * compare equal (memcmp). On a byte mismatch, both addresses and ports are
 * printed in numeric form via xgetnameinfo().
 * NOTE(review): socklen_t values are printed with %d without casts.
 */
633 static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
635 struct sockaddr_storage peerss;
636 socklen_t peersalen = sizeof(peerss);
638 if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
639 perror("getpeername");
643 if (peersalen != salen) {
644 fprintf(stderr, "%s: %d vs %d\n", __func__, peersalen, salen);
648 if (memcmp(ss, &peerss, peersalen)) {
649 char a[INET6_ADDRSTRLEN];
650 char b[INET6_ADDRSTRLEN];
651 char c[INET6_ADDRSTRLEN];
652 char d[INET6_ADDRSTRLEN];
654 xgetnameinfo((struct sockaddr *)ss, salen,
655 a, sizeof(a), b, sizeof(b));
657 xgetnameinfo((struct sockaddr *)&peerss, peersalen,
658 c, sizeof(c), d, sizeof(d));
660 fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
661 __func__, a, c, b, d, peersalen, salen);
/*
 * After connect(), fetch the peer address with getpeername() and render it
 * numerically via xgetnameinfo(). Host and port are then compared with
 * cfg_host/cfg_port using strcmp(), and any difference is reported on
 * stderr.
 * NOTE(review): the comparison is purely textual. A cfg_host given as a
 * hostname, or as a numeric form spelled differently from getnameinfo()'s
 * output (e.g. a non-canonical IPv6 literal), is reported as a mismatch.
 */
665 static void check_getpeername_connect(int fd)
667 struct sockaddr_storage ss;
668 socklen_t salen = sizeof(ss);
669 char a[INET6_ADDRSTRLEN];
670 char b[INET6_ADDRSTRLEN];
672 if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
673 perror("getpeername");
677 xgetnameinfo((struct sockaddr *)&ss, salen,
678 a, sizeof(a), b, sizeof(b));
680 if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
681 fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
682 cfg_host, a, cfg_port, b);
/*
 * When cfg_join is not set, act on fd with 50% probability (lowest bit of
 * rand()). The action is on a line not visible here, presumably close(fd).
 */
685 static void maybe_close(int fd)
687 unsigned int r = rand();
689 if (!cfg_join && (r & 1))
/*
 * Server side:
 * - wait up to poll_timeout for POLLIN on listensock and accept the
 *   connection;
 * - possibly close the listener (maybe_close);
 * - validate the peer address (check_sockaddr, check_getpeername);
 * - copy between stdin/stdout (fds 0 and 1) and the accepted socket via
 *   copyfd_io().
 */
693 int main_loop_s(int listensock)
695 struct sockaddr_storage ss;
700 polls.fd = listensock;
701 polls.events = POLLIN;
703 switch (poll(&polls, 1, poll_timeout)) {
708 fprintf(stderr, "%s: timed out\n", __func__);
714 remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
715 if (remotesock >= 0) {
716 maybe_close(listensock);
717 check_sockaddr(pf, &ss, salen);
718 check_getpeername(remotesock, &ss, salen);
720 return copyfd_io(0, remotesock, 1);
/*
 * Read a seed value (foo) from /dev/urandom, presumably passed to srand()
 * for the rand() calls used elsewhere. The seeding call and error handling
 * are on lines not visible here.
 */
728 static void init_rng(void)
730 int fd = open("/dev/urandom", O_RDONLY);
734 int ret = read(fd, &foo, sizeof(foo));
/*
 * Fragment of the client-side loop; its function header is not visible in
 * this extract. Steps:
 * - connect to cfg_host:cfg_port with cfg_sock_proto;
 * - verify the peer via getpeername();
 * - apply the receive/send buffer sizes (presumably only when non-zero);
 * - copy between stdin/stdout (fds 0 and 1) and the socket.
 */
748 /* listener is ready. */
749 fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
753 check_getpeername_connect(fd);
756 set_rcvbuf(fd, cfg_rcvbuf);
758 set_sndbuf(fd, cfg_sndbuf);
760 return copyfd_io(0, fd, 1);
/*
 * Map the -s argument ("MPTCP" or "TCP", compared case-insensitively) to a
 * protocol number. Anything else is reported as an unknown protocol.
 * NOTE(review): the format "Unknown protocol: %s\n." emits the period after
 * the newline; "%s.\n" was presumably intended.
 */
763 int parse_proto(const char *proto)
765 if (!strcasecmp(proto, "MPTCP"))
766 return IPPROTO_MPTCP;
767 if (!strcasecmp(proto, "TCP"))
770 fprintf(stderr, "Unknown protocol: %s\n.", proto);
773 /* silence compiler warning */
/*
 * Map the -m argument ("poll", "mmap" or "sendfile", compared
 * case-insensitively) to a cfg_mode value. An unknown mode prints the list
 * of supported modes to stderr.
 */
777 int parse_mode(const char *mode)
779 if (!strcasecmp(mode, "poll"))
780 return CFG_MODE_POLL;
781 if (!strcasecmp(mode, "mmap"))
782 return CFG_MODE_MMAP;
783 if (!strcasecmp(mode, "sendfile"))
784 return CFG_MODE_SENDFILE;
786 fprintf(stderr, "Unknown test mode: %s\n", mode);
787 fprintf(stderr, "Supported modes are:\n");
788 fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
789 fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
790 fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
794 /* silence compiler warning */
/*
 * Parse a buffer-size option with strtoul() (base 0: decimal, 0x hex or
 * leading-0 octal). Failures are reported with the errno text or with
 * ERANGE for out-of-range values.
 * NOTE(review): both messages say "sndbuf", but parse_opts() also uses this
 * helper for cfg_rcvbuf. strtoul() is given a NULL end pointer, so
 * trailing garbage such as "12abc" is not rejected.
 */
798 static int parse_int(const char *size)
804 s = strtoul(size, NULL, 0);
807 fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
808 size, strerror(errno));
813 fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
814 size, strerror(ERANGE));
/*
 * Parse options (getopt string "6jlp:s:hut:m:S:R:w:") followed by exactly
 * one positional argument, the host, stored in cfg_host. A ':' in the host
 * is checked for, presumably to select IPv6.
 * NOTE(review): the timeout (seconds * 1000) and wait (seconds * 1000000)
 * are computed from atoi() in int arithmetic. Large inputs overflow, which
 * is undefined behaviour, and the later "poll_timeout <= 0" check comes too
 * late to catch it. atoi() also cannot report malformed input. strtol()
 * with range checks would be safer.
 */
821 static void parse_opts(int argc, char **argv)
825 while ((c = getopt(argc, argv, "6jlp:s:hut:m:S:R:w:")) != -1) {
829 cfg_mode = CFG_MODE_POLL;
839 cfg_sock_proto = parse_proto(optarg);
851 poll_timeout = atoi(optarg) * 1000;
852 if (poll_timeout <= 0)
856 cfg_mode = parse_mode(optarg);
859 cfg_sndbuf = parse_int(optarg);
862 cfg_rcvbuf = parse_int(optarg);
865 cfg_wait = atoi(optarg)*1000000;
870 if (optind + 1 != argc)
872 cfg_host = argv[optind];
874 if (strchr(cfg_host, ':'))
/*
 * Entry point: install the SIGUSR1 handler and parse options. Then one of:
 * - run the TCP_ULP check, exiting 0 on pass (presumably when
 *   tcpulp_audit is set);
 * - create a listening socket, apply the buffer sizes (presumably only when
 *   non-zero), and serve via main_loop_s() (presumably in listen mode).
 * The client branch and the guarding conditions are not visible here.
 */
878 int main(int argc, char *argv[])
882 signal(SIGUSR1, handle_signal);
883 parse_opts(argc, argv);
886 return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;
889 int fd = sock_listen_mptcp(cfg_host, cfg_port);
895 set_rcvbuf(fd, cfg_rcvbuf);
897 set_sndbuf(fd, cfg_sndbuf);
899 return main_loop_s(fd);