Merge tag 'drm-msm-fixes-2020-06-25' of https://gitlab.freedesktop.org/drm/msm into...
[linux-2.6-microblaze.git] / tools / testing / selftests / bpf / test_sockmap.c
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2017-2018 Covalent IO, Inc. http://covalent.io
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <sys/socket.h>
6 #include <sys/ioctl.h>
7 #include <sys/select.h>
8 #include <netinet/in.h>
9 #include <arpa/inet.h>
10 #include <unistd.h>
11 #include <string.h>
12 #include <errno.h>
13 #include <stdbool.h>
14 #include <signal.h>
15 #include <fcntl.h>
16 #include <sys/wait.h>
17 #include <time.h>
18 #include <sched.h>
19
20 #include <sys/time.h>
21 #include <sys/resource.h>
22 #include <sys/types.h>
23 #include <sys/sendfile.h>
24
25 #include <linux/netlink.h>
26 #include <linux/socket.h>
27 #include <linux/sock_diag.h>
28 #include <linux/bpf.h>
29 #include <linux/if_link.h>
30 #include <linux/tls.h>
31 #include <assert.h>
32 #include <libgen.h>
33
34 #include <getopt.h>
35
36 #include <bpf/bpf.h>
37 #include <bpf/libbpf.h>
38
39 #include "bpf_util.h"
40 #include "bpf_rlimit.h"
41 #include "cgroup_helpers.h"
42
43 int running;
44 static void running_handler(int a);
45
46 #ifndef TCP_ULP
47 # define TCP_ULP 31
48 #endif
49 #ifndef SOL_TLS
50 # define SOL_TLS 282
51 #endif
52
53 /* randomly selected ports for testing on lo */
54 #define S1_PORT 10000
55 #define S2_PORT 10001
56
57 #define BPF_SOCKMAP_FILENAME  "test_sockmap_kern.o"
58 #define BPF_SOCKHASH_FILENAME "test_sockhash_kern.o"
59 #define CG_PATH "/sockmap"
60
61 /* global sockets */
62 int s1, s2, c1, c2, p1, p2;
63 int test_cnt;
64 int passed;
65 int failed;
66 int map_fd[9];
67 struct bpf_map *maps[9];
68 int prog_fd[11];
69
70 int txmsg_pass;
71 int txmsg_redir;
72 int txmsg_drop;
73 int txmsg_apply;
74 int txmsg_cork;
75 int txmsg_start;
76 int txmsg_end;
77 int txmsg_start_push;
78 int txmsg_end_push;
79 int txmsg_start_pop;
80 int txmsg_pop;
81 int txmsg_ingress;
82 int txmsg_redir_skb;
83 int txmsg_ktls_skb;
84 int txmsg_ktls_skb_drop;
85 int txmsg_ktls_skb_redir;
86 int ktls;
87 int peek_flag;
88
89 static const struct option long_options[] = {
90         {"help",        no_argument,            NULL, 'h' },
91         {"cgroup",      required_argument,      NULL, 'c' },
92         {"rate",        required_argument,      NULL, 'r' },
93         {"verbose",     optional_argument,      NULL, 'v' },
94         {"iov_count",   required_argument,      NULL, 'i' },
95         {"length",      required_argument,      NULL, 'l' },
96         {"test",        required_argument,      NULL, 't' },
97         {"data_test",   no_argument,            NULL, 'd' },
98         {"txmsg",               no_argument,    &txmsg_pass,  1  },
99         {"txmsg_redir",         no_argument,    &txmsg_redir, 1  },
100         {"txmsg_drop",          no_argument,    &txmsg_drop, 1 },
101         {"txmsg_apply", required_argument,      NULL, 'a'},
102         {"txmsg_cork",  required_argument,      NULL, 'k'},
103         {"txmsg_start", required_argument,      NULL, 's'},
104         {"txmsg_end",   required_argument,      NULL, 'e'},
105         {"txmsg_start_push", required_argument, NULL, 'p'},
106         {"txmsg_end_push",   required_argument, NULL, 'q'},
107         {"txmsg_start_pop",  required_argument, NULL, 'w'},
108         {"txmsg_pop",        required_argument, NULL, 'x'},
109         {"txmsg_ingress", no_argument,          &txmsg_ingress, 1 },
110         {"txmsg_redir_skb", no_argument,        &txmsg_redir_skb, 1 },
111         {"ktls", no_argument,                   &ktls, 1 },
112         {"peek", no_argument,                   &peek_flag, 1 },
113         {"whitelist", required_argument,        NULL, 'n' },
114         {"blacklist", required_argument,        NULL, 'b' },
115         {0, 0, NULL, 0 }
116 };
117
118 struct test_env {
119         const char *type;
120         const char *subtest;
121         const char *prepend;
122
123         int test_num;
124         int subtest_num;
125
126         int succ_cnt;
127         int fail_cnt;
128         int fail_last;
129 };
130
131 struct test_env env;
132
133 struct sockmap_options {
134         int verbose;
135         bool base;
136         bool sendpage;
137         bool data_test;
138         bool drop_expected;
139         int iov_count;
140         int iov_length;
141         int rate;
142         char *map;
143         char *whitelist;
144         char *blacklist;
145         char *prepend;
146 };
147
148 struct _test {
149         char *title;
150         void (*tester)(int cg_fd, struct sockmap_options *opt);
151 };
152
153 static void test_start(void)
154 {
155         env.subtest_num++;
156 }
157
158 static void test_fail(void)
159 {
160         env.fail_cnt++;
161 }
162
163 static void test_pass(void)
164 {
165         env.succ_cnt++;
166 }
167
168 static void test_reset(void)
169 {
170         txmsg_start = txmsg_end = 0;
171         txmsg_start_pop = txmsg_pop = 0;
172         txmsg_start_push = txmsg_end_push = 0;
173         txmsg_pass = txmsg_drop = txmsg_redir = 0;
174         txmsg_apply = txmsg_cork = 0;
175         txmsg_ingress = txmsg_redir_skb = 0;
176         txmsg_ktls_skb = txmsg_ktls_skb_drop = txmsg_ktls_skb_redir = 0;
177 }
178
179 static int test_start_subtest(const struct _test *t, struct sockmap_options *o)
180 {
181         env.type = o->map;
182         env.subtest = t->title;
183         env.prepend = o->prepend;
184         env.test_num++;
185         env.subtest_num = 0;
186         env.fail_last = env.fail_cnt;
187         test_reset();
188         return 0;
189 }
190
191 static void test_end_subtest(void)
192 {
193         int error = env.fail_cnt - env.fail_last;
194         int type = strcmp(env.type, BPF_SOCKMAP_FILENAME);
195
196         if (!error)
197                 test_pass();
198
199         fprintf(stdout, "#%2d/%2d %8s:%s:%s:%s\n",
200                 env.test_num, env.subtest_num,
201                 !type ? "sockmap" : "sockhash",
202                 env.prepend ? : "",
203                 env.subtest, error ? "FAIL" : "OK");
204 }
205
206 static void test_print_results(void)
207 {
208         fprintf(stdout, "Pass: %d Fail: %d\n",
209                 env.succ_cnt, env.fail_cnt);
210 }
211
212 static void usage(char *argv[])
213 {
214         int i;
215
216         printf(" Usage: %s --cgroup <cgroup_path>\n", argv[0]);
217         printf(" options:\n");
218         for (i = 0; long_options[i].name != 0; i++) {
219                 printf(" --%-12s", long_options[i].name);
220                 if (long_options[i].flag != NULL)
221                         printf(" flag (internal value:%d)\n",
222                                 *long_options[i].flag);
223                 else
224                         printf(" -%c\n", long_options[i].val);
225         }
226         printf("\n");
227 }
228
229 char *sock_to_string(int s)
230 {
231         if (s == c1)
232                 return "client1";
233         else if (s == c2)
234                 return "client2";
235         else if (s == s1)
236                 return "server1";
237         else if (s == s2)
238                 return "server2";
239         else if (s == p1)
240                 return "peer1";
241         else if (s == p2)
242                 return "peer2";
243         else
244                 return "unknown";
245 }
246
247 static int sockmap_init_ktls(int verbose, int s)
248 {
249         struct tls12_crypto_info_aes_gcm_128 tls_tx = {
250                 .info = {
251                         .version     = TLS_1_2_VERSION,
252                         .cipher_type = TLS_CIPHER_AES_GCM_128,
253                 },
254         };
255         struct tls12_crypto_info_aes_gcm_128 tls_rx = {
256                 .info = {
257                         .version     = TLS_1_2_VERSION,
258                         .cipher_type = TLS_CIPHER_AES_GCM_128,
259                 },
260         };
261         int so_buf = 6553500;
262         int err;
263
264         err = setsockopt(s, 6, TCP_ULP, "tls", sizeof("tls"));
265         if (err) {
266                 fprintf(stderr, "setsockopt: TCP_ULP(%s) failed with error %i\n", sock_to_string(s), err);
267                 return -EINVAL;
268         }
269         err = setsockopt(s, SOL_TLS, TLS_TX, (void *)&tls_tx, sizeof(tls_tx));
270         if (err) {
271                 fprintf(stderr, "setsockopt: TLS_TX(%s) failed with error %i\n", sock_to_string(s), err);
272                 return -EINVAL;
273         }
274         err = setsockopt(s, SOL_TLS, TLS_RX, (void *)&tls_rx, sizeof(tls_rx));
275         if (err) {
276                 fprintf(stderr, "setsockopt: TLS_RX(%s) failed with error %i\n", sock_to_string(s), err);
277                 return -EINVAL;
278         }
279         err = setsockopt(s, SOL_SOCKET, SO_SNDBUF, &so_buf, sizeof(so_buf));
280         if (err) {
281                 fprintf(stderr, "setsockopt: (%s) failed sndbuf with error %i\n", sock_to_string(s), err);
282                 return -EINVAL;
283         }
284         err = setsockopt(s, SOL_SOCKET, SO_RCVBUF, &so_buf, sizeof(so_buf));
285         if (err) {
286                 fprintf(stderr, "setsockopt: (%s) failed rcvbuf with error %i\n", sock_to_string(s), err);
287                 return -EINVAL;
288         }
289
290         if (verbose)
291                 fprintf(stdout, "socket(%s) kTLS enabled\n", sock_to_string(s));
292         return 0;
293 }
294 static int sockmap_init_sockets(int verbose)
295 {
296         int i, err, one = 1;
297         struct sockaddr_in addr;
298         int *fds[4] = {&s1, &s2, &c1, &c2};
299
300         s1 = s2 = p1 = p2 = c1 = c2 = 0;
301
302         /* Init sockets */
303         for (i = 0; i < 4; i++) {
304                 *fds[i] = socket(AF_INET, SOCK_STREAM, 0);
305                 if (*fds[i] < 0) {
306                         perror("socket s1 failed()");
307                         return errno;
308                 }
309         }
310
311         /* Allow reuse */
312         for (i = 0; i < 2; i++) {
313                 err = setsockopt(*fds[i], SOL_SOCKET, SO_REUSEADDR,
314                                  (char *)&one, sizeof(one));
315                 if (err) {
316                         perror("setsockopt failed()");
317                         return errno;
318                 }
319         }
320
321         /* Non-blocking sockets */
322         for (i = 0; i < 2; i++) {
323                 err = ioctl(*fds[i], FIONBIO, (char *)&one);
324                 if (err < 0) {
325                         perror("ioctl s1 failed()");
326                         return errno;
327                 }
328         }
329
330         /* Bind server sockets */
331         memset(&addr, 0, sizeof(struct sockaddr_in));
332         addr.sin_family = AF_INET;
333         addr.sin_addr.s_addr = inet_addr("127.0.0.1");
334
335         addr.sin_port = htons(S1_PORT);
336         err = bind(s1, (struct sockaddr *)&addr, sizeof(addr));
337         if (err < 0) {
338                 perror("bind s1 failed()");
339                 return errno;
340         }
341
342         addr.sin_port = htons(S2_PORT);
343         err = bind(s2, (struct sockaddr *)&addr, sizeof(addr));
344         if (err < 0) {
345                 perror("bind s2 failed()");
346                 return errno;
347         }
348
349         /* Listen server sockets */
350         addr.sin_port = htons(S1_PORT);
351         err = listen(s1, 32);
352         if (err < 0) {
353                 perror("listen s1 failed()");
354                 return errno;
355         }
356
357         addr.sin_port = htons(S2_PORT);
358         err = listen(s2, 32);
359         if (err < 0) {
360                 perror("listen s1 failed()");
361                 return errno;
362         }
363
364         /* Initiate Connect */
365         addr.sin_port = htons(S1_PORT);
366         err = connect(c1, (struct sockaddr *)&addr, sizeof(addr));
367         if (err < 0 && errno != EINPROGRESS) {
368                 perror("connect c1 failed()");
369                 return errno;
370         }
371
372         addr.sin_port = htons(S2_PORT);
373         err = connect(c2, (struct sockaddr *)&addr, sizeof(addr));
374         if (err < 0 && errno != EINPROGRESS) {
375                 perror("connect c2 failed()");
376                 return errno;
377         } else if (err < 0) {
378                 err = 0;
379         }
380
381         /* Accept Connecrtions */
382         p1 = accept(s1, NULL, NULL);
383         if (p1 < 0) {
384                 perror("accept s1 failed()");
385                 return errno;
386         }
387
388         p2 = accept(s2, NULL, NULL);
389         if (p2 < 0) {
390                 perror("accept s1 failed()");
391                 return errno;
392         }
393
394         if (verbose > 1) {
395                 printf("connected sockets: c1 <-> p1, c2 <-> p2\n");
396                 printf("cgroups binding: c1(%i) <-> s1(%i) - - - c2(%i) <-> s2(%i)\n",
397                         c1, s1, c2, s2);
398         }
399         return 0;
400 }
401
402 struct msg_stats {
403         size_t bytes_sent;
404         size_t bytes_recvd;
405         struct timespec start;
406         struct timespec end;
407 };
408
409 static int msg_loop_sendpage(int fd, int iov_length, int cnt,
410                              struct msg_stats *s,
411                              struct sockmap_options *opt)
412 {
413         bool drop = opt->drop_expected;
414         unsigned char k = 0;
415         FILE *file;
416         int i, fp;
417
418         file = tmpfile();
419         if (!file) {
420                 perror("create file for sendpage");
421                 return 1;
422         }
423         for (i = 0; i < iov_length * cnt; i++, k++)
424                 fwrite(&k, sizeof(char), 1, file);
425         fflush(file);
426         fseek(file, 0, SEEK_SET);
427
428         fp = fileno(file);
429
430         clock_gettime(CLOCK_MONOTONIC, &s->start);
431         for (i = 0; i < cnt; i++) {
432                 int sent;
433
434                 errno = 0;
435                 sent = sendfile(fd, fp, NULL, iov_length);
436
437                 if (!drop && sent < 0) {
438                         perror("sendpage loop error");
439                         fclose(file);
440                         return sent;
441                 } else if (drop && sent >= 0) {
442                         printf("sendpage loop error expected: %i errno %i\n",
443                                sent, errno);
444                         fclose(file);
445                         return -EIO;
446                 }
447
448                 if (sent > 0)
449                         s->bytes_sent += sent;
450         }
451         clock_gettime(CLOCK_MONOTONIC, &s->end);
452         fclose(file);
453         return 0;
454 }
455
456 static void msg_free_iov(struct msghdr *msg)
457 {
458         int i;
459
460         for (i = 0; i < msg->msg_iovlen; i++)
461                 free(msg->msg_iov[i].iov_base);
462         free(msg->msg_iov);
463         msg->msg_iov = NULL;
464         msg->msg_iovlen = 0;
465 }
466
467 static int msg_alloc_iov(struct msghdr *msg,
468                          int iov_count, int iov_length,
469                          bool data, bool xmit)
470 {
471         unsigned char k = 0;
472         struct iovec *iov;
473         int i;
474
475         iov = calloc(iov_count, sizeof(struct iovec));
476         if (!iov)
477                 return errno;
478
479         for (i = 0; i < iov_count; i++) {
480                 unsigned char *d = calloc(iov_length, sizeof(char));
481
482                 if (!d) {
483                         fprintf(stderr, "iov_count %i/%i OOM\n", i, iov_count);
484                         goto unwind_iov;
485                 }
486                 iov[i].iov_base = d;
487                 iov[i].iov_len = iov_length;
488
489                 if (data && xmit) {
490                         int j;
491
492                         for (j = 0; j < iov_length; j++)
493                                 d[j] = k++;
494                 }
495         }
496
497         msg->msg_iov = iov;
498         msg->msg_iovlen = iov_count;
499
500         return 0;
501 unwind_iov:
502         for (i--; i >= 0 ; i--)
503                 free(msg->msg_iov[i].iov_base);
504         return -ENOMEM;
505 }
506
507 static int msg_verify_data(struct msghdr *msg, int size, int chunk_sz)
508 {
509         int i, j = 0, bytes_cnt = 0;
510         unsigned char k = 0;
511
512         for (i = 0; i < msg->msg_iovlen; i++) {
513                 unsigned char *d = msg->msg_iov[i].iov_base;
514
515                 /* Special case test for skb ingress + ktls */
516                 if (i == 0 && txmsg_ktls_skb) {
517                         if (msg->msg_iov[i].iov_len < 4)
518                                 return -EIO;
519                         if (txmsg_ktls_skb_redir) {
520                                 if (memcmp(&d[13], "PASS", 4) != 0) {
521                                         fprintf(stderr,
522                                                 "detected redirect ktls_skb data error with skb ingress update @iov[%i]:%i \"%02x %02x %02x %02x\" != \"PASS\"\n", i, 0, d[13], d[14], d[15], d[16]);
523                                         return -EIO;
524                                 }
525                                 d[13] = 0;
526                                 d[14] = 1;
527                                 d[15] = 2;
528                                 d[16] = 3;
529                                 j = 13;
530                         } else if (txmsg_ktls_skb) {
531                                 if (memcmp(d, "PASS", 4) != 0) {
532                                         fprintf(stderr,
533                                                 "detected ktls_skb data error with skb ingress update @iov[%i]:%i \"%02x %02x %02x %02x\" != \"PASS\"\n", i, 0, d[0], d[1], d[2], d[3]);
534                                         return -EIO;
535                                 }
536                                 d[0] = 0;
537                                 d[1] = 1;
538                                 d[2] = 2;
539                                 d[3] = 3;
540                         }
541                 }
542
543                 for (; j < msg->msg_iov[i].iov_len && size; j++) {
544                         if (d[j] != k++) {
545                                 fprintf(stderr,
546                                         "detected data corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n",
547                                         i, j, d[j], k - 1, d[j+1], k);
548                                 return -EIO;
549                         }
550                         bytes_cnt++;
551                         if (bytes_cnt == chunk_sz) {
552                                 k = 0;
553                                 bytes_cnt = 0;
554                         }
555                         size--;
556                 }
557         }
558         return 0;
559 }
560
561 static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
562                     struct msg_stats *s, bool tx,
563                     struct sockmap_options *opt)
564 {
565         struct msghdr msg = {0}, msg_peek = {0};
566         int err, i, flags = MSG_NOSIGNAL;
567         bool drop = opt->drop_expected;
568         bool data = opt->data_test;
569
570         err = msg_alloc_iov(&msg, iov_count, iov_length, data, tx);
571         if (err)
572                 goto out_errno;
573         if (peek_flag) {
574                 err = msg_alloc_iov(&msg_peek, iov_count, iov_length, data, tx);
575                 if (err)
576                         goto out_errno;
577         }
578
579         if (tx) {
580                 clock_gettime(CLOCK_MONOTONIC, &s->start);
581                 for (i = 0; i < cnt; i++) {
582                         int sent;
583
584                         errno = 0;
585                         sent = sendmsg(fd, &msg, flags);
586
587                         if (!drop && sent < 0) {
588                                 perror("sendmsg loop error");
589                                 goto out_errno;
590                         } else if (drop && sent >= 0) {
591                                 fprintf(stderr,
592                                         "sendmsg loop error expected: %i errno %i\n",
593                                         sent, errno);
594                                 errno = -EIO;
595                                 goto out_errno;
596                         }
597                         if (sent > 0)
598                                 s->bytes_sent += sent;
599                 }
600                 clock_gettime(CLOCK_MONOTONIC, &s->end);
601         } else {
602                 int slct, recvp = 0, recv, max_fd = fd;
603                 float total_bytes, txmsg_pop_total;
604                 int fd_flags = O_NONBLOCK;
605                 struct timeval timeout;
606                 fd_set w;
607
608                 fcntl(fd, fd_flags);
609                 /* Account for pop bytes noting each iteration of apply will
610                  * call msg_pop_data helper so we need to account for this
611                  * by calculating the number of apply iterations. Note user
612                  * of the tool can create cases where no data is sent by
613                  * manipulating pop/push/pull/etc. For example txmsg_apply 1
614                  * with txmsg_pop 1 will try to apply 1B at a time but each
615                  * iteration will then pop 1B so no data will ever be sent.
616                  * This is really only useful for testing edge cases in code
617                  * paths.
618                  */
619                 total_bytes = (float)iov_count * (float)iov_length * (float)cnt;
620                 if (txmsg_apply)
621                         txmsg_pop_total = txmsg_pop * (total_bytes / txmsg_apply);
622                 else
623                         txmsg_pop_total = txmsg_pop * cnt;
624                 total_bytes -= txmsg_pop_total;
625                 err = clock_gettime(CLOCK_MONOTONIC, &s->start);
626                 if (err < 0)
627                         perror("recv start time");
628                 while (s->bytes_recvd < total_bytes) {
629                         if (txmsg_cork) {
630                                 timeout.tv_sec = 0;
631                                 timeout.tv_usec = 300000;
632                         } else {
633                                 timeout.tv_sec = 3;
634                                 timeout.tv_usec = 0;
635                         }
636
637                         /* FD sets */
638                         FD_ZERO(&w);
639                         FD_SET(fd, &w);
640
641                         slct = select(max_fd + 1, &w, NULL, NULL, &timeout);
642                         if (slct == -1) {
643                                 perror("select()");
644                                 clock_gettime(CLOCK_MONOTONIC, &s->end);
645                                 goto out_errno;
646                         } else if (!slct) {
647                                 if (opt->verbose)
648                                         fprintf(stderr, "unexpected timeout: recved %zu/%f pop_total %f\n", s->bytes_recvd, total_bytes, txmsg_pop_total);
649                                 errno = -EIO;
650                                 clock_gettime(CLOCK_MONOTONIC, &s->end);
651                                 goto out_errno;
652                         }
653
654                         errno = 0;
655                         if (peek_flag) {
656                                 flags |= MSG_PEEK;
657                                 recvp = recvmsg(fd, &msg_peek, flags);
658                                 if (recvp < 0) {
659                                         if (errno != EWOULDBLOCK) {
660                                                 clock_gettime(CLOCK_MONOTONIC, &s->end);
661                                                 goto out_errno;
662                                         }
663                                 }
664                                 flags = 0;
665                         }
666
667                         recv = recvmsg(fd, &msg, flags);
668                         if (recv < 0) {
669                                 if (errno != EWOULDBLOCK) {
670                                         clock_gettime(CLOCK_MONOTONIC, &s->end);
671                                         perror("recv failed()");
672                                         goto out_errno;
673                                 }
674                         }
675
676                         s->bytes_recvd += recv;
677
678                         if (data) {
679                                 int chunk_sz = opt->sendpage ?
680                                                 iov_length * cnt :
681                                                 iov_length * iov_count;
682
683                                 errno = msg_verify_data(&msg, recv, chunk_sz);
684                                 if (errno) {
685                                         perror("data verify msg failed");
686                                         goto out_errno;
687                                 }
688                                 if (recvp) {
689                                         errno = msg_verify_data(&msg_peek,
690                                                                 recvp,
691                                                                 chunk_sz);
692                                         if (errno) {
693                                                 perror("data verify msg_peek failed");
694                                                 goto out_errno;
695                                         }
696                                 }
697                         }
698                 }
699                 clock_gettime(CLOCK_MONOTONIC, &s->end);
700         }
701
702         msg_free_iov(&msg);
703         msg_free_iov(&msg_peek);
704         return err;
705 out_errno:
706         msg_free_iov(&msg);
707         msg_free_iov(&msg_peek);
708         return errno;
709 }
710
711 static float giga = 1000000000;
712
713 static inline float sentBps(struct msg_stats s)
714 {
715         return s.bytes_sent / (s.end.tv_sec - s.start.tv_sec);
716 }
717
718 static inline float recvdBps(struct msg_stats s)
719 {
720         return s.bytes_recvd / (s.end.tv_sec - s.start.tv_sec);
721 }
722
723 static int sendmsg_test(struct sockmap_options *opt)
724 {
725         float sent_Bps = 0, recvd_Bps = 0;
726         int rx_fd, txpid, rxpid, err = 0;
727         struct msg_stats s = {0};
728         int iov_count = opt->iov_count;
729         int iov_buf = opt->iov_length;
730         int rx_status, tx_status;
731         int cnt = opt->rate;
732
733         errno = 0;
734
735         if (opt->base)
736                 rx_fd = p1;
737         else
738                 rx_fd = p2;
739
740         if (ktls) {
741                 /* Redirecting into non-TLS socket which sends into a TLS
742                  * socket is not a valid test. So in this case lets not
743                  * enable kTLS but still run the test.
744                  */
745                 if (!txmsg_redir || (txmsg_redir && txmsg_ingress)) {
746                         err = sockmap_init_ktls(opt->verbose, rx_fd);
747                         if (err)
748                                 return err;
749                 }
750                 err = sockmap_init_ktls(opt->verbose, c1);
751                 if (err)
752                         return err;
753         }
754
755         rxpid = fork();
756         if (rxpid == 0) {
757                 iov_buf -= (txmsg_pop - txmsg_start_pop + 1);
758                 if (opt->drop_expected || txmsg_ktls_skb_drop)
759                         _exit(0);
760
761                 if (!iov_buf) /* zero bytes sent case */
762                         _exit(0);
763
764                 if (opt->sendpage)
765                         iov_count = 1;
766                 err = msg_loop(rx_fd, iov_count, iov_buf,
767                                cnt, &s, false, opt);
768                 if (opt->verbose > 1)
769                         fprintf(stderr,
770                                 "msg_loop_rx: iov_count %i iov_buf %i cnt %i err %i\n",
771                                 iov_count, iov_buf, cnt, err);
772                 if (s.end.tv_sec - s.start.tv_sec) {
773                         sent_Bps = sentBps(s);
774                         recvd_Bps = recvdBps(s);
775                 }
776                 if (opt->verbose > 1)
777                         fprintf(stdout,
778                                 "rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB %fB/s %fGB/s %s\n",
779                                 s.bytes_sent, sent_Bps, sent_Bps/giga,
780                                 s.bytes_recvd, recvd_Bps, recvd_Bps/giga,
781                                 peek_flag ? "(peek_msg)" : "");
782                 if (err && txmsg_cork)
783                         err = 0;
784                 exit(err ? 1 : 0);
785         } else if (rxpid == -1) {
786                 perror("msg_loop_rx");
787                 return errno;
788         }
789
790         txpid = fork();
791         if (txpid == 0) {
792                 if (opt->sendpage)
793                         err = msg_loop_sendpage(c1, iov_buf, cnt, &s, opt);
794                 else
795                         err = msg_loop(c1, iov_count, iov_buf,
796                                        cnt, &s, true, opt);
797
798                 if (err)
799                         fprintf(stderr,
800                                 "msg_loop_tx: iov_count %i iov_buf %i cnt %i err %i\n",
801                                 iov_count, iov_buf, cnt, err);
802                 if (s.end.tv_sec - s.start.tv_sec) {
803                         sent_Bps = sentBps(s);
804                         recvd_Bps = recvdBps(s);
805                 }
806                 if (opt->verbose > 1)
807                         fprintf(stdout,
808                                 "tx_sendmsg: TX: %zuB %fB/s %f GB/s RX: %zuB %fB/s %fGB/s\n",
809                                 s.bytes_sent, sent_Bps, sent_Bps/giga,
810                                 s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
811                 exit(err ? 1 : 0);
812         } else if (txpid == -1) {
813                 perror("msg_loop_tx");
814                 return errno;
815         }
816
817         assert(waitpid(rxpid, &rx_status, 0) == rxpid);
818         assert(waitpid(txpid, &tx_status, 0) == txpid);
819         if (WIFEXITED(rx_status)) {
820                 err = WEXITSTATUS(rx_status);
821                 if (err) {
822                         fprintf(stderr, "rx thread exited with err %d.\n", err);
823                         goto out;
824                 }
825         }
826         if (WIFEXITED(tx_status)) {
827                 err = WEXITSTATUS(tx_status);
828                 if (err)
829                         fprintf(stderr, "tx thread exited with err %d.\n", err);
830         }
831 out:
832         return err;
833 }
834
835 static int forever_ping_pong(int rate, struct sockmap_options *opt)
836 {
837         struct timeval timeout;
838         char buf[1024] = {0};
839         int sc;
840
841         timeout.tv_sec = 10;
842         timeout.tv_usec = 0;
843
844         /* Ping/Pong data from client to server */
845         sc = send(c1, buf, sizeof(buf), 0);
846         if (sc < 0) {
847                 perror("send failed()");
848                 return sc;
849         }
850
851         do {
852                 int s, rc, i, max_fd = p2;
853                 fd_set w;
854
855                 /* FD sets */
856                 FD_ZERO(&w);
857                 FD_SET(c1, &w);
858                 FD_SET(c2, &w);
859                 FD_SET(p1, &w);
860                 FD_SET(p2, &w);
861
862                 s = select(max_fd + 1, &w, NULL, NULL, &timeout);
863                 if (s == -1) {
864                         perror("select()");
865                         break;
866                 } else if (!s) {
867                         fprintf(stderr, "unexpected timeout\n");
868                         break;
869                 }
870
871                 for (i = 0; i <= max_fd && s > 0; ++i) {
872                         if (!FD_ISSET(i, &w))
873                                 continue;
874
875                         s--;
876
877                         rc = recv(i, buf, sizeof(buf), 0);
878                         if (rc < 0) {
879                                 if (errno != EWOULDBLOCK) {
880                                         perror("recv failed()");
881                                         return rc;
882                                 }
883                         }
884
885                         if (rc == 0) {
886                                 close(i);
887                                 break;
888                         }
889
890                         sc = send(i, buf, rc, 0);
891                         if (sc < 0) {
892                                 perror("send failed()");
893                                 return sc;
894                         }
895                 }
896
897                 if (rate)
898                         sleep(rate);
899
900                 if (opt->verbose) {
901                         printf(".");
902                         fflush(stdout);
903
904                 }
905         } while (running);
906
907         return 0;
908 }
909
910 enum {
911         SELFTESTS,
912         PING_PONG,
913         SENDMSG,
914         BASE,
915         BASE_SENDPAGE,
916         SENDPAGE,
917 };
918
919 static int run_options(struct sockmap_options *options, int cg_fd,  int test)
920 {
921         int i, key, next_key, err, tx_prog_fd = -1, zero = 0;
922
923         /* If base test skip BPF setup */
924         if (test == BASE || test == BASE_SENDPAGE)
925                 goto run;
926
927         /* Attach programs to sockmap */
928         err = bpf_prog_attach(prog_fd[0], map_fd[0],
929                                 BPF_SK_SKB_STREAM_PARSER, 0);
930         if (err) {
931                 fprintf(stderr,
932                         "ERROR: bpf_prog_attach (sockmap %i->%i): %d (%s)\n",
933                         prog_fd[0], map_fd[0], err, strerror(errno));
934                 return err;
935         }
936
937         err = bpf_prog_attach(prog_fd[1], map_fd[0],
938                                 BPF_SK_SKB_STREAM_VERDICT, 0);
939         if (err) {
940                 fprintf(stderr, "ERROR: bpf_prog_attach (sockmap): %d (%s)\n",
941                         err, strerror(errno));
942                 return err;
943         }
944
945         /* Attach programs to TLS sockmap */
946         if (txmsg_ktls_skb) {
947                 err = bpf_prog_attach(prog_fd[0], map_fd[8],
948                                         BPF_SK_SKB_STREAM_PARSER, 0);
949                 if (err) {
950                         fprintf(stderr,
951                                 "ERROR: bpf_prog_attach (TLS sockmap %i->%i): %d (%s)\n",
952                                 prog_fd[0], map_fd[8], err, strerror(errno));
953                         return err;
954                 }
955
956                 err = bpf_prog_attach(prog_fd[2], map_fd[8],
957                                       BPF_SK_SKB_STREAM_VERDICT, 0);
958                 if (err) {
959                         fprintf(stderr, "ERROR: bpf_prog_attach (TLS sockmap): %d (%s)\n",
960                                 err, strerror(errno));
961                         return err;
962                 }
963         }
964
965         /* Attach to cgroups */
966         err = bpf_prog_attach(prog_fd[3], cg_fd, BPF_CGROUP_SOCK_OPS, 0);
967         if (err) {
968                 fprintf(stderr, "ERROR: bpf_prog_attach (groups): %d (%s)\n",
969                         err, strerror(errno));
970                 return err;
971         }
972
973 run:
974         err = sockmap_init_sockets(options->verbose);
975         if (err) {
976                 fprintf(stderr, "ERROR: test socket failed: %d\n", err);
977                 goto out;
978         }
979
980         /* Attach txmsg program to sockmap */
981         if (txmsg_pass)
982                 tx_prog_fd = prog_fd[4];
983         else if (txmsg_redir)
984                 tx_prog_fd = prog_fd[5];
985         else if (txmsg_apply)
986                 tx_prog_fd = prog_fd[6];
987         else if (txmsg_cork)
988                 tx_prog_fd = prog_fd[7];
989         else if (txmsg_drop)
990                 tx_prog_fd = prog_fd[8];
991         else
992                 tx_prog_fd = 0;
993
994         if (tx_prog_fd) {
995                 int redir_fd, i = 0;
996
997                 err = bpf_prog_attach(tx_prog_fd,
998                                       map_fd[1], BPF_SK_MSG_VERDICT, 0);
999                 if (err) {
1000                         fprintf(stderr,
1001                                 "ERROR: bpf_prog_attach (txmsg): %d (%s)\n",
1002                                 err, strerror(errno));
1003                         goto out;
1004                 }
1005
1006                 err = bpf_map_update_elem(map_fd[1], &i, &c1, BPF_ANY);
1007                 if (err) {
1008                         fprintf(stderr,
1009                                 "ERROR: bpf_map_update_elem (txmsg):  %d (%s\n",
1010                                 err, strerror(errno));
1011                         goto out;
1012                 }
1013
1014                 if (txmsg_redir)
1015                         redir_fd = c2;
1016                 else
1017                         redir_fd = c1;
1018
1019                 err = bpf_map_update_elem(map_fd[2], &i, &redir_fd, BPF_ANY);
1020                 if (err) {
1021                         fprintf(stderr,
1022                                 "ERROR: bpf_map_update_elem (txmsg):  %d (%s\n",
1023                                 err, strerror(errno));
1024                         goto out;
1025                 }
1026
1027                 if (txmsg_apply) {
1028                         err = bpf_map_update_elem(map_fd[3],
1029                                                   &i, &txmsg_apply, BPF_ANY);
1030                         if (err) {
1031                                 fprintf(stderr,
1032                                         "ERROR: bpf_map_update_elem (apply_bytes):  %d (%s\n",
1033                                         err, strerror(errno));
1034                                 goto out;
1035                         }
1036                 }
1037
1038                 if (txmsg_cork) {
1039                         err = bpf_map_update_elem(map_fd[4],
1040                                                   &i, &txmsg_cork, BPF_ANY);
1041                         if (err) {
1042                                 fprintf(stderr,
1043                                         "ERROR: bpf_map_update_elem (cork_bytes):  %d (%s\n",
1044                                         err, strerror(errno));
1045                                 goto out;
1046                         }
1047                 }
1048
1049                 if (txmsg_start) {
1050                         err = bpf_map_update_elem(map_fd[5],
1051                                                   &i, &txmsg_start, BPF_ANY);
1052                         if (err) {
1053                                 fprintf(stderr,
1054                                         "ERROR: bpf_map_update_elem (txmsg_start):  %d (%s)\n",
1055                                         err, strerror(errno));
1056                                 goto out;
1057                         }
1058                 }
1059
1060                 if (txmsg_end) {
1061                         i = 1;
1062                         err = bpf_map_update_elem(map_fd[5],
1063                                                   &i, &txmsg_end, BPF_ANY);
1064                         if (err) {
1065                                 fprintf(stderr,
1066                                         "ERROR: bpf_map_update_elem (txmsg_end):  %d (%s)\n",
1067                                         err, strerror(errno));
1068                                 goto out;
1069                         }
1070                 }
1071
1072                 if (txmsg_start_push) {
1073                         i = 2;
1074                         err = bpf_map_update_elem(map_fd[5],
1075                                                   &i, &txmsg_start_push, BPF_ANY);
1076                         if (err) {
1077                                 fprintf(stderr,
1078                                         "ERROR: bpf_map_update_elem (txmsg_start_push):  %d (%s)\n",
1079                                         err, strerror(errno));
1080                                 goto out;
1081                         }
1082                 }
1083
1084                 if (txmsg_end_push) {
1085                         i = 3;
1086                         err = bpf_map_update_elem(map_fd[5],
1087                                                   &i, &txmsg_end_push, BPF_ANY);
1088                         if (err) {
1089                                 fprintf(stderr,
1090                                         "ERROR: bpf_map_update_elem %i@%i (txmsg_end_push):  %d (%s)\n",
1091                                         txmsg_end_push, i, err, strerror(errno));
1092                                 goto out;
1093                         }
1094                 }
1095
1096                 if (txmsg_start_pop) {
1097                         i = 4;
1098                         err = bpf_map_update_elem(map_fd[5],
1099                                                   &i, &txmsg_start_pop, BPF_ANY);
1100                         if (err) {
1101                                 fprintf(stderr,
1102                                         "ERROR: bpf_map_update_elem %i@%i (txmsg_start_pop):  %d (%s)\n",
1103                                         txmsg_start_pop, i, err, strerror(errno));
1104                                 goto out;
1105                         }
1106                 } else {
1107                         i = 4;
1108                         bpf_map_update_elem(map_fd[5],
1109                                                   &i, &txmsg_start_pop, BPF_ANY);
1110                 }
1111
1112                 if (txmsg_pop) {
1113                         i = 5;
1114                         err = bpf_map_update_elem(map_fd[5],
1115                                                   &i, &txmsg_pop, BPF_ANY);
1116                         if (err) {
1117                                 fprintf(stderr,
1118                                         "ERROR: bpf_map_update_elem %i@%i (txmsg_pop):  %d (%s)\n",
1119                                         txmsg_pop, i, err, strerror(errno));
1120                                 goto out;
1121                         }
1122                 } else {
1123                         i = 5;
1124                         bpf_map_update_elem(map_fd[5],
1125                                             &i, &txmsg_pop, BPF_ANY);
1126
1127                 }
1128
1129                 if (txmsg_ingress) {
1130                         int in = BPF_F_INGRESS;
1131
1132                         i = 0;
1133                         err = bpf_map_update_elem(map_fd[6], &i, &in, BPF_ANY);
1134                         if (err) {
1135                                 fprintf(stderr,
1136                                         "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
1137                                         err, strerror(errno));
1138                         }
1139                         i = 1;
1140                         err = bpf_map_update_elem(map_fd[1], &i, &p1, BPF_ANY);
1141                         if (err) {
1142                                 fprintf(stderr,
1143                                         "ERROR: bpf_map_update_elem (p1 txmsg): %d (%s)\n",
1144                                         err, strerror(errno));
1145                         }
1146                         err = bpf_map_update_elem(map_fd[2], &i, &p1, BPF_ANY);
1147                         if (err) {
1148                                 fprintf(stderr,
1149                                         "ERROR: bpf_map_update_elem (p1 redir): %d (%s)\n",
1150                                         err, strerror(errno));
1151                         }
1152
1153                         i = 2;
1154                         err = bpf_map_update_elem(map_fd[2], &i, &p2, BPF_ANY);
1155                         if (err) {
1156                                 fprintf(stderr,
1157                                         "ERROR: bpf_map_update_elem (p2 txmsg): %d (%s)\n",
1158                                         err, strerror(errno));
1159                         }
1160                 }
1161
1162                 if (txmsg_ktls_skb) {
1163                         int ingress = BPF_F_INGRESS;
1164
1165                         i = 0;
1166                         err = bpf_map_update_elem(map_fd[8], &i, &p2, BPF_ANY);
1167                         if (err) {
1168                                 fprintf(stderr,
1169                                         "ERROR: bpf_map_update_elem (c1 sockmap): %d (%s)\n",
1170                                         err, strerror(errno));
1171                         }
1172
1173                         if (txmsg_ktls_skb_redir) {
1174                                 i = 1;
1175                                 err = bpf_map_update_elem(map_fd[7],
1176                                                           &i, &ingress, BPF_ANY);
1177                                 if (err) {
1178                                         fprintf(stderr,
1179                                                 "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
1180                                                 err, strerror(errno));
1181                                 }
1182                         }
1183
1184                         if (txmsg_ktls_skb_drop) {
1185                                 i = 1;
1186                                 err = bpf_map_update_elem(map_fd[7], &i, &i, BPF_ANY);
1187                         }
1188                 }
1189
1190                 if (txmsg_redir_skb) {
1191                         int skb_fd = (test == SENDMSG || test == SENDPAGE) ?
1192                                         p2 : p1;
1193                         int ingress = BPF_F_INGRESS;
1194
1195                         i = 0;
1196                         err = bpf_map_update_elem(map_fd[7],
1197                                                   &i, &ingress, BPF_ANY);
1198                         if (err) {
1199                                 fprintf(stderr,
1200                                         "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
1201                                         err, strerror(errno));
1202                         }
1203
1204                         i = 3;
1205                         err = bpf_map_update_elem(map_fd[0], &i, &skb_fd, BPF_ANY);
1206                         if (err) {
1207                                 fprintf(stderr,
1208                                         "ERROR: bpf_map_update_elem (c1 sockmap): %d (%s)\n",
1209                                         err, strerror(errno));
1210                         }
1211                 }
1212         }
1213
1214         if (txmsg_drop)
1215                 options->drop_expected = true;
1216
1217         if (test == PING_PONG)
1218                 err = forever_ping_pong(options->rate, options);
1219         else if (test == SENDMSG) {
1220                 options->base = false;
1221                 options->sendpage = false;
1222                 err = sendmsg_test(options);
1223         } else if (test == SENDPAGE) {
1224                 options->base = false;
1225                 options->sendpage = true;
1226                 err = sendmsg_test(options);
1227         } else if (test == BASE) {
1228                 options->base = true;
1229                 options->sendpage = false;
1230                 err = sendmsg_test(options);
1231         } else if (test == BASE_SENDPAGE) {
1232                 options->base = true;
1233                 options->sendpage = true;
1234                 err = sendmsg_test(options);
1235         } else
1236                 fprintf(stderr, "unknown test\n");
1237 out:
1238         /* Detatch and zero all the maps */
1239         bpf_prog_detach2(prog_fd[3], cg_fd, BPF_CGROUP_SOCK_OPS);
1240         bpf_prog_detach2(prog_fd[0], map_fd[0], BPF_SK_SKB_STREAM_PARSER);
1241         bpf_prog_detach2(prog_fd[1], map_fd[0], BPF_SK_SKB_STREAM_VERDICT);
1242         bpf_prog_detach2(prog_fd[0], map_fd[8], BPF_SK_SKB_STREAM_PARSER);
1243         bpf_prog_detach2(prog_fd[2], map_fd[8], BPF_SK_SKB_STREAM_VERDICT);
1244
1245         if (tx_prog_fd >= 0)
1246                 bpf_prog_detach2(tx_prog_fd, map_fd[1], BPF_SK_MSG_VERDICT);
1247
1248         for (i = 0; i < 8; i++) {
1249                 key = next_key = 0;
1250                 bpf_map_update_elem(map_fd[i], &key, &zero, BPF_ANY);
1251                 while (bpf_map_get_next_key(map_fd[i], &key, &next_key) == 0) {
1252                         bpf_map_update_elem(map_fd[i], &key, &zero, BPF_ANY);
1253                         key = next_key;
1254                 }
1255         }
1256
1257         close(s1);
1258         close(s2);
1259         close(p1);
1260         close(p2);
1261         close(c1);
1262         close(c2);
1263         return err;
1264 }
1265
1266 static char *test_to_str(int test)
1267 {
1268         switch (test) {
1269         case SENDMSG:
1270                 return "sendmsg";
1271         case SENDPAGE:
1272                 return "sendpage";
1273         }
1274         return "unknown";
1275 }
1276
1277 #define OPTSTRING 60
1278 static void test_options(char *options)
1279 {
1280         char tstr[OPTSTRING];
1281
1282         memset(options, 0, OPTSTRING);
1283
1284         if (txmsg_pass)
1285                 strncat(options, "pass,", OPTSTRING);
1286         if (txmsg_redir)
1287                 strncat(options, "redir,", OPTSTRING);
1288         if (txmsg_drop)
1289                 strncat(options, "drop,", OPTSTRING);
1290         if (txmsg_apply) {
1291                 snprintf(tstr, OPTSTRING, "apply %d,", txmsg_apply);
1292                 strncat(options, tstr, OPTSTRING);
1293         }
1294         if (txmsg_cork) {
1295                 snprintf(tstr, OPTSTRING, "cork %d,", txmsg_cork);
1296                 strncat(options, tstr, OPTSTRING);
1297         }
1298         if (txmsg_start) {
1299                 snprintf(tstr, OPTSTRING, "start %d,", txmsg_start);
1300                 strncat(options, tstr, OPTSTRING);
1301         }
1302         if (txmsg_end) {
1303                 snprintf(tstr, OPTSTRING, "end %d,", txmsg_end);
1304                 strncat(options, tstr, OPTSTRING);
1305         }
1306         if (txmsg_start_pop) {
1307                 snprintf(tstr, OPTSTRING, "pop (%d,%d),",
1308                          txmsg_start_pop, txmsg_start_pop + txmsg_pop);
1309                 strncat(options, tstr, OPTSTRING);
1310         }
1311         if (txmsg_ingress)
1312                 strncat(options, "ingress,", OPTSTRING);
1313         if (txmsg_redir_skb)
1314                 strncat(options, "redir_skb,", OPTSTRING);
1315         if (txmsg_ktls_skb)
1316                 strncat(options, "ktls_skb,", OPTSTRING);
1317         if (ktls)
1318                 strncat(options, "ktls,", OPTSTRING);
1319         if (peek_flag)
1320                 strncat(options, "peek,", OPTSTRING);
1321 }
1322
1323 static int __test_exec(int cgrp, int test, struct sockmap_options *opt)
1324 {
1325         char *options = calloc(OPTSTRING, sizeof(char));
1326         int err;
1327
1328         if (test == SENDPAGE)
1329                 opt->sendpage = true;
1330         else
1331                 opt->sendpage = false;
1332
1333         if (txmsg_drop)
1334                 opt->drop_expected = true;
1335         else
1336                 opt->drop_expected = false;
1337
1338         test_options(options);
1339
1340         if (opt->verbose) {
1341                 fprintf(stdout,
1342                         " [TEST %i]: (%i, %i, %i, %s, %s): ",
1343                         test_cnt, opt->rate, opt->iov_count, opt->iov_length,
1344                         test_to_str(test), options);
1345                 fflush(stdout);
1346         }
1347         err = run_options(opt, cgrp, test);
1348         if (opt->verbose)
1349                 fprintf(stdout, " %s\n", !err ? "PASS" : "FAILED");
1350         test_cnt++;
1351         !err ? passed++ : failed++;
1352         free(options);
1353         return err;
1354 }
1355
1356 static void test_exec(int cgrp, struct sockmap_options *opt)
1357 {
1358         int type = strcmp(opt->map, BPF_SOCKMAP_FILENAME);
1359         int err;
1360
1361         if (type == 0) {
1362                 test_start();
1363                 err = __test_exec(cgrp, SENDMSG, opt);
1364                 if (err)
1365                         test_fail();
1366         } else {
1367                 test_start();
1368                 err = __test_exec(cgrp, SENDPAGE, opt);
1369                 if (err)
1370                         test_fail();
1371         }
1372 }
1373
1374 static void test_send_one(struct sockmap_options *opt, int cgrp)
1375 {
1376         opt->iov_length = 1;
1377         opt->iov_count = 1;
1378         opt->rate = 1;
1379         test_exec(cgrp, opt);
1380
1381         opt->iov_length = 1;
1382         opt->iov_count = 1024;
1383         opt->rate = 1;
1384         test_exec(cgrp, opt);
1385
1386         opt->iov_length = 1024;
1387         opt->iov_count = 1;
1388         opt->rate = 1;
1389         test_exec(cgrp, opt);
1390
1391 }
1392
1393 static void test_send_many(struct sockmap_options *opt, int cgrp)
1394 {
1395         opt->iov_length = 3;
1396         opt->iov_count = 1;
1397         opt->rate = 512;
1398         test_exec(cgrp, opt);
1399
1400         opt->rate = 100;
1401         opt->iov_count = 1;
1402         opt->iov_length = 5;
1403         test_exec(cgrp, opt);
1404 }
1405
1406 static void test_send_large(struct sockmap_options *opt, int cgrp)
1407 {
1408         opt->iov_length = 256;
1409         opt->iov_count = 1024;
1410         opt->rate = 2;
1411         test_exec(cgrp, opt);
1412 }
1413
1414 static void test_send(struct sockmap_options *opt, int cgrp)
1415 {
1416         test_send_one(opt, cgrp);
1417         test_send_many(opt, cgrp);
1418         test_send_large(opt, cgrp);
1419         sched_yield();
1420 }
1421
1422 static void test_txmsg_pass(int cgrp, struct sockmap_options *opt)
1423 {
1424         /* Test small and large iov_count values with pass/redir/apply/cork */
1425         txmsg_pass = 1;
1426         test_send(opt, cgrp);
1427 }
1428
1429 static void test_txmsg_redir(int cgrp, struct sockmap_options *opt)
1430 {
1431         txmsg_redir = 1;
1432         test_send(opt, cgrp);
1433 }
1434
1435 static void test_txmsg_drop(int cgrp, struct sockmap_options *opt)
1436 {
1437         txmsg_drop = 1;
1438         test_send(opt, cgrp);
1439 }
1440
1441 static void test_txmsg_ingress_redir(int cgrp, struct sockmap_options *opt)
1442 {
1443         txmsg_pass = txmsg_drop = 0;
1444         txmsg_ingress = txmsg_redir = 1;
1445         test_send(opt, cgrp);
1446 }
1447
1448 static void test_txmsg_skb(int cgrp, struct sockmap_options *opt)
1449 {
1450         bool data = opt->data_test;
1451         int k = ktls;
1452
1453         opt->data_test = true;
1454         ktls = 1;
1455
1456         txmsg_pass = txmsg_drop = 0;
1457         txmsg_ingress = txmsg_redir = 0;
1458         txmsg_ktls_skb = 1;
1459         txmsg_pass = 1;
1460
1461         /* Using data verification so ensure iov layout is
1462          * expected from test receiver side. e.g. has enough
1463          * bytes to write test code.
1464          */
1465         opt->iov_length = 100;
1466         opt->iov_count = 1;
1467         opt->rate = 1;
1468         test_exec(cgrp, opt);
1469
1470         txmsg_ktls_skb_drop = 1;
1471         test_exec(cgrp, opt);
1472
1473         txmsg_ktls_skb_drop = 0;
1474         txmsg_ktls_skb_redir = 1;
1475         test_exec(cgrp, opt);
1476
1477         opt->data_test = data;
1478         ktls = k;
1479 }
1480
1481
1482 /* Test cork with hung data. This tests poor usage patterns where
1483  * cork can leave data on the ring if user program is buggy and
1484  * doesn't flush them somehow. They do take some time however
1485  * because they wait for a timeout. Test pass, redir and cork with
1486  * apply logic. Use cork size of 4097 with send_large to avoid
1487  * aligning cork size with send size.
1488  */
1489 static void test_txmsg_cork_hangs(int cgrp, struct sockmap_options *opt)
1490 {
1491         txmsg_pass = 1;
1492         txmsg_redir = 0;
1493         txmsg_cork = 4097;
1494         txmsg_apply = 4097;
1495         test_send_large(opt, cgrp);
1496
1497         txmsg_pass = 0;
1498         txmsg_redir = 1;
1499         txmsg_apply = 0;
1500         txmsg_cork = 4097;
1501         test_send_large(opt, cgrp);
1502
1503         txmsg_pass = 0;
1504         txmsg_redir = 1;
1505         txmsg_apply = 4097;
1506         txmsg_cork = 4097;
1507         test_send_large(opt, cgrp);
1508 }
1509
1510 static void test_txmsg_pull(int cgrp, struct sockmap_options *opt)
1511 {
1512         /* Test basic start/end */
1513         txmsg_start = 1;
1514         txmsg_end = 2;
1515         test_send(opt, cgrp);
1516
1517         /* Test >4k pull */
1518         txmsg_start = 4096;
1519         txmsg_end = 9182;
1520         test_send_large(opt, cgrp);
1521
1522         /* Test pull + redirect */
1523         txmsg_redir = 0;
1524         txmsg_start = 1;
1525         txmsg_end = 2;
1526         test_send(opt, cgrp);
1527
1528         /* Test pull + cork */
1529         txmsg_redir = 0;
1530         txmsg_cork = 512;
1531         txmsg_start = 1;
1532         txmsg_end = 2;
1533         test_send_many(opt, cgrp);
1534
1535         /* Test pull + cork + redirect */
1536         txmsg_redir = 1;
1537         txmsg_cork = 512;
1538         txmsg_start = 1;
1539         txmsg_end = 2;
1540         test_send_many(opt, cgrp);
1541 }
1542
1543 static void test_txmsg_pop(int cgrp, struct sockmap_options *opt)
1544 {
1545         /* Test basic pop */
1546         txmsg_start_pop = 1;
1547         txmsg_pop = 2;
1548         test_send_many(opt, cgrp);
1549
1550         /* Test pop with >4k */
1551         txmsg_start_pop = 4096;
1552         txmsg_pop = 4096;
1553         test_send_large(opt, cgrp);
1554
1555         /* Test pop + redirect */
1556         txmsg_redir = 1;
1557         txmsg_start_pop = 1;
1558         txmsg_pop = 2;
1559         test_send_many(opt, cgrp);
1560
1561         /* Test pop + cork */
1562         txmsg_redir = 0;
1563         txmsg_cork = 512;
1564         txmsg_start_pop = 1;
1565         txmsg_pop = 2;
1566         test_send_many(opt, cgrp);
1567
1568         /* Test pop + redirect + cork */
1569         txmsg_redir = 1;
1570         txmsg_cork = 4;
1571         txmsg_start_pop = 1;
1572         txmsg_pop = 2;
1573         test_send_many(opt, cgrp);
1574 }
1575
1576 static void test_txmsg_push(int cgrp, struct sockmap_options *opt)
1577 {
1578         /* Test basic push */
1579         txmsg_start_push = 1;
1580         txmsg_end_push = 1;
1581         test_send(opt, cgrp);
1582
1583         /* Test push 4kB >4k */
1584         txmsg_start_push = 4096;
1585         txmsg_end_push = 4096;
1586         test_send_large(opt, cgrp);
1587
1588         /* Test push + redirect */
1589         txmsg_redir = 1;
1590         txmsg_start_push = 1;
1591         txmsg_end_push = 2;
1592         test_send_many(opt, cgrp);
1593
1594         /* Test push + cork */
1595         txmsg_redir = 0;
1596         txmsg_cork = 512;
1597         txmsg_start_push = 1;
1598         txmsg_end_push = 2;
1599         test_send_many(opt, cgrp);
1600 }
1601
1602 static void test_txmsg_push_pop(int cgrp, struct sockmap_options *opt)
1603 {
1604         txmsg_start_push = 1;
1605         txmsg_end_push = 10;
1606         txmsg_start_pop = 5;
1607         txmsg_pop = 4;
1608         test_send_large(opt, cgrp);
1609 }
1610
1611 static void test_txmsg_apply(int cgrp, struct sockmap_options *opt)
1612 {
1613         txmsg_pass = 1;
1614         txmsg_redir = 0;
1615         txmsg_apply = 1;
1616         txmsg_cork = 0;
1617         test_send_one(opt, cgrp);
1618
1619         txmsg_pass = 0;
1620         txmsg_redir = 1;
1621         txmsg_apply = 1;
1622         txmsg_cork = 0;
1623         test_send_one(opt, cgrp);
1624
1625         txmsg_pass = 1;
1626         txmsg_redir = 0;
1627         txmsg_apply = 1024;
1628         txmsg_cork = 0;
1629         test_send_large(opt, cgrp);
1630
1631         txmsg_pass = 0;
1632         txmsg_redir = 1;
1633         txmsg_apply = 1024;
1634         txmsg_cork = 0;
1635         test_send_large(opt, cgrp);
1636 }
1637
1638 static void test_txmsg_cork(int cgrp, struct sockmap_options *opt)
1639 {
1640         txmsg_pass = 1;
1641         txmsg_redir = 0;
1642         txmsg_apply = 0;
1643         txmsg_cork = 1;
1644         test_send(opt, cgrp);
1645
1646         txmsg_pass = 1;
1647         txmsg_redir = 0;
1648         txmsg_apply = 1;
1649         txmsg_cork = 1;
1650         test_send(opt, cgrp);
1651 }
1652
1653 char *map_names[] = {
1654         "sock_map",
1655         "sock_map_txmsg",
1656         "sock_map_redir",
1657         "sock_apply_bytes",
1658         "sock_cork_bytes",
1659         "sock_bytes",
1660         "sock_redir_flags",
1661         "sock_skb_opts",
1662         "tls_sock_map",
1663 };
1664
1665 int prog_attach_type[] = {
1666         BPF_SK_SKB_STREAM_PARSER,
1667         BPF_SK_SKB_STREAM_VERDICT,
1668         BPF_SK_SKB_STREAM_VERDICT,
1669         BPF_CGROUP_SOCK_OPS,
1670         BPF_SK_MSG_VERDICT,
1671         BPF_SK_MSG_VERDICT,
1672         BPF_SK_MSG_VERDICT,
1673         BPF_SK_MSG_VERDICT,
1674         BPF_SK_MSG_VERDICT,
1675         BPF_SK_MSG_VERDICT,
1676         BPF_SK_MSG_VERDICT,
1677 };
1678
1679 int prog_type[] = {
1680         BPF_PROG_TYPE_SK_SKB,
1681         BPF_PROG_TYPE_SK_SKB,
1682         BPF_PROG_TYPE_SK_SKB,
1683         BPF_PROG_TYPE_SOCK_OPS,
1684         BPF_PROG_TYPE_SK_MSG,
1685         BPF_PROG_TYPE_SK_MSG,
1686         BPF_PROG_TYPE_SK_MSG,
1687         BPF_PROG_TYPE_SK_MSG,
1688         BPF_PROG_TYPE_SK_MSG,
1689         BPF_PROG_TYPE_SK_MSG,
1690         BPF_PROG_TYPE_SK_MSG,
1691 };
1692
1693 static int populate_progs(char *bpf_file)
1694 {
1695         struct bpf_program *prog;
1696         struct bpf_object *obj;
1697         int i = 0;
1698         long err;
1699
1700         obj = bpf_object__open(bpf_file);
1701         err = libbpf_get_error(obj);
1702         if (err) {
1703                 char err_buf[256];
1704
1705                 libbpf_strerror(err, err_buf, sizeof(err_buf));
1706                 printf("Unable to load eBPF objects in file '%s' : %s\n",
1707                        bpf_file, err_buf);
1708                 return -1;
1709         }
1710
1711         bpf_object__for_each_program(prog, obj) {
1712                 bpf_program__set_type(prog, prog_type[i]);
1713                 bpf_program__set_expected_attach_type(prog,
1714                                                       prog_attach_type[i]);
1715                 i++;
1716         }
1717
1718         i = bpf_object__load(obj);
1719         i = 0;
1720         bpf_object__for_each_program(prog, obj) {
1721                 prog_fd[i] = bpf_program__fd(prog);
1722                 i++;
1723         }
1724
1725         for (i = 0; i < sizeof(map_fd)/sizeof(int); i++) {
1726                 maps[i] = bpf_object__find_map_by_name(obj, map_names[i]);
1727                 map_fd[i] = bpf_map__fd(maps[i]);
1728                 if (map_fd[i] < 0) {
1729                         fprintf(stderr, "load_bpf_file: (%i) %s\n",
1730                                 map_fd[i], strerror(errno));
1731                         return -1;
1732                 }
1733         }
1734
1735         return 0;
1736 }
1737
1738 struct _test test[] = {
1739         {"txmsg test passthrough", test_txmsg_pass},
1740         {"txmsg test redirect", test_txmsg_redir},
1741         {"txmsg test drop", test_txmsg_drop},
1742         {"txmsg test ingress redirect", test_txmsg_ingress_redir},
1743         {"txmsg test skb", test_txmsg_skb},
1744         {"txmsg test apply", test_txmsg_apply},
1745         {"txmsg test cork", test_txmsg_cork},
1746         {"txmsg test hanging corks", test_txmsg_cork_hangs},
1747         {"txmsg test push_data", test_txmsg_push},
1748         {"txmsg test pull-data", test_txmsg_pull},
1749         {"txmsg test pop-data", test_txmsg_pop},
1750         {"txmsg test push/pop data", test_txmsg_push_pop},
1751 };
1752
1753 static int check_whitelist(struct _test *t, struct sockmap_options *opt)
1754 {
1755         char *entry, *ptr;
1756
1757         if (!opt->whitelist)
1758                 return 0;
1759         ptr = strdup(opt->whitelist);
1760         if (!ptr)
1761                 return -ENOMEM;
1762         entry = strtok(ptr, ",");
1763         while (entry) {
1764                 if ((opt->prepend && strstr(opt->prepend, entry) != 0) ||
1765                     strstr(opt->map, entry) != 0 ||
1766                     strstr(t->title, entry) != 0)
1767                         return 0;
1768                 entry = strtok(NULL, ",");
1769         }
1770         return -EINVAL;
1771 }
1772
1773 static int check_blacklist(struct _test *t, struct sockmap_options *opt)
1774 {
1775         char *entry, *ptr;
1776
1777         if (!opt->blacklist)
1778                 return -EINVAL;
1779         ptr = strdup(opt->blacklist);
1780         if (!ptr)
1781                 return -ENOMEM;
1782         entry = strtok(ptr, ",");
1783         while (entry) {
1784                 if ((opt->prepend && strstr(opt->prepend, entry) != 0) ||
1785                     strstr(opt->map, entry) != 0 ||
1786                     strstr(t->title, entry) != 0)
1787                         return 0;
1788                 entry = strtok(NULL, ",");
1789         }
1790         return -EINVAL;
1791 }
1792
1793 static int __test_selftests(int cg_fd, struct sockmap_options *opt)
1794 {
1795         int i, err;
1796
1797         err = populate_progs(opt->map);
1798         if (err < 0) {
1799                 fprintf(stderr, "ERROR: (%i) load bpf failed\n", err);
1800                 return err;
1801         }
1802
1803         /* Tests basic commands and APIs */
1804         for (i = 0; i < sizeof(test)/sizeof(struct _test); i++) {
1805                 struct _test t = test[i];
1806
1807                 if (check_whitelist(&t, opt) != 0)
1808                         continue;
1809                 if (check_blacklist(&t, opt) == 0)
1810                         continue;
1811
1812                 test_start_subtest(&t, opt);
1813                 t.tester(cg_fd, opt);
1814                 test_end_subtest();
1815         }
1816
1817         return err;
1818 }
1819
1820 static void test_selftests_sockmap(int cg_fd, struct sockmap_options *opt)
1821 {
1822         opt->map = BPF_SOCKMAP_FILENAME;
1823         __test_selftests(cg_fd, opt);
1824 }
1825
1826 static void test_selftests_sockhash(int cg_fd, struct sockmap_options *opt)
1827 {
1828         opt->map = BPF_SOCKHASH_FILENAME;
1829         __test_selftests(cg_fd, opt);
1830 }
1831
1832 static void test_selftests_ktls(int cg_fd, struct sockmap_options *opt)
1833 {
1834         opt->map = BPF_SOCKHASH_FILENAME;
1835         opt->prepend = "ktls";
1836         ktls = 1;
1837         __test_selftests(cg_fd, opt);
1838         ktls = 0;
1839 }
1840
1841 static int test_selftest(int cg_fd, struct sockmap_options *opt)
1842 {
1843
1844         test_selftests_sockmap(cg_fd, opt);
1845         test_selftests_sockhash(cg_fd, opt);
1846         test_selftests_ktls(cg_fd, opt);
1847         test_print_results();
1848         return 0;
1849 }
1850
1851 int main(int argc, char **argv)
1852 {
1853         int iov_count = 1, length = 1024, rate = 1;
1854         struct sockmap_options options = {0};
1855         int opt, longindex, err, cg_fd = 0;
1856         char *bpf_file = BPF_SOCKMAP_FILENAME;
1857         int test = SELFTESTS;
1858         bool cg_created = 0;
1859
1860         while ((opt = getopt_long(argc, argv, ":dhv:c:r:i:l:t:p:q:n:b:",
1861                                   long_options, &longindex)) != -1) {
1862                 switch (opt) {
1863                 case 's':
1864                         txmsg_start = atoi(optarg);
1865                         break;
1866                 case 'e':
1867                         txmsg_end = atoi(optarg);
1868                         break;
1869                 case 'p':
1870                         txmsg_start_push = atoi(optarg);
1871                         break;
1872                 case 'q':
1873                         txmsg_end_push = atoi(optarg);
1874                         break;
1875                 case 'w':
1876                         txmsg_start_pop = atoi(optarg);
1877                         break;
1878                 case 'x':
1879                         txmsg_pop = atoi(optarg);
1880                         break;
1881                 case 'a':
1882                         txmsg_apply = atoi(optarg);
1883                         break;
1884                 case 'k':
1885                         txmsg_cork = atoi(optarg);
1886                         break;
1887                 case 'c':
1888                         cg_fd = open(optarg, O_DIRECTORY, O_RDONLY);
1889                         if (cg_fd < 0) {
1890                                 fprintf(stderr,
1891                                         "ERROR: (%i) open cg path failed: %s\n",
1892                                         cg_fd, optarg);
1893                                 return cg_fd;
1894                         }
1895                         break;
1896                 case 'r':
1897                         rate = atoi(optarg);
1898                         break;
1899                 case 'v':
1900                         options.verbose = 1;
1901                         if (optarg)
1902                                 options.verbose = atoi(optarg);
1903                         break;
1904                 case 'i':
1905                         iov_count = atoi(optarg);
1906                         break;
1907                 case 'l':
1908                         length = atoi(optarg);
1909                         break;
1910                 case 'd':
1911                         options.data_test = true;
1912                         break;
1913                 case 't':
1914                         if (strcmp(optarg, "ping") == 0) {
1915                                 test = PING_PONG;
1916                         } else if (strcmp(optarg, "sendmsg") == 0) {
1917                                 test = SENDMSG;
1918                         } else if (strcmp(optarg, "base") == 0) {
1919                                 test = BASE;
1920                         } else if (strcmp(optarg, "base_sendpage") == 0) {
1921                                 test = BASE_SENDPAGE;
1922                         } else if (strcmp(optarg, "sendpage") == 0) {
1923                                 test = SENDPAGE;
1924                         } else {
1925                                 usage(argv);
1926                                 return -1;
1927                         }
1928                         break;
1929                 case 'n':
1930                         options.whitelist = strdup(optarg);
1931                         if (!options.whitelist)
1932                                 return -ENOMEM;
1933                         break;
1934                 case 'b':
1935                         options.blacklist = strdup(optarg);
1936                         if (!options.blacklist)
1937                                 return -ENOMEM;
1938                 case 0:
1939                         break;
1940                 case 'h':
1941                 default:
1942                         usage(argv);
1943                         return -1;
1944                 }
1945         }
1946
1947         if (!cg_fd) {
1948                 if (setup_cgroup_environment()) {
1949                         fprintf(stderr, "ERROR: cgroup env failed\n");
1950                         return -EINVAL;
1951                 }
1952
1953                 cg_fd = create_and_get_cgroup(CG_PATH);
1954                 if (cg_fd < 0) {
1955                         fprintf(stderr,
1956                                 "ERROR: (%i) open cg path failed: %s\n",
1957                                 cg_fd, strerror(errno));
1958                         return cg_fd;
1959                 }
1960
1961                 if (join_cgroup(CG_PATH)) {
1962                         fprintf(stderr, "ERROR: failed to join cgroup\n");
1963                         return -EINVAL;
1964                 }
1965                 cg_created = 1;
1966         }
1967
1968         if (test == SELFTESTS) {
1969                 err = test_selftest(cg_fd, &options);
1970                 goto out;
1971         }
1972
1973         err = populate_progs(bpf_file);
1974         if (err) {
1975                 fprintf(stderr, "populate program: (%s) %s\n",
1976                         bpf_file, strerror(errno));
1977                 return 1;
1978         }
1979         running = 1;
1980
1981         /* catch SIGINT */
1982         signal(SIGINT, running_handler);
1983
1984         options.iov_count = iov_count;
1985         options.iov_length = length;
1986         options.rate = rate;
1987
1988         err = run_options(&options, cg_fd, test);
1989 out:
1990         if (options.whitelist)
1991                 free(options.whitelist);
1992         if (options.blacklist)
1993                 free(options.blacklist);
1994         if (cg_created)
1995                 cleanup_cgroup_environment();
1996         close(cg_fd);
1997         return err;
1998 }
1999
2000 void running_handler(int a)
2001 {
2002         running = 0;
2003 }