Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[linux-2.6-microblaze.git] / net / tls / tls_main.c
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  *
5  * This software is available to you under a choice of one of two
6  * licenses.  You may choose to be licensed under the terms of the GNU
7  * General Public License (GPL) Version 2, available from the file
8  * COPYING in the main directory of this source tree, or the
9  * OpenIB.org BSD license below:
10  *
11  *     Redistribution and use in source and binary forms, with or
12  *     without modification, are permitted provided that the following
13  *     conditions are met:
14  *
15  *      - Redistributions of source code must retain the above
16  *        copyright notice, this list of conditions and the following
17  *        disclaimer.
18  *
19  *      - Redistributions in binary form must reproduce the above
20  *        copyright notice, this list of conditions and the following
21  *        disclaimer in the documentation and/or other materials
22  *        provided with the distribution.
23  *
24  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31  * SOFTWARE.
32  */
33
34 #include <linux/module.h>
35
36 #include <net/tcp.h>
37 #include <net/inet_common.h>
38 #include <linux/highmem.h>
39 #include <linux/netdevice.h>
40 #include <linux/sched/signal.h>
41 #include <linux/inetdevice.h>
42 #include <linux/inet_diag.h>
43
44 #include <net/snmp.h>
45 #include <net/tls.h>
46 #include <net/tls_toe.h>
47
48 #include "tls.h"
49
50 MODULE_AUTHOR("Mellanox Technologies");
51 MODULE_DESCRIPTION("Transport Layer Security Support");
52 MODULE_LICENSE("Dual BSD/GPL");
53 MODULE_ALIAS_TCP_ULP("tls");
54
/* Address-family index (first dimension) of tls_prots[] / tls_proto_ops[]. */
55 enum {
56         TLSV4,
57         TLSV6,
58         TLS_NUM_PROTS,
59 };
60
/*
 * Base TCP protos that the TLSV6 / TLSV4 tables were last built from.
 * tls_build_proto() compares a socket's sk_prot against these and, under
 * the matching mutex, rebuilds the table when the base proto changes.
 */
61 static const struct proto *saved_tcpv6_prot;
62 static DEFINE_MUTEX(tcpv6_prot_mutex);
63 static const struct proto *saved_tcpv4_prot;
64 static DEFINE_MUTEX(tcpv4_prot_mutex);
/*
 * Per-family proto / proto_ops tables indexed [family][tx_conf][rx_conf];
 * update_sk_prot() installs the entry matching a context's configuration.
 */
