Merge tag 'audit-pr-20200226' of git://git.kernel.org/pub/scm/linux/kernel/git/pcmoor...
[linux-2.6-microblaze.git] / tools / testing / vsock / vsock_diag_test.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_diag_test - vsock_diag.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <sys/stat.h>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
24
25 #include "timeout.h"
26 #include "control.h"
27 #include "util.h"
28
29 /* Per-socket status */
30 struct vsock_stat {
31         struct list_head list;
32         struct vsock_diag_msg msg;
33 };
34
/*
 * Map a socket type to a printable name.
 *
 * Returns "DGRAM" or "STREAM" for SOCK_DGRAM / SOCK_STREAM and
 * "INVALID TYPE" for every other value.  The returned string is a
 * literal and must not be freed.
 */
static const char *sock_type_str(int type)
{
	if (type == SOCK_DGRAM)
		return "DGRAM";
	if (type == SOCK_STREAM)
		return "STREAM";

	return "INVALID TYPE";
}
46
/*
 * Map a socket state (expressed with the TCP_* state constants) to a
 * printable name.
 *
 * Returns "INVALID STATE" for any value that is not one of the five
 * states listed in the table below.  The returned string is a literal
 * and must not be freed.
 */
static const char *sock_state_str(int state)
{
	static const struct {
		int state;
		const char *name;
	} names[] = {
		{ TCP_CLOSE,		"UNCONNECTED" },
		{ TCP_SYN_SENT,		"CONNECTING" },
		{ TCP_ESTABLISHED,	"CONNECTED" },
		{ TCP_CLOSING,		"DISCONNECTING" },
		{ TCP_LISTEN,		"LISTEN" },
	};
	size_t i;

	for (i = 0; i < sizeof(names) / sizeof(names[0]); i++) {
		if (names[i].state == state)
			return names[i].name;
	}

	return "INVALID STATE";
}
64
/*
 * Map a shutdown value to a printable name.
 *
 * 1, 2 and 3 map to "RCV_SHUTDOWN", "SEND_SHUTDOWN" and the
 * combination of both; every other value (including 0) yields "0".
 * The returned string is a literal and must not be freed.
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return names[0];

	return names[shutdown];
}
78
/*
 * Write a vsock address to @fp as "<cid>:<port>".
 *
 * VMADDR_CID_ANY and VMADDR_PORT_ANY are each printed as "*" rather
 * than as a number.  No trailing newline is written.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputc('*', fp);
}
91
92 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
93 {
94         print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
95         fprintf(fp, " ");
96         print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
97         fprintf(fp, " %s %s %s %u\n",
98                 sock_type_str(st->msg.vdiag_type),
99                 sock_state_str(st->msg.vdiag_state),
100                 sock_shutdown_str(st->msg.vdiag_shutdown),
101                 st->msg.vdiag_ino);
102 }
103
104 static void print_vsock_stats(FILE *fp, struct list_head *head)
105 {
106         struct vsock_stat *st;
107
108         list_for_each_entry(st, head, list)
109                 print_vsock_stat(fp, st);
110 }
111
112 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
113 {
114         struct vsock_stat *st;
115         struct stat stat;
116
117         if (fstat(fd, &stat) < 0) {
118                 perror("fstat");
119                 exit(EXIT_FAILURE);
120         }
121
122         list_for_each_entry(st, head, list)
123                 if (st->msg.vdiag_ino == stat.st_ino)
124                         return st;
125
126         fprintf(stderr, "cannot find fd %d\n", fd);
127         exit(EXIT_FAILURE);
128 }
129
/*
 * Fail the test unless the list @head is empty.
 *
 * On failure the unexpected entries are dumped to stderr and the
 * process exits with EXIT_FAILURE, the same status every other check
 * in this file uses.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	exit(EXIT_FAILURE);
}
138
139 static void check_num_sockets(struct list_head *head, int expected)
140 {
141         struct list_head *node;
142         int n = 0;
143
144         list_for_each(node, head)
145                 n++;
146
147         if (n != expected) {
148                 fprintf(stderr, "expected %d sockets, found %d\n",
149                         expected, n);
150                 print_vsock_stats(stderr, head);
151                 exit(EXIT_FAILURE);
152         }
153 }
154
155 static void check_socket_state(struct vsock_stat *st, __u8 state)
156 {
157         if (st->msg.vdiag_state != state) {
158                 fprintf(stderr, "expected socket state %#x, got %#x\n",
159                         state, st->msg.vdiag_state);
160                 exit(EXIT_FAILURE);
161         }
162 }
163
164 static void send_req(int fd)
165 {
166         struct sockaddr_nl nladdr = {
167                 .nl_family = AF_NETLINK,
168         };
169         struct {
170                 struct nlmsghdr nlh;
171                 struct vsock_diag_req vreq;
172         } req = {
173                 .nlh = {
174                         .nlmsg_len = sizeof(req),
175                         .nlmsg_type = SOCK_DIAG_BY_FAMILY,
176                         .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
177                 },
178                 .vreq = {
179                         .sdiag_family = AF_VSOCK,
180                         .vdiag_states = ~(__u32)0,
181                 },
182         };
183         struct iovec iov = {
184                 .iov_base = &req,
185                 .iov_len = sizeof(req),
186         };
187         struct msghdr msg = {
188                 .msg_name = &nladdr,
189                 .msg_namelen = sizeof(nladdr),
190                 .msg_iov = &iov,
191                 .msg_iovlen = 1,
192         };
193
194         for (;;) {
195                 if (sendmsg(fd, &msg, 0) < 0) {
196                         if (errno == EINTR)
197                                 continue;
198
199                         perror("sendmsg");
200                         exit(EXIT_FAILURE);
201                 }
202
203                 return;
204         }
205 }
206
207 static ssize_t recv_resp(int fd, void *buf, size_t len)
208 {
209         struct sockaddr_nl nladdr = {
210                 .nl_family = AF_NETLINK,
211         };
212         struct iovec iov = {
213                 .iov_base = buf,
214                 .iov_len = len,
215         };
216         struct msghdr msg = {
217                 .msg_name = &nladdr,
218                 .msg_namelen = sizeof(nladdr),
219                 .msg_iov = &iov,
220                 .msg_iovlen = 1,
221         };
222         ssize_t ret;
223
224         do {
225                 ret = recvmsg(fd, &msg, 0);
226         } while (ret < 0 && errno == EINTR);
227
228         if (ret < 0) {
229                 perror("recvmsg");
230                 exit(EXIT_FAILURE);
231         }
232
233         return ret;
234 }
235
236 static void add_vsock_stat(struct list_head *sockets,
237                            const struct vsock_diag_msg *resp)
238 {
239         struct vsock_stat *st;
240
241         st = malloc(sizeof(*st));
242         if (!st) {
243                 perror("malloc");
244                 exit(EXIT_FAILURE);
245         }
246
247         st->msg = *resp;
248         list_add_tail(&st->list, sockets);
249 }
250
251 /*
252  * Read vsock stats into a list.
253  */
254 static void read_vsock_stat(struct list_head *sockets)
255 {
256         long buf[8192 / sizeof(long)];
257         int fd;
258
259         fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
260         if (fd < 0) {
261                 perror("socket");
262                 exit(EXIT_FAILURE);
263         }
264
265         send_req(fd);
266
267         for (;;) {
268                 const struct nlmsghdr *h;
269                 ssize_t ret;
270
271                 ret = recv_resp(fd, buf, sizeof(buf));
272                 if (ret == 0)
273                         goto done;
274                 if (ret < sizeof(*h)) {
275                         fprintf(stderr, "short read of %zd bytes\n", ret);
276                         exit(EXIT_FAILURE);
277                 }
278
279                 h = (struct nlmsghdr *)buf;
280
281                 while (NLMSG_OK(h, ret)) {
282                         if (h->nlmsg_type == NLMSG_DONE)
283                                 goto done;
284
285                         if (h->nlmsg_type == NLMSG_ERROR) {
286                                 const struct nlmsgerr *err = NLMSG_DATA(h);
287
288                                 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
289                                         fprintf(stderr, "NLMSG_ERROR\n");
290                                 else {
291                                         errno = -err->error;
292                                         perror("NLMSG_ERROR");
293                                 }
294
295                                 exit(EXIT_FAILURE);
296                         }
297
298                         if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
299                                 fprintf(stderr, "unexpected nlmsg_type %#x\n",
300                                         h->nlmsg_type);
301                                 exit(EXIT_FAILURE);
302                         }
303                         if (h->nlmsg_len <
304                             NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
305                                 fprintf(stderr, "short vsock_diag_msg\n");
306                                 exit(EXIT_FAILURE);
307                         }
308
309                         add_vsock_stat(sockets, NLMSG_DATA(h));
310
311                         h = NLMSG_NEXT(h, ret);
312                 }
313         }
314
315 done:
316         close(fd);
317 }
318
319 static void free_sock_stat(struct list_head *sockets)
320 {
321         struct vsock_stat *st;
322         struct vsock_stat *next;
323
324         list_for_each_entry_safe(st, next, sockets, list)
325                 free(st);
326 }
327
328 static void test_no_sockets(const struct test_opts *opts)
329 {
330         LIST_HEAD(sockets);
331
332         read_vsock_stat(&sockets);
333
334         check_no_sockets(&sockets);
335
336         free_sock_stat(&sockets);
337 }
338
339 static void test_listen_socket_server(const struct test_opts *opts)
340 {
341         union {
342                 struct sockaddr sa;
343                 struct sockaddr_vm svm;
344         } addr = {
345                 .svm = {
346                         .svm_family = AF_VSOCK,
347                         .svm_port = 1234,
348                         .svm_cid = VMADDR_CID_ANY,
349                 },
350         };
351         LIST_HEAD(sockets);
352         struct vsock_stat *st;
353         int fd;
354
355         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356
357         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358                 perror("bind");
359                 exit(EXIT_FAILURE);
360         }
361
362         if (listen(fd, 1) < 0) {
363                 perror("listen");
364                 exit(EXIT_FAILURE);
365         }
366
367         read_vsock_stat(&sockets);
368
369         check_num_sockets(&sockets, 1);
370         st = find_vsock_stat(&sockets, fd);
371         check_socket_state(st, TCP_LISTEN);
372
373         close(fd);
374         free_sock_stat(&sockets);
375 }
376
377 static void test_connect_client(const struct test_opts *opts)
378 {
379         int fd;
380         LIST_HEAD(sockets);
381         struct vsock_stat *st;
382
383         fd = vsock_stream_connect(opts->peer_cid, 1234);
384         if (fd < 0) {
385                 perror("connect");
386                 exit(EXIT_FAILURE);
387         }
388
389         read_vsock_stat(&sockets);
390
391         check_num_sockets(&sockets, 1);
392         st = find_vsock_stat(&sockets, fd);
393         check_socket_state(st, TCP_ESTABLISHED);
394
395         control_expectln("DONE");
396         control_writeln("DONE");
397
398         close(fd);
399         free_sock_stat(&sockets);
400 }
401
402 static void test_connect_server(const struct test_opts *opts)
403 {
404         struct vsock_stat *st;
405         LIST_HEAD(sockets);
406         int client_fd;
407
408         client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
409         if (client_fd < 0) {
410                 perror("accept");
411                 exit(EXIT_FAILURE);
412         }
413
414         read_vsock_stat(&sockets);
415
416         check_num_sockets(&sockets, 1);
417         st = find_vsock_stat(&sockets, client_fd);
418         check_socket_state(st, TCP_ESTABLISHED);
419
420         control_writeln("DONE");
421         control_expectln("DONE");
422
423         close(client_fd);
424         free_sock_stat(&sockets);
425 }
426
427 static struct test_case test_cases[] = {
428         {
429                 .name = "No sockets",
430                 .run_server = test_no_sockets,
431         },
432         {
433                 .name = "Listen socket",
434                 .run_server = test_listen_socket_server,
435         },
436         {
437                 .name = "Connect",
438                 .run_client = test_connect_client,
439                 .run_server = test_connect_server,
440         },
441         {},
442 };
443
/* No short options are accepted; everything is a long option. */
static const char optstring[] = "";

