Merge tag 'char-misc-5.14-rc3' of git://git.kernel.org/pub/scm/linux/kernel/git/gregk...
[linux-2.6-microblaze.git] / tools / testing / selftests / net / nettest.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* nettest - used for functional tests of networking APIs
3  *
4  * Copyright (c) 2013-2019 David Ahern <dsahern@gmail.com>. All rights reserved.
5  */
6
7 #define _GNU_SOURCE
8 #include <features.h>
9 #include <sys/types.h>
10 #include <sys/ioctl.h>
11 #include <sys/socket.h>
12 #include <sys/wait.h>
13 #include <linux/tcp.h>
14 #include <linux/udp.h>
15 #include <arpa/inet.h>
16 #include <net/if.h>
17 #include <netinet/in.h>
18 #include <netinet/ip.h>
19 #include <netdb.h>
20 #include <fcntl.h>
21 #include <libgen.h>
22 #include <limits.h>
23 #include <sched.h>
24 #include <stdarg.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <unistd.h>
29 #include <time.h>
30 #include <errno.h>
31
32 #include <linux/xfrm.h>
33 #include <linux/ipsec.h>
34 #include <linux/pfkeyv2.h>
35
36 #ifndef IPV6_UNICAST_IF
37 #define IPV6_UNICAST_IF         76
38 #endif
39 #ifndef IPV6_MULTICAST_IF
40 #define IPV6_MULTICAST_IF       17
41 #endif
42
43 #define DEFAULT_PORT 12345
44
45 #define NS_PREFIX "/run/netns/"
46
47 #ifndef MAX
48 #define MAX(a, b)  ((a) > (b) ? (a) : (b))
49 #endif
50 #ifndef MIN
51 #define MIN(a, b)  ((a) < (b) ? (a) : (b))
52 #endif
53
54 struct sock_args {
55         /* local address */
56         const char *local_addr_str;
57         const char *client_local_addr_str;
58         union {
59                 struct in_addr  in;
60                 struct in6_addr in6;
61         } local_addr;
62
63         /* remote address */
64         const char *remote_addr_str;
65         union {
66                 struct in_addr  in;
67                 struct in6_addr in6;
68         } remote_addr;
69         int scope_id;  /* remote scope; v6 send only */
70
71         struct in_addr grp;     /* multicast group */
72
73         unsigned int has_local_ip:1,
74                      has_remote_ip:1,
75                      has_grp:1,
76                      has_expected_laddr:1,
77                      has_expected_raddr:1,
78                      bind_test_only:1;
79
80         unsigned short port;
81
82         int type;      /* DGRAM, STREAM, RAW */
83         int protocol;
84         int version;   /* AF_INET/AF_INET6 */
85
86         int use_setsockopt;
87         int use_cmsg;
88         const char *dev;
89         const char *server_dev;
90         int ifindex;
91
92         const char *clientns;
93         const char *serverns;
94
95         const char *password;
96         const char *client_pw;
97         /* prefix for MD5 password */
98         const char *md5_prefix_str;
99         union {
100                 struct sockaddr_in v4;
101                 struct sockaddr_in6 v6;
102         } md5_prefix;
103         unsigned int prefix_len;
104
105         /* expected addresses and device index for connection */
106         const char *expected_dev;
107         const char *expected_server_dev;
108         int expected_ifindex;
109
110         /* local address */
111         const char *expected_laddr_str;
112         union {
113                 struct in_addr  in;
114                 struct in6_addr in6;
115         } expected_laddr;
116
117         /* remote address */
118         const char *expected_raddr_str;
119         union {
120                 struct in_addr  in;
121                 struct in6_addr in6;
122         } expected_raddr;
123
124         /* ESP in UDP encap test */
125         int use_xfrm;
126 };
127
128 static int server_mode;
129 static unsigned int prog_timeout = 5;
130 static unsigned int interactive;
131 static int iter = 1;
132 static char *msg = "Hello world!";
133 static int msglen;
134 static int quiet;
135 static int try_broadcast = 1;
136
137 static char *timestamp(char *timebuf, int buflen)
138 {
139         time_t now;
140
141         now = time(NULL);
142         if (strftime(timebuf, buflen, "%T", localtime(&now)) == 0) {
143                 memset(timebuf, 0, buflen);
144                 strncpy(timebuf, "00:00:00", buflen-1);
145         }
146
147         return timebuf;
148 }
149
150 static void log_msg(const char *format, ...)
151 {
152         char timebuf[64];
153         va_list args;
154
155         if (quiet)
156                 return;
157
158         fprintf(stdout, "%s %s:",
159                 timestamp(timebuf, sizeof(timebuf)),
160                 server_mode ? "server" : "client");
161         va_start(args, format);
162         vfprintf(stdout, format, args);
163         va_end(args);
164
165         fflush(stdout);
166 }
167
168 static void log_error(const char *format, ...)
169 {
170         char timebuf[64];
171         va_list args;
172
173         if (quiet)
174                 return;
175
176         fprintf(stderr, "%s %s:",
177                 timestamp(timebuf, sizeof(timebuf)),
178                 server_mode ? "server" : "client");
179         va_start(args, format);
180         vfprintf(stderr, format, args);
181         va_end(args);
182
183         fflush(stderr);
184 }
185
186 static void log_err_errno(const char *fmt, ...)
187 {
188         char timebuf[64];
189         va_list args;
190
191         if (quiet)
192                 return;
193
194         fprintf(stderr, "%s %s: ",
195                 timestamp(timebuf, sizeof(timebuf)),
196                 server_mode ? "server" : "client");
197         va_start(args, fmt);
198         vfprintf(stderr, fmt, args);
199         va_end(args);
200
201         fprintf(stderr, ": %d: %s\n", errno, strerror(errno));
202         fflush(stderr);
203 }
204
205 static void log_address(const char *desc, struct sockaddr *sa)
206 {
207         char addrstr[64];
208
209         if (quiet)
210                 return;
211
212         if (sa->sa_family == AF_INET) {
213                 struct sockaddr_in *s = (struct sockaddr_in *) sa;
214
215                 log_msg("%s %s:%d\n",
216                         desc,
217                         inet_ntop(AF_INET, &s->sin_addr, addrstr,
218                                   sizeof(addrstr)),
219                         ntohs(s->sin_port));
220
221         } else if (sa->sa_family == AF_INET6) {
222                 struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
223
224                 log_msg("%s [%s]:%d\n",
225                         desc,
226                         inet_ntop(AF_INET6, &s6->sin6_addr, addrstr,
227                                   sizeof(addrstr)),
228                         ntohs(s6->sin6_port));
229         }
230
231         fflush(stdout);
232 }
233
234 static int switch_ns(const char *ns)
235 {
236         char path[PATH_MAX];
237         int fd, ret;
238
239         if (geteuid())
240                 log_error("warning: likely need root to set netns %s!\n", ns);
241
242         snprintf(path, sizeof(path), "%s%s", NS_PREFIX, ns);
243         fd = open(path, 0);
244         if (fd < 0) {
245                 log_err_errno("Failed to open netns path; can not switch netns");
246                 return 1;
247         }
248
249         ret = setns(fd, CLONE_NEWNET);
250         close(fd);
251
252         return ret;
253 }
254
255 static int tcp_md5sig(int sd, void *addr, socklen_t alen, struct sock_args *args)
256 {
257         int keylen = strlen(args->password);
258         struct tcp_md5sig md5sig = {};
259         int opt = TCP_MD5SIG;
260         int rc;
261
262         md5sig.tcpm_keylen = keylen;
263         memcpy(md5sig.tcpm_key, args->password, keylen);
264
265         if (args->prefix_len) {
266                 opt = TCP_MD5SIG_EXT;
267                 md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_PREFIX;
268
269                 md5sig.tcpm_prefixlen = args->prefix_len;
270                 addr = &args->md5_prefix;
271         }
272         memcpy(&md5sig.tcpm_addr, addr, alen);
273
274         if (args->ifindex) {
275                 opt = TCP_MD5SIG_EXT;
276                 md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
277
278                 md5sig.tcpm_ifindex = args->ifindex;
279         }
280
281         rc = setsockopt(sd, IPPROTO_TCP, opt, &md5sig, sizeof(md5sig));
282         if (rc < 0) {
283                 /* ENOENT is harmless. Returned when a password is cleared */
284                 if (errno == ENOENT)
285                         rc = 0;
286                 else
287                         log_err_errno("setsockopt(TCP_MD5SIG)");
288         }
289
290         return rc;
291 }
292
293 static int tcp_md5_remote(int sd, struct sock_args *args)
294 {
295         struct sockaddr_in sin = {
296                 .sin_family = AF_INET,
297         };
298         struct sockaddr_in6 sin6 = {
299                 .sin6_family = AF_INET6,
300         };
301         void *addr;
302         int alen;
303
304         switch (args->version) {
305         case AF_INET:
306                 sin.sin_port = htons(args->port);
307                 sin.sin_addr = args->md5_prefix.v4.sin_addr;
308                 addr = &sin;
309                 alen = sizeof(sin);
310                 break;
311         case AF_INET6:
312                 sin6.sin6_port = htons(args->port);
313                 sin6.sin6_addr = args->md5_prefix.v6.sin6_addr;
314                 addr = &sin6;
315                 alen = sizeof(sin6);
316                 break;
317         default:
318                 log_error("unknown address family\n");
319                 exit(1);
320         }
321
322         if (tcp_md5sig(sd, addr, alen, args))
323                 return -1;
324
325         return 0;
326 }
327
328 static int get_ifidx(const char *ifname)
329 {
330         struct ifreq ifdata;
331         int sd, rc;
332
333         if (!ifname || *ifname == '\0')
334                 return -1;
335
336         memset(&ifdata, 0, sizeof(ifdata));
337
338         strcpy(ifdata.ifr_name, ifname);
339
340         sd = socket(PF_INET, SOCK_DGRAM, IPPROTO_IP);
341         if (sd < 0) {
342                 log_err_errno("socket failed");
343                 return -1;
344         }
345
346         rc = ioctl(sd, SIOCGIFINDEX, (char *)&ifdata);
347         close(sd);
348         if (rc != 0) {
349                 log_err_errno("ioctl(SIOCGIFINDEX) failed");
350                 return -1;
351         }
352
353         return ifdata.ifr_ifindex;
354 }
355
356 static int bind_to_device(int sd, const char *name)
357 {
358         int rc;
359
360         rc = setsockopt(sd, SOL_SOCKET, SO_BINDTODEVICE, name, strlen(name)+1);
361         if (rc < 0)
362                 log_err_errno("setsockopt(SO_BINDTODEVICE)");
363
364         return rc;
365 }
366
367 static int get_bind_to_device(int sd, char *name, size_t len)
368 {
369         int rc;
370         socklen_t optlen = len;
371
372         name[0] = '\0';
373         rc = getsockopt(sd, SOL_SOCKET, SO_BINDTODEVICE, name, &optlen);
374         if (rc < 0)
375                 log_err_errno("setsockopt(SO_BINDTODEVICE)");
376
377         return rc;
378 }
379
380 static int check_device(int sd, struct sock_args *args)
381 {
382         int ifindex = 0;
383         char name[32];
384
385         if (get_bind_to_device(sd, name, sizeof(name)))
386                 *name = '\0';
387         else
388                 ifindex = get_ifidx(name);
389
390         log_msg("    bound to device %s/%d\n",
391                 *name ? name : "<none>", ifindex);
392
393         if (!args->expected_ifindex)
394                 return 0;
395
396         if (args->expected_ifindex != ifindex) {
397                 log_error("Device index mismatch: expected %d have %d\n",
398                           args->expected_ifindex, ifindex);
399                 return 1;
400         }
401
402         log_msg("Device index matches: expected %d have %d\n",
403                 args->expected_ifindex, ifindex);
404
405         return 0;
406 }
407
408 static int set_pktinfo_v4(int sd)
409 {
410         int one = 1;
411         int rc;
412
413         rc = setsockopt(sd, SOL_IP, IP_PKTINFO, &one, sizeof(one));
414         if (rc < 0 && rc != -ENOTSUP)
415                 log_err_errno("setsockopt(IP_PKTINFO)");
416
417         return rc;
418 }
419
420 static int set_recvpktinfo_v6(int sd)
421 {
422         int one = 1;
423         int rc;
424
425         rc = setsockopt(sd, SOL_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one));
426         if (rc < 0 && rc != -ENOTSUP)
427                 log_err_errno("setsockopt(IPV6_RECVPKTINFO)");
428
429         return rc;
430 }
431
432 static int set_recverr_v4(int sd)
433 {
434         int one = 1;
435         int rc;
436
437         rc = setsockopt(sd, SOL_IP, IP_RECVERR, &one, sizeof(one));
438         if (rc < 0 && rc != -ENOTSUP)
439                 log_err_errno("setsockopt(IP_RECVERR)");
440
441         return rc;
442 }
443
444 static int set_recverr_v6(int sd)
445 {
446         int one = 1;
447         int rc;
448
449         rc = setsockopt(sd, SOL_IPV6, IPV6_RECVERR, &one, sizeof(one));
450         if (rc < 0 && rc != -ENOTSUP)
451                 log_err_errno("setsockopt(IPV6_RECVERR)");
452
453         return rc;
454 }
455
456 static int set_unicast_if(int sd, int ifindex, int version)
457 {
458         int opt = IP_UNICAST_IF;
459         int level = SOL_IP;
460         int rc;
461
462         ifindex = htonl(ifindex);
463
464         if (version == AF_INET6) {
465                 opt = IPV6_UNICAST_IF;
466                 level = SOL_IPV6;
467         }
468         rc = setsockopt(sd, level, opt, &ifindex, sizeof(ifindex));
469         if (rc < 0)
470                 log_err_errno("setsockopt(IP_UNICAST_IF)");
471
472         return rc;
473 }
474
475 static int set_multicast_if(int sd, int ifindex)
476 {
477         struct ip_mreqn mreq = { .imr_ifindex = ifindex };
478         int rc;
479
480         rc = setsockopt(sd, SOL_IP, IP_MULTICAST_IF, &mreq, sizeof(mreq));
481         if (rc < 0)
482                 log_err_errno("setsockopt(IP_MULTICAST_IF)");
483
484         return rc;
485 }
486
487 static int set_membership(int sd, uint32_t grp, uint32_t addr, int ifindex)
488 {
489         uint32_t if_addr = addr;
490         struct ip_mreqn mreq;
491         int rc;
492
493         if (addr == htonl(INADDR_ANY) && !ifindex) {
494                 log_error("Either local address or device needs to be given for multicast membership\n");
495                 return -1;
496         }
497
498         mreq.imr_multiaddr.s_addr = grp;
499         mreq.imr_address.s_addr = if_addr;
500         mreq.imr_ifindex = ifindex;
501
502         rc = setsockopt(sd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq));
503         if (rc < 0) {
504                 log_err_errno("setsockopt(IP_ADD_MEMBERSHIP)");
505                 return -1;
506         }
507
508         return 0;
509 }
510
511 static int set_broadcast(int sd)
512 {
513         unsigned int one = 1;
514         int rc = 0;
515
516         if (setsockopt(sd, SOL_SOCKET, SO_BROADCAST, &one, sizeof(one)) != 0) {
517                 log_err_errno("setsockopt(SO_BROADCAST)");
518                 rc = -1;
519         }
520
521         return rc;
522 }
523
524 static int set_reuseport(int sd)
525 {
526         unsigned int one = 1;
527         int rc = 0;
528
529         if (setsockopt(sd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)) != 0) {
530                 log_err_errno("setsockopt(SO_REUSEPORT)");
531                 rc = -1;
532         }
533
534         return rc;
535 }
536
537 static int set_reuseaddr(int sd)
538 {
539         unsigned int one = 1;
540         int rc = 0;
541
542         if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) != 0) {
543                 log_err_errno("setsockopt(SO_REUSEADDR)");
544                 rc = -1;
545         }
546
547         return rc;
548 }
549
550 static int str_to_uint(const char *str, int min, int max, unsigned int *value)
551 {
552         int number;
553         char *end;
554
555         errno = 0;
556         number = (unsigned int) strtoul(str, &end, 0);
557
558         /* entire string should be consumed by conversion
559          * and value should be between min and max
560          */
561         if (((*end == '\0') || (*end == '\n')) && (end != str) &&
562             (errno != ERANGE) && (min <= number) && (number <= max)) {
563                 *value = number;
564                 return 0;
565         }
566
567         return -1;
568 }
569
570 static int resolve_devices(struct sock_args *args)
571 {
572         if (args->dev) {
573                 args->ifindex = get_ifidx(args->dev);
574                 if (args->ifindex < 0) {
575                         log_error("Invalid device name\n");
576                         return 1;
577                 }
578         }
579
580         if (args->expected_dev) {
581                 unsigned int tmp;
582
583                 if (str_to_uint(args->expected_dev, 0, INT_MAX, &tmp) == 0) {
584                         args->expected_ifindex = (int)tmp;
585                 } else {
586                         args->expected_ifindex = get_ifidx(args->expected_dev);
587                         if (args->expected_ifindex < 0) {
588                                 fprintf(stderr, "Invalid expected device\n");
589                                 return 1;
590                         }
591                 }
592         }
593
594         return 0;
595 }
596
597 static int expected_addr_match(struct sockaddr *sa, void *expected,
598                                const char *desc)
599 {
600         char addrstr[64];
601         int rc = 0;
602
603         if (sa->sa_family == AF_INET) {
604                 struct sockaddr_in *s = (struct sockaddr_in *) sa;
605                 struct in_addr *exp_in = (struct in_addr *) expected;
606
607                 if (s->sin_addr.s_addr != exp_in->s_addr) {
608                         log_error("%s address does not match expected %s\n",
609                                   desc,
610                                   inet_ntop(AF_INET, exp_in,
611                                             addrstr, sizeof(addrstr)));
612                         rc = 1;
613                 }
614         } else if (sa->sa_family == AF_INET6) {
615                 struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
616                 struct in6_addr *exp_in = (struct in6_addr *) expected;
617
618                 if (memcmp(&s6->sin6_addr, exp_in, sizeof(*exp_in))) {
619                         log_error("%s address does not match expected %s\n",
620                                   desc,
621                                   inet_ntop(AF_INET6, exp_in,
622                                             addrstr, sizeof(addrstr)));
623                         rc = 1;
624                 }
625         } else {
626                 log_error("%s address does not match expected - unknown family\n",
627                           desc);
628                 rc = 1;
629         }
630
631         if (!rc)
632                 log_msg("%s address matches expected\n", desc);
633
634         return rc;
635 }
636
637 static int show_sockstat(int sd, struct sock_args *args)
638 {
639         struct sockaddr_in6 local_addr, remote_addr;
640         socklen_t alen = sizeof(local_addr);
641         struct sockaddr *sa;
642         const char *desc;
643         int rc = 0;
644
645         desc = server_mode ? "server local:" : "client local:";
646         sa = (struct sockaddr *) &local_addr;
647         if (getsockname(sd, sa, &alen) == 0) {
648                 log_address(desc, sa);
649
650                 if (args->has_expected_laddr) {
651                         rc = expected_addr_match(sa, &args->expected_laddr,
652                                                  "local");
653                 }
654         } else {
655                 log_err_errno("getsockname failed");
656         }
657
658         sa = (struct sockaddr *) &remote_addr;
659         desc = server_mode ? "server peer:" : "client peer:";
660         if (getpeername(sd, sa, &alen) == 0) {
661                 log_address(desc, sa);
662
663                 if (args->has_expected_raddr) {
664                         rc |= expected_addr_match(sa, &args->expected_raddr,
665                                                  "remote");
666                 }
667         } else {
668                 log_err_errno("getpeername failed");
669         }
670
671         return rc;
672 }
673
674 enum addr_type {
675         ADDR_TYPE_LOCAL,
676         ADDR_TYPE_REMOTE,
677         ADDR_TYPE_MCAST,
678         ADDR_TYPE_EXPECTED_LOCAL,
679         ADDR_TYPE_EXPECTED_REMOTE,
680         ADDR_TYPE_MD5_PREFIX,
681 };
682
683 static int convert_addr(struct sock_args *args, const char *_str,
684                         enum addr_type atype)
685 {
686         int pfx_len_max = args->version == AF_INET6 ? 128 : 32;
687         int family = args->version;
688         char *str, *dev, *sep;
689         struct in6_addr *in6;
690         struct in_addr  *in;
691         const char *desc;
692         void *addr;
693         int rc = 0;
694
695         str = strdup(_str);
696         if (!str)
697                 return -ENOMEM;
698
699         switch (atype) {
700         case ADDR_TYPE_LOCAL:
701                 desc = "local";
702                 addr = &args->local_addr;
703                 break;
704         case ADDR_TYPE_REMOTE:
705                 desc = "remote";
706                 addr = &args->remote_addr;
707                 break;
708         case ADDR_TYPE_MCAST:
709                 desc = "mcast grp";
710                 addr = &args->grp;
711                 break;
712         case ADDR_TYPE_EXPECTED_LOCAL:
713                 desc = "expected local";
714                 addr = &args->expected_laddr;
715                 break;
716         case ADDR_TYPE_EXPECTED_REMOTE:
717                 desc = "expected remote";
718                 addr = &args->expected_raddr;
719                 break;
720         case ADDR_TYPE_MD5_PREFIX:
721                 desc = "md5 prefix";
722                 if (family == AF_INET) {
723                         args->md5_prefix.v4.sin_family = AF_INET;
724                         addr = &args->md5_prefix.v4.sin_addr;
725                 } else if (family == AF_INET6) {
726                         args->md5_prefix.v6.sin6_family = AF_INET6;
727                         addr = &args->md5_prefix.v6.sin6_addr;
728                 } else
729                         return 1;
730
731                 sep = strchr(str, '/');
732                 if (sep) {
733                         *sep = '\0';
734                         sep++;
735                         if (str_to_uint(sep, 1, pfx_len_max,
736                                         &args->prefix_len) != 0) {
737                                 fprintf(stderr, "Invalid port\n");
738                                 return 1;
739                         }
740                 } else {
741                         args->prefix_len = 0;
742                 }
743                 break;
744         default:
745                 log_error("unknown address type\n");
746                 exit(1);
747         }
748
749         switch (family) {
750         case AF_INET:
751                 in  = (struct in_addr *) addr;
752                 if (str) {
753                         if (inet_pton(AF_INET, str, in) == 0) {
754                                 log_error("Invalid %s IP address\n", desc);
755                                 rc = -1;
756                                 goto out;
757                         }
758                 } else {
759                         in->s_addr = htonl(INADDR_ANY);
760                 }
761                 break;
762
763         case AF_INET6:
764                 dev = strchr(str, '%');
765                 if (dev) {
766                         *dev = '\0';
767                         dev++;
768                 }
769
770                 in6 = (struct in6_addr *) addr;
771                 if (str) {
772                         if (inet_pton(AF_INET6, str, in6) == 0) {
773                                 log_error("Invalid %s IPv6 address\n", desc);
774                                 rc = -1;
775                                 goto out;
776                         }
777                 } else {
778                         *in6 = in6addr_any;
779                 }
780                 if (dev) {
781                         args->scope_id = get_ifidx(dev);
782                         if (args->scope_id < 0) {
783                                 log_error("Invalid scope on %s IPv6 address\n",
784                                           desc);
785                                 rc = -1;
786                                 goto out;
787                         }
788                 }
789                 break;
790
791         default:
792                 log_error("Invalid address family\n");
793         }
794
795 out:
796         free(str);
797         return rc;
798 }
799
800 static int validate_addresses(struct sock_args *args)
801 {
802         if (args->local_addr_str &&
803             convert_addr(args, args->local_addr_str, ADDR_TYPE_LOCAL) < 0)
804                 return 1;
805
806         if (args->remote_addr_str &&
807             convert_addr(args, args->remote_addr_str, ADDR_TYPE_REMOTE) < 0)
808                 return 1;
809
810         if (args->md5_prefix_str &&
811             convert_addr(args, args->md5_prefix_str,
812                          ADDR_TYPE_MD5_PREFIX) < 0)
813                 return 1;
814
815         if (args->expected_laddr_str &&
816             convert_addr(args, args->expected_laddr_str,
817                          ADDR_TYPE_EXPECTED_LOCAL))
818                 return 1;
819
820         if (args->expected_raddr_str &&
821             convert_addr(args, args->expected_raddr_str,
822                          ADDR_TYPE_EXPECTED_REMOTE))
823                 return 1;
824
825         return 0;
826 }
827
828 static int get_index_from_cmsg(struct msghdr *m)
829 {
830         struct cmsghdr *cm;
831         int ifindex = 0;
832         char buf[64];
833
834         for (cm = (struct cmsghdr *)CMSG_FIRSTHDR(m);
835              m->msg_controllen != 0 && cm;
836              cm = (struct cmsghdr *)CMSG_NXTHDR(m, cm)) {
837
838                 if (cm->cmsg_level == SOL_IP &&
839                     cm->cmsg_type == IP_PKTINFO) {
840                         struct in_pktinfo *pi;
841
842                         pi = (struct in_pktinfo *)(CMSG_DATA(cm));
843                         inet_ntop(AF_INET, &pi->ipi_addr, buf, sizeof(buf));
844                         ifindex = pi->ipi_ifindex;
845                 } else if (cm->cmsg_level == SOL_IPV6 &&
846                            cm->cmsg_type == IPV6_PKTINFO) {
847                         struct in6_pktinfo *pi6;
848
849                         pi6 = (struct in6_pktinfo *)(CMSG_DATA(cm));
850                         inet_ntop(AF_INET6, &pi6->ipi6_addr, buf, sizeof(buf));
851                         ifindex = pi6->ipi6_ifindex;
852                 }
853         }
854
855         if (ifindex) {
856                 log_msg("    pktinfo: ifindex %d dest addr %s\n",
857                         ifindex, buf);
858         }
859         return ifindex;
860 }
861
862 static int send_msg_no_cmsg(int sd, void *addr, socklen_t alen)
863 {
864         int err;
865
866 again:
867         err = sendto(sd, msg, msglen, 0, addr, alen);
868         if (err < 0) {
869                 if (errno == EACCES && try_broadcast) {
870                         try_broadcast = 0;
871                         if (!set_broadcast(sd))
872                                 goto again;
873                         errno = EACCES;
874                 }
875
876                 log_err_errno("sendto failed");
877                 return 1;
878         }
879
880         return 0;
881 }
882
883 static int send_msg_cmsg(int sd, void *addr, socklen_t alen,
884                          int ifindex, int version)
885 {
886         unsigned char cmsgbuf[64];
887         struct iovec iov[2];
888         struct cmsghdr *cm;
889         struct msghdr m;
890         int err;
891
892         iov[0].iov_base = msg;
893         iov[0].iov_len = msglen;
894         m.msg_iov = iov;
895         m.msg_iovlen = 1;
896         m.msg_name = (caddr_t)addr;
897         m.msg_namelen = alen;
898
899         memset(cmsgbuf, 0, sizeof(cmsgbuf));
900         cm = (struct cmsghdr *)cmsgbuf;
901         m.msg_control = (caddr_t)cm;
902
903         if (version == AF_INET) {
904                 struct in_pktinfo *pi;
905
906                 cm->cmsg_level = SOL_IP;
907                 cm->cmsg_type = IP_PKTINFO;
908                 cm->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
909                 pi = (struct in_pktinfo *)(CMSG_DATA(cm));
910                 pi->ipi_ifindex = ifindex;
911
912                 m.msg_controllen = cm->cmsg_len;
913
914         } else if (version == AF_INET6) {
915                 struct in6_pktinfo *pi6;
916
917                 cm->cmsg_level = SOL_IPV6;
918                 cm->cmsg_type = IPV6_PKTINFO;
919                 cm->cmsg_len = CMSG_LEN(sizeof(struct in6_pktinfo));
920
921                 pi6 = (struct in6_pktinfo *)(CMSG_DATA(cm));
922                 pi6->ipi6_ifindex = ifindex;
923
924                 m.msg_controllen = cm->cmsg_len;
925         }
926
927 again:
928         err = sendmsg(sd, &m, 0);
929         if (err < 0) {
930                 if (errno == EACCES && try_broadcast) {
931                         try_broadcast = 0;
932                         if (!set_broadcast(sd))
933                                 goto again;
934                         errno = EACCES;
935                 }
936
937                 log_err_errno("sendmsg failed");
938                 return 1;
939         }
940
941         return 0;
942 }
943
944
945 static int send_msg(int sd, void *addr, socklen_t alen, struct sock_args *args)
946 {
947         if (args->type == SOCK_STREAM) {
948                 if (write(sd, msg, msglen) < 0) {
949                         log_err_errno("write failed sending msg to peer");
950                         return 1;
951                 }
952         } else if (args->ifindex && args->use_cmsg) {
953                 if (send_msg_cmsg(sd, addr, alen, args->ifindex, args->version))
954                         return 1;
955         } else {
956                 if (send_msg_no_cmsg(sd, addr, alen))
957                         return 1;
958         }
959
960         log_msg("Sent message:\n");
961         log_msg("    %.24s%s\n", msg, msglen > 24 ? " ..." : "");
962
963         return 0;
964 }
965
966 static int socket_read_dgram(int sd, struct sock_args *args)
967 {
968         unsigned char addr[sizeof(struct sockaddr_in6)];
969         struct sockaddr *sa = (struct sockaddr *) addr;
970         socklen_t alen = sizeof(addr);
971         struct iovec iov[2];
972         struct msghdr m = {
973                 .msg_name = (caddr_t)addr,
974                 .msg_namelen = alen,
975                 .msg_iov = iov,
976                 .msg_iovlen = 1,
977         };
978         unsigned char cmsgbuf[256];
979         struct cmsghdr *cm = (struct cmsghdr *)cmsgbuf;
980         char buf[16*1024];
981         int ifindex;
982         int len;
983
984         iov[0].iov_base = (caddr_t)buf;
985         iov[0].iov_len = sizeof(buf);
986
987         memset(cmsgbuf, 0, sizeof(cmsgbuf));
988         m.msg_control = (caddr_t)cm;
989         m.msg_controllen = sizeof(cmsgbuf);
990
991         len = recvmsg(sd, &m, 0);
992         if (len == 0) {
993                 log_msg("peer closed connection.\n");
994                 return 0;
995         } else if (len < 0) {
996                 log_msg("failed to read message: %d: %s\n",
997                         errno, strerror(errno));
998                 return -1;
999         }
1000
1001         buf[len] = '\0';
1002
1003         log_address("Message from:", sa);
1004         log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
1005
1006         ifindex = get_index_from_cmsg(&m);
1007         if (args->expected_ifindex) {
1008                 if (args->expected_ifindex != ifindex) {
1009                         log_error("Device index mismatch: expected %d have %d\n",
1010                                   args->expected_ifindex, ifindex);
1011                         return -1;
1012                 }
1013                 log_msg("Device index matches: expected %d have %d\n",
1014                         args->expected_ifindex, ifindex);
1015         }
1016
1017         if (!interactive && server_mode) {
1018                 if (sa->sa_family == AF_INET6) {
1019                         struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
1020                         struct in6_addr *in6 = &s6->sin6_addr;
1021
1022                         if (IN6_IS_ADDR_V4MAPPED(in6)) {
1023                                 const uint32_t *pa = (uint32_t *) &in6->s6_addr;
1024                                 struct in_addr in4;
1025                                 struct sockaddr_in *sin;
1026
1027                                 sin = (struct sockaddr_in *) addr;
1028                                 pa += 3;
1029                                 in4.s_addr = *pa;
1030                                 sin->sin_addr = in4;
1031                                 sin->sin_family = AF_INET;
1032                                 if (send_msg_cmsg(sd, addr, alen,
1033                                                   ifindex, AF_INET) < 0)
1034                                         goto out_err;
1035                         }
1036                 }
1037 again:
1038                 iov[0].iov_len = len;
1039
1040                 if (args->version == AF_INET6) {
1041                         struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
1042
1043                         if (args->dev) {
1044                                 /* avoid PKTINFO conflicts with bindtodev */
1045                                 if (sendto(sd, buf, len, 0,
1046                                            (void *) addr, alen) < 0)
1047                                         goto out_err;
1048                         } else {
1049                                 /* kernel is allowing scope_id to be set to VRF
1050                                  * index for LLA. for sends to global address
1051                                  * reset scope id
1052                                  */
1053                                 s6->sin6_scope_id = ifindex;
1054                                 if (sendmsg(sd, &m, 0) < 0)
1055                                         goto out_err;
1056                         }
1057                 } else {
1058                         int err;
1059
1060                         err = sendmsg(sd, &m, 0);
1061                         if (err < 0) {
1062                                 if (errno == EACCES && try_broadcast) {
1063                                         try_broadcast = 0;
1064                                         if (!set_broadcast(sd))
1065                                                 goto again;
1066                                         errno = EACCES;
1067                                 }
1068                                 goto out_err;
1069                         }
1070                 }
1071                 log_msg("Sent message:\n");
1072                 log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
1073         }
1074
1075         return 1;
1076 out_err:
1077         log_err_errno("failed to send msg to peer");
1078         return -1;
1079 }
1080
1081 static int socket_read_stream(int sd)
1082 {
1083         char buf[1024];
1084         int len;
1085
1086         len = read(sd, buf, sizeof(buf)-1);
1087         if (len == 0) {
1088                 log_msg("client closed connection.\n");
1089                 return 0;
1090         } else if (len < 0) {
1091                 log_msg("failed to read message\n");
1092                 return -1;
1093         }
1094
1095         buf[len] = '\0';
1096         log_msg("Incoming message:\n");
1097         log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
1098
1099         if (!interactive && server_mode) {
1100                 if (write(sd, buf, len) < 0) {
1101                         log_err_errno("failed to send buf");
1102                         return -1;
1103                 }
1104                 log_msg("Sent message:\n");
1105                 log_msg("     %.24s%s\n", buf, len > 24 ? " ..." : "");
1106         }
1107
1108         return 1;
1109 }
1110
1111 static int socket_read(int sd, struct sock_args *args)
1112 {
1113         if (args->type == SOCK_STREAM)
1114                 return socket_read_stream(sd);
1115
1116         return socket_read_dgram(sd, args);
1117 }
1118
1119 static int stdin_to_socket(int sd, int type, void *addr, socklen_t alen)
1120 {
1121         char buf[1024];
1122         int len;
1123
1124         if (fgets(buf, sizeof(buf), stdin) == NULL)
1125                 return 0;
1126
1127         len = strlen(buf);
1128         if (type == SOCK_STREAM) {
1129                 if (write(sd, buf, len) < 0) {
1130                         log_err_errno("failed to send buf");
1131                         return -1;
1132                 }
1133         } else {
1134                 int err;
1135
1136 again:
1137                 err = sendto(sd, buf, len, 0, addr, alen);
1138                 if (err < 0) {
1139                         if (errno == EACCES && try_broadcast) {
1140                                 try_broadcast = 0;
1141                                 if (!set_broadcast(sd))
1142                                         goto again;
1143                                 errno = EACCES;
1144                         }
1145                         log_err_errno("failed to send msg to peer");
1146                         return -1;
1147                 }
1148         }
1149         log_msg("Sent message:\n");
1150         log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
1151
1152         return 1;
1153 }
1154
1155 static void set_recv_attr(int sd, int version)
1156 {
1157         if (version == AF_INET6) {
1158                 set_recvpktinfo_v6(sd);
1159                 set_recverr_v6(sd);
1160         } else {
1161                 set_pktinfo_v4(sd);
1162                 set_recverr_v4(sd);
1163         }
1164 }
1165
1166 static int msg_loop(int client, int sd, void *addr, socklen_t alen,
1167                     struct sock_args *args)
1168 {
1169         struct timeval timeout = { .tv_sec = prog_timeout }, *ptval = NULL;
1170         fd_set rfds;
1171         int nfds;
1172         int rc;
1173
1174         if (args->type != SOCK_STREAM)
1175                 set_recv_attr(sd, args->version);
1176
1177         if (msg) {
1178                 msglen = strlen(msg);
1179
1180                 /* client sends first message */
1181                 if (client) {
1182                         if (send_msg(sd, addr, alen, args))
1183                                 return 1;
1184                 }
1185                 if (!interactive) {
1186                         ptval = &timeout;
1187                         if (!prog_timeout)
1188                                 timeout.tv_sec = 5;
1189                 }
1190         }
1191
1192         nfds = interactive ? MAX(fileno(stdin), sd)  + 1 : sd + 1;
1193         while (1) {
1194                 FD_ZERO(&rfds);
1195                 FD_SET(sd, &rfds);
1196                 if (interactive)
1197                         FD_SET(fileno(stdin), &rfds);
1198
1199                 rc = select(nfds, &rfds, NULL, NULL, ptval);
1200                 if (rc < 0) {
1201                         if (errno == EINTR)
1202                                 continue;
1203
1204                         rc = 1;
1205                         log_err_errno("select failed");
1206                         break;
1207                 } else if (rc == 0) {
1208                         log_error("Timed out waiting for response\n");
1209                         rc = 2;
1210                         break;
1211                 }
1212
1213                 if (FD_ISSET(sd, &rfds)) {
1214                         rc = socket_read(sd, args);
1215                         if (rc < 0) {
1216                                 rc = 1;
1217                                 break;
1218                         }
1219                         if (rc == 0)
1220                                 break;
1221                 }
1222
1223                 rc = 0;
1224
1225                 if (FD_ISSET(fileno(stdin), &rfds)) {
1226                         if (stdin_to_socket(sd, args->type, addr, alen) <= 0)
1227                                 break;
1228                 }
1229
1230                 if (interactive)
1231                         continue;
1232
1233                 if (iter != -1) {
1234                         --iter;
1235                         if (iter == 0)
1236                                 break;
1237                 }
1238
1239                 log_msg("Going into quiet mode\n");
1240                 quiet = 1;
1241
1242                 if (client) {
1243                         if (send_msg(sd, addr, alen, args)) {
1244                                 rc = 1;
1245                                 break;
1246                         }
1247                 }
1248         }
1249
1250         return rc;
1251 }
1252
1253 static int msock_init(struct sock_args *args, int server)
1254 {
1255         uint32_t if_addr = htonl(INADDR_ANY);
1256         struct sockaddr_in laddr = {
1257                 .sin_family = AF_INET,
1258                 .sin_port = htons(args->port),
1259         };
1260         int one = 1;
1261         int sd;
1262
1263         if (!server && args->has_local_ip)
1264                 if_addr = args->local_addr.in.s_addr;
1265
1266         sd = socket(PF_INET, SOCK_DGRAM, 0);
1267         if (sd < 0) {
1268                 log_err_errno("socket");
1269                 return -1;
1270         }
1271
1272         if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR,
1273                        (char *)&one, sizeof(one)) < 0) {
1274                 log_err_errno("Setting SO_REUSEADDR error");
1275                 goto out_err;
1276         }
1277
1278         if (setsockopt(sd, SOL_SOCKET, SO_BROADCAST,
1279                        (char *)&one, sizeof(one)) < 0)
1280                 log_err_errno("Setting SO_BROADCAST error");
1281
1282         if (args->dev && bind_to_device(sd, args->dev) != 0)
1283                 goto out_err;
1284         else if (args->use_setsockopt &&
1285                  set_multicast_if(sd, args->ifindex))
1286                 goto out_err;
1287
1288         laddr.sin_addr.s_addr = if_addr;
1289
1290         if (bind(sd, (struct sockaddr *) &laddr, sizeof(laddr)) < 0) {
1291                 log_err_errno("bind failed");
1292                 goto out_err;
1293         }
1294
1295         if (server &&
1296             set_membership(sd, args->grp.s_addr,
1297                            args->local_addr.in.s_addr, args->ifindex))
1298                 goto out_err;
1299
1300         return sd;
1301 out_err:
1302         close(sd);
1303         return -1;
1304 }
1305
1306 static int msock_server(struct sock_args *args)
1307 {
1308         return msock_init(args, 1);
1309 }
1310
1311 static int msock_client(struct sock_args *args)
1312 {
1313         return msock_init(args, 0);
1314 }
1315
1316 static int bind_socket(int sd, struct sock_args *args)
1317 {
1318         struct sockaddr_in serv_addr = {
1319                 .sin_family = AF_INET,
1320         };
1321         struct sockaddr_in6 serv6_addr = {
1322                 .sin6_family = AF_INET6,
1323         };
1324         void *addr;
1325         socklen_t alen;
1326
1327         if (!args->has_local_ip && args->type == SOCK_RAW)
1328                 return 0;
1329
1330         switch (args->version) {
1331         case AF_INET:
1332                 serv_addr.sin_port = htons(args->port);
1333                 serv_addr.sin_addr = args->local_addr.in;
1334                 addr = &serv_addr;
1335                 alen = sizeof(serv_addr);
1336                 break;
1337
1338         case AF_INET6:
1339                 serv6_addr.sin6_port = htons(args->port);
1340                 serv6_addr.sin6_addr = args->local_addr.in6;
1341                 addr = &serv6_addr;
1342                 alen = sizeof(serv6_addr);
1343                 break;
1344
1345         default:
1346                 log_error("Invalid address family\n");
1347                 return -1;
1348         }
1349
1350         if (bind(sd, addr, alen) < 0) {
1351                 log_err_errno("error binding socket");
1352                 return -1;
1353         }
1354
1355         return 0;
1356 }
1357
1358 static int config_xfrm_policy(int sd, struct sock_args *args)
1359 {
1360         struct xfrm_userpolicy_info policy = {};
1361         int type = UDP_ENCAP_ESPINUDP;
1362         int xfrm_af = IP_XFRM_POLICY;
1363         int level = SOL_IP;
1364
1365         if (args->type != SOCK_DGRAM) {
1366                 log_error("Invalid socket type. Only DGRAM could be used for XFRM\n");
1367                 return 1;
1368         }
1369
1370         policy.action = XFRM_POLICY_ALLOW;
1371         policy.sel.family = args->version;
1372         if (args->version == AF_INET6) {
1373                 xfrm_af = IPV6_XFRM_POLICY;
1374                 level = SOL_IPV6;
1375         }
1376
1377         policy.dir = XFRM_POLICY_OUT;
1378         if (setsockopt(sd, level, xfrm_af, &policy, sizeof(policy)) < 0)
1379                 return 1;
1380
1381         policy.dir = XFRM_POLICY_IN;
1382         if (setsockopt(sd, level, xfrm_af, &policy, sizeof(policy)) < 0)
1383                 return 1;
1384
1385         if (setsockopt(sd, IPPROTO_UDP, UDP_ENCAP, &type, sizeof(type)) < 0) {
1386                 log_err_errno("Failed to set xfrm encap");
1387                 return 1;
1388         }
1389
1390         return 0;
1391 }
1392
1393 static int lsock_init(struct sock_args *args)
1394 {
1395         long flags;
1396         int sd;
1397
1398         sd = socket(args->version, args->type, args->protocol);
1399         if (sd < 0) {
1400                 log_err_errno("Error opening socket");
1401                 return  -1;
1402         }
1403
1404         if (set_reuseaddr(sd) != 0)
1405                 goto err;
1406
1407         if (set_reuseport(sd) != 0)
1408                 goto err;
1409
1410         if (args->dev && bind_to_device(sd, args->dev) != 0)
1411                 goto err;
1412         else if (args->use_setsockopt &&
1413                  set_unicast_if(sd, args->ifindex, args->version))
1414                 goto err;
1415
1416         if (bind_socket(sd, args))
1417                 goto err;
1418
1419         if (args->bind_test_only)
1420                 goto out;
1421
1422         if (args->type == SOCK_STREAM && listen(sd, 1) < 0) {
1423                 log_err_errno("listen failed");
1424                 goto err;
1425         }
1426
1427         flags = fcntl(sd, F_GETFL);
1428         if ((flags < 0) || (fcntl(sd, F_SETFL, flags|O_NONBLOCK) < 0)) {
1429                 log_err_errno("Failed to set non-blocking option");
1430                 goto err;
1431         }
1432
1433         if (fcntl(sd, F_SETFD, FD_CLOEXEC) < 0)
1434                 log_err_errno("Failed to set close-on-exec flag");
1435
1436         if (args->use_xfrm && config_xfrm_policy(sd, args)) {
1437                 log_err_errno("Failed to set xfrm policy");
1438                 goto err;
1439         }
1440
1441 out:
1442         return sd;
1443
1444 err:
1445         close(sd);
1446         return -1;
1447 }
1448
1449 static void ipc_write(int fd, int message)
1450 {
1451         /* Not in both_mode, so there's no process to signal */
1452         if (fd < 0)
1453                 return;
1454
1455         if (write(fd, &message, sizeof(message)) < 0)
1456                 log_err_errno("Failed to send client status");
1457 }
1458
1459 static int do_server(struct sock_args *args, int ipc_fd)
1460 {
1461         /* ipc_fd = -1 if no parent process to signal */
1462         struct timeval timeout = { .tv_sec = prog_timeout }, *ptval = NULL;
1463         unsigned char addr[sizeof(struct sockaddr_in6)] = {};
1464         socklen_t alen = sizeof(addr);
1465         int lsd, csd = -1;
1466
1467         fd_set rfds;
1468         int rc;
1469
1470         if (args->serverns) {
1471                 if (switch_ns(args->serverns)) {
1472                         log_error("Could not set server netns to %s\n",
1473                                   args->serverns);
1474                         goto err_exit;
1475                 }
1476                 log_msg("Switched server netns\n");
1477         }
1478
1479         args->dev = args->server_dev;
1480         args->expected_dev = args->expected_server_dev;
1481         if (resolve_devices(args) || validate_addresses(args))
1482                 goto err_exit;
1483
1484         if (prog_timeout)
1485                 ptval = &timeout;
1486
1487         if (args->has_grp)
1488                 lsd = msock_server(args);
1489         else
1490                 lsd = lsock_init(args);
1491
1492         if (lsd < 0)
1493                 goto err_exit;
1494
1495         if (args->bind_test_only) {
1496                 close(lsd);
1497                 ipc_write(ipc_fd, 1);
1498                 return 0;
1499         }
1500
1501         if (args->type != SOCK_STREAM) {
1502                 ipc_write(ipc_fd, 1);
1503                 rc = msg_loop(0, lsd, (void *) addr, alen, args);
1504                 close(lsd);
1505                 return rc;
1506         }
1507
1508         if (args->password && tcp_md5_remote(lsd, args)) {
1509                 close(lsd);
1510                 goto err_exit;
1511         }
1512
1513         ipc_write(ipc_fd, 1);
1514         while (1) {
1515                 log_msg("waiting for client connection.\n");
1516                 FD_ZERO(&rfds);
1517                 FD_SET(lsd, &rfds);
1518
1519                 rc = select(lsd+1, &rfds, NULL, NULL, ptval);
1520                 if (rc == 0) {
1521                         rc = 2;
1522                         break;
1523                 }
1524
1525                 if (rc < 0) {
1526                         if (errno == EINTR)
1527                                 continue;
1528
1529                         log_err_errno("select failed");
1530                         break;
1531                 }
1532
1533                 if (FD_ISSET(lsd, &rfds)) {
1534
1535                         csd = accept(lsd, (void *) addr, &alen);
1536                         if (csd < 0) {
1537                                 log_err_errno("accept failed");
1538                                 break;
1539                         }
1540
1541                         rc = show_sockstat(csd, args);
1542                         if (rc)
1543                                 break;
1544
1545                         rc = check_device(csd, args);
1546                         if (rc)
1547                                 break;
1548                 }
1549
1550                 rc = msg_loop(0, csd, (void *) addr, alen, args);
1551                 close(csd);
1552
1553                 if (!interactive)
1554                         break;
1555         }
1556
1557         close(lsd);
1558
1559         return rc;
1560 err_exit:
1561         ipc_write(ipc_fd, 0);
1562         return 1;
1563 }
1564
1565 static int wait_for_connect(int sd)
1566 {
1567         struct timeval _tv = { .tv_sec = prog_timeout }, *tv = NULL;
1568         fd_set wfd;
1569         int val = 0, sz = sizeof(val);
1570         int rc;
1571
1572         FD_ZERO(&wfd);
1573         FD_SET(sd, &wfd);
1574
1575         if (prog_timeout)
1576                 tv = &_tv;
1577
1578         rc = select(FD_SETSIZE, NULL, &wfd, NULL, tv);
1579         if (rc == 0) {
1580                 log_error("connect timed out\n");
1581                 return -2;
1582         } else if (rc < 0) {
1583                 log_err_errno("select failed");
1584                 return -3;
1585         }
1586
1587         if (getsockopt(sd, SOL_SOCKET, SO_ERROR, &val, (socklen_t *)&sz) < 0) {
1588                 log_err_errno("getsockopt(SO_ERROR) failed");
1589                 return -4;
1590         }
1591
1592         if (val != 0) {
1593                 log_error("connect failed: %d: %s\n", val, strerror(val));
1594                 return -1;
1595         }
1596
1597         return 0;
1598 }
1599
1600 static int connectsock(void *addr, socklen_t alen, struct sock_args *args)
1601 {
1602         int sd, rc = -1;
1603         long flags;
1604
1605         sd = socket(args->version, args->type, args->protocol);
1606         if (sd < 0) {
1607                 log_err_errno("Failed to create socket");
1608                 return -1;
1609         }
1610
1611         flags = fcntl(sd, F_GETFL);
1612         if ((flags < 0) || (fcntl(sd, F_SETFL, flags|O_NONBLOCK) < 0)) {
1613                 log_err_errno("Failed to set non-blocking option");
1614                 goto err;
1615         }
1616
1617         if (set_reuseport(sd) != 0)
1618                 goto err;
1619
1620         if (args->dev && bind_to_device(sd, args->dev) != 0)
1621                 goto err;
1622         else if (args->use_setsockopt &&
1623                  set_unicast_if(sd, args->ifindex, args->version))
1624                 goto err;
1625
1626         if (args->has_local_ip && bind_socket(sd, args))
1627                 goto err;
1628
1629         if (args->type != SOCK_STREAM)
1630                 goto out;
1631
1632         if (args->password && tcp_md5sig(sd, addr, alen, args))
1633                 goto err;
1634
1635         if (args->bind_test_only)
1636                 goto out;
1637
1638         if (connect(sd, addr, alen) < 0) {
1639                 if (errno != EINPROGRESS) {
1640                         log_err_errno("Failed to connect to remote host");
1641                         rc = -1;
1642                         goto err;
1643                 }
1644                 rc = wait_for_connect(sd);
1645                 if (rc < 0)
1646                         goto err;
1647         }
1648 out:
1649         return sd;
1650
1651 err:
1652         close(sd);
1653         return rc;
1654 }
1655
1656 static int do_client(struct sock_args *args)
1657 {
1658         struct sockaddr_in sin = {
1659                 .sin_family = AF_INET,
1660         };
1661         struct sockaddr_in6 sin6 = {
1662                 .sin6_family = AF_INET6,
1663         };
1664         void *addr;
1665         int alen;
1666         int rc = 0;
1667         int sd;
1668
1669         if (!args->has_remote_ip && !args->has_grp) {
1670                 fprintf(stderr, "remote IP or multicast group not given\n");
1671                 return 1;
1672         }
1673
1674         if (args->clientns) {
1675                 if (switch_ns(args->clientns)) {
1676                         log_error("Could not set client netns to %s\n",
1677                                   args->clientns);
1678                         return 1;
1679                 }
1680                 log_msg("Switched client netns\n");
1681         }
1682
1683         args->local_addr_str = args->client_local_addr_str;
1684         if (resolve_devices(args) || validate_addresses(args))
1685                 return 1;
1686
1687         if ((args->use_setsockopt || args->use_cmsg) && !args->ifindex) {
1688                 fprintf(stderr, "Device binding not specified\n");
1689                 return 1;
1690         }
1691         if (args->use_setsockopt || args->use_cmsg)
1692                 args->dev = NULL;
1693
1694         switch (args->version) {
1695         case AF_INET:
1696                 sin.sin_port = htons(args->port);
1697                 if (args->has_grp)
1698                         sin.sin_addr = args->grp;
1699                 else
1700                         sin.sin_addr = args->remote_addr.in;
1701                 addr = &sin;
1702                 alen = sizeof(sin);
1703                 break;
1704         case AF_INET6:
1705                 sin6.sin6_port = htons(args->port);
1706                 sin6.sin6_addr = args->remote_addr.in6;
1707                 sin6.sin6_scope_id = args->scope_id;
1708                 addr = &sin6;
1709                 alen = sizeof(sin6);
1710                 break;
1711         }
1712
1713         args->password = args->client_pw;
1714
1715         if (args->has_grp)
1716                 sd = msock_client(args);
1717         else
1718                 sd = connectsock(addr, alen, args);
1719
1720         if (sd < 0)
1721                 return -sd;
1722
1723         if (args->bind_test_only)
1724                 goto out;
1725
1726         if (args->type == SOCK_STREAM) {
1727                 rc = show_sockstat(sd, args);
1728                 if (rc != 0)
1729                         goto out;
1730         }
1731
1732         rc = msg_loop(1, sd, addr, alen, args);
1733
1734 out:
1735         close(sd);
1736
1737         return rc;
1738 }
1739
1740 static char *random_msg(int len)
1741 {
1742         int i, n = 0, olen = len + 1;
1743         char *m;
1744
1745         if (len <= 0)
1746                 return NULL;
1747
1748         m = malloc(olen);
1749         if (!m)
1750                 return NULL;
1751
1752         while (len > 26) {
1753                 i = snprintf(m + n, olen - n, "%.26s",
1754                              "abcdefghijklmnopqrstuvwxyz");
1755                 n += i;
1756                 len -= i;
1757         }
1758         i = snprintf(m + n, olen - n, "%.*s", len,
1759                      "abcdefghijklmnopqrstuvwxyz");
1760         return m;
1761 }
1762
1763 static int ipc_child(int fd, struct sock_args *args)
1764 {
1765         char *outbuf, *errbuf;
1766         int rc = 1;
1767
1768         outbuf = malloc(4096);
1769         errbuf = malloc(4096);
1770         if (!outbuf || !errbuf) {
1771                 fprintf(stderr, "server: Failed to allocate buffers for stdout and stderr\n");
1772                 goto out;
1773         }
1774
1775         setbuffer(stdout, outbuf, 4096);
1776         setbuffer(stderr, errbuf, 4096);
1777
1778         server_mode = 1; /* to tell log_msg in case we are in both_mode */
1779
1780         /* when running in both mode, address validation applies
1781          * solely to client side
1782          */
1783         args->has_expected_laddr = 0;
1784         args->has_expected_raddr = 0;
1785
1786         rc = do_server(args, fd);
1787
1788 out:
1789         free(outbuf);
1790         free(errbuf);
1791
1792         return rc;
1793 }
1794
1795 static int ipc_parent(int cpid, int fd, struct sock_args *args)
1796 {
1797         int client_status;
1798         int status;
1799         int buf;
1800
1801         /* do the client-side function here in the parent process,
1802          * waiting to be told when to continue
1803          */
1804         if (read(fd, &buf, sizeof(buf)) <= 0) {
1805                 log_err_errno("Failed to read IPC status from status");
1806                 return 1;
1807         }
1808         if (!buf) {
1809                 log_error("Server failed; can not continue\n");
1810                 return 1;
1811         }
1812         log_msg("Server is ready\n");
1813
1814         client_status = do_client(args);
1815         log_msg("parent is done!\n");
1816
1817         if (kill(cpid, 0) == 0)
1818                 kill(cpid, SIGKILL);
1819
1820         wait(&status);
1821         return client_status;
1822 }
1823
1824 #define GETOPT_STR  "sr:l:c:p:t:g:P:DRn:M:X:m:d:I:BN:O:SCi6xL:0:1:2:3:Fbq"
1825
1826 static void print_usage(char *prog)
1827 {
1828         printf(
1829         "usage: %s OPTS\n"
1830         "Required:\n"
1831         "    -r addr       remote address to connect to (client mode only)\n"
1832         "    -p port       port to connect to (client mode)/listen on (server mode)\n"
1833         "                  (default: %d)\n"
1834         "    -s            server mode (default: client mode)\n"
1835         "    -t            timeout seconds (default: none)\n"
1836         "\n"
1837         "Optional:\n"
1838         "    -B            do both client and server via fork and IPC\n"
1839         "    -N ns         set client to network namespace ns (requires root)\n"
1840         "    -O ns         set server to network namespace ns (requires root)\n"
1841         "    -F            Restart server loop\n"
1842         "    -6            IPv6 (default is IPv4)\n"
1843         "    -P proto      protocol for socket: icmp, ospf (default: none)\n"
1844         "    -D|R          datagram (D) / raw (R) socket (default stream)\n"
1845         "    -l addr       local address to bind to in server mode\n"
1846         "    -c addr       local address to bind to in client mode\n"
1847         "    -x            configure XFRM policy on socket\n"
1848         "\n"
1849         "    -d dev        bind socket to given device name\n"
1850         "    -I dev        bind socket to given device name - server mode\n"
1851         "    -S            use setsockopt (IP_UNICAST_IF or IP_MULTICAST_IF)\n"
1852         "                  to set device binding\n"
1853         "    -C            use cmsg and IP_PKTINFO to specify device binding\n"
1854         "\n"
1855         "    -L len        send random message of given length\n"
1856         "    -n num        number of times to send message\n"
1857         "\n"
1858         "    -M password   use MD5 sum protection\n"
1859         "    -X password   MD5 password for client mode\n"
1860         "    -m prefix/len prefix and length to use for MD5 key\n"
1861         "    -g grp        multicast group (e.g., 239.1.1.1)\n"
1862         "    -i            interactive mode (default is echo and terminate)\n"
1863         "\n"
1864         "    -0 addr       Expected local address\n"
1865         "    -1 addr       Expected remote address\n"
1866         "    -2 dev        Expected device name (or index) to receive packet\n"
1867         "    -3 dev        Expected device name (or index) to receive packets - server mode\n"
1868         "\n"
1869         "    -b            Bind test only.\n"
1870         "    -q            Be quiet. Run test without printing anything.\n"
1871         , prog, DEFAULT_PORT);
1872 }
1873
1874 int main(int argc, char *argv[])
1875 {
1876         struct sock_args args = {
1877                 .version = AF_INET,
1878                 .type    = SOCK_STREAM,
1879                 .port    = DEFAULT_PORT,
1880         };
1881         struct protoent *pe;
1882         int both_mode = 0;
1883         unsigned int tmp;
1884         int forever = 0;
1885         int fd[2];
1886         int cpid;
1887
1888         /* process inputs */
1889         extern char *optarg;
1890         int rc = 0;
1891
1892         /*
1893          * process input args
1894          */
1895
1896         while ((rc = getopt(argc, argv, GETOPT_STR)) != -1) {
1897                 switch (rc) {
1898                 case 'B':
1899                         both_mode = 1;
1900                         break;
1901                 case 's':
1902                         server_mode = 1;
1903                         break;
1904                 case 'F':
1905                         forever = 1;
1906                         break;
1907                 case 'l':
1908                         args.has_local_ip = 1;
1909                         args.local_addr_str = optarg;
1910                         break;
1911                 case 'r':
1912                         args.has_remote_ip = 1;
1913                         args.remote_addr_str = optarg;
1914                         break;
1915                 case 'c':
1916                         args.has_local_ip = 1;
1917                         args.client_local_addr_str = optarg;
1918                         break;
1919                 case 'p':
1920                         if (str_to_uint(optarg, 1, 65535, &tmp) != 0) {
1921                                 fprintf(stderr, "Invalid port\n");
1922                                 return 1;
1923                         }
1924                         args.port = (unsigned short) tmp;
1925                         break;
1926                 case 't':
1927                         if (str_to_uint(optarg, 0, INT_MAX,
1928                                         &prog_timeout) != 0) {
1929                                 fprintf(stderr, "Invalid timeout\n");
1930                                 return 1;
1931                         }
1932                         break;
1933                 case 'D':
1934                         args.type = SOCK_DGRAM;
1935                         break;
1936                 case 'R':
1937                         args.type = SOCK_RAW;
1938                         args.port = 0;
1939                         if (!args.protocol)
1940                                 args.protocol = IPPROTO_RAW;
1941                         break;
1942                 case 'P':
1943                         pe = getprotobyname(optarg);
1944                         if (pe) {
1945                                 args.protocol = pe->p_proto;
1946                         } else {
1947                                 if (str_to_uint(optarg, 0, 0xffff, &tmp) != 0) {
1948                                         fprintf(stderr, "Invalid protocol\n");
1949                                         return 1;
1950                                 }
1951                                 args.protocol = tmp;
1952                         }
1953                         break;
1954                 case 'n':
1955                         iter = atoi(optarg);
1956                         break;
1957                 case 'N':
1958                         args.clientns = optarg;
1959                         break;
1960                 case 'O':
1961                         args.serverns = optarg;
1962                         break;
1963                 case 'L':
1964                         msg = random_msg(atoi(optarg));
1965                         break;
1966                 case 'M':
1967                         args.password = optarg;
1968                         break;
1969                 case 'X':
1970                         args.client_pw = optarg;
1971                         break;
1972                 case 'm':
1973                         args.md5_prefix_str = optarg;
1974                         break;
1975                 case 'S':
1976                         args.use_setsockopt = 1;
1977                         break;
1978                 case 'C':
1979                         args.use_cmsg = 1;
1980                         break;
1981                 case 'd':
1982                         args.dev = optarg;
1983                         break;
1984                 case 'I':
1985                         args.server_dev = optarg;
1986                         break;
1987                 case 'i':
1988                         interactive = 1;
1989                         break;
1990                 case 'g':
1991                         args.has_grp = 1;
1992                         if (convert_addr(&args, optarg, ADDR_TYPE_MCAST) < 0)
1993                                 return 1;
1994                         args.type = SOCK_DGRAM;
1995                         break;
1996                 case '6':
1997                         args.version = AF_INET6;
1998                         break;
1999                 case 'b':
2000                         args.bind_test_only = 1;
2001                         break;
2002                 case '0':
2003                         args.has_expected_laddr = 1;
2004                         args.expected_laddr_str = optarg;
2005                         break;
2006                 case '1':
2007                         args.has_expected_raddr = 1;
2008                         args.expected_raddr_str = optarg;
2009                         break;
2010                 case '2':
2011                         args.expected_dev = optarg;
2012                         break;
2013                 case '3':
2014                         args.expected_server_dev = optarg;
2015                         break;
2016                 case 'q':
2017                         quiet = 1;
2018                         break;
2019                 case 'x':
2020                         args.use_xfrm = 1;
2021                         break;
2022                 default:
2023                         print_usage(argv[0]);
2024                         return 1;
2025                 }
2026         }
2027
2028         if (args.password &&
2029             ((!args.has_remote_ip && !args.md5_prefix_str) ||
2030               args.type != SOCK_STREAM)) {
2031                 log_error("MD5 passwords apply to TCP only and require a remote ip for the password\n");
2032                 return 1;
2033         }
2034
2035         if (args.md5_prefix_str && !args.password) {
2036                 log_error("Prefix range for MD5 protection specified without a password\n");
2037                 return 1;
2038         }
2039
2040         if (iter == 0) {
2041                 fprintf(stderr, "Invalid number of messages to send\n");
2042                 return 1;
2043         }
2044
2045         if (args.type == SOCK_STREAM && !args.protocol)
2046                 args.protocol = IPPROTO_TCP;
2047         if (args.type == SOCK_DGRAM && !args.protocol)
2048                 args.protocol = IPPROTO_UDP;
2049
2050         if ((args.type == SOCK_STREAM || args.type == SOCK_DGRAM) &&
2051              args.port == 0) {
2052                 fprintf(stderr, "Invalid port number\n");
2053                 return 1;
2054         }
2055
2056         if ((both_mode || !server_mode) && !args.has_grp &&
2057             !args.has_remote_ip && !args.has_local_ip) {
2058                 fprintf(stderr,
2059                         "Local (server mode) or remote IP (client IP) required\n");
2060                 return 1;
2061         }
2062
2063         if (interactive) {
2064                 prog_timeout = 0;
2065                 msg = NULL;
2066         }
2067
2068         if (both_mode) {
2069                 if (pipe(fd) < 0) {
2070                         perror("pipe");
2071                         exit(1);
2072                 }
2073
2074                 cpid = fork();
2075                 if (cpid < 0) {
2076                         perror("fork");
2077                         exit(1);
2078                 }
2079                 if (cpid)
2080                         return ipc_parent(cpid, fd[0], &args);
2081
2082                 return ipc_child(fd[1], &args);
2083         }
2084
2085         if (server_mode) {
2086                 do {
2087                         rc = do_server(&args, -1);
2088                 } while (forever);
2089
2090                 return rc;
2091         }
2092         return do_client(&args);
2093 }