65 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
66 static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
67 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
68                          const struct proto *base);
69
70 void update_sk_prot(struct sock *sk, struct tls_context *ctx)
71 {
72         int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
73
74         WRITE_ONCE(sk->sk_prot,
75                    &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
76         WRITE_ONCE(sk->sk_socket->ops,
77                    &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
78 }
79
80 int wait_on_pending_writer(struct sock *sk, long *timeo)
81 {
82         int rc = 0;
83         DEFINE_WAIT_FUNC(wait, woken_wake_function);
84
85         add_wait_queue(sk_sleep(sk), &wait);
86         while (1) {
87                 if (!*timeo) {
88                         rc = -EAGAIN;
89                         break;
90                 }
91
92                 if (signal_pending(current)) {
93                         rc = sock_intr_errno(*timeo);
94                         break;
95                 }
96
97                 if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
98                         break;
99         }
100         remove_wait_queue(sk_sleep(sk), &wait);
101         return rc;
102 }
103
/*
 * tls_push_sg() - hand an encrypted record's scatterlist to TCP
 * @sk:           socket to send on
 * @ctx:          TLS context of @sk
 * @sg:           first scatterlist entry to send
 * @first_offset: byte offset into the first entry at which to start
 * @flags:        MSG_* flags passed to do_tcp_sendpages()
 *
 * Walks the list with do_tcp_sendpages(). Every entry except the last is
 * sent with MSG_SENDPAGE_NOTLAST added to @flags. A fully sent entry has
 * its page reference dropped and its length uncharged from the socket.
 * A short send is retried from where TCP stopped. If TCP accepts nothing
 * (return value <= 0), the current entry and the offset within it are
 * saved in ctx->partially_sent_record / ctx->partially_sent_offset so
 * that tls_push_partial_record() can resume later.
 *
 * Return: 0 if the whole list was sent, otherwise the non-positive value
 * returned by do_tcp_sendpages().
 */
104 int tls_push_sg(struct sock *sk,
105                 struct tls_context *ctx,
106                 struct scatterlist *sg,
107                 u16 first_offset,
108                 int flags)
109 {
110         int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
111         int ret = 0;
112         struct page *p;
113         size_t size;
114         int offset = first_offset;
115
116         size = sg->length - offset;
117         offset += sg->offset;
118
        /* Checked by tls_write_space() while we are inside TCP's send path. */
119         ctx->in_tcp_sendpages = true;
120         while (1) {
121                 if (sg_is_last(sg))
122                         sendpage_flags = flags;
123
124                 /* is sending application-limited? */
125                 tcp_rate_check_app_limited(sk);
126                 p = sg_page(sg);
127 retry:
128                 ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
129
130                 if (ret != size) {
                        /* Partial progress: advance within this entry and retry. */
131                         if (ret > 0) {
132                                 offset += ret;
133                                 size -= ret;
134                                 goto retry;
135                         }
136
                        /* Nothing accepted (or error): stash the resume point.
                         * The offset is converted back to be relative to this
                         * entry, matching the meaning of @first_offset.
                         */
137                         offset -= sg->offset;
138                         ctx->partially_sent_offset = offset;
139                         ctx->partially_sent_record = (void *)sg;
140                         ctx->in_tcp_sendpages = false;
141                         return ret;
142                 }
143
                /* Entry fully sent: release its page and memory charge. */
144                 put_page(p);
145                 sk_mem_uncharge(sk, sg->length);
146                 sg = sg_next(sg);
147                 if (!sg)
148                         break;
149
150                 offset = sg->offset;
151                 size = sg->length;
152         }
153
154         ctx->in_tcp_sendpages = false;
155
156         return 0;
157 }
158
159 static int tls_handle_open_record(struct sock *sk, int flags)
160 {
161         struct tls_context *ctx = tls_get_ctx(sk);
162
163         if (tls_is_pending_open_record(ctx))
164                 return ctx->push_pending_record(sk, flags);
165
166         return 0;
167 }
168
169 int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
170                      unsigned char *record_type)
171 {
172         struct cmsghdr *cmsg;
173         int rc = -EINVAL;
174
175         for_each_cmsghdr(cmsg, msg) {
176                 if (!CMSG_OK(msg, cmsg))
177                         return -EINVAL;
178                 if (cmsg->cmsg_level != SOL_TLS)
179                         continue;
180
181                 switch (cmsg->cmsg_type) {
182                 case TLS_SET_RECORD_TYPE:
183                         if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
184                                 return -EINVAL;
185
186                         if (msg->msg_flags & MSG_MORE)
187                                 return -EINVAL;
188
189                         rc = tls_handle_open_record(sk, msg->msg_flags);
190                         if (rc)
191                                 return rc;
192
193                         *record_type = *(unsigned char *)CMSG_DATA(cmsg);
194                         rc = 0;
195                         break;
196                 default:
197                         return -EINVAL;
198                 }
199         }
200
201         return rc;
202 }
203
204 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
205                             int flags)
206 {
207         struct scatterlist *sg;
208         u16 offset;
209
210         sg = ctx->partially_sent_record;
211         offset = ctx->partially_sent_offset;
212
213         ctx->partially_sent_record = NULL;
214         return tls_push_sg(sk, ctx, sg, offset, flags);
215 }
216
217 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
218 {
219         struct scatterlist *sg;
220
221         for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
222                 put_page(sg_page(sg));
223                 sk_mem_uncharge(sk, sg->length);
224         }
225         ctx->partially_sent_record = NULL;
226 }
227
228 static void tls_write_space(struct sock *sk)
229 {
230         struct tls_context *ctx = tls_get_ctx(sk);
231
232         /* If in_tcp_sendpages call lower protocol write space handler
233          * to ensure we wake up any waiting operations there. For example
234          * if do_tcp_sendpages where to call sk_wait_event.
235          */
236         if (ctx->in_tcp_sendpages) {
237                 ctx->sk_write_space(sk);
238                 return;
239         }
240
241 #ifdef CONFIG_TLS_DEVICE
242         if (ctx->tx_conf == TLS_HW)
243                 tls_device_write_space(sk, ctx);
244         else
245 #endif
246                 tls_sw_write_space(sk, ctx);
247
248         ctx->sk_write_space(sk);
249 }
250
251 /**
252  * tls_ctx_free() - free TLS ULP context
253  * @sk:  socket to with @ctx is attached
254  * @ctx: TLS context structure
255  *
256  * Free TLS context. If @sk is %NULL caller guarantees that the socket
257  * to which @ctx was attached has no outstanding references.
258  */
259 void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
260 {
261         if (!ctx)
262                 return;
263
264         memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
265         memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
266         mutex_destroy(&ctx->tx_lock);
267
268         if (sk)
269                 kfree_rcu(ctx, rcu);
270         else
271                 kfree(ctx);
272 }
273
274 static void tls_sk_proto_cleanup(struct sock *sk,
275                                  struct tls_context *ctx, long timeo)
276 {
277         if (unlikely(sk->sk_write_pending) &&
278             !wait_on_pending_writer(sk, &timeo))
279                 tls_handle_open_record(sk, 0);
280
281         /* We need these for tls_sw_fallback handling of other packets */
282         if (ctx->tx_conf == TLS_SW) {
283                 kfree(ctx->tx.rec_seq);
284                 kfree(ctx->tx.iv);
285                 tls_sw_release_resources_tx(sk);
286                 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
287         } else if (ctx->tx_conf == TLS_HW) {
288                 tls_device_free_resources_tx(sk);
289                 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
290         }
291
292         if (ctx->rx_conf == TLS_SW) {
293                 tls_sw_release_resources_rx(sk);
294                 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
295         } else if (ctx->rx_conf == TLS_HW) {
296                 tls_device_offload_cleanup_rx(sk);
297                 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
298         }
299 }
300
/*
 * tls_sk_proto_close() - close handler for TLS sockets
 * @sk:      socket being closed
 * @timeout: passed through unchanged to the base protocol's close()
 *
 * Releases TX/RX TLS state, restores the original sk_prot and
 * sk_write_space, then calls the saved base protocol's close().
 * The context is detached from icsk_ulp_data and freed here only when
 * neither direction uses TLS_HW.
 */
