treewide: Replace GPLv2 boilerplate/reference with SPDX - rule 441
[linux-2.6-microblaze.git] / tools / testing / vsock / vsock_diag_test.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_diag_test - vsock_diag.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9
#include <getopt.h>
#include <stdio.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <signal.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <linux/list.h>
#include <linux/net.h>
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <netinet/tcp.h>

#include "../../../include/uapi/linux/vm_sockets.h"
#include "../../../include/uapi/linux/vm_sockets_diag.h"

#include "timeout.h"
#include "control.h"
32
/* Role this process plays in the test run, chosen with --mode. */
enum test_mode {
	TEST_MODE_UNSET = 0,	/* --mode not given */
	TEST_MODE_CLIENT = 1,
	TEST_MODE_SERVER = 2,
};
38
39 /* Per-socket status */
40 struct vsock_stat {
41         struct list_head list;
42         struct vsock_diag_msg msg;
43 };
44
/* Return a printable name for a socket type as reported in vdiag_type. */
static const char *sock_type_str(int type)
{
	if (type == SOCK_STREAM)
		return "STREAM";
	if (type == SOCK_DGRAM)
		return "DGRAM";
	return "INVALID TYPE";
}
56
57 static const char *sock_state_str(int state)
58 {
59         switch (state) {
60         case TCP_CLOSE:
61                 return "UNCONNECTED";
62         case TCP_SYN_SENT:
63                 return "CONNECTING";
64         case TCP_ESTABLISHED:
65                 return "CONNECTED";
66         case TCP_CLOSING:
67                 return "DISCONNECTING";
68         case TCP_LISTEN:
69                 return "LISTEN";
70         default:
71                 return "INVALID STATE";
72         }
73 }
74
/*
 * Return a printable form of the vdiag_shutdown bitmask: bit value 1 is
 * shown as RCV_SHUTDOWN, 2 as SEND_SHUTDOWN.  Zero and any value outside
 * 0..3 are shown as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 0 || shutdown > 3)
		return "0";
	return names[shutdown];
}
88
/*
 * Write "<cid>:<port>" to @fp without a trailing newline.  The wildcards
 * VMADDR_CID_ANY and VMADDR_PORT_ANY are printed as "*".
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputs("*", fp);
}
101
102 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
103 {
104         print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
105         fprintf(fp, " ");
106         print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
107         fprintf(fp, " %s %s %s %u\n",
108                 sock_type_str(st->msg.vdiag_type),
109                 sock_state_str(st->msg.vdiag_state),
110                 sock_shutdown_str(st->msg.vdiag_shutdown),
111                 st->msg.vdiag_ino);
112 }
113
114 static void print_vsock_stats(FILE *fp, struct list_head *head)
115 {
116         struct vsock_stat *st;
117
118         list_for_each_entry(st, head, list)
119                 print_vsock_stat(fp, st);
120 }
121
122 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
123 {
124         struct vsock_stat *st;
125         struct stat stat;
126
127         if (fstat(fd, &stat) < 0) {
128                 perror("fstat");
129                 exit(EXIT_FAILURE);
130         }
131
132         list_for_each_entry(st, head, list)
133                 if (st->msg.vdiag_ino == stat.st_ino)
134                         return st;
135
136         fprintf(stderr, "cannot find fd %d\n", fd);
137         exit(EXIT_FAILURE);
138 }
139
/*
 * Fail the test unless @head is empty.  On failure, dump the unexpected
 * sockets to stderr and exit with EXIT_FAILURE, as every other check in
 * this file does.
 */
static void check_no_sockets(struct list_head *head)
{
	if (!list_empty(head)) {
		fprintf(stderr, "expected no sockets\n");
		print_vsock_stats(stderr, head);
		exit(EXIT_FAILURE);
	}
}
148
149 static void check_num_sockets(struct list_head *head, int expected)
150 {
151         struct list_head *node;
152         int n = 0;
153
154         list_for_each(node, head)
155                 n++;
156
157         if (n != expected) {
158                 fprintf(stderr, "expected %d sockets, found %d\n",
159                         expected, n);
160                 print_vsock_stats(stderr, head);
161                 exit(EXIT_FAILURE);
162         }
163 }
164
165 static void check_socket_state(struct vsock_stat *st, __u8 state)
166 {
167         if (st->msg.vdiag_state != state) {
168                 fprintf(stderr, "expected socket state %#x, got %#x\n",
169                         state, st->msg.vdiag_state);
170                 exit(EXIT_FAILURE);
171         }
172 }
173
/*
 * Send one SOCK_DIAG_BY_FAMILY dump request for AF_VSOCK sockets in every
 * state on the netlink socket @fd.  Retries when interrupted by a signal
 * (EINTR); any other sendmsg() failure exits the process.
 */
