8b838e91cfe5b126d292993bae479ea9a4051883
[linux-2.6-microblaze.git] / tools / testing / selftests / bpf / test_sockmap.c
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2017-2018 Covalent IO, Inc. http://covalent.io
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <sys/socket.h>
6 #include <sys/ioctl.h>
7 #include <sys/select.h>
8 #include <netinet/in.h>
9 #include <arpa/inet.h>
10 #include <unistd.h>
11 #include <string.h>
12 #include <errno.h>
13 #include <stdbool.h>
14 #include <signal.h>
15 #include <fcntl.h>
16 #include <sys/wait.h>
17 #include <time.h>
18 #include <sched.h>
19
20 #include <sys/time.h>
21 #include <sys/resource.h>
22 #include <sys/types.h>
23 #include <sys/sendfile.h>
24
25 #include <linux/netlink.h>
26 #include <linux/socket.h>
27 #include <linux/sock_diag.h>
28 #include <linux/bpf.h>
29 #include <linux/if_link.h>
30 #include <linux/tls.h>
31 #include <assert.h>
32 #include <libgen.h>
33
34 #include <getopt.h>
35
36 #include <bpf/bpf.h>
37 #include <bpf/libbpf.h>
38
39 #include "bpf_util.h"
40 #include "bpf_rlimit.h"
41 #include "cgroup_helpers.h"
42
43 int running;
44 static void running_handler(int a);
45
46 #ifndef TCP_ULP
47 # define TCP_ULP 31
48 #endif
49 #ifndef SOL_TLS
50 # define SOL_TLS 282
51 #endif
52
53 /* randomly selected ports for testing on lo */
54 #define S1_PORT 10000
55 #define S2_PORT 10001
56
57 #define BPF_SOCKMAP_FILENAME "test_sockmap_kern.o"
58 #define BPF_SOCKHASH_FILENAME "test_sockhash_kern.o"
59 #define CG_PATH "/sockmap"
60
61 /* global sockets */
62 int s1, s2, c1, c2, p1, p2;
63 int test_cnt;
64 int passed;
65 int failed;
66 int map_fd[8];
67 struct bpf_map *maps[8];
68 int prog_fd[11];
69
70 int txmsg_pass;
71 int txmsg_noisy;
72 int txmsg_redir;
73 int txmsg_redir_noisy;
74 int txmsg_drop;
75 int txmsg_apply;
76 int txmsg_cork;
77 int txmsg_start;
78 int txmsg_end;
79 int txmsg_start_push;
80 int txmsg_end_push;
81 int txmsg_start_pop;
82 int txmsg_pop;
83 int txmsg_ingress;
84 int txmsg_skb;
85 int ktls;
86 int peek_flag;
87
88 static const struct option long_options[] = {
89         {"help",        no_argument,            NULL, 'h' },
90         {"cgroup",      required_argument,      NULL, 'c' },
91         {"rate",        required_argument,      NULL, 'r' },
92         {"verbose",     no_argument,            NULL, 'v' },
93         {"iov_count",   required_argument,      NULL, 'i' },
94         {"length",      required_argument,      NULL, 'l' },
95         {"test",        required_argument,      NULL, 't' },
96         {"data_test",   no_argument,            NULL, 'd' },
97         {"txmsg",               no_argument,    &txmsg_pass,  1  },
98         {"txmsg_noisy",         no_argument,    &txmsg_noisy, 1  },
99         {"txmsg_redir",         no_argument,    &txmsg_redir, 1  },
100         {"txmsg_redir_noisy",   no_argument,    &txmsg_redir_noisy, 1},
101         {"txmsg_drop",          no_argument,    &txmsg_drop, 1 },
102         {"txmsg_apply", required_argument,      NULL, 'a'},
103         {"txmsg_cork",  required_argument,      NULL, 'k'},
104         {"txmsg_start", required_argument,      NULL, 's'},
105         {"txmsg_end",   required_argument,      NULL, 'e'},
106         {"txmsg_start_push", required_argument, NULL, 'p'},
107         {"txmsg_end_push",   required_argument, NULL, 'q'},
108         {"txmsg_start_pop",  required_argument, NULL, 'w'},
109         {"txmsg_pop",        required_argument, NULL, 'x'},
110         {"txmsg_ingress", no_argument,          &txmsg_ingress, 1 },
111         {"txmsg_skb", no_argument,              &txmsg_skb, 1 },
112         {"ktls", no_argument,                   &ktls, 1 },
113         {"peek", no_argument,                   &peek_flag, 1 },
114         {0, 0, NULL, 0 }
115 };
116
117 static void usage(char *argv[])
118 {
119         int i;
120
121         printf(" Usage: %s --cgroup <cgroup_path>\n", argv[0]);
122         printf(" options:\n");
123         for (i = 0; long_options[i].name != 0; i++) {
124                 printf(" --%-12s", long_options[i].name);
125                 if (long_options[i].flag != NULL)
126                         printf(" flag (internal value:%d)\n",
127                                 *long_options[i].flag);
128                 else
129                         printf(" -%c\n", long_options[i].val);
130         }
131         printf("\n");
132 }
133
134 char *sock_to_string(int s)
135 {
136         if (s == c1)
137                 return "client1";
138         else if (s == c2)
139                 return "client2";
140         else if (s == s1)
141                 return "server1";
142         else if (s == s2)
143                 return "server2";
144         else if (s == p1)
145                 return "peer1";
146         else if (s == p2)
147                 return "peer2";
148         else
149                 return "unknown";
150 }
151
152 static int sockmap_init_ktls(int verbose, int s)
153 {
154         struct tls12_crypto_info_aes_gcm_128 tls_tx = {
155                 .info = {
156                         .version     = TLS_1_2_VERSION,
157                         .cipher_type = TLS_CIPHER_AES_GCM_128,
158                 },
159         };
160         struct tls12_crypto_info_aes_gcm_128 tls_rx = {
161                 .info = {
162                         .version     = TLS_1_2_VERSION,
163                         .cipher_type = TLS_CIPHER_AES_GCM_128,
164                 },
165         };
166         int so_buf = 6553500;
167         int err;
168
169         err = setsockopt(s, 6, TCP_ULP, "tls", sizeof("tls"));
170         if (err) {
171                 fprintf(stderr, "setsockopt: TCP_ULP(%s) failed with error %i\n", sock_to_string(s), err);
172                 return -EINVAL;
173         }
174         err = setsockopt(s, SOL_TLS, TLS_TX, (void *)&tls_tx, sizeof(tls_tx));
175         if (err) {
176                 fprintf(stderr, "setsockopt: TLS_TX(%s) failed with error %i\n", sock_to_string(s), err);
177                 return -EINVAL;
178         }
179         err = setsockopt(s, SOL_TLS, TLS_RX, (void *)&tls_rx, sizeof(tls_rx));
180         if (err) {
181                 fprintf(stderr, "setsockopt: TLS_RX(%s) failed with error %i\n", sock_to_string(s), err);
182                 return -EINVAL;
183         }
184         err = setsockopt(s, SOL_SOCKET, SO_SNDBUF, &so_buf, sizeof(so_buf));
185         if (err) {
186                 fprintf(stderr, "setsockopt: (%s) failed sndbuf with error %i\n", sock_to_string(s), err);
187                 return -EINVAL;
188         }
189         err = setsockopt(s, SOL_SOCKET, SO_RCVBUF, &so_buf, sizeof(so_buf));
190         if (err) {
191                 fprintf(stderr, "setsockopt: (%s) failed rcvbuf with error %i\n", sock_to_string(s), err);
192                 return -EINVAL;
193         }
194
195         if (verbose)
196                 fprintf(stdout, "socket(%s) kTLS enabled\n", sock_to_string(s));
197         return 0;
198 }
199 static int sockmap_init_sockets(int verbose)
200 {
201         int i, err, one = 1;
202         struct sockaddr_in addr;
203         int *fds[4] = {&s1, &s2, &c1, &c2};
204
205         s1 = s2 = p1 = p2 = c1 = c2 = 0;
206
207         /* Init sockets */
208         for (i = 0; i < 4; i++) {
209                 *fds[i] = socket(AF_INET, SOCK_STREAM, 0);
210                 if (*fds[i] < 0) {
211                         perror("socket s1 failed()");
212                         return errno;
213                 }
214         }
215
216         /* Allow reuse */
217         for (i = 0; i < 2; i++) {
218                 err = setsockopt(*fds[i], SOL_SOCKET, SO_REUSEADDR,
219                                  (char *)&one, sizeof(one));
220                 if (err) {
221                         perror("setsockopt failed()");
222                         return errno;
223                 }
224         }
225
226         /* Non-blocking sockets */
227         for (i = 0; i < 2; i++) {
228                 err = ioctl(*fds[i], FIONBIO, (char *)&one);
229                 if (err < 0) {
230                         perror("ioctl s1 failed()");
231                         return errno;
232                 }
233         }
234
235         /* Bind server sockets */
236         memset(&addr, 0, sizeof(struct sockaddr_in));
237         addr.sin_family = AF_INET;
238         addr.sin_addr.s_addr = inet_addr("127.0.0.1");
239
240         addr.sin_port = htons(S1_PORT);
241         err = bind(s1, (struct sockaddr *)&addr, sizeof(addr));
242         if (err < 0) {
243                 perror("bind s1 failed()\n");
244                 return errno;
245         }
246
247         addr.sin_port = htons(S2_PORT);
248         err = bind(s2, (struct sockaddr *)&addr, sizeof(addr));
249         if (err < 0) {
250                 perror("bind s2 failed()\n");
251                 return errno;
252         }
253
254         /* Listen server sockets */
255         addr.sin_port = htons(S1_PORT);
256         err = listen(s1, 32);
257         if (err < 0) {
258                 perror("listen s1 failed()\n");
259                 return errno;
260         }
261
262         addr.sin_port = htons(S2_PORT);
263         err = listen(s2, 32);
264         if (err < 0) {
265                 perror("listen s1 failed()\n");
266                 return errno;
267         }
268
269         /* Initiate Connect */
270         addr.sin_port = htons(S1_PORT);
271         err = connect(c1, (struct sockaddr *)&addr, sizeof(addr));
272         if (err < 0 && errno != EINPROGRESS) {
273                 perror("connect c1 failed()\n");
274                 return errno;
275         }
276
277         addr.sin_port = htons(S2_PORT);
278         err = connect(c2, (struct sockaddr *)&addr, sizeof(addr));
279         if (err < 0 && errno != EINPROGRESS) {
280                 perror("connect c2 failed()\n");
281                 return errno;
282         } else if (err < 0) {
283                 err = 0;
284         }
285
286         /* Accept Connecrtions */
287         p1 = accept(s1, NULL, NULL);
288         if (p1 < 0) {
289                 perror("accept s1 failed()\n");
290                 return errno;
291         }
292
293         p2 = accept(s2, NULL, NULL);
294         if (p2 < 0) {
295                 perror("accept s1 failed()\n");
296                 return errno;
297         }
298
299         if (verbose) {
300                 printf("connected sockets: c1 <-> p1, c2 <-> p2\n");
301                 printf("cgroups binding: c1(%i) <-> s1(%i) - - - c2(%i) <-> s2(%i)\n",
302                         c1, s1, c2, s2);
303         }
304         return 0;
305 }
306
307 struct msg_stats {
308         size_t bytes_sent;
309         size_t bytes_recvd;
310         struct timespec start;
311         struct timespec end;
312 };
313
314 struct sockmap_options {
315         int verbose;
316         bool base;
317         bool sendpage;
318         bool data_test;
319         bool drop_expected;
320         int iov_count;
321         int iov_length;
322         int rate;
323 };
324
325 static int msg_loop_sendpage(int fd, int iov_length, int cnt,
326                              struct msg_stats *s,
327                              struct sockmap_options *opt)
328 {
329         bool drop = opt->drop_expected;
330         unsigned char k = 0;
331         FILE *file;
332         int i, fp;
333
334         file = fopen(".sendpage_tst.tmp", "w+");
335         if (!file) {
336                 perror("create file for sendpage");
337                 return 1;
338         }
339         for (i = 0; i < iov_length * cnt; i++, k++)
340                 fwrite(&k, sizeof(char), 1, file);
341         fflush(file);
342         fseek(file, 0, SEEK_SET);
343         fclose(file);
344
345         fp = open(".sendpage_tst.tmp", O_RDONLY);
346         if (fp < 0) {
347                 perror("reopen file for sendpage");
348                 return 1;
349         }
350
351         clock_gettime(CLOCK_MONOTONIC, &s->start);
352         for (i = 0; i < cnt; i++) {
353                 int sent = sendfile(fd, fp, NULL, iov_length);
354
355                 if (!drop && sent < 0) {
356                         perror("send loop error:");
357                         close(fp);
358                         return sent;
359                 } else if (drop && sent >= 0) {
360                         printf("sendpage loop error expected: %i\n", sent);
361                         close(fp);
362                         return -EIO;
363                 }
364
365                 if (sent > 0)
366                         s->bytes_sent += sent;
367         }
368         clock_gettime(CLOCK_MONOTONIC, &s->end);
369         close(fp);
370         return 0;
371 }
372
373 static void msg_free_iov(struct msghdr *msg)
374 {
375         int i;
376
377         for (i = 0; i < msg->msg_iovlen; i++)
378                 free(msg->msg_iov[i].iov_base);
379         free(msg->msg_iov);
380         msg->msg_iov = NULL;
381         msg->msg_iovlen = 0;
382 }
383
384 static int msg_alloc_iov(struct msghdr *msg,
385                          int iov_count, int iov_length,
386                          bool data, bool xmit)
387 {
388         unsigned char k = 0;
389         struct iovec *iov;
390         int i;
391
392         iov = calloc(iov_count, sizeof(struct iovec));
393         if (!iov)
394                 return errno;
395
396         for (i = 0; i < iov_count; i++) {
397                 unsigned char *d = calloc(iov_length, sizeof(char));
398
399                 if (!d) {
400                         fprintf(stderr, "iov_count %i/%i OOM\n", i, iov_count);
401                         goto unwind_iov;
402                 }
403                 iov[i].iov_base = d;
404                 iov[i].iov_len = iov_length;
405
406                 if (data && xmit) {
407                         int j;
408
409                         for (j = 0; j < iov_length; j++)
410                                 d[j] = k++;
411                 }
412         }
413
414         msg->msg_iov = iov;
415         msg->msg_iovlen = iov_count;
416
417         return 0;
418 unwind_iov:
419         for (i--; i >= 0 ; i--)
420                 free(msg->msg_iov[i].iov_base);
421         return -ENOMEM;
422 }
423
424 static int msg_verify_data(struct msghdr *msg, int size, int chunk_sz)
425 {
426         int i, j, bytes_cnt = 0;
427         unsigned char k = 0;
428
429         for (i = 0; i < msg->msg_iovlen; i++) {
430                 unsigned char *d = msg->msg_iov[i].iov_base;
431
432                 for (j = 0;
433                      j < msg->msg_iov[i].iov_len && size; j++) {
434                         if (d[j] != k++) {
435                                 fprintf(stderr,
436                                         "detected data corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n",
437                                         i, j, d[j], k - 1, d[j+1], k);
438                                 return -EIO;
439                         }
440                         bytes_cnt++;
441                         if (bytes_cnt == chunk_sz) {
442                                 k = 0;
443                                 bytes_cnt = 0;
444                         }
445                         size--;
446                 }
447         }
448         return 0;
449 }
450
451 static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
452                     struct msg_stats *s, bool tx,
453                     struct sockmap_options *opt)
454 {
455         struct msghdr msg = {0}, msg_peek = {0};
456         int err, i, flags = MSG_NOSIGNAL;
457         bool drop = opt->drop_expected;
458         bool data = opt->data_test;
459
460         err = msg_alloc_iov(&msg, iov_count, iov_length, data, tx);
461         if (err)
462                 goto out_errno;
463         if (peek_flag) {
464                 err = msg_alloc_iov(&msg_peek, iov_count, iov_length, data, tx);
465                 if (err)
466                         goto out_errno;
467         }
468
469         if (tx) {
470                 clock_gettime(CLOCK_MONOTONIC, &s->start);
471                 for (i = 0; i < cnt; i++) {
472                         int sent = sendmsg(fd, &msg, flags);
473
474                         if (!drop && sent < 0) {
475                                 perror("send loop error:");
476                                 goto out_errno;
477                         } else if (drop && sent >= 0) {
478                                 printf("send loop error expected: %i\n", sent);
479                                 errno = -EIO;
480                                 goto out_errno;
481                         }
482                         if (sent > 0)
483                                 s->bytes_sent += sent;
484                 }
485                 clock_gettime(CLOCK_MONOTONIC, &s->end);
486         } else {
487                 int slct, recvp = 0, recv, max_fd = fd;
488                 float total_bytes, txmsg_pop_total;
489                 int fd_flags = O_NONBLOCK;
490                 struct timeval timeout;
491                 fd_set w;
492
493                 fcntl(fd, fd_flags);
494                 /* Account for pop bytes noting each iteration of apply will
495                  * call msg_pop_data helper so we need to account for this
496                  * by calculating the number of apply iterations. Note user
497                  * of the tool can create cases where no data is sent by
498                  * manipulating pop/push/pull/etc. For example txmsg_apply 1
499                  * with txmsg_pop 1 will try to apply 1B at a time but each
500                  * iteration will then pop 1B so no data will ever be sent.
501                  * This is really only useful for testing edge cases in code
502                  * paths.
503                  */
504                 total_bytes = (float)iov_count * (float)iov_length * (float)cnt;
505                 txmsg_pop_total = txmsg_pop;
506                 if (txmsg_apply)
507                         txmsg_pop_total *= (total_bytes / txmsg_apply);
508                 total_bytes -= txmsg_pop_total;
509                 err = clock_gettime(CLOCK_MONOTONIC, &s->start);
510                 if (err < 0)
511                         perror("recv start time: ");
512                 while (s->bytes_recvd < total_bytes) {
513                         if (txmsg_cork) {
514                                 timeout.tv_sec = 0;
515                                 timeout.tv_usec = 300000;
516                         } else {
517                                 timeout.tv_sec = 3;
518                                 timeout.tv_usec = 0;
519                         }
520
521                         /* FD sets */
522                         FD_ZERO(&w);
523                         FD_SET(fd, &w);
524
525                         slct = select(max_fd + 1, &w, NULL, NULL, &timeout);
526                         if (slct == -1) {
527                                 perror("select()");
528                                 clock_gettime(CLOCK_MONOTONIC, &s->end);
529                                 goto out_errno;
530                         } else if (!slct) {
531                                 if (opt->verbose)
532                                         fprintf(stderr, "unexpected timeout: recved %zu/%f pop_total %f\n", s->bytes_recvd, total_bytes, txmsg_pop_total);
533                                 errno = -EIO;
534                                 clock_gettime(CLOCK_MONOTONIC, &s->end);
535                                 goto out_errno;
536                         }
537
538                         errno = 0;
539                         if (peek_flag) {
540                                 flags |= MSG_PEEK;
541                                 recvp = recvmsg(fd, &msg_peek, flags);
542                                 if (recvp < 0) {
543                                         if (errno != EWOULDBLOCK) {
544                                                 clock_gettime(CLOCK_MONOTONIC, &s->end);
545                                                 goto out_errno;
546                                         }
547                                 }
548                                 flags = 0;
549                         }
550
551                         recv = recvmsg(fd, &msg, flags);
552                         if (recv < 0) {
553                                 if (errno != EWOULDBLOCK) {
554                                         clock_gettime(CLOCK_MONOTONIC, &s->end);
555                                         perror("recv failed()\n");
556                                         goto out_errno;
557                                 }
558                         }
559
560                         s->bytes_recvd += recv;
561
562                         if (data) {
563                                 int chunk_sz = opt->sendpage ?
564                                                 iov_length * cnt :
565                                                 iov_length * iov_count;
566
567                                 errno = msg_verify_data(&msg, recv, chunk_sz);
568                                 if (errno) {
569                                         perror("data verify msg failed\n");
570                                         goto out_errno;
571                                 }
572                                 if (recvp) {
573                                         errno = msg_verify_data(&msg_peek,
574                                                                 recvp,
575                                                                 chunk_sz);
576                                         if (errno) {
577                                                 perror("data verify msg_peek failed\n");
578                                                 goto out_errno;
579                                         }
580                                 }
581                         }
582                 }
583                 clock_gettime(CLOCK_MONOTONIC, &s->end);
584         }
585
586         msg_free_iov(&msg);
587         msg_free_iov(&msg_peek);
588         return err;
589 out_errno:
590         msg_free_iov(&msg);
591         msg_free_iov(&msg_peek);
592         return errno;
593 }
594
595 static float giga = 1000000000;
596
597 static inline float sentBps(struct msg_stats s)
598 {
599         return s.bytes_sent / (s.end.tv_sec - s.start.tv_sec);
600 }
601
602 static inline float recvdBps(struct msg_stats s)
603 {
604         return s.bytes_recvd / (s.end.tv_sec - s.start.tv_sec);
605 }
606
607 static int sendmsg_test(struct sockmap_options *opt)
608 {
609         float sent_Bps = 0, recvd_Bps = 0;
610         int rx_fd, txpid, rxpid, err = 0;
611         struct msg_stats s = {0};
612         int iov_count = opt->iov_count;
613         int iov_buf = opt->iov_length;
614         int rx_status, tx_status;
615         int cnt = opt->rate;
616
617         errno = 0;
618
619         if (opt->base)
620                 rx_fd = p1;
621         else
622                 rx_fd = p2;
623
624         if (ktls) {
625                 /* Redirecting into non-TLS socket which sends into a TLS
626                  * socket is not a valid test. So in this case lets not
627                  * enable kTLS but still run the test.
628                  */
629                 if (!txmsg_redir || (txmsg_redir && txmsg_ingress)) {
630                         err = sockmap_init_ktls(opt->verbose, rx_fd);
631                         if (err)
632                                 return err;
633                 }
634                 err = sockmap_init_ktls(opt->verbose, c1);
635                 if (err)
636                         return err;
637         }
638
639         rxpid = fork();
640         if (rxpid == 0) {
641                 if (opt->drop_expected)
642                         exit(0);
643
644                 if (opt->sendpage)
645                         iov_count = 1;
646                 err = msg_loop(rx_fd, iov_count, iov_buf,
647                                cnt, &s, false, opt);
648                 if (opt->verbose)
649                         fprintf(stderr,
650                                 "msg_loop_rx: iov_count %i iov_buf %i cnt %i err %i\n",
651                                 iov_count, iov_buf, cnt, err);
652                 if (s.end.tv_sec - s.start.tv_sec) {
653                         sent_Bps = sentBps(s);
654                         recvd_Bps = recvdBps(s);
655                 }
656                 if (opt->verbose)
657                         fprintf(stdout,
658                                 "rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB %fB/s %fGB/s %s\n",
659                                 s.bytes_sent, sent_Bps, sent_Bps/giga,
660                                 s.bytes_recvd, recvd_Bps, recvd_Bps/giga,
661                                 peek_flag ? "(peek_msg)" : "");
662                 if (err && txmsg_cork)
663                         err = 0;
664                 exit(err ? 1 : 0);
665         } else if (rxpid == -1) {
666                 perror("msg_loop_rx: ");
667                 return errno;
668         }
669
670         txpid = fork();
671         if (txpid == 0) {
672                 if (opt->sendpage)
673                         err = msg_loop_sendpage(c1, iov_buf, cnt, &s, opt);
674                 else
675                         err = msg_loop(c1, iov_count, iov_buf,
676                                        cnt, &s, true, opt);
677
678                 if (err)
679                         fprintf(stderr,
680                                 "msg_loop_tx: iov_count %i iov_buf %i cnt %i err %i\n",
681                                 iov_count, iov_buf, cnt, err);
682                 if (s.end.tv_sec - s.start.tv_sec) {
683                         sent_Bps = sentBps(s);
684                         recvd_Bps = recvdBps(s);
685                 }
686                 if (opt->verbose)
687                         fprintf(stdout,
688                                 "tx_sendmsg: TX: %zuB %fB/s %f GB/s RX: %zuB %fB/s %fGB/s\n",
689                                 s.bytes_sent, sent_Bps, sent_Bps/giga,
690                                 s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
691                 exit(err ? 1 : 0);
692         } else if (txpid == -1) {
693                 perror("msg_loop_tx: ");
694                 return errno;
695         }
696
697         assert(waitpid(rxpid, &rx_status, 0) == rxpid);
698         assert(waitpid(txpid, &tx_status, 0) == txpid);
699         if (WIFEXITED(rx_status)) {
700                 err = WEXITSTATUS(rx_status);
701                 if (err) {
702                         fprintf(stderr, "rx thread exited with err %d. ", err);
703                         goto out;
704                 }
705         }
706         if (WIFEXITED(tx_status)) {
707                 err = WEXITSTATUS(tx_status);
708                 if (err)
709                         fprintf(stderr, "tx thread exited with err %d. ", err);
710         }
711 out:
712         return err;
713 }
714
715 static int forever_ping_pong(int rate, struct sockmap_options *opt)
716 {
717         struct timeval timeout;
718         char buf[1024] = {0};
719         int sc;
720
721         timeout.tv_sec = 10;
722         timeout.tv_usec = 0;
723
724         /* Ping/Pong data from client to server */
725         sc = send(c1, buf, sizeof(buf), 0);
726         if (sc < 0) {
727                 perror("send failed()\n");
728                 return sc;
729         }
730
731         do {
732                 int s, rc, i, max_fd = p2;
733                 fd_set w;
734
735                 /* FD sets */
736                 FD_ZERO(&w);
737                 FD_SET(c1, &w);
738                 FD_SET(c2, &w);
739                 FD_SET(p1, &w);
740                 FD_SET(p2, &w);
741
742                 s = select(max_fd + 1, &w, NULL, NULL, &timeout);
743                 if (s == -1) {
744                         perror("select()");
745                         break;
746                 } else if (!s) {
747                         fprintf(stderr, "unexpected timeout\n");
748                         break;
749                 }
750
751                 for (i = 0; i <= max_fd && s > 0; ++i) {
752                         if (!FD_ISSET(i, &w))
753                                 continue;
754
755                         s--;
756
757                         rc = recv(i, buf, sizeof(buf), 0);
758                         if (rc < 0) {
759                                 if (errno != EWOULDBLOCK) {
760                                         perror("recv failed()\n");
761                                         return rc;
762                                 }
763                         }
764
765                         if (rc == 0) {
766                                 close(i);
767                                 break;
768                         }
769
770                         sc = send(i, buf, rc, 0);
771                         if (sc < 0) {
772                                 perror("send failed()\n");
773                                 return sc;
774                         }
775                 }
776
777                 if (rate)
778                         sleep(rate);
779
780                 if (opt->verbose) {
781                         printf(".");
782                         fflush(stdout);
783
784                 }
785         } while (running);
786
787         return 0;
788 }
789
790 enum {
791         PING_PONG,
792         SENDMSG,
793         BASE,
794         BASE_SENDPAGE,
795         SENDPAGE,
796 };
797
798 static int run_options(struct sockmap_options *options, int cg_fd,  int test)
799 {
800         int i, key, next_key, err, tx_prog_fd = -1, zero = 0;
801
802         /* If base test skip BPF setup */
803         if (test == BASE || test == BASE_SENDPAGE)
804                 goto run;
805
806         /* Attach programs to sockmap */
807         err = bpf_prog_attach(prog_fd[0], map_fd[0],
808                                 BPF_SK_SKB_STREAM_PARSER, 0);
809         if (err) {
810                 fprintf(stderr,
811                         "ERROR: bpf_prog_attach (sockmap %i->%i): %d (%s)\n",
812                         prog_fd[0], map_fd[0], err, strerror(errno));
813                 return err;
814         }
815
816         err = bpf_prog_attach(prog_fd[1], map_fd[0],
817                                 BPF_SK_SKB_STREAM_VERDICT, 0);
818         if (err) {
819                 fprintf(stderr, "ERROR: bpf_prog_attach (sockmap): %d (%s)\n",
820                         err, strerror(errno));
821                 return err;
822         }
823
824         /* Attach to cgroups */
825         err = bpf_prog_attach(prog_fd[2], cg_fd, BPF_CGROUP_SOCK_OPS, 0);
826         if (err) {
827                 fprintf(stderr, "ERROR: bpf_prog_attach (groups): %d (%s)\n",
828                         err, strerror(errno));
829                 return err;
830         }
831
832 run:
833         err = sockmap_init_sockets(options->verbose);
834         if (err) {
835                 fprintf(stderr, "ERROR: test socket failed: %d\n", err);
836                 goto out;
837         }
838
839         /* Attach txmsg program to sockmap */
840         if (txmsg_pass)
841                 tx_prog_fd = prog_fd[3];
842         else if (txmsg_noisy)
843                 tx_prog_fd = prog_fd[4];
844         else if (txmsg_redir)
845                 tx_prog_fd = prog_fd[5];
846         else if (txmsg_redir_noisy)
847                 tx_prog_fd = prog_fd[6];
848         else if (txmsg_drop)
849                 tx_prog_fd = prog_fd[9];
850         /* apply and cork must be last */
851         else if (txmsg_apply)
852                 tx_prog_fd = prog_fd[7];
853         else if (txmsg_cork)
854                 tx_prog_fd = prog_fd[8];
855         else
856                 tx_prog_fd = 0;
857
858         if (tx_prog_fd) {
859                 int redir_fd, i = 0;
860
861                 err = bpf_prog_attach(tx_prog_fd,
862                                       map_fd[1], BPF_SK_MSG_VERDICT, 0);
863                 if (err) {
864                         fprintf(stderr,
865                                 "ERROR: bpf_prog_attach (txmsg): %d (%s)\n",
866                                 err, strerror(errno));
867                         goto out;
868                 }
869
870                 err = bpf_map_update_elem(map_fd[1], &i, &c1, BPF_ANY);
871                 if (err) {
872                         fprintf(stderr,
873                                 "ERROR: bpf_map_update_elem (txmsg):  %d (%s\n",
874                                 err, strerror(errno));
875                         goto out;
876                 }
877
878                 if (txmsg_redir || txmsg_redir_noisy)
879                         redir_fd = c2;
880                 else
881                         redir_fd = c1;
882
883                 err = bpf_map_update_elem(map_fd[2], &i, &redir_fd, BPF_ANY);
884                 if (err) {
885                         fprintf(stderr,
886                                 "ERROR: bpf_map_update_elem (txmsg):  %d (%s\n",
887                                 err, strerror(errno));
888                         goto out;
889                 }
890
891                 if (txmsg_apply) {
892                         err = bpf_map_update_elem(map_fd[3],
893                                                   &i, &txmsg_apply, BPF_ANY);
894                         if (err) {
895                                 fprintf(stderr,
896                                         "ERROR: bpf_map_update_elem (apply_bytes):  %d (%s\n",
897                                         err, strerror(errno));
898                                 goto out;
899                         }
900                 }
901
902                 if (txmsg_cork) {
903                         err = bpf_map_update_elem(map_fd[4],
904                                                   &i, &txmsg_cork, BPF_ANY);
905                         if (err) {
906                                 fprintf(stderr,
907                                         "ERROR: bpf_map_update_elem (cork_bytes):  %d (%s\n",
908                                         err, strerror(errno));
909                                 goto out;
910                         }
911                 }
912
913                 if (txmsg_start) {
914                         err = bpf_map_update_elem(map_fd[5],
915                                                   &i, &txmsg_start, BPF_ANY);
916                         if (err) {
917                                 fprintf(stderr,
918                                         "ERROR: bpf_map_update_elem (txmsg_start):  %d (%s)\n",
919                                         err, strerror(errno));
920                                 goto out;
921                         }
922                 }
923
924                 if (txmsg_end) {
925                         i = 1;
926                         err = bpf_map_update_elem(map_fd[5],
927                                                   &i, &txmsg_end, BPF_ANY);
928                         if (err) {
929                                 fprintf(stderr,
930                                         "ERROR: bpf_map_update_elem (txmsg_end):  %d (%s)\n",
931                                         err, strerror(errno));
932                                 goto out;
933                         }
934                 }
935
936                 if (txmsg_start_push) {
937                         i = 2;
938                         err = bpf_map_update_elem(map_fd[5],
939                                                   &i, &txmsg_start_push, BPF_ANY);
940                         if (err) {
941                                 fprintf(stderr,
942                                         "ERROR: bpf_map_update_elem (txmsg_start_push):  %d (%s)\n",
943                                         err, strerror(errno));
944                                 goto out;
945                         }
946                 }
947
948                 if (txmsg_end_push) {
949                         i = 3;
950                         err = bpf_map_update_elem(map_fd[5],
951                                                   &i, &txmsg_end_push, BPF_ANY);
952                         if (err) {
953                                 fprintf(stderr,
954                                         "ERROR: bpf_map_update_elem %i@%i (txmsg_end_push):  %d (%s)\n",
955                                         txmsg_end_push, i, err, strerror(errno));
956                                 goto out;
957                         }
958                 }
959
960                 if (txmsg_start_pop) {
961                         i = 4;
962                         err = bpf_map_update_elem(map_fd[5],
963                                                   &i, &txmsg_start_pop, BPF_ANY);
964                         if (err) {
965                                 fprintf(stderr,
966                                         "ERROR: bpf_map_update_elem %i@%i (txmsg_start_pop):  %d (%s)\n",
967                                         txmsg_start_pop, i, err, strerror(errno));
968                                 goto out;
969                         }
970                 } else {
971                         i = 4;
972                         bpf_map_update_elem(map_fd[5],
973                                                   &i, &txmsg_start_pop, BPF_ANY);
974                 }
975
976                 if (txmsg_pop) {
977                         i = 5;
978                         err = bpf_map_update_elem(map_fd[5],
979                                                   &i, &txmsg_pop, BPF_ANY);
980                         if (err) {
981                                 fprintf(stderr,
982                                         "ERROR: bpf_map_update_elem %i@%i (txmsg_pop):  %d (%s)\n",
983                                         txmsg_pop, i, err, strerror(errno));
984                                 goto out;
985                         }
986                 } else {
987                         i = 5;
988                         bpf_map_update_elem(map_fd[5],
989                                             &i, &txmsg_pop, BPF_ANY);
990
991                 }
992
993                 if (txmsg_ingress) {
994                         int in = BPF_F_INGRESS;
995
996                         i = 0;
997                         err = bpf_map_update_elem(map_fd[6], &i, &in, BPF_ANY);
998                         if (err) {
999                                 fprintf(stderr,
1000                                         "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
1001                                         err, strerror(errno));
1002                         }
1003                         i = 1;
1004                         err = bpf_map_update_elem(map_fd[1], &i, &p1, BPF_ANY);
1005                         if (err) {
1006                                 fprintf(stderr,
1007                                         "ERROR: bpf_map_update_elem (p1 txmsg): %d (%s)\n",
1008                                         err, strerror(errno));
1009                         }
1010                         err = bpf_map_update_elem(map_fd[2], &i, &p1, BPF_ANY);
1011                         if (err) {
1012                                 fprintf(stderr,
1013                                         "ERROR: bpf_map_update_elem (p1 redir): %d (%s)\n",
1014                                         err, strerror(errno));
1015                         }
1016
1017                         i = 2;
1018                         err = bpf_map_update_elem(map_fd[2], &i, &p2, BPF_ANY);
1019                         if (err) {
1020                                 fprintf(stderr,
1021                                         "ERROR: bpf_map_update_elem (p2 txmsg): %d (%s)\n",
1022                                         err, strerror(errno));
1023                         }
1024                 }
1025
1026                 if (txmsg_skb) {
1027                         int skb_fd = (test == SENDMSG || test == SENDPAGE) ?
1028                                         p2 : p1;
1029                         int ingress = BPF_F_INGRESS;
1030
1031                         i = 0;
1032                         err = bpf_map_update_elem(map_fd[7],
1033                                                   &i, &ingress, BPF_ANY);
1034                         if (err) {
1035                                 fprintf(stderr,
1036                                         "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
1037                                         err, strerror(errno));
1038                         }
1039
1040                         i = 3;
1041                         err = bpf_map_update_elem(map_fd[0],
1042                                                   &i, &skb_fd, BPF_ANY);
1043                         if (err) {
1044                                 fprintf(stderr,
1045                                         "ERROR: bpf_map_update_elem (c1 sockmap): %d (%s)\n",
1046                                         err, strerror(errno));
1047                         }
1048                 }
1049         }
1050
1051         if (txmsg_drop)
1052                 options->drop_expected = true;
1053
1054         if (test == PING_PONG)
1055                 err = forever_ping_pong(options->rate, options);
1056         else if (test == SENDMSG) {
1057                 options->base = false;
1058                 options->sendpage = false;
1059                 err = sendmsg_test(options);
1060         } else if (test == SENDPAGE) {
1061                 options->base = false;
1062                 options->sendpage = true;
1063                 err = sendmsg_test(options);
1064         } else if (test == BASE) {
1065                 options->base = true;
1066                 options->sendpage = false;
1067                 err = sendmsg_test(options);
1068         } else if (test == BASE_SENDPAGE) {
1069                 options->base = true;
1070                 options->sendpage = true;
1071                 err = sendmsg_test(options);
1072         } else
1073                 fprintf(stderr, "unknown test\n");
1074 out:
1075         /* Detatch and zero all the maps */
1076         bpf_prog_detach2(prog_fd[2], cg_fd, BPF_CGROUP_SOCK_OPS);
1077         bpf_prog_detach2(prog_fd[0], map_fd[0], BPF_SK_SKB_STREAM_PARSER);
1078         bpf_prog_detach2(prog_fd[1], map_fd[0], BPF_SK_SKB_STREAM_VERDICT);
1079         if (tx_prog_fd >= 0)
1080                 bpf_prog_detach2(tx_prog_fd, map_fd[1], BPF_SK_MSG_VERDICT);
1081
1082         for (i = 0; i < 8; i++) {
1083                 key = next_key = 0;
1084                 bpf_map_update_elem(map_fd[i], &key, &zero, BPF_ANY);
1085                 while (bpf_map_get_next_key(map_fd[i], &key, &next_key) == 0) {
1086                         bpf_map_update_elem(map_fd[i], &key, &zero, BPF_ANY);
1087                         key = next_key;
1088                 }
1089         }
1090
1091         close(s1);
1092         close(s2);
1093         close(p1);
1094         close(p2);
1095         close(c1);
1096         close(c2);
1097         return err;
1098 }
1099
1100 static char *test_to_str(int test)
1101 {
1102         switch (test) {
1103         case SENDMSG:
1104                 return "sendmsg";
1105         case SENDPAGE:
1106                 return "sendpage";
1107         }
1108         return "unknown";
1109 }
1110
1111 #define OPTSTRING 60
1112 static void test_options(char *options)
1113 {
1114         char tstr[OPTSTRING];
1115
1116         memset(options, 0, OPTSTRING);
1117
1118         if (txmsg_pass)
1119                 strncat(options, "pass,", OPTSTRING);
1120         if (txmsg_noisy)
1121                 strncat(options, "pass_noisy,", OPTSTRING);
1122         if (txmsg_redir)
1123                 strncat(options, "redir,", OPTSTRING);
1124         if (txmsg_redir_noisy)
1125                 strncat(options, "redir_noisy,", OPTSTRING);
1126         if (txmsg_drop)
1127                 strncat(options, "drop,", OPTSTRING);
1128         if (txmsg_apply) {
1129                 snprintf(tstr, OPTSTRING, "apply %d,", txmsg_apply);
1130                 strncat(options, tstr, OPTSTRING);
1131         }
1132         if (txmsg_cork) {
1133                 snprintf(tstr, OPTSTRING, "cork %d,", txmsg_cork);
1134                 strncat(options, tstr, OPTSTRING);
1135         }
1136         if (txmsg_start) {
1137                 snprintf(tstr, OPTSTRING, "start %d,", txmsg_start);
1138                 strncat(options, tstr, OPTSTRING);
1139         }
1140         if (txmsg_end) {
1141                 snprintf(tstr, OPTSTRING, "end %d,", txmsg_end);
1142                 strncat(options, tstr, OPTSTRING);
1143         }
1144         if (txmsg_start_pop) {
1145                 snprintf(tstr, OPTSTRING, "pop (%d,%d),",
1146                          txmsg_start_pop, txmsg_start_pop + txmsg_pop);
1147                 strncat(options, tstr, OPTSTRING);
1148         }
1149         if (txmsg_ingress)
1150                 strncat(options, "ingress,", OPTSTRING);
1151         if (txmsg_skb)
1152                 strncat(options, "skb,", OPTSTRING);
1153         if (ktls)
1154                 strncat(options, "ktls,", OPTSTRING);
1155         if (peek_flag)
1156                 strncat(options, "peek,", OPTSTRING);
1157 }
1158
1159 static int __test_exec(int cgrp, int test, struct sockmap_options *opt)
1160 {
1161         char *options = calloc(OPTSTRING, sizeof(char));
1162         int err;
1163
1164         if (test == SENDPAGE)
1165                 opt->sendpage = true;
1166         else
1167                 opt->sendpage = false;
1168
1169         if (txmsg_drop)
1170                 opt->drop_expected = true;
1171         else
1172                 opt->drop_expected = false;
1173
1174         test_options(options);
1175
1176         fprintf(stdout,
1177                 "[TEST %i]: (%i, %i, %i, %s, %s): ",
1178                 test_cnt, opt->rate, opt->iov_count, opt->iov_length,
1179                 test_to_str(test), options);
1180         fflush(stdout);
1181         err = run_options(opt, cgrp, test);
1182         fprintf(stdout, "%s\n", !err ? "PASS" : "FAILED");
1183         test_cnt++;
1184         !err ? passed++ : failed++;
1185         free(options);
1186         return err;
1187 }
1188
1189 static int test_exec(int cgrp, struct sockmap_options *opt)
1190 {
1191         int err = __test_exec(cgrp, SENDMSG, opt);
1192
1193         if (err)
1194                 goto out;
1195
1196         err = __test_exec(cgrp, SENDPAGE, opt);
1197 out:
1198         return err;
1199 }
1200
1201 static int test_loop(int cgrp)
1202 {
1203         struct sockmap_options opt;
1204
1205         int err, i, l, r;
1206
1207         opt.verbose = 0;
1208         opt.base = false;
1209         opt.sendpage = false;
1210         opt.data_test = false;
1211         opt.drop_expected = false;
1212         opt.iov_count = 0;
1213         opt.iov_length = 0;
1214         opt.rate = 0;
1215
1216         r = 1;
1217         for (i = 1; i < 100; i += 33) {
1218                 for (l = 1; l < 100; l += 33) {
1219                         opt.rate = r;
1220                         opt.iov_count = i;
1221                         opt.iov_length = l;
1222                         err = test_exec(cgrp, &opt);
1223                         if (err)
1224                                 goto out;
1225                 }
1226         }
1227         sched_yield();
1228 out:
1229         return err;
1230 }
1231
1232 static int test_txmsg(int cgrp)
1233 {
1234         int err;
1235
1236         txmsg_pass = txmsg_noisy = txmsg_redir_noisy = txmsg_drop = 0;
1237         txmsg_apply = txmsg_cork = 0;
1238         txmsg_ingress = txmsg_skb = 0;
1239
1240         txmsg_pass = 1;
1241         err = test_loop(cgrp);
1242         txmsg_pass = 0;
1243         if (err)
1244                 goto out;
1245
1246         txmsg_redir = 1;
1247         err = test_loop(cgrp);
1248         txmsg_redir = 0;
1249         if (err)
1250                 goto out;
1251
1252         txmsg_drop = 1;
1253         err = test_loop(cgrp);
1254         txmsg_drop = 0;
1255         if (err)
1256                 goto out;
1257
1258         txmsg_redir = 1;
1259         txmsg_ingress = 1;
1260         err = test_loop(cgrp);
1261         txmsg_redir = 0;
1262         txmsg_ingress = 0;
1263         if (err)
1264                 goto out;
1265 out:
1266         txmsg_pass = 0;
1267         txmsg_redir = 0;
1268         txmsg_drop = 0;
1269         return err;
1270 }
1271
1272 static int test_send(struct sockmap_options *opt, int cgrp)
1273 {
1274         int err;
1275
1276         opt->iov_length = 1;
1277         opt->iov_count = 1;
1278         opt->rate = 1;
1279         err = test_exec(cgrp, opt);
1280         if (err)
1281                 goto out;
1282
1283         opt->iov_length = 1;
1284         opt->iov_count = 1024;
1285         opt->rate = 1;
1286         err = test_exec(cgrp, opt);
1287         if (err)
1288                 goto out;
1289
1290         opt->iov_length = 1024;
1291         opt->iov_count = 1;
1292         opt->rate = 1;
1293         err = test_exec(cgrp, opt);
1294         if (err)
1295                 goto out;
1296
1297         opt->iov_length = 1;
1298         opt->iov_count = 1;
1299         opt->rate = 512;
1300         err = test_exec(cgrp, opt);
1301         if (err)
1302                 goto out;
1303
1304         opt->iov_length = 256;
1305         opt->iov_count = 1024;
1306         opt->rate = 2;
1307         err = test_exec(cgrp, opt);
1308         if (err)
1309                 goto out;
1310
1311         opt->rate = 100;
1312         opt->iov_count = 1;
1313         opt->iov_length = 5;
1314         err = test_exec(cgrp, opt);
1315         if (err)
1316                 goto out;
1317 out:
1318         sched_yield();
1319         return err;
1320 }
1321
1322 static int test_mixed(int cgrp)
1323 {
1324         struct sockmap_options opt = {0};
1325         int err;
1326
1327         txmsg_pass = txmsg_noisy = txmsg_redir_noisy = txmsg_drop = 0;
1328         txmsg_apply = txmsg_cork = 0;
1329         txmsg_start = txmsg_end = 0;
1330         txmsg_start_push = txmsg_end_push = 0;
1331         txmsg_start_pop = txmsg_pop = 0;
1332
1333         /* Test small and large iov_count values with pass/redir/apply/cork */
1334         txmsg_pass = 1;
1335         txmsg_redir = 0;
1336         txmsg_apply = 1;
1337         txmsg_cork = 0;
1338         err = test_send(&opt, cgrp);
1339         if (err)
1340                 goto out;
1341
1342         txmsg_pass = 1;
1343         txmsg_redir = 0;
1344         txmsg_apply = 0;
1345         txmsg_cork = 1;
1346         err = test_send(&opt, cgrp);
1347         if (err)
1348                 goto out;
1349
1350         txmsg_pass = 1;
1351         txmsg_redir = 0;
1352         txmsg_apply = 1;
1353         txmsg_cork = 1;
1354         err = test_send(&opt, cgrp);
1355         if (err)
1356                 goto out;
1357
1358         txmsg_pass = 1;
1359         txmsg_redir = 0;
1360         txmsg_apply = 1024;
1361         txmsg_cork = 0;
1362         err = test_send(&opt, cgrp);
1363         if (err)
1364                 goto out;
1365
1366         txmsg_pass = 1;
1367         txmsg_redir = 0;
1368         txmsg_apply = 0;
1369         txmsg_cork = 1024;
1370         err = test_send(&opt, cgrp);
1371         if (err)
1372                 goto out;
1373
1374         txmsg_pass = 1;
1375         txmsg_redir = 0;
1376         txmsg_apply = 1024;
1377         txmsg_cork = 1024;
1378         err = test_send(&opt, cgrp);
1379         if (err)
1380                 goto out;
1381
1382         txmsg_pass = 1;
1383         txmsg_redir = 0;
1384         txmsg_cork = 4096;
1385         txmsg_apply = 4096;
1386         err = test_send(&opt, cgrp);
1387         if (err)
1388                 goto out;
1389
1390         txmsg_pass = 0;
1391         txmsg_redir = 1;
1392         txmsg_apply = 1;
1393         txmsg_cork = 0;
1394         err = test_send(&opt, cgrp);
1395         if (err)
1396                 goto out;
1397
1398         txmsg_pass = 0;
1399         txmsg_redir = 1;
1400         txmsg_apply = 0;
1401         txmsg_cork = 1;
1402         err = test_send(&opt, cgrp);
1403         if (err)
1404                 goto out;
1405
1406         txmsg_pass = 0;
1407         txmsg_redir = 1;
1408         txmsg_apply = 1024;
1409         txmsg_cork = 0;
1410         err = test_send(&opt, cgrp);
1411         if (err)
1412                 goto out;
1413
1414         txmsg_pass = 0;
1415         txmsg_redir = 1;
1416         txmsg_apply = 0;
1417         txmsg_cork = 1024;
1418         err = test_send(&opt, cgrp);
1419         if (err)
1420                 goto out;
1421
1422         txmsg_pass = 0;
1423         txmsg_redir = 1;
1424         txmsg_apply = 1024;
1425         txmsg_cork = 1024;
1426         err = test_send(&opt, cgrp);
1427         if (err)
1428                 goto out;
1429
1430         txmsg_pass = 0;
1431         txmsg_redir = 1;
1432         txmsg_cork = 4096;
1433         txmsg_apply = 4096;
1434         err = test_send(&opt, cgrp);
1435         if (err)
1436                 goto out;
1437 out:
1438         return err;
1439 }
1440
1441 static int test_start_end(int cgrp)
1442 {
1443         struct sockmap_options opt = {0};
1444         int err, i;
1445
1446         /* Test basic start/end with lots of iov_count and iov_lengths */
1447         txmsg_start = 1;
1448         txmsg_end = 2;
1449         txmsg_start_push = 1;
1450         txmsg_end_push = 2;
1451         txmsg_start_pop = 1;
1452         txmsg_pop = 1;
1453         err = test_txmsg(cgrp);
1454         if (err)
1455                 goto out;
1456
1457         /* Cut a byte of pushed data but leave reamining in place */
1458         txmsg_start = 1;
1459         txmsg_end = 2;
1460         txmsg_start_push = 1;
1461         txmsg_end_push = 3;
1462         txmsg_start_pop = 1;
1463         txmsg_pop = 1;
1464         err = test_txmsg(cgrp);
1465         if (err)
1466                 goto out;
1467
1468         /* Test start/end with cork */
1469         opt.rate = 16;
1470         opt.iov_count = 1;
1471         opt.iov_length = 100;
1472         txmsg_cork = 1600;
1473
1474         txmsg_start_pop = 0;
1475         txmsg_pop = 0;
1476
1477         for (i = 99; i <= 1600; i += 500) {
1478                 txmsg_start = 0;
1479                 txmsg_end = i;
1480                 txmsg_start_push = 0;
1481                 txmsg_end_push = i;
1482                 err = test_exec(cgrp, &opt);
1483                 if (err)
1484                         goto out;
1485         }
1486
1487         /* Test pop data in middle of cork */
1488         for (i = 99; i <= 1600; i += 500) {
1489                 txmsg_start_pop = 10;
1490                 txmsg_pop = i;
1491                 err = test_exec(cgrp, &opt);
1492                 if (err)
1493                         goto out;
1494         }
1495         txmsg_start_pop = 0;
1496         txmsg_pop = 0;
1497
1498         /* Test start/end with cork but pull data in middle */
1499         for (i = 199; i <= 1600; i += 500) {
1500                 txmsg_start = 100;
1501                 txmsg_end = i;
1502                 txmsg_start_push = 100;
1503                 txmsg_end_push = i;
1504                 err = test_exec(cgrp, &opt);
1505                 if (err)
1506                         goto out;
1507         }
1508
1509         /* Test start/end with cork pulling last sg entry */
1510         txmsg_start = 1500;
1511         txmsg_end = 1600;
1512         txmsg_start_push = 1500;
1513         txmsg_end_push = 1600;
1514         err = test_exec(cgrp, &opt);
1515         if (err)
1516                 goto out;
1517
1518         /* Test pop with cork pulling last sg entry */
1519         txmsg_start_pop = 1500;
1520         txmsg_pop = 1600;
1521         err = test_exec(cgrp, &opt);
1522         if (err)
1523                 goto out;
1524         txmsg_start_pop = 0;
1525         txmsg_pop = 0;
1526
1527         /* Test start/end pull of single byte in last page */
1528         txmsg_start = 1111;
1529         txmsg_end = 1112;
1530         txmsg_start_push = 1111;
1531         txmsg_end_push = 1112;
1532         err = test_exec(cgrp, &opt);
1533         if (err)
1534                 goto out;
1535
1536         /* Test pop of single byte in last page */
1537         txmsg_start_pop = 1111;
1538         txmsg_pop = 1112;
1539         err = test_exec(cgrp, &opt);
1540         if (err)
1541                 goto out;
1542
1543         /* Test start/end with end < start */
1544         txmsg_start = 1111;
1545         txmsg_end = 0;
1546         txmsg_start_push = 1111;
1547         txmsg_end_push = 0;
1548         err = test_exec(cgrp, &opt);
1549         if (err)
1550                 goto out;
1551
1552         /* Test start/end with end > data */
1553         txmsg_start = 0;
1554         txmsg_end = 1601;
1555         txmsg_start_push = 0;
1556         txmsg_end_push = 1601;
1557         err = test_exec(cgrp, &opt);
1558         if (err)
1559                 goto out;
1560
1561         /* Test start/end with start > data */
1562         txmsg_start = 1601;
1563         txmsg_end = 1600;
1564         txmsg_start_push = 1601;
1565         txmsg_end_push = 1600;
1566         err = test_exec(cgrp, &opt);
1567         if (err)
1568                 goto out;
1569
1570         /* Test pop with start > data */
1571         txmsg_start_pop = 1601;
1572         txmsg_pop = 1;
1573         err = test_exec(cgrp, &opt);
1574         if (err)
1575                 goto out;
1576
1577         /* Test pop with pop range > data */
1578         txmsg_start_pop = 1599;
1579         txmsg_pop = 10;
1580         err = test_exec(cgrp, &opt);
1581 out:
1582         txmsg_start = 0;
1583         txmsg_end = 0;
1584         sched_yield();
1585         return err;
1586 }
1587
1588 char *map_names[] = {
1589         "sock_map",
1590         "sock_map_txmsg",
1591         "sock_map_redir",
1592         "sock_apply_bytes",
1593         "sock_cork_bytes",
1594         "sock_bytes",
1595         "sock_redir_flags",
1596         "sock_skb_opts",
1597 };
1598
1599 int prog_attach_type[] = {
1600         BPF_SK_SKB_STREAM_PARSER,
1601         BPF_SK_SKB_STREAM_VERDICT,
1602         BPF_CGROUP_SOCK_OPS,
1603         BPF_SK_MSG_VERDICT,
1604         BPF_SK_MSG_VERDICT,
1605         BPF_SK_MSG_VERDICT,
1606         BPF_SK_MSG_VERDICT,
1607         BPF_SK_MSG_VERDICT,
1608         BPF_SK_MSG_VERDICT,
1609         BPF_SK_MSG_VERDICT,
1610 };
1611
1612 int prog_type[] = {
1613         BPF_PROG_TYPE_SK_SKB,
1614         BPF_PROG_TYPE_SK_SKB,
1615         BPF_PROG_TYPE_SOCK_OPS,
1616         BPF_PROG_TYPE_SK_MSG,
1617         BPF_PROG_TYPE_SK_MSG,
1618         BPF_PROG_TYPE_SK_MSG,
1619         BPF_PROG_TYPE_SK_MSG,
1620         BPF_PROG_TYPE_SK_MSG,
1621         BPF_PROG_TYPE_SK_MSG,
1622         BPF_PROG_TYPE_SK_MSG,
1623 };
1624
1625 static int populate_progs(char *bpf_file)
1626 {
1627         struct bpf_program *prog;
1628         struct bpf_object *obj;
1629         int i = 0;
1630         long err;
1631
1632         obj = bpf_object__open(bpf_file);
1633         err = libbpf_get_error(obj);
1634         if (err) {
1635                 char err_buf[256];
1636
1637                 libbpf_strerror(err, err_buf, sizeof(err_buf));
1638                 printf("Unable to load eBPF objects in file '%s' : %s\n",
1639                        bpf_file, err_buf);
1640                 return -1;
1641         }
1642
1643         bpf_object__for_each_program(prog, obj) {
1644                 bpf_program__set_type(prog, prog_type[i]);
1645                 bpf_program__set_expected_attach_type(prog,
1646                                                       prog_attach_type[i]);
1647                 i++;
1648         }
1649
1650         i = bpf_object__load(obj);
1651         i = 0;
1652         bpf_object__for_each_program(prog, obj) {
1653                 prog_fd[i] = bpf_program__fd(prog);
1654                 i++;
1655         }
1656
1657         for (i = 0; i < sizeof(map_fd)/sizeof(int); i++) {
1658                 maps[i] = bpf_object__find_map_by_name(obj, map_names[i]);
1659                 map_fd[i] = bpf_map__fd(maps[i]);
1660                 if (map_fd[i] < 0) {
1661                         fprintf(stderr, "load_bpf_file: (%i) %s\n",
1662                                 map_fd[i], strerror(errno));
1663                         return -1;
1664                 }
1665         }
1666
1667         return 0;
1668 }
1669
1670 static int __test_suite(int cg_fd, char *bpf_file)
1671 {
1672         int err, cleanup = cg_fd;
1673
1674         err = populate_progs(bpf_file);
1675         if (err < 0) {
1676                 fprintf(stderr, "ERROR: (%i) load bpf failed\n", err);
1677                 return err;
1678         }
1679
1680         if (cg_fd < 0) {
1681                 if (setup_cgroup_environment()) {
1682                         fprintf(stderr, "ERROR: cgroup env failed\n");
1683                         return -EINVAL;
1684                 }
1685
1686                 cg_fd = create_and_get_cgroup(CG_PATH);
1687                 if (cg_fd < 0) {
1688                         fprintf(stderr,
1689                                 "ERROR: (%i) open cg path failed: %s\n",
1690                                 cg_fd, optarg);
1691                         return cg_fd;
1692                 }
1693
1694                 if (join_cgroup(CG_PATH)) {
1695                         fprintf(stderr, "ERROR: failed to join cgroup\n");
1696                         return -EINVAL;
1697                 }
1698         }
1699
1700         /* Tests basic commands and APIs with range of iov values */
1701         txmsg_start = txmsg_end = txmsg_start_push = txmsg_end_push = 0;
1702         err = test_txmsg(cg_fd);
1703         if (err)
1704                 goto out;
1705
1706         /* Tests interesting combinations of APIs used together */
1707         err = test_mixed(cg_fd);
1708         if (err)
1709                 goto out;
1710
1711         /* Tests pull_data API using start/end API */
1712         err = test_start_end(cg_fd);
1713         if (err)
1714                 goto out;
1715
1716 out:
1717         printf("Summary: %i PASSED %i FAILED\n", passed, failed);
1718         if (cleanup < 0) {
1719                 cleanup_cgroup_environment();
1720                 close(cg_fd);
1721         }
1722         return err;
1723 }
1724
1725 static int test_suite(int cg_fd)
1726 {
1727         int err;
1728
1729         err = __test_suite(cg_fd, BPF_SOCKMAP_FILENAME);
1730         if (err)
1731                 goto out;
1732         err = __test_suite(cg_fd, BPF_SOCKHASH_FILENAME);
1733 out:
1734         if (cg_fd > -1)
1735                 close(cg_fd);
1736         return err;
1737 }
1738
1739 int main(int argc, char **argv)
1740 {
1741         int iov_count = 1, length = 1024, rate = 1;
1742         struct sockmap_options options = {0};
1743         int opt, longindex, err, cg_fd = 0;
1744         char *bpf_file = BPF_SOCKMAP_FILENAME;
1745         int test = PING_PONG;
1746
1747         if (argc < 2)
1748                 return test_suite(-1);
1749
1750         while ((opt = getopt_long(argc, argv, ":dhvc:r:i:l:t:p:q:",
1751                                   long_options, &longindex)) != -1) {
1752                 switch (opt) {
1753                 case 's':
1754                         txmsg_start = atoi(optarg);
1755                         break;
1756                 case 'e':
1757                         txmsg_end = atoi(optarg);
1758                         break;
1759                 case 'p':
1760                         txmsg_start_push = atoi(optarg);
1761                         break;
1762                 case 'q':
1763                         txmsg_end_push = atoi(optarg);
1764                         break;
1765                 case 'w':
1766                         txmsg_start_pop = atoi(optarg);
1767                         break;
1768                 case 'x':
1769                         txmsg_pop = atoi(optarg);
1770                         break;
1771                 case 'a':
1772                         txmsg_apply = atoi(optarg);
1773                         break;
1774                 case 'k':
1775                         txmsg_cork = atoi(optarg);
1776                         break;
1777                 case 'c':
1778                         cg_fd = open(optarg, O_DIRECTORY, O_RDONLY);
1779                         if (cg_fd < 0) {
1780                                 fprintf(stderr,
1781                                         "ERROR: (%i) open cg path failed: %s\n",
1782                                         cg_fd, optarg);
1783                                 return cg_fd;
1784                         }
1785                         break;
1786                 case 'r':
1787                         rate = atoi(optarg);
1788                         break;
1789                 case 'v':
1790                         options.verbose = 1;
1791                         break;
1792                 case 'i':
1793                         iov_count = atoi(optarg);
1794                         break;
1795                 case 'l':
1796                         length = atoi(optarg);
1797                         break;
1798                 case 'd':
1799                         options.data_test = true;
1800                         break;
1801                 case 't':
1802                         if (strcmp(optarg, "ping") == 0) {
1803                                 test = PING_PONG;
1804                         } else if (strcmp(optarg, "sendmsg") == 0) {
1805                                 test = SENDMSG;
1806                         } else if (strcmp(optarg, "base") == 0) {
1807                                 test = BASE;
1808                         } else if (strcmp(optarg, "base_sendpage") == 0) {
1809                                 test = BASE_SENDPAGE;
1810                         } else if (strcmp(optarg, "sendpage") == 0) {
1811                                 test = SENDPAGE;
1812                         } else {
1813                                 usage(argv);
1814                                 return -1;
1815                         }
1816                         break;
1817                 case 0:
1818                         break;
1819                 case 'h':
1820                 default:
1821                         usage(argv);
1822                         return -1;
1823                 }
1824         }
1825
1826         if (argc <= 3 && cg_fd)
1827                 return test_suite(cg_fd);
1828
1829         if (!cg_fd) {
1830                 fprintf(stderr, "%s requires cgroup option: --cgroup <path>\n",
1831                         argv[0]);
1832                 return -1;
1833         }
1834
1835         err = populate_progs(bpf_file);
1836         if (err) {
1837                 fprintf(stderr, "populate program: (%s) %s\n",
1838                         bpf_file, strerror(errno));
1839                 return 1;
1840         }
1841         running = 1;
1842
1843         /* catch SIGINT */
1844         signal(SIGINT, running_handler);
1845
1846         options.iov_count = iov_count;
1847         options.iov_length = length;
1848         options.rate = rate;
1849
1850         err = run_options(&options, cg_fd, test);
1851         close(cg_fd);
1852         return err;
1853 }
1854
1855 void running_handler(int a)
1856 {
1857         running = 0;
1858 }