301 static void tls_sk_proto_close(struct sock *sk, long timeout)
302 {
303         struct inet_connection_sock *icsk = inet_csk(sk);
304         struct tls_context *ctx = tls_get_ctx(sk);
305         long timeo = sock_sndtimeo(sk, 0);
306         bool free_ctx;
307
        /* Cancel deferred SW TX work before taking the socket lock.
         * NOTE(review): presumably because that work takes the lock
         * itself - confirm against tls_sw.c.
         */
308         if (ctx->tx_conf == TLS_SW)
309                 tls_sw_cancel_work_tx(ctx);
310
311         lock_sock(sk);
        /* HW-offloaded contexts are neither detached nor freed here.
         * NOTE(review): presumably the device offload code releases
         * them later - confirm.
         */
312         free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
313
314         if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
315                 tls_sk_proto_cleanup(sk, ctx, timeo);
316
        /* Unhook TLS: restore the base proto and write-space callback. */
317         write_lock_bh(&sk->sk_callback_lock);
318         if (free_ctx)
319                 rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
320         WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
321         if (sk->sk_write_space == tls_write_space)
322                 sk->sk_write_space = ctx->sk_write_space;
323         write_unlock_bh(&sk->sk_callback_lock);
324         release_sock(sk);
        /* The remaining SW / strparser teardown runs without the socket lock. */
325         if (ctx->tx_conf == TLS_SW)
326                 tls_sw_free_ctx_tx(ctx);
327         if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
328                 tls_sw_strparser_done(ctx);
329         if (ctx->rx_conf == TLS_SW)
330                 tls_sw_free_ctx_rx(ctx);
331         ctx->sk_proto->close(sk, timeout);
332
        /* A non-NULL sk makes tls_ctx_free() defer the kfree via RCU. */
333         if (free_ctx)
334                 tls_ctx_free(sk, ctx);
335 }
336
337 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
338                                   int __user *optlen, int tx)
339 {
340         int rc = 0;
341         struct tls_context *ctx = tls_get_ctx(sk);
342         struct tls_crypto_info *crypto_info;
343         struct cipher_context *cctx;
344         int len;
345
346         if (get_user(len, optlen))
347                 return -EFAULT;
348
349         if (!optval || (len < sizeof(*crypto_info))) {
350                 rc = -EINVAL;
351                 goto out;
352         }
353
354         if (!ctx) {
355                 rc = -EBUSY;
356                 goto out;
357         }
358
359         /* get user crypto info */
360         if (tx) {
361                 crypto_info = &ctx->crypto_send.info;
362                 cctx = &ctx->tx;
363         } else {
364                 crypto_info = &ctx->crypto_recv.info;
365                 cctx = &ctx->rx;
366         }
367
368         if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
369                 rc = -EBUSY;
370                 goto out;
371         }
372
373         if (len == sizeof(*crypto_info)) {
374                 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
375                         rc = -EFAULT;
376                 goto out;
377         }
378
379         switch (crypto_info->cipher_type) {
380         case TLS_CIPHER_AES_GCM_128: {
381                 struct tls12_crypto_info_aes_gcm_128 *
382                   crypto_info_aes_gcm_128 =
383                   container_of(crypto_info,
384                                struct tls12_crypto_info_aes_gcm_128,
385                                info);
386
387                 if (len != sizeof(*crypto_info_aes_gcm_128)) {
388                         rc = -EINVAL;
389                         goto out;
390                 }
391                 lock_sock(sk);
392                 memcpy(crypto_info_aes_gcm_128->iv,
393                        cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
394                        TLS_CIPHER_AES_GCM_128_IV_SIZE);
395                 memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
396                        TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
397                 release_sock(sk);
398                 if (copy_to_user(optval,
399                                  crypto_info_aes_gcm_128,
400                                  sizeof(*crypto_info_aes_gcm_128)))
401                         rc = -EFAULT;
402                 break;
403         }
404         case TLS_CIPHER_AES_GCM_256: {
405                 struct tls12_crypto_info_aes_gcm_256 *
406                   crypto_info_aes_gcm_256 =
407                   container_of(crypto_info,
408                                struct tls12_crypto_info_aes_gcm_256,
409                                info);
410
411                 if (len != sizeof(*crypto_info_aes_gcm_256)) {
412                         rc = -EINVAL;
413                         goto out;
414                 }
415                 lock_sock(sk);
416                 memcpy(crypto_info_aes_gcm_256->iv,
417                        cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
418                        TLS_CIPHER_AES_GCM_256_IV_SIZE);
419                 memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
420                        TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
421                 release_sock(sk);
422                 if (copy_to_user(optval,
423                                  crypto_info_aes_gcm_256,
424                                  sizeof(*crypto_info_aes_gcm_256)))
425                         rc = -EFAULT;
426                 break;
427         }
428         case TLS_CIPHER_AES_CCM_128: {
429                 struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
430                         container_of(crypto_info,
431                                 struct tls12_crypto_info_aes_ccm_128, info);
432
433                 if (len != sizeof(*aes_ccm_128)) {
434                         rc = -EINVAL;
435                         goto out;
436                 }
437                 lock_sock(sk);
438                 memcpy(aes_ccm_128->iv,
439                        cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
440                        TLS_CIPHER_AES_CCM_128_IV_SIZE);
441                 memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
442                        TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
443                 release_sock(sk);
444                 if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
445                         rc = -EFAULT;
446                 break;
447         }
448         case TLS_CIPHER_CHACHA20_POLY1305: {
449                 struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
450                         container_of(crypto_info,
451                                 struct tls12_crypto_info_chacha20_poly1305,
452                                 info);
453
454                 if (len != sizeof(*chacha20_poly1305)) {
455                         rc = -EINVAL;
456                         goto out;
457                 }
458                 lock_sock(sk);
459                 memcpy(chacha20_poly1305->iv,
460                        cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
461                        TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
462                 memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
463                        TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
464                 release_sock(sk);
465                 if (copy_to_user(optval, chacha20_poly1305,
466                                 sizeof(*chacha20_poly1305)))
467                         rc = -EFAULT;
468                 break;
469         }
470         case TLS_CIPHER_SM4_GCM: {
471                 struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
472                         container_of(crypto_info,
473                                 struct tls12_crypto_info_sm4_gcm, info);
474
475                 if (len != sizeof(*sm4_gcm_info)) {
476                         rc = -EINVAL;
477                         goto out;
478                 }
479                 lock_sock(sk);
480                 memcpy(sm4_gcm_info->iv,
481                        cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
482                        TLS_CIPHER_SM4_GCM_IV_SIZE);
483                 memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
484                        TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
485                 release_sock(sk);
486                 if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
487                         rc = -EFAULT;
488                 break;
489         }
490         case TLS_CIPHER_SM4_CCM: {
491                 struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
492                         container_of(crypto_info,
493                                 struct tls12_crypto_info_sm4_ccm, info);
494
495                 if (len != sizeof(*sm4_ccm_info)) {
496                         rc = -EINVAL;
497                         goto out;
498                 }
499                 lock_sock(sk);
500                 memcpy(sm4_ccm_info->iv,
501                        cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
502                        TLS_CIPHER_SM4_CCM_IV_SIZE);
503                 memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
504                        TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
505                 release_sock(sk);
506                 if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
507                         rc = -EFAULT;
508                 break;
509         }
510         default:
511                 rc = -EINVAL;
512         }
513
514 out:
515         return rc;
516 }
517
518 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
519                                    int __user *optlen)
520 {
521         struct tls_context *ctx = tls_get_ctx(sk);
522         unsigned int value;
523         int len;
524
525         if (get_user(len, optlen))
526                 return -EFAULT;
527
528         if (len != sizeof(value))
529                 return -EINVAL;
530
531         value = ctx->zerocopy_sendfile;
532         if (copy_to_user(optval, &value, sizeof(value)))
533                 return -EFAULT;
534
535         return 0;
536 }
537
538 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
539                                     int __user *optlen)
540 {
541         struct tls_context *ctx = tls_get_ctx(sk);
542         int value, len;
543
544         if (ctx->prot_info.version != TLS_1_3_VERSION)
545                 return -EINVAL;
546
547         if (get_user(len, optlen))
548                 return -EFAULT;
549         if (len < sizeof(value))
550                 return -EINVAL;
551
552         lock_sock(sk);
553         value = -EINVAL;
554         if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
555                 value = ctx->rx_no_pad;
556         release_sock(sk);
557         if (value < 0)
558                 return value;
559
560         if (put_user(sizeof(value), optlen))
561                 return -EFAULT;
562         if (copy_to_user(optval, &value, sizeof(value)))
563                 return -EFAULT;
564
565         return 0;
566 }
567
568 static int do_tls_getsockopt(struct sock *sk, int optname,
569                              char __user *optval, int __user *optlen)
570 {
571         int rc = 0;
572
573         switch (optname) {
574         case TLS_TX:
575         case TLS_RX:
576                 rc = do_tls_getsockopt_conf(sk, optval, optlen,
577                                             optname == TLS_TX);
578                 break;
579         case TLS_TX_ZEROCOPY_RO:
580                 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
581                 break;
582         case TLS_RX_EXPECT_NO_PAD:
583                 rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
584                 break;
585         default:
586                 rc = -ENOPROTOOPT;
587                 break;
588         }
589         return rc;
590 }
591
592 static int tls_getsockopt(struct sock *sk, int level, int optname,
593                           char __user *optval, int __user *optlen)
594 {
595         struct tls_context *ctx = tls_get_ctx(sk);
596
597         if (level != SOL_TLS)
598                 return ctx->sk_proto->getsockopt(sk, level,
599                                                  optname, optval, optlen);
600
601         return do_tls_getsockopt(sk, optname, optval, optlen);
602 }
603
/*
 * do_tls_setsockopt_conf() - install TLS_TX or TLS_RX crypto parameters
 * @sk:     socket; the caller (do_tls_setsockopt()) holds the socket lock
 * @optval: pointer to a cipher-specific tls12_crypto_info_* structure
 * @optlen: length of @optval; must equal the size for the chosen cipher
 * @tx:     non-zero for the transmit direction, zero for receive
 *
 * Copies the crypto info into ctx->crypto_send / ctx->crypto_recv, then
 * tries device offload and falls back to software crypto if that fails.
 * On success the socket's proto/proto_ops are switched to the matching
 * TLS table entry and, for TX, sk_write_space is hooked.
 *
 * Return: 0 on success; -EINVAL for a bad argument, version, cipher or a
 * mismatch with the other direction; -EBUSY if this direction is already
 * configured; -EFAULT on a failed copy; otherwise the error from
 * tls_set_sw_offload(). On failure after the first copy, this direction's
 * crypto info is zeroed again.
 */
