Merge tag 'hwmon-for-v5.8-rc6' of git://git.kernel.org/pub/scm/linux/kernel/git/groec...
[linux-2.6-microblaze.git] / tools / testing / selftests / bpf / prog_tests / bpf_tcp_ca.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook */
3
4 #include <linux/err.h>
5 #include <test_progs.h>
6 #include "bpf_dctcp.skel.h"
7 #include "bpf_cubic.skel.h"
8
/* Type-generic minimum. Both arguments may be evaluated twice, so do not
 * pass expressions with side effects.
 */
#define min(a, b) ((a) < (b) ? (a) : (b))

/* Number of bytes the server sends and the client expects to receive. */
static const unsigned int total_bytes = 10 * 1024 * 1024;

/* 10 second timeout used for SO_RCVTIMEO / SO_SNDTIMEO in settimeo(). */
static const struct timeval timeo_sec = { .tv_sec = 10 };
static const size_t timeo_optlen = sizeof timeo_sec;

/* Value stored in sk_stg_map for the client socket; the dctcp subtest
 * compares the program's stg_result against it.
 */
static int expected_stg = 0xEB9F;

/* Set to ask both ends of the transfer to stop early. */
static int stop;

/* Not referenced directly in this file; presumably consumed by the CHECK()
 * macro from test_progs.h — confirm before removing.
 */
static int duration;
16
17 static int settimeo(int fd)
18 {
19         int err;
20
21         err = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec,
22                          timeo_optlen);
23         if (CHECK(err == -1, "setsockopt(fd, SO_RCVTIMEO)", "errno:%d\n",
24                   errno))
25                 return -1;
26
27         err = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeo_sec,
28                          timeo_optlen);
29         if (CHECK(err == -1, "setsockopt(fd, SO_SNDTIMEO)", "errno:%d\n",
30                   errno))
31                 return -1;
32
33         return 0;
34 }
35
/*
 * Select the TCP congestion control algorithm named @tcp_ca on @fd.
 *
 * Return: 0 on success, -1 if setsockopt(TCP_CONGESTION) failed (the failure
 * is reported via CHECK()).
 */