static void send_req(int fd)
{
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			.vdiag_states = ~(__u32)0,	/* all state bits set */
		},
	};
	struct iovec iov = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr msg = {
		.msg_name = &nladdr,
		.msg_namelen = sizeof(nladdr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t sent;

	do {
		sent = sendmsg(fd, &msg, 0);
	} while (sent < 0 && errno == EINTR);

	if (sent < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
216
/*
 * Receive one datagram from @fd into @buf (at most @len bytes) and return
 * the byte count.  Retries when interrupted by a signal (EINTR); any other
 * recvmsg() failure exits the process.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl nladdr = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr msg = {
		.msg_name = &nladdr,
		.msg_namelen = sizeof(nladdr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};

	for (;;) {
		ssize_t got = recvmsg(fd, &msg, 0);

		if (got >= 0)
			return got;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
245
246 static void add_vsock_stat(struct list_head *sockets,
247                            const struct vsock_diag_msg *resp)
248 {
249         struct vsock_stat *st;
250
251         st = malloc(sizeof(*st));
252         if (!st) {
253                 perror("malloc");
254                 exit(EXIT_FAILURE);
255         }
256
257         st->msg = *resp;
258         list_add_tail(&st->list, sockets);
259 }
260
/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends an AF_VSOCK dump request and
 * appends one struct vsock_stat per reported socket to @sockets.  Reading
 * stops at NLMSG_DONE or when recv returns 0 bytes.  An NLMSG_ERROR reply,
 * an unexpected message type, or a truncated message exits the process.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* presumably long[] so the nlmsghdr cast below is suitably aligned */
	long buf[8192 / sizeof(long)];
	bool done = false;
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	while (!done) {
		const struct nlmsghdr *h = (struct nlmsghdr *)buf;
		ssize_t len = recv_resp(fd, buf, sizeof(buf));

		if (len == 0)
			break;
		if (len < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", len);
			exit(EXIT_FAILURE);
		}

		/* One datagram may carry several netlink messages. */
		for (; NLMSG_OK(h, len); h = NLMSG_NEXT(h, len)) {
			if (h->nlmsg_type == NLMSG_DONE) {
				done = true;
				break;
			}

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(h);

				if (h->nlmsg_len >= NLMSG_LENGTH(sizeof(*err))) {
					errno = -err->error;
					perror("NLMSG_ERROR");
				} else {
					fprintf(stderr, "NLMSG_ERROR\n");
				}
				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					h->nlmsg_type);
				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(h));
		}
	}

	close(fd);
}
328
329 static void free_sock_stat(struct list_head *sockets)
330 {
331         struct vsock_stat *st;
332         struct vsock_stat *next;
333
334         list_for_each_entry_safe(st, next, sockets, list)
335                 free(st);
336 }
337
338 static void test_no_sockets(unsigned int peer_cid)
339 {
340         LIST_HEAD(sockets);
341
342         read_vsock_stat(&sockets);
343
344         check_no_sockets(&sockets);
345
346         free_sock_stat(&sockets);
347 }
348
349 static void test_listen_socket_server(unsigned int peer_cid)
350 {
351         union {
352                 struct sockaddr sa;
353                 struct sockaddr_vm svm;
354         } addr = {
355                 .svm = {
356                         .svm_family = AF_VSOCK,
357                         .svm_port = 1234,
358                         .svm_cid = VMADDR_CID_ANY,
359                 },
360         };
361         LIST_HEAD(sockets);
362         struct vsock_stat *st;
363         int fd;
364
365         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
366
367         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
368                 perror("bind");
369                 exit(EXIT_FAILURE);
370         }
371
372         if (listen(fd, 1) < 0) {
373                 perror("listen");
374                 exit(EXIT_FAILURE);
375         }
376
377         read_vsock_stat(&sockets);
378
379         check_num_sockets(&sockets, 1);
380         st = find_vsock_stat(&sockets, fd);
381         check_socket_state(st, TCP_LISTEN);
382
383         close(fd);
384         free_sock_stat(&sockets);
385 }
386
387 static void test_connect_client(unsigned int peer_cid)
388 {
389         union {
390                 struct sockaddr sa;
391                 struct sockaddr_vm svm;
392         } addr = {
393                 .svm = {
394                         .svm_family = AF_VSOCK,
395                         .svm_port = 1234,
396                         .svm_cid = peer_cid,
397                 },
398         };
399         int fd;
400         int ret;
401         LIST_HEAD(sockets);
402         struct vsock_stat *st;
403
404         control_expectln("LISTENING");
405
406         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
407
408         timeout_begin(TIMEOUT);
409         do {
410                 ret = connect(fd, &addr.sa, sizeof(addr.svm));
411                 timeout_check("connect");
412         } while (ret < 0 && errno == EINTR);
413         timeout_end();
414
415         if (ret < 0) {
416                 perror("connect");
417                 exit(EXIT_FAILURE);
418         }
419
420         read_vsock_stat(&sockets);
421
422         check_num_sockets(&sockets, 1);
423         st = find_vsock_stat(&sockets, fd);
424         check_socket_state(st, TCP_ESTABLISHED);
425
426         control_expectln("DONE");
427         control_writeln("DONE");
428
429         close(fd);
430         free_sock_stat(&sockets);
431 }
432
433 static void test_connect_server(unsigned int peer_cid)
434 {
435         union {
436                 struct sockaddr sa;
437                 struct sockaddr_vm svm;
438         } addr = {
439                 .svm = {
440                         .svm_family = AF_VSOCK,
441                         .svm_port = 1234,
442                         .svm_cid = VMADDR_CID_ANY,
443                 },
444         };
445         union {
446                 struct sockaddr sa;
447                 struct sockaddr_vm svm;
448         } clientaddr;
449         socklen_t clientaddr_len = sizeof(clientaddr.svm);
450         LIST_HEAD(sockets);
451         struct vsock_stat *st;
452         int fd;
453         int client_fd;
454
455         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
456
457         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
458                 perror("bind");
459                 exit(EXIT_FAILURE);
460         }
461
462         if (listen(fd, 1) < 0) {
463                 perror("listen");
464                 exit(EXIT_FAILURE);
465         }
466
467         control_writeln("LISTENING");
468
469         timeout_begin(TIMEOUT);
470         do {
471                 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
472                 timeout_check("accept");
473         } while (client_fd < 0 && errno == EINTR);
474         timeout_end();
475
476         if (client_fd < 0) {
477                 perror("accept");
478                 exit(EXIT_FAILURE);
479         }
480         if (clientaddr.sa.sa_family != AF_VSOCK) {
481                 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
482                         clientaddr.sa.sa_family);
483                 exit(EXIT_FAILURE);
484         }
485         if (clientaddr.svm.svm_cid != peer_cid) {
486                 fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
487                         peer_cid, clientaddr.svm.svm_cid);
488                 exit(EXIT_FAILURE);
489         }
490
491         read_vsock_stat(&sockets);
492
493         check_num_sockets(&sockets, 2);
494         find_vsock_stat(&sockets, fd);
495         st = find_vsock_stat(&sockets, client_fd);
496         check_socket_state(st, TCP_ESTABLISHED);
497
498         control_writeln("DONE");
499         control_expectln("DONE");
500
501         close(client_fd);
502         close(fd);
503         free_sock_stat(&sockets);
504 }
505
506 static struct {
507         const char *name;
508         void (*run_client)(unsigned int peer_cid);
509         void (*run_server)(unsigned int peer_cid);
510 } test_cases[] = {
511         {
512                 .name = "No sockets",
513                 .run_server = test_no_sockets,
514         },
515         {
516                 .name = "Listen socket",
517                 .run_server = test_listen_socket_server,
518         },
519         {
520                 .name = "Connect",
521                 .run_client = test_connect_client,
522                 .run_server = test_connect_server,
523         },
524         {},
525 };
526
527 static void init_signals(void)
528 {
529         struct sigaction act = {
530                 .sa_handler = sigalrm,
531         };
532
533         sigaction(SIGALRM, &act, NULL);
534         signal(SIGPIPE, SIG_IGN);
535 }
536
/*
 * Parse a decimal CID from a command-line argument.
 *
 * Only plain decimal digits are accepted.  The following are all reported
 * as "malformed CID" and exit the process:
 *  - an empty string;
 *  - a leading sign or whitespace;
 *  - trailing characters;
 *  - an out-of-range value, or one that does not fit in unsigned int.
 */
static unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long int n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	/*
	 * strtoul() cannot catch all of these on its own:
	 *  - it returns 0 without an error for "";
	 *  - it skips leading whitespace;
	 *  - it negates "-N";
	 *  - on LP64 it returns values wider than unsigned int, which would
	 *    otherwise be truncated silently.
	 */
	if (str[0] < '0' || str[0] > '9' || errno || *endptr != '\0' ||
	    n > UINT_MAX) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
550
/* No short options are defined; every option below is long-only. */
static const char optstring[] = "";
/*
 * Long options for getopt_long().  Fields are in struct option order:
 * name, has_arg, flag, val.  main() switches on val.
 */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0 },	/* terminator */
};
580
/* Print the command-line help to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	/* The text contains no conversion specifiers, so fputs() is enough. */
	fputs("Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
	      "\n"
	      "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
	      "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
	      "\n"
	      "Run vsock_diag.ko tests.  Must be launched in both\n"
	      "guest and host.  One side must use --mode=client and\n"
	      "the other side must use --mode=server.\n"
	      "\n"
	      "A TCP control socket connection is used to coordinate tests\n"
	      "between the client and the server.  The server requires a\n"
	      "listen address and the client requires an address to\n"
	      "connect to.\n"
	      "\n"
	      "The CID of the other side must be given with --peer-cid=<cid>.\n",
	      stderr);
	exit(EXIT_FAILURE);
}
600
601 int main(int argc, char **argv)
602 {
603         const char *control_host = NULL;
604         const char *control_port = NULL;
605         int mode = TEST_MODE_UNSET;
606         unsigned int peer_cid = VMADDR_CID_ANY;
607         int i;
608
609         init_signals();
610
611         for (;;) {
612                 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
613
614                 if (opt == -1)
615                         break;
616
617                 switch (opt) {
618                 case 'H':
619                         control_host = optarg;
620                         break;
621                 case 'm':
622                         if (strcmp(optarg, "client") == 0)
623                                 mode = TEST_MODE_CLIENT;
624                         else if (strcmp(optarg, "server") == 0)
625                                 mode = TEST_MODE_SERVER;
626                         else {
627                                 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
628                                 return EXIT_FAILURE;
629                         }
630                         break;
631                 case 'p':
632                         peer_cid = parse_cid(optarg);
633                         break;
634                 case 'P':
635                         control_port = optarg;
636                         break;
637                 case '?':
638                 default:
639                         usage();
640                 }
641         }
642
643         if (!control_port)
644                 usage();
645         if (mode == TEST_MODE_UNSET)
646                 usage();
647         if (peer_cid == VMADDR_CID_ANY)
648                 usage();
649
650         if (!control_host) {
651                 if (mode != TEST_MODE_SERVER)
652                         usage();
653                 control_host = "0.0.0.0";
654         }
655
656         control_init(control_host, control_port, mode == TEST_MODE_SERVER);
657
658         for (i = 0; test_cases[i].name; i++) {
659                 void (*run)(unsigned int peer_cid);
660
661                 printf("%s...", test_cases[i].name);
662                 fflush(stdout);
663
664                 if (mode == TEST_MODE_CLIENT)
665                         run = test_cases[i].run_client;
666                 else
667                         run = test_cases[i].run_server;
668
669                 if (run)
670                         run(peer_cid);
671
672                 printf("ok\n");
673         }
674
675         control_cleanup();
676         return EXIT_SUCCESS;
677 }