604 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
605                                   unsigned int optlen, int tx)
606 {
607         struct tls_crypto_info *crypto_info;
608         struct tls_crypto_info *alt_crypto_info;
609         struct tls_context *ctx = tls_get_ctx(sk);
610         size_t optsize;
611         int rc = 0;
612         int conf;
613
614         if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
615                 return -EINVAL;
616
617         if (tx) {
618                 crypto_info = &ctx->crypto_send.info;
619                 alt_crypto_info = &ctx->crypto_recv.info;
620         } else {
621                 crypto_info = &ctx->crypto_recv.info;
622                 alt_crypto_info = &ctx->crypto_send.info;
623         }
624
625         /* Setting the crypto info more than once is not supported. */
626         if (TLS_CRYPTO_INFO_READY(crypto_info))
627                 return -EBUSY;
628
        /* Copy only the common header first to learn version and cipher. */
629         rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
630         if (rc) {
631                 rc = -EFAULT;
632                 goto err_crypto_info;
633         }
634
635         /* check version */
636         if (crypto_info->version != TLS_1_2_VERSION &&
637             crypto_info->version != TLS_1_3_VERSION) {
638                 rc = -EINVAL;
639                 goto err_crypto_info;
640         }
641
642         /* Ensure that TLS version and ciphers are the same in both directions */
643         if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
644                 if (alt_crypto_info->version != crypto_info->version ||
645                     alt_crypto_info->cipher_type != crypto_info->cipher_type) {
646                         rc = -EINVAL;
647                         goto err_crypto_info;
648                 }
649         }
650
        /* Expected total size of the cipher-specific structure. */
