power: supply: core: Fix parsing of battery chemistry/technology
[linux-2.6-microblaze.git] / tools / testing / selftests / net / ipsec.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * ipsec.c - Check xfrm on veth inside a net-ns.
4  * Copyright (c) 2018 Dmitry Safonov
5  */
6
7 #define _GNU_SOURCE
8
9 #include <arpa/inet.h>
10 #include <asm/types.h>
11 #include <errno.h>
12 #include <fcntl.h>
13 #include <limits.h>
14 #include <linux/limits.h>
15 #include <linux/netlink.h>
16 #include <linux/random.h>
17 #include <linux/rtnetlink.h>
18 #include <linux/veth.h>
19 #include <linux/xfrm.h>
20 #include <netinet/in.h>
21 #include <net/if.h>
22 #include <sched.h>
23 #include <stdbool.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/mman.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/syscall.h>
32 #include <sys/types.h>
33 #include <sys/wait.h>
34 #include <time.h>
35 #include <unistd.h>
36
37 #include "../kselftest.h"
38
39 #define printk(fmt, ...)                                                \
40         ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)
41
42 #define pr_err(fmt, ...)        printk(fmt ": %m", ##__VA_ARGS__)
43
44 #define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
45 #define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))
46
47 #define IPV4_STR_SZ     16      /* xxx.xxx.xxx.xxx is longest + \0 */
48 #define MAX_PAYLOAD     2048
49 #define XFRM_ALGO_KEY_BUF_SIZE  512
50 #define MAX_PROCESSES   (1 << 14) /* /16 mask divided by /30 subnets */
51 #define INADDR_A        ((in_addr_t) 0x0a000000) /* 10.0.0.0 */
52 #define INADDR_B        ((in_addr_t) 0xc0a80000) /* 192.168.0.0 */
53
54 /* /30 mask for one veth connection */
55 #define PREFIX_LEN      30
56 #define child_ip(nr)    (4*nr + 1)
57 #define grchild_ip(nr)  (4*nr + 2)
58
59 #define VETH_FMT        "ktst-%d"
60 #define VETH_LEN        12
61
62 static int nsfd_parent  = -1;
63 static int nsfd_childa  = -1;
64 static int nsfd_childb  = -1;
65 static long page_size;
66
67 /*
68  * ksft_cnt is static in kselftest, so isn't shared with children.
69  * We have to send a test result back to parent and count there.
70  * results_fd is a pipe with test feedback from children.
71  */
72 static int results_fd[2];
73
74 const unsigned int ping_delay_nsec      = 50 * 1000 * 1000;
75 const unsigned int ping_timeout         = 300;
76 const unsigned int ping_count           = 100;
77 const unsigned int ping_success         = 80;
78
79 static void randomize_buffer(void *buf, size_t buflen)
80 {
81         int *p = (int *)buf;
82         size_t words = buflen / sizeof(int);
83         size_t leftover = buflen % sizeof(int);
84
85         if (!buflen)
86                 return;
87
88         while (words--)
89                 *p++ = rand();
90
91         if (leftover) {
92                 int tmp = rand();
93
94                 memcpy(buf + buflen - leftover, &tmp, leftover);
95         }
96
97         return;
98 }
99
100 static int unshare_open(void)
101 {
102         const char *netns_path = "/proc/self/ns/net";
103         int fd;
104
105         if (unshare(CLONE_NEWNET) != 0) {
106                 pr_err("unshare()");
107                 return -1;
108         }
109
110         fd = open(netns_path, O_RDONLY);
111         if (fd <= 0) {
112                 pr_err("open(%s)", netns_path);
113                 return -1;
114         }
115
116         return fd;
117 }
118
119 static int switch_ns(int fd)
120 {
121         if (setns(fd, CLONE_NEWNET)) {
122                 pr_err("setns()");
123                 return -1;
124         }
125         return 0;
126 }
127
128 /*
129  * Running the test inside a new parent net namespace to bother less
130  * about cleanup on error-path.
131  */
132 static int init_namespaces(void)
133 {
134         nsfd_parent = unshare_open();
135         if (nsfd_parent <= 0)
136                 return -1;
137
138         nsfd_childa = unshare_open();
139         if (nsfd_childa <= 0)
140                 return -1;
141
142         if (switch_ns(nsfd_parent))
143                 return -1;
144
145         nsfd_childb = unshare_open();
146         if (nsfd_childb <= 0)
147                 return -1;
148
149         if (switch_ns(nsfd_parent))
150                 return -1;
151         return 0;
152 }
153
154 static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
155 {
156         if (*sock > 0) {
157                 seq_nr++;
158                 return 0;
159         }
160
161         *sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
162         if (*sock <= 0) {
163                 pr_err("socket(AF_NETLINK)");
164                 return -1;
165         }
166
167         randomize_buffer(seq_nr, sizeof(*seq_nr));
168
169         return 0;
170 }
171
172 static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
173 {
174         return (struct rtattr *)((char *)(nh) + RTA_ALIGN((nh)->nlmsg_len));
175 }
176
177 static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
178                 unsigned short rta_type, const void *payload, size_t size)
179 {
180         /* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
181         struct rtattr *attr = rtattr_hdr(nh);
182         size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);
183
184         if (req_sz < nl_size) {
185                 printk("req buf is too small: %zu < %zu", req_sz, nl_size);
186                 return -1;
187         }
188         nh->nlmsg_len = nl_size;
189
190         attr->rta_len = RTA_LENGTH(size);
191         attr->rta_type = rta_type;
192         memcpy(RTA_DATA(attr), payload, size);
193
194         return 0;
195 }
196
197 static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
198                 unsigned short rta_type, const void *payload, size_t size)
199 {
200         struct rtattr *ret = rtattr_hdr(nh);
201
202         if (rtattr_pack(nh, req_sz, rta_type, payload, size))
203                 return 0;
204
205         return ret;
206 }
207
208 static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
209                 unsigned short rta_type)
210 {
211         return _rtattr_begin(nh, req_sz, rta_type, 0, 0);
212 }
213
214 static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
215 {
216         char *nlmsg_end = (char *)nh + nh->nlmsg_len;
217
218         attr->rta_len = nlmsg_end - (char *)attr;
219 }
220
221 static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
222                 const char *peer, int ns)
223 {
224         struct ifinfomsg pi;
225         struct rtattr *peer_attr;
226
227         memset(&pi, 0, sizeof(pi));
228         pi.ifi_family   = AF_UNSPEC;
229         pi.ifi_change   = 0xFFFFFFFF;
230
231         peer_attr = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &pi, sizeof(pi));
232         if (!peer_attr)
233                 return -1;
234
235         if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)))
236                 return -1;
237
238         if (rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
239                 return -1;
240
241         rtattr_end(nh, peer_attr);
242
243         return 0;
244 }
245
246 static int netlink_check_answer(int sock)
247 {
248         struct nlmsgerror {
249                 struct nlmsghdr hdr;
250                 int error;
251                 struct nlmsghdr orig_msg;
252         } answer;
253
254         if (recv(sock, &answer, sizeof(answer), 0) < 0) {
255                 pr_err("recv()");
256                 return -1;
257         } else if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
258                 printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
259                 return -1;
260         } else if (answer.error) {
261                 printk("NLMSG_ERROR: %d: %s",
262                         answer.error, strerror(-answer.error));
263                 return answer.error;
264         }
265
266         return 0;
267 }
268
269 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
270                 const char *peerb, int ns_b)
271 {
272         uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
273         struct {
274                 struct nlmsghdr         nh;
275                 struct ifinfomsg        info;
276                 char                    attrbuf[MAX_PAYLOAD];
277         } req;
278         const char veth_type[] = "veth";
279         struct rtattr *link_info, *info_data;
280
281         memset(&req, 0, sizeof(req));
282         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
283         req.nh.nlmsg_type       = RTM_NEWLINK;
284         req.nh.nlmsg_flags      = flags;
285         req.nh.nlmsg_seq        = seq;
286         req.info.ifi_family     = AF_UNSPEC;
287         req.info.ifi_change     = 0xFFFFFFFF;
288
289         if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
290                 return -1;
291
292         if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
293                 return -1;
294
295         link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
296         if (!link_info)
297                 return -1;
298
299         if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
300                 return -1;
301
302         info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
303         if (!info_data)
304                 return -1;
305
306         if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
307                 return -1;
308
309         rtattr_end(&req.nh, info_data);
310         rtattr_end(&req.nh, link_info);
311
312         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
313                 pr_err("send()");
314                 return -1;
315         }
316         return netlink_check_answer(sock);
317 }
318
319 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
320                 struct in_addr addr, uint8_t prefix)
321 {
322         uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
323         struct {
324                 struct nlmsghdr         nh;
325                 struct ifaddrmsg        info;
326                 char                    attrbuf[MAX_PAYLOAD];
327         } req;
328
329         memset(&req, 0, sizeof(req));
330         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
331         req.nh.nlmsg_type       = RTM_NEWADDR;
332         req.nh.nlmsg_flags      = flags;
333         req.nh.nlmsg_seq        = seq;
334         req.info.ifa_family     = AF_INET;
335         req.info.ifa_prefixlen  = prefix;
336         req.info.ifa_index      = if_nametoindex(intf);
337
338 #ifdef DEBUG
339         {
340                 char addr_str[IPV4_STR_SZ] = {};
341
342                 strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
343
344                 printk("ip addr set %s", addr_str);
345         }
346 #endif
347
348         if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
349                 return -1;
350
351         if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
352                 return -1;
353
354         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
355                 pr_err("send()");
356                 return -1;
357         }
358         return netlink_check_answer(sock);
359 }
360
361 static int link_set_up(int sock, uint32_t seq, const char *intf)
362 {
363         struct {
364                 struct nlmsghdr         nh;
365                 struct ifinfomsg        info;
366                 char                    attrbuf[MAX_PAYLOAD];
367         } req;
368
369         memset(&req, 0, sizeof(req));
370         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
371         req.nh.nlmsg_type       = RTM_NEWLINK;
372         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
373         req.nh.nlmsg_seq        = seq;
374         req.info.ifi_family     = AF_UNSPEC;
375         req.info.ifi_change     = 0xFFFFFFFF;
376         req.info.ifi_index      = if_nametoindex(intf);
377         req.info.ifi_flags      = IFF_UP;
378         req.info.ifi_change     = IFF_UP;
379
380         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
381                 pr_err("send()");
382                 return -1;
383         }
384         return netlink_check_answer(sock);
385 }
386
387 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
388                 struct in_addr src, struct in_addr dst)
389 {
390         struct {
391                 struct nlmsghdr nh;
392                 struct rtmsg    rt;
393                 char            attrbuf[MAX_PAYLOAD];
394         } req;
395         unsigned int index = if_nametoindex(intf);
396
397         memset(&req, 0, sizeof(req));
398         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.rt));
399         req.nh.nlmsg_type       = RTM_NEWROUTE;
400         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
401         req.nh.nlmsg_seq        = seq;
402         req.rt.rtm_family       = AF_INET;
403         req.rt.rtm_dst_len      = 32;
404         req.rt.rtm_table        = RT_TABLE_MAIN;
405         req.rt.rtm_protocol     = RTPROT_BOOT;
406         req.rt.rtm_scope        = RT_SCOPE_LINK;
407         req.rt.rtm_type         = RTN_UNICAST;
408
409         if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
410                 return -1;
411
412         if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
413                 return -1;
414
415         if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
416                 return -1;
417
418         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
419                 pr_err("send()");
420                 return -1;
421         }
422
423         return netlink_check_answer(sock);
424 }
425
426 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
427                 struct in_addr tunsrc, struct in_addr tundst)
428 {
429         if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
430                         tunsrc, PREFIX_LEN)) {
431                 printk("Failed to set ipv4 addr");
432                 return -1;
433         }
434
435         if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
436                 printk("Failed to set ipv4 route");
437                 return -1;
438         }
439
440         return 0;
441 }
442
443 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
444 {
445         struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
446         struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
447         struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
448         int route_sock = -1, ret = -1;
449         uint32_t route_seq;
450
451         if (switch_ns(nsfd))
452                 return -1;
453
454         if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
455                 printk("Failed to open netlink route socket in child");
456                 return -1;
457         }
458
459         if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
460                 printk("Failed to set ipv4 addr");
461                 goto err;
462         }
463
464         if (link_set_up(route_sock, route_seq++, veth)) {
465                 printk("Failed to bring up %s", veth);
466                 goto err;
467         }
468
469         if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
470                 printk("Failed to add tunnel route on %s", veth);
471                 goto err;
472         }
473         ret = 0;
474
475 err:
476         close(route_sock);
477         return ret;
478 }
479
480 #define ALGO_LEN        64
481 enum desc_type {
482         CREATE_TUNNEL   = 0,
483         ALLOCATE_SPI,
484         MONITOR_ACQUIRE,
485         EXPIRE_STATE,
486         EXPIRE_POLICY,
487 };
488 const char *desc_name[] = {
489         "create tunnel",
490         "alloc spi",
491         "monitor acquire",
492         "expire state",
493         "expire policy"
494 };
495 struct xfrm_desc {
496         enum desc_type  type;
497         uint8_t         proto;
498         char            a_algo[ALGO_LEN];
499         char            e_algo[ALGO_LEN];
500         char            c_algo[ALGO_LEN];
501         char            ae_algo[ALGO_LEN];
502         unsigned int    icv_len;
503         /* unsigned key_len; */
504 };
505
506 enum msg_type {
507         MSG_ACK         = 0,
508         MSG_EXIT,
509         MSG_PING,
510         MSG_XFRM_PREPARE,
511         MSG_XFRM_ADD,
512         MSG_XFRM_DEL,
513         MSG_XFRM_CLEANUP,
514 };
515
516 struct test_desc {
517         enum msg_type type;
518         union {
519                 struct {
520                         in_addr_t reply_ip;
521                         unsigned int port;
522                 } ping;
523                 struct xfrm_desc xfrm_desc;
524         } body;
525 };
526
527 struct test_result {
528         struct xfrm_desc desc;
529         unsigned int res;
530 };
531
532 static void write_test_result(unsigned int res, struct xfrm_desc *d)
533 {
534         struct test_result tr = {};
535         ssize_t ret;
536
537         tr.desc = *d;
538         tr.res = res;
539
540         ret = write(results_fd[1], &tr, sizeof(tr));
541         if (ret != sizeof(tr))
542                 pr_err("Failed to write the result in pipe %zd", ret);
543 }
544
545 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
546 {
547         ssize_t bytes = write(fd, msg, sizeof(*msg));
548
549         /* Make sure that write/read is atomic to a pipe */
550         BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
551
552         if (bytes < 0) {
553                 pr_err("write()");
554                 if (exit_of_fail)
555                         exit(KSFT_FAIL);
556         }
557         if (bytes != sizeof(*msg)) {
558                 pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
559                 if (exit_of_fail)
560                         exit(KSFT_FAIL);
561         }
562 }
563
564 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
565 {
566         ssize_t bytes = read(fd, msg, sizeof(*msg));
567
568         if (bytes < 0) {
569                 pr_err("read()");
570                 if (exit_of_fail)
571                         exit(KSFT_FAIL);
572         }
573         if (bytes != sizeof(*msg)) {
574                 pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
575                 if (exit_of_fail)
576                         exit(KSFT_FAIL);
577         }
578 }
579
580 static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
581                 unsigned int *server_port, int sock[2])
582 {
583         struct sockaddr_in server;
584         struct timeval t = { .tv_sec = 0, .tv_usec = u_timeout };
585         socklen_t s_len = sizeof(server);
586
587         sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
588         if (sock[0] < 0) {
589                 pr_err("socket()");
590                 return -1;
591         }
592
593         server.sin_family       = AF_INET;
594         server.sin_port         = 0;
595         memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));
596
597         if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
598                 pr_err("bind()");
599                 goto err_close_server;
600         }
601
602         if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
603                 pr_err("getsockname()");
604                 goto err_close_server;
605         }
606
607         *server_port = ntohs(server.sin_port);
608
609         if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
610                 pr_err("setsockopt()");
611                 goto err_close_server;
612         }
613
614         sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
615         if (sock[1] < 0) {
616                 pr_err("socket()");
617                 goto err_close_server;
618         }
619
620         return 0;
621
622 err_close_server:
623         close(sock[0]);
624         return -1;
625 }
626
627 static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
628                 char *buf, size_t buf_len)
629 {
630         struct sockaddr_in server;
631         const struct sockaddr *dest_addr = (struct sockaddr *)&server;
632         char *sock_buf[buf_len];
633         ssize_t r_bytes, s_bytes;
634
635         server.sin_family       = AF_INET;
636         server.sin_port         = htons(port);
637         server.sin_addr.s_addr  = dest_ip;
638
639         s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
640         if (s_bytes < 0) {
641                 pr_err("sendto()");
642                 return -1;
643         } else if (s_bytes != buf_len) {
644                 printk("send part of the message: %zd/%zu", s_bytes, sizeof(server));
645                 return -1;
646         }
647
648         r_bytes = recv(sock[0], sock_buf, buf_len, 0);
649         if (r_bytes < 0) {
650                 if (errno != EAGAIN)
651                         pr_err("recv()");
652                 return -1;
653         } else if (r_bytes == 0) { /* EOF */
654                 printk("EOF on reply to ping");
655                 return -1;
656         } else if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
657                 printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
658                 return -1;
659         }
660
661         return 0;
662 }
663
664 static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
665                 char *buf, size_t buf_len)
666 {
667         struct sockaddr_in server;
668         const struct sockaddr *dest_addr = (struct sockaddr *)&server;
669         char *sock_buf[buf_len];
670         ssize_t r_bytes, s_bytes;
671
672         server.sin_family       = AF_INET;
673         server.sin_port         = htons(port);
674         server.sin_addr.s_addr  = dest_ip;
675
676         r_bytes = recv(sock[0], sock_buf, buf_len, 0);
677         if (r_bytes < 0) {
678                 if (errno != EAGAIN)
679                         pr_err("recv()");
680                 return -1;
681         }
682         if (r_bytes == 0) { /* EOF */
683                 printk("EOF on reply to ping");
684                 return -1;
685         }
686         if (r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
687                 printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
688                 return -1;
689         }
690
691         s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
692         if (s_bytes < 0) {
693                 pr_err("sendto()");
694                 return -1;
695         } else if (s_bytes != buf_len) {
696                 printk("send part of the message: %zd/%zu", s_bytes, sizeof(server));
697                 return -1;
698         }
699
700         return 0;
701 }
702
703 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
704                 char *buf, size_t buf_len);
705 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
706                 bool init_side, int d_port, in_addr_t to, ping_f func)
707 {
708         struct test_desc msg;
709         unsigned int s_port, i, ping_succeeded = 0;
710         int ping_sock[2];
711         char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
712
713         if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
714                 printk("Failed to init ping");
715                 return -1;
716         }
717
718         memset(&msg, 0, sizeof(msg));
719         msg.type                = MSG_PING;
720         msg.body.ping.port      = s_port;
721         memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
722
723         write_msg(cmd_fd, &msg, 0);
724         if (init_side) {
725                 /* The other end sends ip to ping */
726                 read_msg(cmd_fd, &msg, 0);
727                 if (msg.type != MSG_PING)
728                         return -1;
729                 to = msg.body.ping.reply_ip;
730                 d_port = msg.body.ping.port;
731         }
732
733         for (i = 0; i < ping_count ; i++) {
734                 struct timespec sleep_time = {
735                         .tv_sec = 0,
736                         .tv_nsec = ping_delay_nsec,
737                 };
738
739                 ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
740                 nanosleep(&sleep_time, 0);
741         }
742
743         close(ping_sock[0]);
744         close(ping_sock[1]);
745
746         strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
747         strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
748
749         if (ping_succeeded < ping_success) {
750                 printk("ping (%s) %s->%s failed %u/%u times",
751                         init_side ? "send" : "reply", from_str, to_str,
752                         ping_count - ping_succeeded, ping_count);
753                 return -1;
754         }
755
756 #ifdef DEBUG
757         printk("ping (%s) %s->%s succeeded %u/%u times",
758                 init_side ? "send" : "reply", from_str, to_str,
759                 ping_succeeded, ping_count);
760 #endif
761
762         return 0;
763 }
764
765 static int xfrm_fill_key(char *name, char *buf,
766                 size_t buf_len, unsigned int *key_len)
767 {
768         /* TODO: use set/map instead */
769         if (strncmp(name, "digest_null", ALGO_LEN) == 0)
770                 *key_len = 0;
771         else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
772                 *key_len = 0;
773         else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
774                 *key_len = 64;
775         else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
776                 *key_len = 128;
777         else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
778                 *key_len = 128;
779         else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
780                 *key_len = 128;
781         else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
782                 *key_len = 128;
783         else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
784                 *key_len = 128;
785         else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
786                 *key_len = 160;
787         else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
788                 *key_len = 160;
789         else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
790                 *key_len = 192;
791         else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
792                 *key_len = 256;
793         else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
794                 *key_len = 256;
795         else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
796                 *key_len = 256;
797         else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
798                 *key_len = 256;
799         else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
800                 *key_len = 288;
801         else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
802                 *key_len = 384;
803         else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
804                 *key_len = 448;
805         else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
806                 *key_len = 512;
807         else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
808                 *key_len = 160;
809         else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
810                 *key_len = 160;
811         else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
812                 *key_len = 152;
813         else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
814                 *key_len = 224;
815         else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
816                 *key_len = 224;
817         else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
818                 *key_len = 216;
819         else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
820                 *key_len = 288;
821         else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
822                 *key_len = 288;
823         else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
824                 *key_len = 280;
825         else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
826                 *key_len = 0;
827
828         if (*key_len > buf_len) {
829                 printk("Can't pack a key - too big for buffer");
830                 return -1;
831         }
832
833         randomize_buffer(buf, *key_len);
834
835         return 0;
836 }
837
838 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
839                 struct xfrm_desc *desc)
840 {
841         struct {
842                 union {
843                         struct xfrm_algo        alg;
844                         struct xfrm_algo_aead   aead;
845                         struct xfrm_algo_auth   auth;
846                 } u;
847                 char buf[XFRM_ALGO_KEY_BUF_SIZE];
848         } alg = {};
849         size_t alen, elen, clen, aelen;
850         unsigned short type;
851
852         alen = strlen(desc->a_algo);
853         elen = strlen(desc->e_algo);
854         clen = strlen(desc->c_algo);
855         aelen = strlen(desc->ae_algo);
856
857         /* Verify desc */
858         switch (desc->proto) {
859         case IPPROTO_AH:
860                 if (!alen || elen || clen || aelen) {
861                         printk("BUG: buggy ah desc");
862                         return -1;
863                 }
864                 strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
865                 if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
866                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
867                         return -1;
868                 type = XFRMA_ALG_AUTH;
869                 break;
870         case IPPROTO_COMP:
871                 if (!clen || elen || alen || aelen) {
872                         printk("BUG: buggy comp desc");
873                         return -1;
874                 }
875                 strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
876                 if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
877                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
878                         return -1;
879                 type = XFRMA_ALG_COMP;
880                 break;
881         case IPPROTO_ESP:
882                 if (!((alen && elen) ^ aelen) || clen) {
883                         printk("BUG: buggy esp desc");
884                         return -1;
885                 }
886                 if (aelen) {
887                         alg.u.aead.alg_icv_len = desc->icv_len;
888                         strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
889                         if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
890                                                 sizeof(alg.buf), &alg.u.aead.alg_key_len))
891                                 return -1;
892                         type = XFRMA_ALG_AEAD;
893                 } else {
894
895                         strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
896                         type = XFRMA_ALG_CRYPT;
897                         if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
898                                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
899                                 return -1;
900                         if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
901                                 return -1;
902
903                         strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
904                         type = XFRMA_ALG_AUTH;
905                         if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
906                                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
907                                 return -1;
908                 }
909                 break;
910         default:
911                 printk("BUG: unknown proto in desc");
912                 return -1;
913         }
914
915         if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
916                 return -1;
917
918         return 0;
919 }
920
921 static inline uint32_t gen_spi(struct in_addr src)
922 {
923         return htonl(inet_lnaof(src));
924 }
925
926 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
927                 struct in_addr src, struct in_addr dst,
928                 struct xfrm_desc *desc)
929 {
930         struct {
931                 struct nlmsghdr         nh;
932                 struct xfrm_usersa_info info;
933                 char                    attrbuf[MAX_PAYLOAD];
934         } req;
935
936         memset(&req, 0, sizeof(req));
937         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
938         req.nh.nlmsg_type       = XFRM_MSG_NEWSA;
939         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
940         req.nh.nlmsg_seq        = seq;
941
942         /* Fill selector. */
943         memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
944         memcpy(&req.info.sel.saddr, &src, sizeof(src));
945         req.info.sel.family             = AF_INET;
946         req.info.sel.prefixlen_d        = PREFIX_LEN;
947         req.info.sel.prefixlen_s        = PREFIX_LEN;
948
949         /* Fill id */
950         memcpy(&req.info.id.daddr, &dst, sizeof(dst));
951         /* Note: zero-spi cannot be deleted */
952         req.info.id.spi = spi;
953         req.info.id.proto       = desc->proto;
954
955         memcpy(&req.info.saddr, &src, sizeof(src));
956
957         /* Fill lifteme_cfg */
958         req.info.lft.soft_byte_limit    = XFRM_INF;
959         req.info.lft.hard_byte_limit    = XFRM_INF;
960         req.info.lft.soft_packet_limit  = XFRM_INF;
961         req.info.lft.hard_packet_limit  = XFRM_INF;
962
963         req.info.family         = AF_INET;
964         req.info.mode           = XFRM_MODE_TUNNEL;
965
966         if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
967                 return -1;
968
969         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
970                 pr_err("send()");
971                 return -1;
972         }
973
974         return netlink_check_answer(xfrm_sock);
975 }
976
977 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
978                 struct in_addr src, struct in_addr dst,
979                 struct xfrm_desc *desc)
980 {
981         if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
982                 return false;
983
984         if (memcmp(&info->sel.saddr, &src, sizeof(src)))
985                 return false;
986
987         if (info->sel.family != AF_INET                                 ||
988                         info->sel.prefixlen_d != PREFIX_LEN             ||
989                         info->sel.prefixlen_s != PREFIX_LEN)
990                 return false;
991
992         if (info->id.spi != spi || info->id.proto != desc->proto)
993                 return false;
994
995         if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
996                 return false;
997
998         if (memcmp(&info->saddr, &src, sizeof(src)))
999                 return false;
1000
1001         if (info->lft.soft_byte_limit != XFRM_INF                       ||
1002                         info->lft.hard_byte_limit != XFRM_INF           ||
1003                         info->lft.soft_packet_limit != XFRM_INF         ||
1004                         info->lft.hard_packet_limit != XFRM_INF)
1005                 return false;
1006
1007         if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1008                 return false;
1009
1010         /* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1011
1012         return true;
1013 }
1014
1015 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1016                 struct in_addr src, struct in_addr dst,
1017                 struct xfrm_desc *desc)
1018 {
1019         struct {
1020                 struct nlmsghdr         nh;
1021                 char                    attrbuf[MAX_PAYLOAD];
1022         } req;
1023         struct {
1024                 struct nlmsghdr         nh;
1025                 union {
1026                         struct xfrm_usersa_info info;
1027                         int error;
1028                 };
1029                 char                    attrbuf[MAX_PAYLOAD];
1030         } answer;
1031         struct xfrm_address_filter filter = {};
1032         bool found = false;
1033
1034
1035         memset(&req, 0, sizeof(req));
1036         req.nh.nlmsg_len        = NLMSG_LENGTH(0);
1037         req.nh.nlmsg_type       = XFRM_MSG_GETSA;
1038         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_DUMP;
1039         req.nh.nlmsg_seq        = seq;
1040
1041         /*
1042          * Add dump filter by source address as there may be other tunnels
1043          * in this netns (if tests run in parallel).
1044          */
1045         filter.family = AF_INET;
1046         filter.splen = 0x1f;    /* 0xffffffff mask see addr_match() */
1047         memcpy(&filter.saddr, &src, sizeof(src));
1048         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1049                                 &filter, sizeof(filter)))
1050                 return -1;
1051
1052         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1053                 pr_err("send()");
1054                 return -1;
1055         }
1056
1057         while (1) {
1058                 if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1059                         pr_err("recv()");
1060                         return -1;
1061                 }
1062                 if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1063                         printk("NLMSG_ERROR: %d: %s",
1064                                 answer.error, strerror(-answer.error));
1065                         return -1;
1066                 } else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1067                         if (found)
1068                                 return 0;
1069                         printk("didn't find allocated xfrm state in dump");
1070                         return -1;
1071                 } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1072                         if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1073                                 found = true;
1074                 }
1075         }
1076 }
1077
1078 static int xfrm_set(int xfrm_sock, uint32_t *seq,
1079                 struct in_addr src, struct in_addr dst,
1080                 struct in_addr tunsrc, struct in_addr tundst,
1081                 struct xfrm_desc *desc)
1082 {
1083         int err;
1084
1085         err = xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc);
1086         if (err) {
1087                 printk("Failed to add xfrm state");
1088                 return -1;
1089         }
1090
1091         err = xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc);
1092         if (err) {
1093                 printk("Failed to add xfrm state");
1094                 return -1;
1095         }
1096
1097         /* Check dumps for XFRM_MSG_GETSA */
1098         err = xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc);
1099         err |= xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc);
1100         if (err) {
1101                 printk("Failed to check xfrm state");
1102                 return -1;
1103         }
1104
1105         return 0;
1106 }
1107
1108 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1109                 struct in_addr src, struct in_addr dst, uint8_t dir,
1110                 struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1111 {
1112         struct {
1113                 struct nlmsghdr                 nh;
1114                 struct xfrm_userpolicy_info     info;
1115                 char                            attrbuf[MAX_PAYLOAD];
1116         } req;
1117         struct xfrm_user_tmpl tmpl;
1118
1119         memset(&req, 0, sizeof(req));
1120         memset(&tmpl, 0, sizeof(tmpl));
1121         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
1122         req.nh.nlmsg_type       = XFRM_MSG_NEWPOLICY;
1123         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1124         req.nh.nlmsg_seq        = seq;
1125
1126         /* Fill selector. */
1127         memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1128         memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1129         req.info.sel.family             = AF_INET;
1130         req.info.sel.prefixlen_d        = PREFIX_LEN;
1131         req.info.sel.prefixlen_s        = PREFIX_LEN;
1132
1133         /* Fill lifteme_cfg */
1134         req.info.lft.soft_byte_limit    = XFRM_INF;
1135         req.info.lft.hard_byte_limit    = XFRM_INF;
1136         req.info.lft.soft_packet_limit  = XFRM_INF;
1137         req.info.lft.hard_packet_limit  = XFRM_INF;
1138
1139         req.info.dir = dir;
1140
1141         /* Fill tmpl */
1142         memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1143         /* Note: zero-spi cannot be deleted */
1144         tmpl.id.spi = spi;
1145         tmpl.id.proto   = proto;
1146         tmpl.family     = AF_INET;
1147         memcpy(&tmpl.saddr, &src, sizeof(src));
1148         tmpl.mode       = XFRM_MODE_TUNNEL;
1149         tmpl.aalgos = (~(uint32_t)0);
1150         tmpl.ealgos = (~(uint32_t)0);
1151         tmpl.calgos = (~(uint32_t)0);
1152
1153         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1154                 return -1;
1155
1156         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1157                 pr_err("send()");
1158                 return -1;
1159         }
1160
1161         return netlink_check_answer(xfrm_sock);
1162 }
1163
1164 static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
1165                 struct in_addr src, struct in_addr dst,
1166                 struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1167 {
1168         if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1169                                 XFRM_POLICY_OUT, tunsrc, tundst, proto)) {
1170                 printk("Failed to add xfrm policy");
1171                 return -1;
1172         }
1173
1174         if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
1175                                 XFRM_POLICY_IN, tunsrc, tundst, proto)) {
1176                 printk("Failed to add xfrm policy");
1177                 return -1;
1178         }
1179
1180         return 0;
1181 }
1182
1183 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1184                 struct in_addr src, struct in_addr dst, uint8_t dir,
1185                 struct in_addr tunsrc, struct in_addr tundst)
1186 {
1187         struct {
1188                 struct nlmsghdr                 nh;
1189                 struct xfrm_userpolicy_id       id;
1190                 char                            attrbuf[MAX_PAYLOAD];
1191         } req;
1192
1193         memset(&req, 0, sizeof(req));
1194         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.id));
1195         req.nh.nlmsg_type       = XFRM_MSG_DELPOLICY;
1196         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1197         req.nh.nlmsg_seq        = seq;
1198
1199         /* Fill id */
1200         memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1201         memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1202         req.id.sel.family               = AF_INET;
1203         req.id.sel.prefixlen_d          = PREFIX_LEN;
1204         req.id.sel.prefixlen_s          = PREFIX_LEN;
1205         req.id.dir = dir;
1206
1207         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1208                 pr_err("send()");
1209                 return -1;
1210         }
1211
1212         return netlink_check_answer(xfrm_sock);
1213 }
1214
1215 static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
1216                 struct in_addr src, struct in_addr dst,
1217                 struct in_addr tunsrc, struct in_addr tundst)
1218 {
1219         if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
1220                                 XFRM_POLICY_OUT, tunsrc, tundst)) {
1221                 printk("Failed to add xfrm policy");
1222                 return -1;
1223         }
1224
1225         if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
1226                                 XFRM_POLICY_IN, tunsrc, tundst)) {
1227                 printk("Failed to add xfrm policy");
1228                 return -1;
1229         }
1230
1231         return 0;
1232 }
1233
1234 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1235                 struct in_addr src, struct in_addr dst, uint8_t proto)
1236 {
1237         struct {
1238                 struct nlmsghdr         nh;
1239                 struct xfrm_usersa_id   id;
1240                 char                    attrbuf[MAX_PAYLOAD];
1241         } req;
1242         xfrm_address_t saddr = {};
1243
1244         memset(&req, 0, sizeof(req));
1245         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.id));
1246         req.nh.nlmsg_type       = XFRM_MSG_DELSA;
1247         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1248         req.nh.nlmsg_seq        = seq;
1249
1250         memcpy(&req.id.daddr, &dst, sizeof(dst));
1251         req.id.family           = AF_INET;
1252         req.id.proto            = proto;
1253         /* Note: zero-spi cannot be deleted */
1254         req.id.spi = spi;
1255
1256         memcpy(&saddr, &src, sizeof(src));
1257         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1258                 return -1;
1259
1260         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1261                 pr_err("send()");
1262                 return -1;
1263         }
1264
1265         return netlink_check_answer(xfrm_sock);
1266 }
1267
1268 static int xfrm_delete(int xfrm_sock, uint32_t *seq,
1269                 struct in_addr src, struct in_addr dst,
1270                 struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1271 {
1272         if (xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto)) {
1273                 printk("Failed to remove xfrm state");
1274                 return -1;
1275         }
1276
1277         if (xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto)) {
1278                 printk("Failed to remove xfrm state");
1279                 return -1;
1280         }
1281
1282         return 0;
1283 }
1284
1285 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1286                 uint32_t spi, uint8_t proto)
1287 {
1288         struct {
1289                 struct nlmsghdr                 nh;
1290                 struct xfrm_userspi_info        spi;
1291         } req;
1292         struct {
1293                 struct nlmsghdr                 nh;
1294                 union {
1295                         struct xfrm_usersa_info info;
1296                         int error;
1297                 };
1298         } answer;
1299
1300         memset(&req, 0, sizeof(req));
1301         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.spi));
1302         req.nh.nlmsg_type       = XFRM_MSG_ALLOCSPI;
1303         req.nh.nlmsg_flags      = NLM_F_REQUEST;
1304         req.nh.nlmsg_seq        = (*seq)++;
1305
1306         req.spi.info.family     = AF_INET;
1307         req.spi.min             = spi;
1308         req.spi.max             = spi;
1309         req.spi.info.id.proto   = proto;
1310
1311         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1312                 pr_err("send()");
1313                 return KSFT_FAIL;
1314         }
1315
1316         if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1317                 pr_err("recv()");
1318                 return KSFT_FAIL;
1319         } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1320                 uint32_t new_spi = htonl(answer.info.id.spi);
1321
1322                 if (new_spi != spi) {
1323                         printk("allocated spi is different from requested: %#x != %#x",
1324                                         new_spi, spi);
1325                         return KSFT_FAIL;
1326                 }
1327                 return KSFT_PASS;
1328         } else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1329                 printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1330                 return KSFT_FAIL;
1331         }
1332
1333         printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1334         return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1335 }
1336
1337 static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
1338 {
1339         struct sockaddr_nl snl = {};
1340         socklen_t addr_len;
1341         int ret = -1;
1342
1343         snl.nl_family = AF_NETLINK;
1344         snl.nl_groups = groups;
1345
1346         if (netlink_sock(sock, seq, proto)) {
1347                 printk("Failed to open xfrm netlink socket");
1348                 return -1;
1349         }
1350
1351         if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
1352                 pr_err("bind()");
1353                 goto out_close;
1354         }
1355
1356         addr_len = sizeof(snl);
1357         if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
1358                 pr_err("getsockname()");
1359                 goto out_close;
1360         }
1361         if (addr_len != sizeof(snl)) {
1362                 printk("Wrong address length %d", addr_len);
1363                 goto out_close;
1364         }
1365         if (snl.nl_family != AF_NETLINK) {
1366                 printk("Wrong address family %d", snl.nl_family);
1367                 goto out_close;
1368         }
1369         return 0;
1370
1371 out_close:
1372         close(*sock);
1373         return ret;
1374 }
1375
1376 static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1377 {
1378         struct {
1379                 struct nlmsghdr nh;
1380                 union {
1381                         struct xfrm_user_acquire acq;
1382                         int error;
1383                 };
1384                 char attrbuf[MAX_PAYLOAD];
1385         } req;
1386         struct xfrm_user_tmpl xfrm_tmpl = {};
1387         int xfrm_listen = -1, ret = KSFT_FAIL;
1388         uint32_t seq_listen;
1389
1390         if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1391                 return KSFT_FAIL;
1392
1393         memset(&req, 0, sizeof(req));
1394         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.acq));
1395         req.nh.nlmsg_type       = XFRM_MSG_ACQUIRE;
1396         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1397         req.nh.nlmsg_seq        = (*seq)++;
1398
1399         req.acq.policy.sel.family       = AF_INET;
1400         req.acq.aalgos  = 0xfeed;
1401         req.acq.ealgos  = 0xbaad;
1402         req.acq.calgos  = 0xbabe;
1403
1404         xfrm_tmpl.family = AF_INET;
1405         xfrm_tmpl.id.proto = IPPROTO_ESP;
1406         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1407                 goto out_close;
1408
1409         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1410                 pr_err("send()");
1411                 goto out_close;
1412         }
1413
1414         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1415                 pr_err("recv()");
1416                 goto out_close;
1417         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1418                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1419                 goto out_close;
1420         }
1421
1422         if (req.error) {
1423                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1424                 ret = req.error;
1425                 goto out_close;
1426         }
1427
1428         if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1429                 pr_err("recv()");
1430                 goto out_close;
1431         }
1432
1433         if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1434                         || req.acq.calgos != 0xbabe) {
1435                 printk("xfrm_user_acquire has changed  %x %x %x",
1436                                 req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1437                 goto out_close;
1438         }
1439
1440         ret = KSFT_PASS;
1441 out_close:
1442         close(xfrm_listen);
1443         return ret;
1444 }
1445
1446 static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1447                 unsigned int nr, struct xfrm_desc *desc)
1448 {
1449         struct {
1450                 struct nlmsghdr nh;
1451                 union {
1452                         struct xfrm_user_expire expire;
1453                         int error;
1454                 };
1455         } req;
1456         struct in_addr src, dst;
1457         int xfrm_listen = -1, ret = KSFT_FAIL;
1458         uint32_t seq_listen;
1459
1460         src = inet_makeaddr(INADDR_B, child_ip(nr));
1461         dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1462
1463         if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1464                 printk("Failed to add xfrm state");
1465                 return KSFT_FAIL;
1466         }
1467
1468         if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1469                 return KSFT_FAIL;
1470
1471         memset(&req, 0, sizeof(req));
1472         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.expire));
1473         req.nh.nlmsg_type       = XFRM_MSG_EXPIRE;
1474         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1475         req.nh.nlmsg_seq        = (*seq)++;
1476
1477         memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1478         req.expire.state.id.spi         = gen_spi(src);
1479         req.expire.state.id.proto       = desc->proto;
1480         req.expire.state.family         = AF_INET;
1481         req.expire.hard                 = 0xff;
1482
1483         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1484                 pr_err("send()");
1485                 goto out_close;
1486         }
1487
1488         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1489                 pr_err("recv()");
1490                 goto out_close;
1491         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1492                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1493                 goto out_close;
1494         }
1495
1496         if (req.error) {
1497                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1498                 ret = req.error;
1499                 goto out_close;
1500         }
1501
1502         if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1503                 pr_err("recv()");
1504                 goto out_close;
1505         }
1506
1507         if (req.expire.hard != 0x1) {
1508                 printk("expire.hard is not set: %x", req.expire.hard);
1509                 goto out_close;
1510         }
1511
1512         ret = KSFT_PASS;
1513 out_close:
1514         close(xfrm_listen);
1515         return ret;
1516 }
1517
1518 static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1519                 unsigned int nr, struct xfrm_desc *desc)
1520 {
1521         struct {
1522                 struct nlmsghdr nh;
1523                 union {
1524                         struct xfrm_user_polexpire expire;
1525                         int error;
1526                 };
1527         } req;
1528         struct in_addr src, dst, tunsrc, tundst;
1529         int xfrm_listen = -1, ret = KSFT_FAIL;
1530         uint32_t seq_listen;
1531
1532         src = inet_makeaddr(INADDR_B, child_ip(nr));
1533         dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1534         tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1535         tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1536
1537         if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1538                                 XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1539                 printk("Failed to add xfrm policy");
1540                 return KSFT_FAIL;
1541         }
1542
1543         if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1544                 return KSFT_FAIL;
1545
1546         memset(&req, 0, sizeof(req));
1547         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.expire));
1548         req.nh.nlmsg_type       = XFRM_MSG_POLEXPIRE;
1549         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1550         req.nh.nlmsg_seq        = (*seq)++;
1551
1552         /* Fill selector. */
1553         memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1554         memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1555         req.expire.pol.sel.family       = AF_INET;
1556         req.expire.pol.sel.prefixlen_d  = PREFIX_LEN;
1557         req.expire.pol.sel.prefixlen_s  = PREFIX_LEN;
1558         req.expire.pol.dir              = XFRM_POLICY_OUT;
1559         req.expire.hard                 = 0xff;
1560
1561         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1562                 pr_err("send()");
1563                 goto out_close;
1564         }
1565
1566         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1567                 pr_err("recv()");
1568                 goto out_close;
1569         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1570                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1571                 goto out_close;
1572         }
1573
1574         if (req.error) {
1575                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1576                 ret = req.error;
1577                 goto out_close;
1578         }
1579
1580         if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1581                 pr_err("recv()");
1582                 goto out_close;
1583         }
1584
1585         if (req.expire.hard != 0x1) {
1586                 printk("expire.hard is not set: %x", req.expire.hard);
1587                 goto out_close;
1588         }
1589
1590         ret = KSFT_PASS;
1591 out_close:
1592         close(xfrm_listen);
1593         return ret;
1594 }
1595
1596 static int child_serv(int xfrm_sock, uint32_t *seq,
1597                 unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1598 {
1599         struct in_addr src, dst, tunsrc, tundst;
1600         struct test_desc msg;
1601         int ret = KSFT_FAIL;
1602
1603         src = inet_makeaddr(INADDR_B, child_ip(nr));
1604         dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1605         tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1606         tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1607
1608         /* UDP pinging without xfrm */
1609         if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1610                 printk("ping failed before setting xfrm");
1611                 return KSFT_FAIL;
1612         }
1613
1614         memset(&msg, 0, sizeof(msg));
1615         msg.type = MSG_XFRM_PREPARE;
1616         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1617         write_msg(cmd_fd, &msg, 1);
1618
1619         if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1620                 printk("failed to prepare xfrm");
1621                 goto cleanup;
1622         }
1623
1624         memset(&msg, 0, sizeof(msg));
1625         msg.type = MSG_XFRM_ADD;
1626         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1627         write_msg(cmd_fd, &msg, 1);
1628         if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1629                 printk("failed to set xfrm");
1630                 goto delete;
1631         }
1632
1633         /* UDP pinging with xfrm tunnel */
1634         if (do_ping(cmd_fd, buf, page_size, tunsrc,
1635                                 true, 0, 0, udp_ping_send)) {
1636                 printk("ping failed for xfrm");
1637                 goto delete;
1638         }
1639
1640         ret = KSFT_PASS;
1641 delete:
1642         /* xfrm delete */
1643         memset(&msg, 0, sizeof(msg));
1644         msg.type = MSG_XFRM_DEL;
1645         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1646         write_msg(cmd_fd, &msg, 1);
1647
1648         if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1649                 printk("failed ping to remove xfrm");
1650                 ret = KSFT_FAIL;
1651         }
1652
1653 cleanup:
1654         memset(&msg, 0, sizeof(msg));
1655         msg.type = MSG_XFRM_CLEANUP;
1656         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1657         write_msg(cmd_fd, &msg, 1);
1658         if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1659                 printk("failed ping to cleanup xfrm");
1660                 ret = KSFT_FAIL;
1661         }
1662         return ret;
1663 }
1664
1665 static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1666 {
1667         struct xfrm_desc desc;
1668         struct test_desc msg;
1669         int xfrm_sock = -1;
1670         uint32_t seq;
1671
1672         if (switch_ns(nsfd_childa))
1673                 exit(KSFT_FAIL);
1674
1675         if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1676                 printk("Failed to open xfrm netlink socket");
1677                 exit(KSFT_FAIL);
1678         }
1679
1680         /* Check that seq sock is ready, just for sure. */
1681         memset(&msg, 0, sizeof(msg));
1682         msg.type = MSG_ACK;
1683         write_msg(cmd_fd, &msg, 1);
1684         read_msg(cmd_fd, &msg, 1);
1685         if (msg.type != MSG_ACK) {
1686                 printk("Ack failed");
1687                 exit(KSFT_FAIL);
1688         }
1689
1690         for (;;) {
1691                 ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1692                 int ret;
1693
1694                 if (received == 0) /* EOF */
1695                         break;
1696
1697                 if (received != sizeof(desc)) {
1698                         pr_err("read() returned %zd", received);
1699                         exit(KSFT_FAIL);
1700                 }
1701
1702                 switch (desc.type) {
1703                 case CREATE_TUNNEL:
1704                         ret = child_serv(xfrm_sock, &seq, nr,
1705                                          cmd_fd, buf, &desc);
1706                         break;
1707                 case ALLOCATE_SPI:
1708                         ret = xfrm_state_allocspi(xfrm_sock, &seq,
1709                                                   -1, desc.proto);
1710                         break;
1711                 case MONITOR_ACQUIRE:
1712                         ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1713                         break;
1714                 case EXPIRE_STATE:
1715                         ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1716                         break;
1717                 case EXPIRE_POLICY:
1718                         ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1719                         break;
1720                 default:
1721                         printk("Unknown desc type %d", desc.type);
1722                         exit(KSFT_FAIL);
1723                 }
1724                 write_test_result(ret, &desc);
1725         }
1726
1727         close(xfrm_sock);
1728
1729         msg.type = MSG_EXIT;
1730         write_msg(cmd_fd, &msg, 1);
1731         exit(KSFT_PASS);
1732 }
1733
1734 static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1735                 struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1736 {
1737         struct in_addr src, dst, tunsrc, tundst;
1738         bool tun_reply;
1739         struct xfrm_desc *desc = &msg->body.xfrm_desc;
1740
1741         src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1742         dst = inet_makeaddr(INADDR_B, child_ip(nr));
1743         tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1744         tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1745
1746         switch (msg->type) {
1747         case MSG_EXIT:
1748                 exit(KSFT_PASS);
1749         case MSG_ACK:
1750                 write_msg(cmd_fd, msg, 1);
1751                 break;
1752         case MSG_PING:
1753                 tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1754                 /* UDP pinging without xfrm */
1755                 if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1756                                 false, msg->body.ping.port,
1757                                 msg->body.ping.reply_ip, udp_ping_reply)) {
1758                         printk("ping failed before setting xfrm");
1759                 }
1760                 break;
1761         case MSG_XFRM_PREPARE:
1762                 if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1763                                         desc->proto)) {
1764                         xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1765                         printk("failed to prepare xfrm");
1766                 }
1767                 break;
1768         case MSG_XFRM_ADD:
1769                 if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1770                         xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1771                         printk("failed to set xfrm");
1772                 }
1773                 break;
1774         case MSG_XFRM_DEL:
1775                 if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1776                                         desc->proto)) {
1777                         xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1778                         printk("failed to remove xfrm");
1779                 }
1780                 break;
1781         case MSG_XFRM_CLEANUP:
1782                 if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1783                         printk("failed to cleanup xfrm");
1784                 }
1785                 break;
1786         default:
1787                 printk("got unknown msg type %d", msg->type);
1788         }
1789 }
1790
1791 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1792 {
1793         struct test_desc msg;
1794         int xfrm_sock = -1;
1795         uint32_t seq;
1796
1797         if (switch_ns(nsfd_childb))
1798                 exit(KSFT_FAIL);
1799
1800         if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1801                 printk("Failed to open xfrm netlink socket");
1802                 exit(KSFT_FAIL);
1803         }
1804
1805         do {
1806                 read_msg(cmd_fd, &msg, 1);
1807                 grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1808         } while (1);
1809
1810         close(xfrm_sock);
1811         exit(KSFT_FAIL);
1812 }
1813
1814 static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
1815 {
1816         int cmd_sock[2];
1817         void *data_map;
1818         pid_t child;
1819
1820         if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
1821                 return -1;
1822
1823         if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
1824                 return -1;
1825
1826         child = fork();
1827         if (child < 0) {
1828                 pr_err("fork()");
1829                 return -1;
1830         } else if (child) {
1831                 /* in parent - selftest */
1832                 return switch_ns(nsfd_parent);
1833         }
1834
1835         if (close(test_desc_fd[1])) {
1836                 pr_err("close()");
1837                 return -1;
1838         }
1839
1840         /* child */
1841         data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
1842                         MAP_SHARED | MAP_ANONYMOUS, -1, 0);
1843         if (data_map == MAP_FAILED) {
1844                 pr_err("mmap()");
1845                 return -1;
1846         }
1847
1848         randomize_buffer(data_map, page_size);
1849
1850         if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
1851                 pr_err("socketpair()");
1852                 return -1;
1853         }
1854
1855         child = fork();
1856         if (child < 0) {
1857                 pr_err("fork()");
1858                 return -1;
1859         } else if (child) {
1860                 if (close(cmd_sock[0])) {
1861                         pr_err("close()");
1862                         return -1;
1863                 }
1864                 return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
1865         }
1866         if (close(cmd_sock[1])) {
1867                 pr_err("close()");
1868                 return -1;
1869         }
1870         return grand_child_f(nr, cmd_sock[0], data_map);
1871 }
1872
1873 static void exit_usage(char **argv)
1874 {
1875         printk("Usage: %s [nr_process]", argv[0]);
1876         exit(KSFT_FAIL);
1877 }
1878
1879 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
1880 {
1881         ssize_t ret;
1882
1883         ret = write(test_desc_fd, desc, sizeof(*desc));
1884
1885         if (ret == sizeof(*desc))
1886                 return 0;
1887
1888         pr_err("Writing test's desc failed %ld", ret);
1889
1890         return -1;
1891 }
1892
1893 static int write_desc(int proto, int test_desc_fd,
1894                 char *a, char *e, char *c, char *ae)
1895 {
1896         struct xfrm_desc desc = {};
1897
1898         desc.type = CREATE_TUNNEL;
1899         desc.proto = proto;
1900
1901         if (a)
1902                 strncpy(desc.a_algo, a, ALGO_LEN - 1);
1903         if (e)
1904                 strncpy(desc.e_algo, e, ALGO_LEN - 1);
1905         if (c)
1906                 strncpy(desc.c_algo, c, ALGO_LEN - 1);
1907         if (ae)
1908                 strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
1909
1910         return __write_desc(test_desc_fd, &desc);
1911 }
1912
1913 int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
1914 char *ah_list[] = {
1915         "digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
1916         "hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
1917         "xcbc(aes)", "cmac(aes)"
1918 };
1919 char *comp_list[] = {
1920         "deflate",
1921 #if 0
1922         /* No compression backend realization */
1923         "lzs", "lzjh"
1924 #endif
1925 };
1926 char *e_list[] = {
1927         "ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
1928         "cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
1929         "cbc(twofish)", "rfc3686(ctr(aes))"
1930 };
1931 char *ae_list[] = {
1932 #if 0
1933         /* not implemented */
1934         "rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
1935         "rfc7539esp(chacha20,poly1305)"
1936 #endif
1937 };
1938
1939 const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
1940                                 + (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
1941                                 + ARRAY_SIZE(ae_list);
1942
1943 static int write_proto_plan(int fd, int proto)
1944 {
1945         unsigned int i;
1946
1947         switch (proto) {
1948         case IPPROTO_AH:
1949                 for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
1950                         if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
1951                                 return -1;
1952                 }
1953                 break;
1954         case IPPROTO_COMP:
1955                 for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
1956                         if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
1957                                 return -1;
1958                 }
1959                 break;
1960         case IPPROTO_ESP:
1961                 for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
1962                         int j;
1963
1964                         for (j = 0; j < ARRAY_SIZE(e_list); j++) {
1965                                 if (write_desc(proto, fd, ah_list[i],
1966                                                         e_list[j], 0, 0))
1967                                         return -1;
1968                         }
1969                 }
1970                 for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
1971                         if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
1972                                 return -1;
1973                 }
1974                 break;
1975         default:
1976                 printk("BUG: Specified unknown proto %d", proto);
1977                 return -1;
1978         }
1979
1980         return 0;
1981 }
1982
1983 /*
1984  * Some structures in xfrm uapi header differ in size between
1985  * 64-bit and 32-bit ABI:
1986  *
1987  *             32-bit UABI               |            64-bit UABI
1988  *  -------------------------------------|-------------------------------------
1989  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
1990  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
1991  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
1992  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
1993  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
1994  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
1995  *
1996  * Check the affected by the UABI difference structures.
1997  */
1998 const unsigned int compat_plan = 4;
1999 static int write_compat_struct_tests(int test_desc_fd)
2000 {
2001         struct xfrm_desc desc = {};
2002
2003         desc.type = ALLOCATE_SPI;
2004         desc.proto = IPPROTO_AH;
2005         strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2006
2007         if (__write_desc(test_desc_fd, &desc))
2008                 return -1;
2009
2010         desc.type = MONITOR_ACQUIRE;
2011         if (__write_desc(test_desc_fd, &desc))
2012                 return -1;
2013
2014         desc.type = EXPIRE_STATE;
2015         if (__write_desc(test_desc_fd, &desc))
2016                 return -1;
2017
2018         desc.type = EXPIRE_POLICY;
2019         if (__write_desc(test_desc_fd, &desc))
2020                 return -1;
2021
2022         return 0;
2023 }
2024
2025 static int write_test_plan(int test_desc_fd)
2026 {
2027         unsigned int i;
2028         pid_t child;
2029
2030         child = fork();
2031         if (child < 0) {
2032                 pr_err("fork()");
2033                 return -1;
2034         }
2035         if (child) {
2036                 if (close(test_desc_fd))
2037                         printk("close(): %m");
2038                 return 0;
2039         }
2040
2041         if (write_compat_struct_tests(test_desc_fd))
2042                 exit(KSFT_FAIL);
2043
2044         for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2045                 if (write_proto_plan(test_desc_fd, proto_list[i]))
2046                         exit(KSFT_FAIL);
2047         }
2048
2049         exit(KSFT_PASS);
2050 }
2051
2052 static int children_cleanup(void)
2053 {
2054         unsigned ret = KSFT_PASS;
2055
2056         while (1) {
2057                 int status;
2058                 pid_t p = wait(&status);
2059
2060                 if ((p < 0) && errno == ECHILD)
2061                         break;
2062
2063                 if (p < 0) {
2064                         pr_err("wait()");
2065                         return KSFT_FAIL;
2066                 }
2067
2068                 if (!WIFEXITED(status)) {
2069                         ret = KSFT_FAIL;
2070                         continue;
2071                 }
2072
2073                 if (WEXITSTATUS(status) == KSFT_FAIL)
2074                         ret = KSFT_FAIL;
2075         }
2076
2077         return ret;
2078 }
2079
2080 typedef void (*print_res)(const char *, ...);
2081
2082 static int check_results(void)
2083 {
2084         struct test_result tr = {};
2085         struct xfrm_desc *d = &tr.desc;
2086         int ret = KSFT_PASS;
2087
2088         while (1) {
2089                 ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2090                 print_res result;
2091
2092                 if (received == 0) /* EOF */
2093                         break;
2094
2095                 if (received != sizeof(tr)) {
2096                         pr_err("read() returned %zd", received);
2097                         return KSFT_FAIL;
2098                 }
2099
2100                 switch (tr.res) {
2101                 case KSFT_PASS:
2102                         result = ksft_test_result_pass;
2103                         break;
2104                 case KSFT_FAIL:
2105                 default:
2106                         result = ksft_test_result_fail;
2107                         ret = KSFT_FAIL;
2108                 }
2109
2110                 result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2111                        desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2112                        d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2113         }
2114
2115         return ret;
2116 }
2117
2118 int main(int argc, char **argv)
2119 {
2120         unsigned int nr_process = 1;
2121         int route_sock = -1, ret = KSFT_SKIP;
2122         int test_desc_fd[2];
2123         uint32_t route_seq;
2124         unsigned int i;
2125
2126         if (argc > 2)
2127                 exit_usage(argv);
2128
2129         if (argc > 1) {
2130                 char *endptr;
2131
2132                 errno = 0;
2133                 nr_process = strtol(argv[1], &endptr, 10);
2134                 if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2135                                 || (errno != 0 && nr_process == 0)
2136                                 || (endptr == argv[1]) || (*endptr != '\0')) {
2137                         printk("Failed to parse [nr_process]");
2138                         exit_usage(argv);
2139                 }
2140
2141                 if (nr_process > MAX_PROCESSES || !nr_process) {
2142                         printk("nr_process should be between [1; %u]",
2143                                         MAX_PROCESSES);
2144                         exit_usage(argv);
2145                 }
2146         }
2147
2148         srand(time(NULL));
2149         page_size = sysconf(_SC_PAGESIZE);
2150         if (page_size < 1)
2151                 ksft_exit_skip("sysconf(): %m\n");
2152
2153         if (pipe2(test_desc_fd, O_DIRECT) < 0)
2154                 ksft_exit_skip("pipe(): %m\n");
2155
2156         if (pipe2(results_fd, O_DIRECT) < 0)
2157                 ksft_exit_skip("pipe(): %m\n");
2158
2159         if (init_namespaces())
2160                 ksft_exit_skip("Failed to create namespaces\n");
2161
2162         if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2163                 ksft_exit_skip("Failed to open netlink route socket\n");
2164
2165         for (i = 0; i < nr_process; i++) {
2166                 char veth[VETH_LEN];
2167
2168                 snprintf(veth, VETH_LEN, VETH_FMT, i);
2169
2170                 if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2171                         close(route_sock);
2172                         ksft_exit_fail_msg("Failed to create veth device");
2173                 }
2174
2175                 if (start_child(i, veth, test_desc_fd)) {
2176                         close(route_sock);
2177                         ksft_exit_fail_msg("Child %u failed to start", i);
2178                 }
2179         }
2180
2181         if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2182                 ksft_exit_fail_msg("close(): %m");
2183
2184         ksft_set_plan(proto_plan + compat_plan);
2185
2186         if (write_test_plan(test_desc_fd[1]))
2187                 ksft_exit_fail_msg("Failed to write test plan to pipe");
2188
2189         ret = check_results();
2190
2191         if (children_cleanup() == KSFT_FAIL)
2192                 exit(KSFT_FAIL);
2193
2194         exit(ret);
2195 }