/*
 * Long options for getopt_long().  Each entry leaves .flag unset, so
 * getopt_long() returns the entry's .val.  The table ends with an
 * all-zero entry.
 */
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "list",         .has_arg = no_argument,       .val = 'l' },
	{ .name = "skip",         .has_arg = required_argument, .val = 's' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{},
};
483
/*
 * Print the command-line help text to stderr and terminate the
 * process with EXIT_FAILURE.  Does not return.
 */
static void usage(void)
{
	/* The text contains no conversions, so it is written verbatim. */
	fputs("Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
	      "\n"
	      "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
	      "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
	      "\n"
	      "Run vsock_diag.ko tests.  Must be launched in both\n"
	      "guest and host.  One side must use --mode=client and\n"
	      "the other side must use --mode=server.\n"
	      "\n"
	      "A TCP control socket connection is used to coordinate tests\n"
	      "between the client and the server.  The server requires a\n"
	      "listen address and the client requires an address to\n"
	      "connect to.\n"
	      "\n"
	      "The CID of the other side must be given with --peer-cid=<cid>.\n"
	      "\n"
	      "Options:\n"
	      "  --help                 This help message\n"
	      "  --control-host <host>  Server IP address to connect to\n"
	      "  --control-port <port>  Server port to listen on/connect to\n"
	      "  --mode client|server   Server or client mode\n"
	      "  --peer-cid <cid>       CID of the other side\n"
	      "  --list                 List of tests that will be executed\n"
	      "  --skip <test_id>       Test ID to skip;\n"
	      "                         use multiple --skip options to skip more tests\n",
	      stderr);
	exit(EXIT_FAILURE);
}
514
515 int main(int argc, char **argv)
516 {
517         const char *control_host = NULL;
518         const char *control_port = NULL;
519         struct test_opts opts = {
520                 .mode = TEST_MODE_UNSET,
521                 .peer_cid = VMADDR_CID_ANY,
522         };
523
524         init_signals();
525
526         for (;;) {
527                 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
528
529                 if (opt == -1)
530                         break;
531
532                 switch (opt) {
533                 case 'H':
534                         control_host = optarg;
535                         break;
536                 case 'm':
537                         if (strcmp(optarg, "client") == 0)
538                                 opts.mode = TEST_MODE_CLIENT;
539                         else if (strcmp(optarg, "server") == 0)
540                                 opts.mode = TEST_MODE_SERVER;
541                         else {
542                                 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
543                                 return EXIT_FAILURE;
544                         }
545                         break;
546                 case 'p':
547                         opts.peer_cid = parse_cid(optarg);
548                         break;
549                 case 'P':
550                         control_port = optarg;
551                         break;
552                 case 'l':
553                         list_tests(test_cases);
554                         break;
555                 case 's':
556                         skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
557                                   optarg);
558                         break;
559                 case '?':
560                 default:
561                         usage();
562                 }
563         }
564
565         if (!control_port)
566                 usage();
567         if (opts.mode == TEST_MODE_UNSET)
568                 usage();
569         if (opts.peer_cid == VMADDR_CID_ANY)
570                 usage();
571
572         if (!control_host) {
573                 if (opts.mode != TEST_MODE_SERVER)
574                         usage();
575                 control_host = "0.0.0.0";
576         }
577
578         control_init(control_host, control_port,
579                      opts.mode == TEST_MODE_SERVER);
580
581         run_tests(test_cases, &opts);
582
583         control_cleanup();
584         return EXIT_SUCCESS;
585 }