651         switch (crypto_info->cipher_type) {
652         case TLS_CIPHER_AES_GCM_128:
653                 optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
654                 break;
655         case TLS_CIPHER_AES_GCM_256: {
656                 optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
657                 break;
658         }
659         case TLS_CIPHER_AES_CCM_128:
660                 optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
661                 break;
662         case TLS_CIPHER_CHACHA20_POLY1305:
663                 optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
664                 break;
665         case TLS_CIPHER_SM4_GCM:
666                 optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
667                 break;
668         case TLS_CIPHER_SM4_CCM:
669                 optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
670                 break;
671         default:
672                 rc = -EINVAL;
673                 goto err_crypto_info;
674         }
675
676         if (optlen != optsize) {
677                 rc = -EINVAL;
678                 goto err_crypto_info;
679         }
680
        /* Copy the cipher-specific remainder that follows the header. */
681         rc = copy_from_sockptr_offset(crypto_info + 1, optval,
682                                       sizeof(*crypto_info),
683                                       optlen - sizeof(*crypto_info));
684         if (rc) {
685                 rc = -EFAULT;
686                 goto err_crypto_info;
687         }
688
        /* Prefer device offload; fall back to software crypto on failure. */
689         if (tx) {
690                 rc = tls_set_device_offload(sk, ctx);
691                 conf = TLS_HW;
692                 if (!rc) {
693                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
694                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
695                 } else {
696                         rc = tls_set_sw_offload(sk, ctx, 1);
697                         if (rc)
698                                 goto err_crypto_info;
699                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
700                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
701                         conf = TLS_SW;
702                 }
703         } else {
704                 rc = tls_set_device_offload_rx(sk, ctx);
705                 conf = TLS_HW;
706                 if (!rc) {
707                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
708                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
709                 } else {
710                         rc = tls_set_sw_offload(sk, ctx, 0);
711                         if (rc)
712                                 goto err_crypto_info;
713                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
714                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
715                         conf = TLS_SW;
716                 }
                /* The RX strparser is armed for both HW and SW RX. */
717                 tls_sw_strparser_arm(sk, ctx);
718         }
719
        /* Record the new mode, then switch sk_prot/ops to the matching entry. */