static int settcpca(int fd, const char *tcp_ca)
{
	size_t name_len = strlen(tcp_ca);
	int rc = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca, name_len);

	return CHECK(rc == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n",
		     errno) ? -1 : 0;
}
47
48 static void *server(void *arg)
49 {
50         int lfd = (int)(long)arg, err = 0, fd;
51         ssize_t nr_sent = 0, bytes = 0;
52         char batch[1500];
53
54         fd = accept(lfd, NULL, NULL);
55         while (fd == -1) {
56                 if (errno == EINTR)
57                         continue;
58                 err = -errno;
59                 goto done;
60         }
61
62         if (settimeo(fd)) {
63                 err = -errno;
64                 goto done;
65         }
66
67         while (bytes < total_bytes && !READ_ONCE(stop)) {
68                 nr_sent = send(fd, &batch,
69                                min(total_bytes - bytes, sizeof(batch)), 0);
70                 if (nr_sent == -1 && errno == EINTR)
71                         continue;
72                 if (nr_sent == -1) {
73                         err = -errno;
74                         break;
75                 }
76                 bytes += nr_sent;
77         }
78
79         CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n",
80               bytes, total_bytes, nr_sent, errno);
81
82 done:
83         if (fd != -1)
84                 close(fd);
85         if (err) {
86                 WRITE_ONCE(stop, 1);
87                 return ERR_PTR(err);
88         }
89         return NULL;
90 }
91
92 static void do_test(const char *tcp_ca, const struct bpf_map *sk_stg_map)
93 {
94         struct sockaddr_in6 sa6 = {};
95         ssize_t nr_recv = 0, bytes = 0;
96         int lfd = -1, fd = -1;
97         pthread_t srv_thread;
98         socklen_t addrlen = sizeof(sa6);
99         void *thread_ret;
100         char batch[1500];
101         int err;
102
103         WRITE_ONCE(stop, 0);
104
105         lfd = socket(AF_INET6, SOCK_STREAM, 0);
106         if (CHECK(lfd == -1, "socket", "errno:%d\n", errno))
107                 return;
108         fd = socket(AF_INET6, SOCK_STREAM, 0);
109         if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) {
110                 close(lfd);
111                 return;
112         }
113
114         if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) ||
115             settimeo(lfd) || settimeo(fd))
116                 goto done;
117
118         /* bind, listen and start server thread to accept */
119         sa6.sin6_family = AF_INET6;
120         sa6.sin6_addr = in6addr_loopback;
121         err = bind(lfd, (struct sockaddr *)&sa6, addrlen);
122         if (CHECK(err == -1, "bind", "errno:%d\n", errno))
123                 goto done;
124         err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen);
125         if (CHECK(err == -1, "getsockname", "errno:%d\n", errno))
126                 goto done;
127         err = listen(lfd, 1);
128         if (CHECK(err == -1, "listen", "errno:%d\n", errno))
129                 goto done;
130
131         if (sk_stg_map) {
132                 err = bpf_map_update_elem(bpf_map__fd(sk_stg_map), &fd,
133                                           &expected_stg, BPF_NOEXIST);
134                 if (CHECK(err, "bpf_map_update_elem(sk_stg_map)",
135                           "err:%d errno:%d\n", err, errno))
136                         goto done;
137         }
138
139         /* connect to server */
140         err = connect(fd, (struct sockaddr *)&sa6, addrlen);
141         if (CHECK(err == -1, "connect", "errno:%d\n", errno))
142                 goto done;
143
144         if (sk_stg_map) {
145                 int tmp_stg;
146
147                 err = bpf_map_lookup_elem(bpf_map__fd(sk_stg_map), &fd,
148                                           &tmp_stg);
149                 if (CHECK(!err || errno != ENOENT,
150                           "bpf_map_lookup_elem(sk_stg_map)",
151                           "err:%d errno:%d\n", err, errno))
152                         goto done;
153         }
154
155         err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd);
156         if (CHECK(err != 0, "pthread_create", "err:%d errno:%d\n", err, errno))
157                 goto done;
158
159         /* recv total_bytes */
160         while (bytes < total_bytes && !READ_ONCE(stop)) {
161                 nr_recv = recv(fd, &batch,
162                                min(total_bytes - bytes, sizeof(batch)), 0);
163                 if (nr_recv == -1 && errno == EINTR)
164                         continue;
165                 if (nr_recv == -1)
166                         break;
167                 bytes += nr_recv;
168         }
169
170         CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n",
171               bytes, total_bytes, nr_recv, errno);
172
173         WRITE_ONCE(stop, 1);
174         pthread_join(srv_thread, &thread_ret);
175         CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld",
176               PTR_ERR(thread_ret));
177 done:
178         close(lfd);
179         close(fd);
180 }
181
182 static void test_cubic(void)
183 {
184         struct bpf_cubic *cubic_skel;
185         struct bpf_link *link;
186
187         cubic_skel = bpf_cubic__open_and_load();
188         if (CHECK(!cubic_skel, "bpf_cubic__open_and_load", "failed\n"))
189                 return;
190
191         link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic);
192         if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n",
193                   PTR_ERR(link))) {
194                 bpf_cubic__destroy(cubic_skel);
195                 return;
196         }
197
198         do_test("bpf_cubic", NULL);
199
200         bpf_link__destroy(link);
201         bpf_cubic__destroy(cubic_skel);
202 }
203
204 static void test_dctcp(void)
205 {
206         struct bpf_dctcp *dctcp_skel;
207         struct bpf_link *link;
208
209         dctcp_skel = bpf_dctcp__open_and_load();
210         if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n"))
211                 return;
212
213         link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp);
214         if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n",
215                   PTR_ERR(link))) {
216                 bpf_dctcp__destroy(dctcp_skel);
217                 return;
218         }
219
220         do_test("bpf_dctcp", dctcp_skel->maps.sk_stg_map);
221         CHECK(dctcp_skel->bss->stg_result != expected_stg,
222               "Unexpected stg_result", "stg_result (%x) != expected_stg (%x)\n",
223               dctcp_skel->bss->stg_result, expected_stg);
224
225         bpf_link__destroy(link);
226         bpf_dctcp__destroy(dctcp_skel);
227 }
228
229 void test_bpf_tcp_ca(void)
230 {
231         if (test__start_subtest("dctcp"))
232                 test_dctcp();
233         if (test__start_subtest("cubic"))
234                 test_cubic();
235 }