selftests: tls: skip cmsg_to_pipe tests with TLS=n
[linux-2.6-microblaze.git] / tools / testing / selftests / net / tls.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17
18 #include <sys/types.h>
19 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/stat.h>
22
23 #include "../kselftest_harness.h"
24
25 #define TLS_PAYLOAD_MAX_LEN 16384
26 #define SOL_TLS 282
27
28 struct tls_crypto_info_keys {
29         union {
30                 struct tls12_crypto_info_aes_gcm_128 aes128;
31                 struct tls12_crypto_info_chacha20_poly1305 chacha20;
32                 struct tls12_crypto_info_sm4_gcm sm4gcm;
33                 struct tls12_crypto_info_sm4_ccm sm4ccm;
34                 struct tls12_crypto_info_aes_ccm_128 aesccm128;
35                 struct tls12_crypto_info_aes_gcm_256 aesgcm256;
36         };
37         size_t len;
38 };
39
40 static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
41                                  struct tls_crypto_info_keys *tls12)
42 {
43         memset(tls12, 0, sizeof(*tls12));
44
45         switch (cipher_type) {
46         case TLS_CIPHER_CHACHA20_POLY1305:
47                 tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
48                 tls12->chacha20.info.version = tls_version;
49                 tls12->chacha20.info.cipher_type = cipher_type;
50                 break;
51         case TLS_CIPHER_AES_GCM_128:
52                 tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
53                 tls12->aes128.info.version = tls_version;
54                 tls12->aes128.info.cipher_type = cipher_type;
55                 break;
56         case TLS_CIPHER_SM4_GCM:
57                 tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
58                 tls12->sm4gcm.info.version = tls_version;
59                 tls12->sm4gcm.info.cipher_type = cipher_type;
60                 break;
61         case TLS_CIPHER_SM4_CCM:
62                 tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
63                 tls12->sm4ccm.info.version = tls_version;
64                 tls12->sm4ccm.info.cipher_type = cipher_type;
65                 break;
66         case TLS_CIPHER_AES_CCM_128:
67                 tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
68                 tls12->aesccm128.info.version = tls_version;
69                 tls12->aesccm128.info.cipher_type = cipher_type;
70                 break;
71         case TLS_CIPHER_AES_GCM_256:
72                 tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
73                 tls12->aesgcm256.info.version = tls_version;
74                 tls12->aesgcm256.info.cipher_type = cipher_type;
75                 break;
76         default:
77                 break;
78         }
79 }
80
81 static void memrnd(void *s, size_t n)
82 {
83         int *dword = s;
84         char *byte;
85
86         for (; n >= 4; n -= 4)
87                 *dword++ = rand();
88         byte = (void *)dword;
89         while (n--)
90                 *byte++ = rand();
91 }
92
93 static void ulp_sock_pair(struct __test_metadata *_metadata,
94                           int *fd, int *cfd, bool *notls)
95 {
96         struct sockaddr_in addr;
97         socklen_t len;
98         int sfd, ret;
99
100         *notls = false;
101         len = sizeof(addr);
102
103         addr.sin_family = AF_INET;
104         addr.sin_addr.s_addr = htonl(INADDR_ANY);
105         addr.sin_port = 0;
106
107         *fd = socket(AF_INET, SOCK_STREAM, 0);
108         sfd = socket(AF_INET, SOCK_STREAM, 0);
109
110         ret = bind(sfd, &addr, sizeof(addr));
111         ASSERT_EQ(ret, 0);
112         ret = listen(sfd, 10);
113         ASSERT_EQ(ret, 0);
114
115         ret = getsockname(sfd, &addr, &len);
116         ASSERT_EQ(ret, 0);
117
118         ret = connect(*fd, &addr, sizeof(addr));
119         ASSERT_EQ(ret, 0);
120
121         *cfd = accept(sfd, &addr, &len);
122         ASSERT_GE(*cfd, 0);
123
124         close(sfd);
125
126         ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
127         if (ret != 0) {
128                 ASSERT_EQ(errno, ENOENT);
129                 *notls = true;
130                 printf("Failure setting TCP_ULP, testing without tls\n");
131                 return;
132         }
133
134         ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
135         ASSERT_EQ(ret, 0);
136 }
137
138 /* Produce a basic cmsg */
139 static int tls_send_cmsg(int fd, unsigned char record_type,
140                          void *data, size_t len, int flags)
141 {
142         char cbuf[CMSG_SPACE(sizeof(char))];
143         int cmsg_len = sizeof(char);
144         struct cmsghdr *cmsg;
145         struct msghdr msg;
146         struct iovec vec;
147
148         vec.iov_base = data;
149         vec.iov_len = len;
150         memset(&msg, 0, sizeof(struct msghdr));
151         msg.msg_iov = &vec;
152         msg.msg_iovlen = 1;
153         msg.msg_control = cbuf;
154         msg.msg_controllen = sizeof(cbuf);
155         cmsg = CMSG_FIRSTHDR(&msg);
156         cmsg->cmsg_level = SOL_TLS;
157         /* test sending non-record types. */
158         cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
159         cmsg->cmsg_len = CMSG_LEN(cmsg_len);
160         *CMSG_DATA(cmsg) = record_type;
161         msg.msg_controllen = cmsg->cmsg_len;
162
163         return sendmsg(fd, &msg, flags);
164 }
165
166 static int tls_recv_cmsg(struct __test_metadata *_metadata,
167                          int fd, unsigned char record_type,
168                          void *data, size_t len, int flags)
169 {
170         char cbuf[CMSG_SPACE(sizeof(char))];
171         struct cmsghdr *cmsg;
172         unsigned char ctype;
173         struct msghdr msg;
174         struct iovec vec;
175         int n;
176
177         vec.iov_base = data;
178         vec.iov_len = len;
179         memset(&msg, 0, sizeof(struct msghdr));
180         msg.msg_iov = &vec;
181         msg.msg_iovlen = 1;
182         msg.msg_control = cbuf;
183         msg.msg_controllen = sizeof(cbuf);
184
185         n = recvmsg(fd, &msg, flags);
186
187         cmsg = CMSG_FIRSTHDR(&msg);
188         EXPECT_NE(cmsg, NULL);
189         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
190         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
191         ctype = *((unsigned char *)CMSG_DATA(cmsg));
192         EXPECT_EQ(ctype, record_type);
193
194         return n;
195 }
196
197 FIXTURE(tls_basic)
198 {
199         int fd, cfd;
200         bool notls;
201 };
202
203 FIXTURE_SETUP(tls_basic)
204 {
205         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
206 }
207
208 FIXTURE_TEARDOWN(tls_basic)
209 {
210         close(self->fd);
211         close(self->cfd);
212 }
213
214 /* Send some data through with ULP but no keys */
215 TEST_F(tls_basic, base_base)
216 {
217         char const *test_str = "test_read";
218         int send_len = 10;
219         char buf[10];
220
221         ASSERT_EQ(strlen(test_str) + 1, send_len);
222
223         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
224         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
225         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
226 };
227
228 FIXTURE(tls)
229 {
230         int fd, cfd;
231         bool notls;
232 };
233
234 FIXTURE_VARIANT(tls)
235 {
236         uint16_t tls_version;
237         uint16_t cipher_type;
238 };
239
240 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
241 {
242         .tls_version = TLS_1_2_VERSION,
243         .cipher_type = TLS_CIPHER_AES_GCM_128,
244 };
245
246 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
247 {
248         .tls_version = TLS_1_3_VERSION,
249         .cipher_type = TLS_CIPHER_AES_GCM_128,
250 };
251
252 FIXTURE_VARIANT_ADD(tls, 12_chacha)
253 {
254         .tls_version = TLS_1_2_VERSION,
255         .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
256 };
257
258 FIXTURE_VARIANT_ADD(tls, 13_chacha)
259 {
260         .tls_version = TLS_1_3_VERSION,
261         .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
262 };
263
264 FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
265 {
266         .tls_version = TLS_1_3_VERSION,
267         .cipher_type = TLS_CIPHER_SM4_GCM,
268 };
269
270 FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
271 {
272         .tls_version = TLS_1_3_VERSION,
273         .cipher_type = TLS_CIPHER_SM4_CCM,
274 };
275
276 FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
277 {
278         .tls_version = TLS_1_2_VERSION,
279         .cipher_type = TLS_CIPHER_AES_CCM_128,
280 };
281
282 FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
283 {
284         .tls_version = TLS_1_3_VERSION,
285         .cipher_type = TLS_CIPHER_AES_CCM_128,
286 };
287
288 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
289 {
290         .tls_version = TLS_1_2_VERSION,
291         .cipher_type = TLS_CIPHER_AES_GCM_256,
292 };
293
294 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
295 {
296         .tls_version = TLS_1_3_VERSION,
297         .cipher_type = TLS_CIPHER_AES_GCM_256,
298 };
299
300 FIXTURE_SETUP(tls)
301 {
302         struct tls_crypto_info_keys tls12;
303         int ret;
304
305         tls_crypto_info_init(variant->tls_version, variant->cipher_type,
306                              &tls12);
307
308         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
309
310         if (self->notls)
311                 return;
312
313         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
314         ASSERT_EQ(ret, 0);
315
316         ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
317         ASSERT_EQ(ret, 0);
318 }
319
320 FIXTURE_TEARDOWN(tls)
321 {
322         close(self->fd);
323         close(self->cfd);
324 }
325
326 TEST_F(tls, sendfile)
327 {
328         int filefd = open("/proc/self/exe", O_RDONLY);
329         struct stat st;
330
331         EXPECT_GE(filefd, 0);
332         fstat(filefd, &st);
333         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
334 }
335
336 TEST_F(tls, send_then_sendfile)
337 {
338         int filefd = open("/proc/self/exe", O_RDONLY);
339         char const *test_str = "test_send";
340         int to_send = strlen(test_str) + 1;
341         char recv_buf[10];
342         struct stat st;
343         char *buf;
344
345         EXPECT_GE(filefd, 0);
346         fstat(filefd, &st);
347         buf = (char *)malloc(st.st_size);
348
349         EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
350         EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
351         EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
352
353         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
354         EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
355 }
356
357 static void chunked_sendfile(struct __test_metadata *_metadata,
358                              struct _test_data_tls *self,
359                              uint16_t chunk_size,
360                              uint16_t extra_payload_size)
361 {
362         char buf[TLS_PAYLOAD_MAX_LEN];
363         uint16_t test_payload_size;
364         int size = 0;
365         int ret;
366         char filename[] = "/tmp/mytemp.XXXXXX";
367         int fd = mkstemp(filename);
368         off_t offset = 0;
369
370         unlink(filename);
371         ASSERT_GE(fd, 0);
372         EXPECT_GE(chunk_size, 1);
373         test_payload_size = chunk_size + extra_payload_size;
374         ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
375         memset(buf, 1, test_payload_size);
376         size = write(fd, buf, test_payload_size);
377         EXPECT_EQ(size, test_payload_size);
378         fsync(fd);
379
380         while (size > 0) {
381                 ret = sendfile(self->fd, fd, &offset, chunk_size);
382                 EXPECT_GE(ret, 0);
383                 size -= ret;
384         }
385
386         EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
387                   test_payload_size);
388
389         close(fd);
390 }
391
392 TEST_F(tls, multi_chunk_sendfile)
393 {
394         chunked_sendfile(_metadata, self, 4096, 4096);
395         chunked_sendfile(_metadata, self, 4096, 0);
396         chunked_sendfile(_metadata, self, 4096, 1);
397         chunked_sendfile(_metadata, self, 4096, 2048);
398         chunked_sendfile(_metadata, self, 8192, 2048);
399         chunked_sendfile(_metadata, self, 4096, 8192);
400         chunked_sendfile(_metadata, self, 8192, 4096);
401         chunked_sendfile(_metadata, self, 12288, 1024);
402         chunked_sendfile(_metadata, self, 12288, 2000);
403         chunked_sendfile(_metadata, self, 15360, 100);
404         chunked_sendfile(_metadata, self, 15360, 300);
405         chunked_sendfile(_metadata, self, 1, 4096);
406         chunked_sendfile(_metadata, self, 2048, 4096);
407         chunked_sendfile(_metadata, self, 2048, 8192);
408         chunked_sendfile(_metadata, self, 4096, 8192);
409         chunked_sendfile(_metadata, self, 1024, 12288);
410         chunked_sendfile(_metadata, self, 2000, 12288);
411         chunked_sendfile(_metadata, self, 100, 15360);
412         chunked_sendfile(_metadata, self, 300, 15360);
413 }
414
415 TEST_F(tls, recv_max)
416 {
417         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
418         char recv_mem[TLS_PAYLOAD_MAX_LEN];
419         char buf[TLS_PAYLOAD_MAX_LEN];
420
421         memrnd(buf, sizeof(buf));
422
423         EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
424         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
425         EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
426 }
427
428 TEST_F(tls, recv_small)
429 {
430         char const *test_str = "test_read";
431         int send_len = 10;
432         char buf[10];
433
434         send_len = strlen(test_str) + 1;
435         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
436         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
437         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
438 }
439
440 TEST_F(tls, msg_more)
441 {
442         char const *test_str = "test_read";
443         int send_len = 10;
444         char buf[10 * 2];
445
446         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
447         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
448         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
449         EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
450                   send_len * 2);
451         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
452 }
453
454 TEST_F(tls, msg_more_unsent)
455 {
456         char const *test_str = "test_read";
457         int send_len = 10;
458         char buf[10];
459
460         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
461         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
462 }
463
464 TEST_F(tls, sendmsg_single)
465 {
466         struct msghdr msg;
467
468         char const *test_str = "test_sendmsg";
469         size_t send_len = 13;
470         struct iovec vec;
471         char buf[13];
472
473         vec.iov_base = (char *)test_str;
474         vec.iov_len = send_len;
475         memset(&msg, 0, sizeof(struct msghdr));
476         msg.msg_iov = &vec;
477         msg.msg_iovlen = 1;
478         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
479         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
480         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
481 }
482
483 #define MAX_FRAGS       64
484 #define SEND_LEN        13
485 TEST_F(tls, sendmsg_fragmented)
486 {
487         char const *test_str = "test_sendmsg";
488         char buf[SEND_LEN * MAX_FRAGS];
489         struct iovec vec[MAX_FRAGS];
490         struct msghdr msg;
491         int i, frags;
492
493         for (frags = 1; frags <= MAX_FRAGS; frags++) {
494                 for (i = 0; i < frags; i++) {
495                         vec[i].iov_base = (char *)test_str;
496                         vec[i].iov_len = SEND_LEN;
497                 }
498
499                 memset(&msg, 0, sizeof(struct msghdr));
500                 msg.msg_iov = vec;
501                 msg.msg_iovlen = frags;
502
503                 EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
504                 EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
505                           SEND_LEN * frags);
506
507                 for (i = 0; i < frags; i++)
508                         EXPECT_EQ(memcmp(buf + SEND_LEN * i,
509                                          test_str, SEND_LEN), 0);
510         }
511 }
512 #undef MAX_FRAGS
513 #undef SEND_LEN
514
515 TEST_F(tls, sendmsg_large)
516 {
517         void *mem = malloc(16384);
518         size_t send_len = 16384;
519         size_t sends = 128;
520         struct msghdr msg;
521         size_t recvs = 0;
522         size_t sent = 0;
523
524         memset(&msg, 0, sizeof(struct msghdr));
525         while (sent++ < sends) {
526                 struct iovec vec = { (void *)mem, send_len };
527
528                 msg.msg_iov = &vec;
529                 msg.msg_iovlen = 1;
530                 EXPECT_EQ(sendmsg(self->cfd, &msg, 0), send_len);
531         }
532
533         while (recvs++ < sends) {
534                 EXPECT_NE(recv(self->fd, mem, send_len, 0), -1);
535         }
536
537         free(mem);
538 }
539
540 TEST_F(tls, sendmsg_multiple)
541 {
542         char const *test_str = "test_sendmsg_multiple";
543         struct iovec vec[5];
544         char *test_strs[5];
545         struct msghdr msg;
546         int total_len = 0;
547         int len_cmp = 0;
548         int iov_len = 5;
549         char *buf;
550         int i;
551
552         memset(&msg, 0, sizeof(struct msghdr));
553         for (i = 0; i < iov_len; i++) {
554                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
555                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
556                 vec[i].iov_base = (void *)test_strs[i];
557                 vec[i].iov_len = strlen(test_strs[i]) + 1;
558                 total_len += vec[i].iov_len;
559         }
560         msg.msg_iov = vec;
561         msg.msg_iovlen = iov_len;
562
563         EXPECT_EQ(sendmsg(self->cfd, &msg, 0), total_len);
564         buf = malloc(total_len);
565         EXPECT_NE(recv(self->fd, buf, total_len, 0), -1);
566         for (i = 0; i < iov_len; i++) {
567                 EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
568                                  strlen(test_strs[i])),
569                           0);
570                 len_cmp += strlen(buf + len_cmp) + 1;
571         }
572         for (i = 0; i < iov_len; i++)
573                 free(test_strs[i]);
574         free(buf);
575 }
576
577 TEST_F(tls, sendmsg_multiple_stress)
578 {
579         char const *test_str = "abcdefghijklmno";
580         struct iovec vec[1024];
581         char *test_strs[1024];
582         int iov_len = 1024;
583         int total_len = 0;
584         char buf[1 << 14];
585         struct msghdr msg;
586         int len_cmp = 0;
587         int i;
588
589         memset(&msg, 0, sizeof(struct msghdr));
590         for (i = 0; i < iov_len; i++) {
591                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
592                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
593                 vec[i].iov_base = (void *)test_strs[i];
594                 vec[i].iov_len = strlen(test_strs[i]) + 1;
595                 total_len += vec[i].iov_len;
596         }
597         msg.msg_iov = vec;
598         msg.msg_iovlen = iov_len;
599
600         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
601         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
602
603         for (i = 0; i < iov_len; i++)
604                 len_cmp += strlen(buf + len_cmp) + 1;
605
606         for (i = 0; i < iov_len; i++)
607                 free(test_strs[i]);
608 }
609
610 TEST_F(tls, splice_from_pipe)
611 {
612         int send_len = TLS_PAYLOAD_MAX_LEN;
613         char mem_send[TLS_PAYLOAD_MAX_LEN];
614         char mem_recv[TLS_PAYLOAD_MAX_LEN];
615         int p[2];
616
617         ASSERT_GE(pipe(p), 0);
618         EXPECT_GE(write(p[1], mem_send, send_len), 0);
619         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
620         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
621         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
622 }
623
624 TEST_F(tls, splice_from_pipe2)
625 {
626         int send_len = 16000;
627         char mem_send[16000];
628         char mem_recv[16000];
629         int p2[2];
630         int p[2];
631
632         ASSERT_GE(pipe(p), 0);
633         ASSERT_GE(pipe(p2), 0);
634         EXPECT_GE(write(p[1], mem_send, 8000), 0);
635         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, 8000, 0), 0);
636         EXPECT_GE(write(p2[1], mem_send + 8000, 8000), 0);
637         EXPECT_GE(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 0);
638         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
639         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
640 }
641
642 TEST_F(tls, send_and_splice)
643 {
644         int send_len = TLS_PAYLOAD_MAX_LEN;
645         char mem_send[TLS_PAYLOAD_MAX_LEN];
646         char mem_recv[TLS_PAYLOAD_MAX_LEN];
647         char const *test_str = "test_read";
648         int send_len2 = 10;
649         char buf[10];
650         int p[2];
651
652         ASSERT_GE(pipe(p), 0);
653         EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
654         EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
655         EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
656
657         EXPECT_GE(write(p[1], mem_send, send_len), send_len);
658         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
659
660         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
661         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
662 }
663
664 TEST_F(tls, splice_to_pipe)
665 {
666         int send_len = TLS_PAYLOAD_MAX_LEN;
667         char mem_send[TLS_PAYLOAD_MAX_LEN];
668         char mem_recv[TLS_PAYLOAD_MAX_LEN];
669         int p[2];
670
671         ASSERT_GE(pipe(p), 0);
672         EXPECT_GE(send(self->fd, mem_send, send_len, 0), 0);
673         EXPECT_GE(splice(self->cfd, NULL, p[1], NULL, send_len, 0), 0);
674         EXPECT_GE(read(p[0], mem_recv, send_len), 0);
675         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
676 }
677
678 TEST_F(tls, splice_cmsg_to_pipe)
679 {
680         char *test_str = "test_read";
681         char record_type = 100;
682         int send_len = 10;
683         char buf[10];
684         int p[2];
685
686         if (self->notls)
687                 SKIP(return, "no TLS support");
688
689         ASSERT_GE(pipe(p), 0);
690         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
691         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
692         EXPECT_EQ(errno, EINVAL);
693         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
694         EXPECT_EQ(errno, EIO);
695         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
696                                 buf, sizeof(buf), MSG_WAITALL),
697                   send_len);
698         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
699 }
700
701 TEST_F(tls, splice_dec_cmsg_to_pipe)
702 {
703         char *test_str = "test_read";
704         char record_type = 100;
705         int send_len = 10;
706         char buf[10];
707         int p[2];
708
709         if (self->notls)
710                 SKIP(return, "no TLS support");
711
712         ASSERT_GE(pipe(p), 0);
713         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
714         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
715         EXPECT_EQ(errno, EIO);
716         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
717         EXPECT_EQ(errno, EINVAL);
718         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
719                                 buf, sizeof(buf), MSG_WAITALL),
720                   send_len);
721         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
722 }
723
724 TEST_F(tls, recv_and_splice)
725 {
726         int send_len = TLS_PAYLOAD_MAX_LEN;
727         char mem_send[TLS_PAYLOAD_MAX_LEN];
728         char mem_recv[TLS_PAYLOAD_MAX_LEN];
729         int half = send_len / 2;
730         int p[2];
731
732         ASSERT_GE(pipe(p), 0);
733         EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
734         /* Recv hald of the record, splice the other half */
735         EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
736         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
737                   half);
738         EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
739         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
740 }
741
742 TEST_F(tls, peek_and_splice)
743 {
744         int send_len = TLS_PAYLOAD_MAX_LEN;
745         char mem_send[TLS_PAYLOAD_MAX_LEN];
746         char mem_recv[TLS_PAYLOAD_MAX_LEN];
747         int chunk = TLS_PAYLOAD_MAX_LEN / 4;
748         int n, i, p[2];
749
750         memrnd(mem_send, sizeof(mem_send));
751
752         ASSERT_GE(pipe(p), 0);
753         for (i = 0; i < 4; i++)
754                 EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
755                           chunk);
756
757         EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
758                        MSG_WAITALL | MSG_PEEK),
759                   chunk * 5 / 2);
760         EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
761
762         n = 0;
763         while (n < send_len) {
764                 i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
765                 EXPECT_GT(i, 0);
766                 n += i;
767         }
768         EXPECT_EQ(n, send_len);
769         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
770         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
771 }
772
773 TEST_F(tls, recvmsg_single)
774 {
775         char const *test_str = "test_recvmsg_single";
776         int send_len = strlen(test_str) + 1;
777         char buf[20];
778         struct msghdr hdr;
779         struct iovec vec;
780
781         memset(&hdr, 0, sizeof(hdr));
782         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
783         vec.iov_base = (char *)buf;
784         vec.iov_len = send_len;
785         hdr.msg_iovlen = 1;
786         hdr.msg_iov = &vec;
787         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
788         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
789 }
790
791 TEST_F(tls, recvmsg_single_max)
792 {
793         int send_len = TLS_PAYLOAD_MAX_LEN;
794         char send_mem[TLS_PAYLOAD_MAX_LEN];
795         char recv_mem[TLS_PAYLOAD_MAX_LEN];
796         struct iovec vec;
797         struct msghdr hdr;
798
799         memrnd(send_mem, sizeof(send_mem));
800
801         EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
802         vec.iov_base = (char *)recv_mem;
803         vec.iov_len = TLS_PAYLOAD_MAX_LEN;
804
805         hdr.msg_iovlen = 1;
806         hdr.msg_iov = &vec;
807         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
808         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
809 }
810
811 TEST_F(tls, recvmsg_multiple)
812 {
813         unsigned int msg_iovlen = 1024;
814         struct iovec vec[1024];
815         char *iov_base[1024];
816         unsigned int iov_len = 16;
817         int send_len = 1 << 14;
818         char buf[1 << 14];
819         struct msghdr hdr;
820         int i;
821
822         memrnd(buf, sizeof(buf));
823
824         EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
825         for (i = 0; i < msg_iovlen; i++) {
826                 iov_base[i] = (char *)malloc(iov_len);
827                 vec[i].iov_base = iov_base[i];
828                 vec[i].iov_len = iov_len;
829         }
830
831         hdr.msg_iovlen = msg_iovlen;
832         hdr.msg_iov = vec;
833         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
834
835         for (i = 0; i < msg_iovlen; i++)
836                 free(iov_base[i]);
837 }
838
839 TEST_F(tls, single_send_multiple_recv)
840 {
841         unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
842         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
843         char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
844         char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
845
846         memrnd(send_mem, sizeof(send_mem));
847
848         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
849         memset(recv_mem, 0, total_len);
850
851         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
852         EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
853         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
854 }
855
856 TEST_F(tls, multiple_send_single_recv)
857 {
858         unsigned int total_len = 2 * 10;
859         unsigned int send_len = 10;
860         char recv_mem[2 * 10];
861         char send_mem[10];
862
863         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
864         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
865         memset(recv_mem, 0, total_len);
866         EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
867
868         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
869         EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
870 }
871
872 TEST_F(tls, single_send_multiple_recv_non_align)
873 {
874         const unsigned int total_len = 15;
875         const unsigned int recv_len = 10;
876         char recv_mem[recv_len * 2];
877         char send_mem[total_len];
878
879         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
880         memset(recv_mem, 0, total_len);
881
882         EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
883         EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
884         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
885 }
886
887 TEST_F(tls, recv_partial)
888 {
889         char const *test_str = "test_read_partial";
890         char const *test_str_first = "test_read";
891         char const *test_str_second = "_partial";
892         int send_len = strlen(test_str) + 1;
893         char recv_mem[18];
894
895         memset(recv_mem, 0, sizeof(recv_mem));
896         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
897         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
898                        MSG_WAITALL), -1);
899         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
900         memset(recv_mem, 0, sizeof(recv_mem));
901         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
902                        MSG_WAITALL), -1);
903         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
904                   0);
905 }
906
907 TEST_F(tls, recv_nonblock)
908 {
909         char buf[4096];
910         bool err;
911
912         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
913         err = (errno == EAGAIN || errno == EWOULDBLOCK);
914         EXPECT_EQ(err, true);
915 }
916
917 TEST_F(tls, recv_peek)
918 {
919         char const *test_str = "test_read_peek";
920         int send_len = strlen(test_str) + 1;
921         char buf[15];
922
923         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
924         EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
925         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
926         memset(buf, 0, sizeof(buf));
927         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
928         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
929 }
930
931 TEST_F(tls, recv_peek_multiple)
932 {
933         char const *test_str = "test_read_peek";
934         int send_len = strlen(test_str) + 1;
935         unsigned int num_peeks = 100;
936         char buf[15];
937         int i;
938
939         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
940         for (i = 0; i < num_peeks; i++) {
941                 EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
942                 EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
943                 memset(buf, 0, sizeof(buf));
944         }
945         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
946         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
947 }
948
949 TEST_F(tls, recv_peek_multiple_records)
950 {
951         char const *test_str = "test_read_peek_mult_recs";
952         char const *test_str_first = "test_read_peek";
953         char const *test_str_second = "_mult_recs";
954         int len;
955         char buf[64];
956
957         len = strlen(test_str_first);
958         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
959
960         len = strlen(test_str_second) + 1;
961         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
962
963         len = strlen(test_str_first);
964         memset(buf, 0, len);
965         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
966
967         /* MSG_PEEK can only peek into the current record. */
968         len = strlen(test_str_first);
969         EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
970
971         len = strlen(test_str) + 1;
972         memset(buf, 0, len);
973         EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
974
975         /* Non-MSG_PEEK will advance strparser (and therefore record)
976          * however.
977          */
978         len = strlen(test_str) + 1;
979         EXPECT_EQ(memcmp(test_str, buf, len), 0);
980
981         /* MSG_MORE will hold current record open, so later MSG_PEEK
982          * will see everything.
983          */
984         len = strlen(test_str_first);
985         EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
986
987         len = strlen(test_str_second) + 1;
988         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
989
990         len = strlen(test_str) + 1;
991         memset(buf, 0, len);
992         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
993
994         len = strlen(test_str) + 1;
995         EXPECT_EQ(memcmp(test_str, buf, len), 0);
996 }
997
998 TEST_F(tls, recv_peek_large_buf_mult_recs)
999 {
1000         char const *test_str = "test_read_peek_mult_recs";
1001         char const *test_str_first = "test_read_peek";
1002         char const *test_str_second = "_mult_recs";
1003         int len;
1004         char buf[64];
1005
1006         len = strlen(test_str_first);
1007         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1008
1009         len = strlen(test_str_second) + 1;
1010         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1011
1012         len = strlen(test_str) + 1;
1013         memset(buf, 0, len);
1014         EXPECT_NE((len = recv(self->cfd, buf, len,
1015                               MSG_PEEK | MSG_WAITALL)), -1);
1016         len = strlen(test_str) + 1;
1017         EXPECT_EQ(memcmp(test_str, buf, len), 0);
1018 }
1019
1020 TEST_F(tls, recv_lowat)
1021 {
1022         char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1023         char recv_mem[20];
1024         int lowat = 8;
1025
1026         EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1027         EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1028
1029         memset(recv_mem, 0, 20);
1030         EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1031                              &lowat, sizeof(lowat)), 0);
1032         EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1033         EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1034         EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1035
1036         EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1037         EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1038 }
1039
1040 TEST_F(tls, bidir)
1041 {
1042         char const *test_str = "test_read";
1043         int send_len = 10;
1044         char buf[10];
1045         int ret;
1046
1047         if (!self->notls) {
1048                 struct tls_crypto_info_keys tls12;
1049
1050                 tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1051                                      &tls12);
1052
1053                 ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1054                                  tls12.len);
1055                 ASSERT_EQ(ret, 0);
1056
1057                 ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1058                                  tls12.len);
1059                 ASSERT_EQ(ret, 0);
1060         }
1061
1062         ASSERT_EQ(strlen(test_str) + 1, send_len);
1063
1064         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1065         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1066         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1067
1068         memset(buf, 0, sizeof(buf));
1069
1070         EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1071         EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1072         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1073 };
1074
1075 TEST_F(tls, pollin)
1076 {
1077         char const *test_str = "test_poll";
1078         struct pollfd fd = { 0, 0, 0 };
1079         char buf[10];
1080         int send_len = 10;
1081
1082         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1083         fd.fd = self->cfd;
1084         fd.events = POLLIN;
1085
1086         EXPECT_EQ(poll(&fd, 1, 20), 1);
1087         EXPECT_EQ(fd.revents & POLLIN, 1);
1088         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1089         /* Test timing out */
1090         EXPECT_EQ(poll(&fd, 1, 20), 0);
1091 }
1092
1093 TEST_F(tls, poll_wait)
1094 {
1095         char const *test_str = "test_poll_wait";
1096         int send_len = strlen(test_str) + 1;
1097         struct pollfd fd = { 0, 0, 0 };
1098         char recv_mem[15];
1099
1100         fd.fd = self->cfd;
1101         fd.events = POLLIN;
1102         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1103         /* Set timeout to inf. secs */
1104         EXPECT_EQ(poll(&fd, 1, -1), 1);
1105         EXPECT_EQ(fd.revents & POLLIN, 1);
1106         EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1107 }
1108
1109 TEST_F(tls, poll_wait_split)
1110 {
1111         struct pollfd fd = { 0, 0, 0 };
1112         char send_mem[20] = {};
1113         char recv_mem[15];
1114
1115         fd.fd = self->cfd;
1116         fd.events = POLLIN;
1117         /* Send 20 bytes */
1118         EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1119                   sizeof(send_mem));
1120         /* Poll with inf. timeout */
1121         EXPECT_EQ(poll(&fd, 1, -1), 1);
1122         EXPECT_EQ(fd.revents & POLLIN, 1);
1123         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1124                   sizeof(recv_mem));
1125
1126         /* Now the remaining 5 bytes of record data are in TLS ULP */
1127         fd.fd = self->cfd;
1128         fd.events = POLLIN;
1129         EXPECT_EQ(poll(&fd, 1, -1), 1);
1130         EXPECT_EQ(fd.revents & POLLIN, 1);
1131         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1132                   sizeof(send_mem) - sizeof(recv_mem));
1133 }
1134
1135 TEST_F(tls, blocking)
1136 {
1137         size_t data = 100000;
1138         int res = fork();
1139
1140         EXPECT_NE(res, -1);
1141
1142         if (res) {
1143                 /* parent */
1144                 size_t left = data;
1145                 char buf[16384];
1146                 int status;
1147                 int pid2;
1148
1149                 while (left) {
1150                         int res = send(self->fd, buf,
1151                                        left > 16384 ? 16384 : left, 0);
1152
1153                         EXPECT_GE(res, 0);
1154                         left -= res;
1155                 }
1156
1157                 pid2 = wait(&status);
1158                 EXPECT_EQ(status, 0);
1159                 EXPECT_EQ(res, pid2);
1160         } else {
1161                 /* child */
1162                 size_t left = data;
1163                 char buf[16384];
1164
1165                 while (left) {
1166                         int res = recv(self->cfd, buf,
1167                                        left > 16384 ? 16384 : left, 0);
1168
1169                         EXPECT_GE(res, 0);
1170                         left -= res;
1171                 }
1172         }
1173 }
1174
1175 TEST_F(tls, nonblocking)
1176 {
1177         size_t data = 100000;
1178         int sendbuf = 100;
1179         int flags;
1180         int res;
1181
1182         flags = fcntl(self->fd, F_GETFL, 0);
1183         fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1184         fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1185
1186         /* Ensure nonblocking behavior by imposing a small send
1187          * buffer.
1188          */
1189         EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1190                              &sendbuf, sizeof(sendbuf)), 0);
1191
1192         res = fork();
1193         EXPECT_NE(res, -1);
1194
1195         if (res) {
1196                 /* parent */
1197                 bool eagain = false;
1198                 size_t left = data;
1199                 char buf[16384];
1200                 int status;
1201                 int pid2;
1202
1203                 while (left) {
1204                         int res = send(self->fd, buf,
1205                                        left > 16384 ? 16384 : left, 0);
1206
1207                         if (res == -1 && errno == EAGAIN) {
1208                                 eagain = true;
1209                                 usleep(10000);
1210                                 continue;
1211                         }
1212                         EXPECT_GE(res, 0);
1213                         left -= res;
1214                 }
1215
1216                 EXPECT_TRUE(eagain);
1217                 pid2 = wait(&status);
1218
1219                 EXPECT_EQ(status, 0);
1220                 EXPECT_EQ(res, pid2);
1221         } else {
1222                 /* child */
1223                 bool eagain = false;
1224                 size_t left = data;
1225                 char buf[16384];
1226
1227                 while (left) {
1228                         int res = recv(self->cfd, buf,
1229                                        left > 16384 ? 16384 : left, 0);
1230
1231                         if (res == -1 && errno == EAGAIN) {
1232                                 eagain = true;
1233                                 usleep(10000);
1234                                 continue;
1235                         }
1236                         EXPECT_GE(res, 0);
1237                         left -= res;
1238                 }
1239                 EXPECT_TRUE(eagain);
1240         }
1241 }
1242
1243 static void
1244 test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1245                bool sendpg, unsigned int n_readers, unsigned int n_writers)
1246 {
1247         const unsigned int n_children = n_readers + n_writers;
1248         const size_t data = 6 * 1000 * 1000;
1249         const size_t file_sz = data / 100;
1250         size_t read_bias, write_bias;
1251         int i, fd, child_id;
1252         char buf[file_sz];
1253         pid_t pid;
1254
1255         /* Only allow multiples for simplicity */
1256         ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1257         read_bias = n_writers / n_readers ?: 1;
1258         write_bias = n_readers / n_writers ?: 1;
1259
1260         /* prep a file to send */
1261         fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1262         ASSERT_GE(fd, 0);
1263
1264         memset(buf, 0xac, file_sz);
1265         ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1266
1267         /* spawn children */
1268         for (child_id = 0; child_id < n_children; child_id++) {
1269                 pid = fork();
1270                 ASSERT_NE(pid, -1);
1271                 if (!pid)
1272                         break;
1273         }
1274
1275         /* parent waits for all children */
1276         if (pid) {
1277                 for (i = 0; i < n_children; i++) {
1278                         int status;
1279
1280                         wait(&status);
1281                         EXPECT_EQ(status, 0);
1282                 }
1283
1284                 return;
1285         }
1286
1287         /* Split threads for reading and writing */
1288         if (child_id < n_readers) {
1289                 size_t left = data * read_bias;
1290                 char rb[8001];
1291
1292                 while (left) {
1293                         int res;
1294
1295                         res = recv(self->cfd, rb,
1296                                    left > sizeof(rb) ? sizeof(rb) : left, 0);
1297
1298                         EXPECT_GE(res, 0);
1299                         left -= res;
1300                 }
1301         } else {
1302                 size_t left = data * write_bias;
1303
1304                 while (left) {
1305                         int res;
1306
1307                         ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1308                         if (sendpg)
1309                                 res = sendfile(self->fd, fd, NULL,
1310                                                left > file_sz ? file_sz : left);
1311                         else
1312                                 res = send(self->fd, buf,
1313                                            left > file_sz ? file_sz : left, 0);
1314
1315                         EXPECT_GE(res, 0);
1316                         left -= res;
1317                 }
1318         }
1319 }
1320
1321 TEST_F(tls, mutliproc_even)
1322 {
1323         test_mutliproc(_metadata, self, false, 6, 6);
1324 }
1325
1326 TEST_F(tls, mutliproc_readers)
1327 {
1328         test_mutliproc(_metadata, self, false, 4, 12);
1329 }
1330
1331 TEST_F(tls, mutliproc_writers)
1332 {
1333         test_mutliproc(_metadata, self, false, 10, 2);
1334 }
1335
1336 TEST_F(tls, mutliproc_sendpage_even)
1337 {
1338         test_mutliproc(_metadata, self, true, 6, 6);
1339 }
1340
1341 TEST_F(tls, mutliproc_sendpage_readers)
1342 {
1343         test_mutliproc(_metadata, self, true, 4, 12);
1344 }
1345
1346 TEST_F(tls, mutliproc_sendpage_writers)
1347 {
1348         test_mutliproc(_metadata, self, true, 10, 2);
1349 }
1350
1351 TEST_F(tls, control_msg)
1352 {
1353         char *test_str = "test_read";
1354         char record_type = 100;
1355         int send_len = 10;
1356         char buf[10];
1357
1358         if (self->notls)
1359                 SKIP(return, "no TLS support");
1360
1361         EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1362                   send_len);
1363         /* Should fail because we didn't provide a control message */
1364         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1365
1366         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1367                                 buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1368                   send_len);
1369         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1370
1371         /* Recv the message again without MSG_PEEK */
1372         memset(buf, 0, sizeof(buf));
1373
1374         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1375                                 buf, sizeof(buf), MSG_WAITALL),
1376                   send_len);
1377         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1378 }
1379
1380 TEST_F(tls, shutdown)
1381 {
1382         char const *test_str = "test_read";
1383         int send_len = 10;
1384         char buf[10];
1385
1386         ASSERT_EQ(strlen(test_str) + 1, send_len);
1387
1388         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1389         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1390         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1391
1392         shutdown(self->fd, SHUT_RDWR);
1393         shutdown(self->cfd, SHUT_RDWR);
1394 }
1395
1396 TEST_F(tls, shutdown_unsent)
1397 {
1398         char const *test_str = "test_read";
1399         int send_len = 10;
1400
1401         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1402
1403         shutdown(self->fd, SHUT_RDWR);
1404         shutdown(self->cfd, SHUT_RDWR);
1405 }
1406
1407 TEST_F(tls, shutdown_reuse)
1408 {
1409         struct sockaddr_in addr;
1410         int ret;
1411
1412         shutdown(self->fd, SHUT_RDWR);
1413         shutdown(self->cfd, SHUT_RDWR);
1414         close(self->cfd);
1415
1416         addr.sin_family = AF_INET;
1417         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1418         addr.sin_port = 0;
1419
1420         ret = bind(self->fd, &addr, sizeof(addr));
1421         EXPECT_EQ(ret, 0);
1422         ret = listen(self->fd, 10);
1423         EXPECT_EQ(ret, -1);
1424         EXPECT_EQ(errno, EINVAL);
1425
1426         ret = connect(self->fd, &addr, sizeof(addr));
1427         EXPECT_EQ(ret, -1);
1428         EXPECT_EQ(errno, EISCONN);
1429 }
1430
1431 FIXTURE(tls_err)
1432 {
1433         int fd, cfd;
1434         int fd2, cfd2;
1435         bool notls;
1436 };
1437
1438 FIXTURE_VARIANT(tls_err)
1439 {
1440         uint16_t tls_version;
1441 };
1442
1443 FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
1444 {
1445         .tls_version = TLS_1_2_VERSION,
1446 };
1447
1448 FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
1449 {
1450         .tls_version = TLS_1_3_VERSION,
1451 };
1452
1453 FIXTURE_SETUP(tls_err)
1454 {
1455         struct tls_crypto_info_keys tls12;
1456         int ret;
1457
1458         tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
1459                              &tls12);
1460
1461         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
1462         ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
1463         if (self->notls)
1464                 return;
1465
1466         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
1467         ASSERT_EQ(ret, 0);
1468
1469         ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
1470         ASSERT_EQ(ret, 0);
1471 }
1472
1473 FIXTURE_TEARDOWN(tls_err)
1474 {
1475         close(self->fd);
1476         close(self->cfd);
1477         close(self->fd2);
1478         close(self->cfd2);
1479 }
1480
1481 TEST_F(tls_err, bad_rec)
1482 {
1483         char buf[64];
1484
1485         if (self->notls)
1486                 SKIP(return, "no TLS support");
1487
1488         memset(buf, 0x55, sizeof(buf));
1489         EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
1490         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1491         EXPECT_EQ(errno, EMSGSIZE);
1492         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
1493         EXPECT_EQ(errno, EAGAIN);
1494 }
1495
1496 TEST_F(tls_err, bad_auth)
1497 {
1498         char buf[128];
1499         int n;
1500
1501         if (self->notls)
1502                 SKIP(return, "no TLS support");
1503
1504         memrnd(buf, sizeof(buf) / 2);
1505         EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
1506         n = recv(self->cfd, buf, sizeof(buf), 0);
1507         EXPECT_GT(n, sizeof(buf) / 2);
1508
1509         buf[n - 1]++;
1510
1511         EXPECT_EQ(send(self->fd2, buf, n, 0), n);
1512         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1513         EXPECT_EQ(errno, EBADMSG);
1514         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1515         EXPECT_EQ(errno, EBADMSG);
1516 }
1517
1518 TEST_F(tls_err, bad_in_large_read)
1519 {
1520         char txt[3][64];
1521         char cip[3][128];
1522         char buf[3 * 128];
1523         int i, n;
1524
1525         if (self->notls)
1526                 SKIP(return, "no TLS support");
1527
1528         /* Put 3 records in the sockets */
1529         for (i = 0; i < 3; i++) {
1530                 memrnd(txt[i], sizeof(txt[i]));
1531                 EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
1532                           sizeof(txt[i]));
1533                 n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
1534                 EXPECT_GT(n, sizeof(txt[i]));
1535                 /* Break the third message */
1536                 if (i == 2)
1537                         cip[2][n - 1]++;
1538                 EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
1539         }
1540
1541         /* We should be able to receive the first two messages */
1542         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
1543         EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
1544         EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
1545         /* Third mesasge is bad */
1546         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1547         EXPECT_EQ(errno, EBADMSG);
1548         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1549         EXPECT_EQ(errno, EBADMSG);
1550 }
1551
1552 TEST_F(tls_err, bad_cmsg)
1553 {
1554         char *test_str = "test_read";
1555         int send_len = 10;
1556         char cip[128];
1557         char buf[128];
1558         char txt[64];
1559         int n;
1560
1561         if (self->notls)
1562                 SKIP(return, "no TLS support");
1563
1564         /* Queue up one data record */
1565         memrnd(txt, sizeof(txt));
1566         EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
1567         n = recv(self->cfd, cip, sizeof(cip), 0);
1568         EXPECT_GT(n, sizeof(txt));
1569         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1570
1571         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
1572         n = recv(self->cfd, cip, sizeof(cip), 0);
1573         cip[n - 1]++; /* Break it */
1574         EXPECT_GT(n, send_len);
1575         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1576
1577         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
1578         EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
1579         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1580         EXPECT_EQ(errno, EBADMSG);
1581         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1582         EXPECT_EQ(errno, EBADMSG);
1583 }
1584
1585 TEST(non_established) {
1586         struct tls12_crypto_info_aes_gcm_256 tls12;
1587         struct sockaddr_in addr;
1588         int sfd, ret, fd;
1589         socklen_t len;
1590
1591         len = sizeof(addr);
1592
1593         memset(&tls12, 0, sizeof(tls12));
1594         tls12.info.version = TLS_1_2_VERSION;
1595         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1596
1597         addr.sin_family = AF_INET;
1598         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1599         addr.sin_port = 0;
1600
1601         fd = socket(AF_INET, SOCK_STREAM, 0);
1602         sfd = socket(AF_INET, SOCK_STREAM, 0);
1603
1604         ret = bind(sfd, &addr, sizeof(addr));
1605         ASSERT_EQ(ret, 0);
1606         ret = listen(sfd, 10);
1607         ASSERT_EQ(ret, 0);
1608
1609         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1610         EXPECT_EQ(ret, -1);
1611         /* TLS ULP not supported */
1612         if (errno == ENOENT)
1613                 return;
1614         EXPECT_EQ(errno, ENOTCONN);
1615
1616         ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1617         EXPECT_EQ(ret, -1);
1618         EXPECT_EQ(errno, ENOTCONN);
1619
1620         ret = getsockname(sfd, &addr, &len);
1621         ASSERT_EQ(ret, 0);
1622
1623         ret = connect(fd, &addr, sizeof(addr));
1624         ASSERT_EQ(ret, 0);
1625
1626         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1627         ASSERT_EQ(ret, 0);
1628
1629         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1630         EXPECT_EQ(ret, -1);
1631         EXPECT_EQ(errno, EEXIST);
1632
1633         close(fd);
1634         close(sfd);
1635 }
1636
1637 TEST(keysizes) {
1638         struct tls12_crypto_info_aes_gcm_256 tls12;
1639         int ret, fd, cfd;
1640         bool notls;
1641
1642         memset(&tls12, 0, sizeof(tls12));
1643         tls12.info.version = TLS_1_2_VERSION;
1644         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1645
1646         ulp_sock_pair(_metadata, &fd, &cfd, &notls);
1647
1648         if (!notls) {
1649                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1650                                  sizeof(tls12));
1651                 EXPECT_EQ(ret, 0);
1652
1653                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1654                                  sizeof(tls12));
1655                 EXPECT_EQ(ret, 0);
1656         }
1657
1658         close(fd);
1659         close(cfd);
1660 }
1661
1662 TEST(tls_v6ops) {
1663         struct tls_crypto_info_keys tls12;
1664         struct sockaddr_in6 addr, addr2;
1665         int sfd, ret, fd;
1666         socklen_t len, len2;
1667
1668         tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);
1669
1670         addr.sin6_family = AF_INET6;
1671         addr.sin6_addr = in6addr_any;
1672         addr.sin6_port = 0;
1673
1674         fd = socket(AF_INET6, SOCK_STREAM, 0);
1675         sfd = socket(AF_INET6, SOCK_STREAM, 0);
1676
1677         ret = bind(sfd, &addr, sizeof(addr));
1678         ASSERT_EQ(ret, 0);
1679         ret = listen(sfd, 10);
1680         ASSERT_EQ(ret, 0);
1681
1682         len = sizeof(addr);
1683         ret = getsockname(sfd, &addr, &len);
1684         ASSERT_EQ(ret, 0);
1685
1686         ret = connect(fd, &addr, sizeof(addr));
1687         ASSERT_EQ(ret, 0);
1688
1689         len = sizeof(addr);
1690         ret = getsockname(fd, &addr, &len);
1691         ASSERT_EQ(ret, 0);
1692
1693         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1694         if (ret) {
1695                 ASSERT_EQ(errno, ENOENT);
1696                 SKIP(return, "no TLS support");
1697         }
1698         ASSERT_EQ(ret, 0);
1699
1700         ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
1701         ASSERT_EQ(ret, 0);
1702
1703         ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
1704         ASSERT_EQ(ret, 0);
1705
1706         len2 = sizeof(addr2);
1707         ret = getsockname(fd, &addr2, &len2);
1708         ASSERT_EQ(ret, 0);
1709
1710         EXPECT_EQ(len2, len);
1711         EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
1712
1713         close(fd);
1714         close(sfd);
1715 }
1716
1717 TEST_HARNESS_MAIN