720         if (tx)
721                 ctx->tx_conf = conf;
722         else
723                 ctx->rx_conf = conf;
724         update_sk_prot(sk, ctx);
        /* Chain tls_write_space() in front of the original handler; it is
         * restored in tls_sk_proto_close().
         */
725         if (tx) {
726                 ctx->sk_write_space = sk->sk_write_space;
727                 sk->sk_write_space = tls_write_space;
728         }
729         return 0;
730
        /* Wipe the whole crypto context union for this direction.
         * NOTE(review): assumes crypto_send/crypto_recv are
         * union tls_crypto_context - confirm in the tls headers.
         */
731 err_crypto_info:
732         memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
733         return rc;
734 }
735
736 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
737                                    unsigned int optlen)
738 {
739         struct tls_context *ctx = tls_get_ctx(sk);
740         unsigned int value;
741
742         if (sockptr_is_null(optval) || optlen != sizeof(value))
743                 return -EINVAL;
744
745         if (copy_from_sockptr(&value, optval, sizeof(value)))
746                 return -EFAULT;
747
748         if (value > 1)
749                 return -EINVAL;
750
751         ctx->zerocopy_sendfile = value;
752
753         return 0;
754 }
755
756 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
757                                     unsigned int optlen)
758 {
759         struct tls_context *ctx = tls_get_ctx(sk);
760         u32 val;
761         int rc;
762
763         if (ctx->prot_info.version != TLS_1_3_VERSION ||
764             sockptr_is_null(optval) || optlen < sizeof(val))
765                 return -EINVAL;
766
767         rc = copy_from_sockptr(&val, optval, sizeof(val));
768         if (rc)
769                 return -EFAULT;
770         if (val > 1)
771                 return -EINVAL;
772         rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
773         if (rc < 1)
774                 return rc == 0 ? -EINVAL : rc;
775
776         lock_sock(sk);
777         rc = -EINVAL;
778         if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
779                 ctx->rx_no_pad = val;
780                 tls_update_rx_zc_capable(ctx);
781                 rc = 0;
782         }
783         release_sock(sk);
784
785         return rc;
786 }
787
788 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
789                              unsigned int optlen)
790 {
791         int rc = 0;
792
793         switch (optname) {
794         case TLS_TX:
795         case TLS_RX:
796                 lock_sock(sk);
797                 rc = do_tls_setsockopt_conf(sk, optval, optlen,
798                                             optname == TLS_TX);
799                 release_sock(sk);
800                 break;
801         case TLS_TX_ZEROCOPY_RO:
802                 lock_sock(sk);
803                 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
804                 release_sock(sk);
805                 break;
806         case TLS_RX_EXPECT_NO_PAD:
807                 rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
808                 break;
809         default:
810                 rc = -ENOPROTOOPT;
811                 break;
812         }
813         return rc;
814 }
815
816 static int tls_setsockopt(struct sock *sk, int level, int optname,
817                           sockptr_t optval, unsigned int optlen)
818 {
819         struct tls_context *ctx = tls_get_ctx(sk);
820
821         if (level != SOL_TLS)
822                 return ctx->sk_proto->setsockopt(sk, level, optname, optval,
823                                                  optlen);
824
825         return do_tls_setsockopt(sk, optname, optval, optlen);
826 }
827
828 struct tls_context *tls_ctx_create(struct sock *sk)
829 {
830         struct inet_connection_sock *icsk = inet_csk(sk);
831         struct tls_context *ctx;
832
833         ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
834         if (!ctx)
835                 return NULL;
836
837         mutex_init(&ctx->tx_lock);
838         rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
839         ctx->sk_proto = READ_ONCE(sk->sk_prot);
840         ctx->sk = sk;
841         return ctx;
842 }
843
/*
 * build_proto_ops() - derive the TLS proto_ops table for one address family
 * @ops:  [tx_conf][rx_conf] table to fill
 * @base: the socket's original proto_ops
 *
 * Every entry starts as a copy of @base. SW TX installs
 * tls_sw_sendpage_locked, SW or HW RX installs tls_sw_splice_read, and
 * HW TX clears sendpage_locked. Later entries are copied from earlier
 * ones, so the assignment order below matters.
 */
844 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
845                             const struct proto_ops *base)
846 {
847         ops[TLS_BASE][TLS_BASE] = *base;
848
849         ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
850         ops[TLS_SW  ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;
851
852         ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
853         ops[TLS_BASE][TLS_SW  ].splice_read     = tls_sw_splice_read;
854
        /* SW/SW = SW TX sendpage_locked + SW RX splice_read. */
855         ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
856         ops[TLS_SW  ][TLS_SW  ].splice_read     = tls_sw_splice_read;
857
858 #ifdef CONFIG_TLS_DEVICE
859         ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
860         ops[TLS_HW  ][TLS_BASE].sendpage_locked = NULL;
861
862         ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
863         ops[TLS_HW  ][TLS_SW  ].sendpage_locked = NULL;
864
        /* HW RX reuses the SW RX ops (tls_sw_splice_read). */
865         ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];
866
867         ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];
868
869         ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
870         ops[TLS_HW  ][TLS_HW  ].sendpage_locked = NULL;
871 #endif
872 #ifdef CONFIG_TLS_TOE
        /* TOE record offload uses the unmodified base ops. */
873         ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
874 #endif
875 }
876
877 static void tls_build_proto(struct sock *sk)
878 {
879         int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
880         struct proto *prot = READ_ONCE(sk->sk_prot);
881
882         /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
883         if (ip_ver == TLSV6 &&
884             unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
885                 mutex_lock(&tcpv6_prot_mutex);
886                 if (likely(prot != saved_tcpv6_prot)) {
887                         build_protos(tls_prots[TLSV6], prot);
888                         build_proto_ops(tls_proto_ops[TLSV6],
889                                         sk->sk_socket->ops);
890                         smp_store_release(&saved_tcpv6_prot, prot);
891                 }
892                 mutex_unlock(&tcpv6_prot_mutex);
893         }
894
895         if (ip_ver == TLSV4 &&
896             unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
897                 mutex_lock(&tcpv4_prot_mutex);
898                 if (likely(prot != saved_tcpv4_prot)) {
899                         build_protos(tls_prots[TLSV4], prot);
900                         build_proto_ops(tls_proto_ops[TLSV4],
901                                         sk->sk_socket->ops);
902                         smp_store_release(&saved_tcpv4_prot, prot);
903                 }
904                 mutex_unlock(&tcpv4_prot_mutex);
905         }
906 }
907
/*
 * Fill one family's table of TLS struct proto variants.
 *
 * @prot: table to fill; tls_build_proto() passes tls_prots[TLSV4] or
 *        tls_prots[TLSV6].
 * @base: the socket's current sk_prot that every variant is derived from.
 *
 * The table is indexed as prot[tx config][rx config]. TX configurations
 * override sendmsg/sendpage, and RX configurations override recvmsg and
 * sock_is_readable. All TLS_BASE/SW/HW entries descend from
 * prot[TLS_BASE][TLS_BASE], which installs the TLS setsockopt, getsockopt
 * and close hooks. Most entries are copies of entries filled earlier in
 * this function, so the order of the assignments below is significant.
 */
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;

	/* SW TX: software TLS sendmsg/sendpage. */
	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].sendpage		= tls_sw_sendpage;

	/*
	 * SW RX: software TLS recvmsg/sock_is_readable. The .close assignment
	 * repeats the value already inherited from prot[TLS_BASE][TLS_BASE].
	 */
	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;

	/* SW TX + SW RX: start from SW TX, then add the SW RX overrides. */
	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	/* HW TX: device-offload sendmsg/sendpage. */
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;

	/* HW RX entries reuse the matching SW RX entries unchanged. */
	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	/*
	 * TLS_HW_RECORD (CONFIG_TLS_TOE): derived directly from @base, not
	 * from prot[TLS_BASE][TLS_BASE]. It therefore does NOT get the TLS
	 * setsockopt/getsockopt/close hooks; only hash/unhash are replaced.
	 */
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
#endif
}
951
952 static int tls_init(struct sock *sk)
953 {
954         struct tls_context *ctx;
955         int rc = 0;
956
957         tls_build_proto(sk);
958
959 #ifdef CONFIG_TLS_TOE
960         if (tls_toe_bypass(sk))
961                 return 0;
962 #endif
963
964         /* The TLS ulp is currently supported only for TCP sockets
965          * in ESTABLISHED state.
966          * Supporting sockets in LISTEN state will require us
967          * to modify the accept implementation to clone rather then
968          * share the ulp context.
969          */
970         if (sk->sk_state != TCP_ESTABLISHED)
971                 return -ENOTCONN;
972
973         /* allocate tls context */
974         write_lock_bh(&sk->sk_callback_lock);
975         ctx = tls_ctx_create(sk);
976         if (!ctx) {
977                 rc = -ENOMEM;
978                 goto out;
979         }
980
981         ctx->tx_conf = TLS_BASE;
982         ctx->rx_conf = TLS_BASE;
983         update_sk_prot(sk, ctx);
984 out:
985         write_unlock_bh(&sk->sk_callback_lock);
986         return rc;
987 }
988
989 static void tls_update(struct sock *sk, struct proto *p,
990                        void (*write_space)(struct sock *sk))
991 {
992         struct tls_context *ctx;
993
994         WARN_ON_ONCE(sk->sk_prot == p);
995
996         ctx = tls_get_ctx(sk);
997         if (likely(ctx)) {
998                 ctx->sk_write_space = write_space;
999                 ctx->sk_proto = p;
1000         } else {
1001                 /* Pairs with lockless read in sk_clone_lock(). */
1002                 WRITE_ONCE(sk->sk_prot, p);
1003                 sk->sk_write_space = write_space;
1004         }
1005 }
1006
1007 static u16 tls_user_config(struct tls_context *ctx, bool tx)
1008 {
1009         u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
1010
1011         switch (config) {
1012         case TLS_BASE:
1013                 return TLS_CONF_BASE;
1014         case TLS_SW:
1015                 return TLS_CONF_SW;
1016         case TLS_HW:
1017                 return TLS_CONF_HW;
1018         case TLS_HW_RECORD:
1019                 return TLS_CONF_HW_RECORD;
1020         }
1021         return 0;
1022 }
1023
1024 static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
1025 {
1026         u16 version, cipher_type;
1027         struct tls_context *ctx;
1028         struct nlattr *start;
1029         int err;
1030
1031         start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
1032         if (!start)
1033                 return -EMSGSIZE;
1034
1035         rcu_read_lock();
1036         ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
1037         if (!ctx) {
1038                 err = 0;
1039                 goto nla_failure;
1040         }
1041         version = ctx->prot_info.version;
1042         if (version) {
1043                 err = nla_put_u16(skb, TLS_INFO_VERSION, version);
1044                 if (err)
1045                         goto nla_failure;
1046         }
1047         cipher_type = ctx->prot_info.cipher_type;
1048         if (cipher_type) {
1049                 err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
1050                 if (err)
1051                         goto nla_failure;
1052         }
1053         err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
1054         if (err)
1055                 goto nla_failure;
1056
1057         err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
1058         if (err)
1059                 goto nla_failure;
1060
1061         if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
1062                 err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
1063                 if (err)
1064                         goto nla_failure;
1065         }
1066         if (ctx->rx_no_pad) {
1067                 err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
1068                 if (err)
1069                         goto nla_failure;
1070         }
1071
1072         rcu_read_unlock();
1073         nla_nest_end(skb, start);
1074         return 0;
1075
1076 nla_failure:
1077         rcu_read_unlock();
1078         nla_nest_cancel(skb, start);
1079         return err;
1080 }
1081
1082 static size_t tls_get_info_size(const struct sock *sk)
1083 {
1084         size_t size = 0;
1085
1086         size += nla_total_size(0) +             /* INET_ULP_INFO_TLS */
1087                 nla_total_size(sizeof(u16)) +   /* TLS_INFO_VERSION */
1088                 nla_total_size(sizeof(u16)) +   /* TLS_INFO_CIPHER */
1089                 nla_total_size(sizeof(u16)) +   /* TLS_INFO_RXCONF */
1090                 nla_total_size(sizeof(u16)) +   /* TLS_INFO_TXCONF */
1091                 nla_total_size(0) +             /* TLS_INFO_ZC_RO_TX */
1092                 nla_total_size(0) +             /* TLS_INFO_RX_NO_PAD */
1093                 0;
1094
1095         return size;
1096 }
1097
1098 static int __net_init tls_init_net(struct net *net)
1099 {
1100         int err;
1101
1102         net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
1103         if (!net->mib.tls_statistics)
1104                 return -ENOMEM;
1105
1106         err = tls_proc_init(net);
1107         if (err)
1108                 goto err_free_stats;
1109
1110         return 0;
1111 err_free_stats:
1112         free_percpu(net->mib.tls_statistics);
1113         return err;
1114 }
1115
/*
 * Per-network-namespace teardown. It undoes tls_init_net() in reverse
 * order: the proc entries are removed before the per-CPU MIB counters are
 * freed.
 */
static void __net_exit tls_exit_net(struct net *net)
{
	tls_proc_fini(net);
	free_percpu(net->mib.tls_statistics);
}

/* Per-netns init/exit hooks, registered via register_pernet_subsys() in
 * tls_register().
 */
static struct pernet_operations tls_proc_ops = {
	.init = tls_init_net,
	.exit = tls_exit_net,
};
1126
/*
 * The "tls" TCP upper-layer protocol, registered with tcp_register_ulp() in
 * tls_register():
 *   .init          - tls_init(): attach TLS to an ESTABLISHED socket
 *   .update        - tls_update(): replace the proto TLS is layered over
 *   .get_info      - tls_get_info(): dump TLS state as netlink attributes
 *   .get_info_size - tls_get_info_size(): space needed by .get_info
 */
static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
	.get_info		= tls_get_info,
	.get_info_size		= tls_get_info_size,
};
1135
1136 static int __init tls_register(void)
1137 {
1138         int err;
1139
1140         err = register_pernet_subsys(&tls_proc_ops);
1141         if (err)
1142                 return err;
1143
1144         err = tls_device_init();
1145         if (err) {
1146                 unregister_pernet_subsys(&tls_proc_ops);
1147                 return err;
1148         }
1149
1150         tcp_register_ulp(&tcp_tls_ulp_ops);
1151
1152         return 0;
1153 }
1154
/*
 * Module exit: undo tls_register() in reverse order. The ULP is
 * unregistered first, then device offload is cleaned up, and finally the
 * per-netns ops are removed.
 */
static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
	tls_device_cleanup();
	unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);