2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
5 * This software is available to you under a choice of one of two
6 * licenses. You may choose to be licensed under the terms of the GNU
7 * General Public License (GPL) Version 2, available from the file
8 * COPYING in the main directory of this source tree, or the
9 * OpenIB.org BSD license below:
11 * Redistribution and use in source and binary forms, with or
12 * without modification, are permitted provided that the following
15 * - Redistributions of source code must retain the above
16 * copyright notice, this list of conditions and the following
19 * - Redistributions in binary form must reproduce the above
20 * copyright notice, this list of conditions and the following
21 * disclaimer in the documentation and/or other materials
22 * provided with the distribution.
24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34 #include <linux/module.h>
37 #include <net/inet_common.h>
38 #include <linux/highmem.h>
39 #include <linux/netdevice.h>
40 #include <linux/sched/signal.h>
41 #include <linux/inetdevice.h>
42 #include <linux/inet_diag.h>
46 #include <net/tls_toe.h>
50 MODULE_AUTHOR("Mellanox Technologies");
51 MODULE_DESCRIPTION("Transport Layer Security Support");
52 MODULE_LICENSE("Dual BSD/GPL");
53 MODULE_ALIAS_TCP_ULP("tls");
/* Cached base TCP proto pointers; a change (e.g. module reload) triggers a
 * rebuild of the TLS proto tables in tls_build_proto(). Each mutex guards
 * the lazy build for its address family.
 */
61 static const struct proto *saved_tcpv6_prot;
62 static DEFINE_MUTEX(tcpv6_prot_mutex);
63 static const struct proto *saved_tcpv4_prot;
64 static DEFINE_MUTEX(tcpv4_prot_mutex);
/* Per-family proto/proto_ops tables, indexed by [ip_ver][tx_conf][rx_conf]. */
65 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
66 static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
67 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
68 const struct proto *base);
/* Point the socket at the TLS proto/proto_ops vtables matching this
 * context's TX/RX configuration. WRITE_ONCE pairs with lockless readers
 * of sk->sk_prot elsewhere (NOTE(review): readers not visible in this
 * chunk — confirm).
 */
70 void update_sk_prot(struct sock *sk, struct tls_context *ctx)
72 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
74 WRITE_ONCE(sk->sk_prot,
75 &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
76 WRITE_ONCE(sk->sk_socket->ops,
77 &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
/* Block until sk->sk_write_pending drops to zero, the timeout expires, or
 * a signal arrives (in which case the signal errno is returned via
 * sock_intr_errno()).
 */
80 int wait_on_pending_writer(struct sock *sk, long *timeo)
83 DEFINE_WAIT_FUNC(wait, woken_wake_function);
85 add_wait_queue(sk_sleep(sk), &wait);
92 if (signal_pending(current)) {
93 rc = sock_intr_errno(*timeo);
97 if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
100 remove_wait_queue(sk_sleep(sk), &wait);
/* Transmit a scatterlist of already-encrypted record data via
 * do_tcp_sendpages(), starting @first_offset bytes into the first entry.
 * On partial send the remaining sg/offset is stashed in the context so a
 * later tls_push_partial_record() can resume. in_tcp_sendpages tells
 * tls_write_space() to defer to the base protocol's handler meanwhile.
 */
104 int tls_push_sg(struct sock *sk,
105 struct tls_context *ctx,
106 struct scatterlist *sg,
110 int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
114 int offset = first_offset;
116 size = sg->length - offset;
117 offset += sg->offset;
119 ctx->in_tcp_sendpages = true;
/* Last fragment: drop MSG_SENDPAGE_NOTLAST so TCP may push the data. */
122 sendpage_flags = flags;
124 /* is sending application-limited? */
125 tcp_rate_check_app_limited(sk);
128 ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
/* Partial send: record where we stopped so the caller can retry later. */
137 offset -= sg->offset;
138 ctx->partially_sent_offset = offset;
139 ctx->partially_sent_record = (void *)sg;
140 ctx->in_tcp_sendpages = false;
/* Fully sent entry: return its memory accounting to the socket. */
145 sk_mem_uncharge(sk, sg->length);
154 ctx->in_tcp_sendpages = false;
/* If a record is open (partially filled), flush it via the context's
 * push_pending_record callback; otherwise nothing to do.
 */
159 static int tls_handle_open_record(struct sock *sk, int flags)
161 struct tls_context *ctx = tls_get_ctx(sk);
163 if (tls_is_pending_open_record(ctx))
164 return ctx->push_pending_record(sk, flags);
/* Parse SOL_TLS control messages on a sendmsg call. Currently only
 * TLS_SET_RECORD_TYPE is handled: it closes any open record (record type
 * cannot change mid-record, hence the MSG_MORE rejection) and returns the
 * requested record type through @record_type.
 */
169 int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
170 unsigned char *record_type)
172 struct cmsghdr *cmsg;
175 for_each_cmsghdr(cmsg, msg) {
176 if (!CMSG_OK(msg, cmsg))
178 if (cmsg->cmsg_level != SOL_TLS)
181 switch (cmsg->cmsg_type) {
182 case TLS_SET_RECORD_TYPE:
183 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
/* Record type change requires a record boundary, so MSG_MORE is invalid. */
186 if (msg->msg_flags & MSG_MORE)
189 rc = tls_handle_open_record(sk, msg->msg_flags);
193 *record_type = *(unsigned char *)CMSG_DATA(cmsg);
/* Resume transmission of a record that tls_push_sg() previously sent only
 * partially. Clears partially_sent_record before retrying; tls_push_sg()
 * will re-set it if the send is again incomplete.
 */
204 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
207 struct scatterlist *sg;
210 sg = ctx->partially_sent_record;
211 offset = ctx->partially_sent_offset;
213 ctx->partially_sent_record = NULL;
214 return tls_push_sg(sk, ctx, sg, offset, flags);
/* Drop an unsent partial record: release each page reference and return
 * the accounted memory to the socket, then clear the pointer.
 */
217 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
219 struct scatterlist *sg;
221 for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
222 put_page(sg_page(sg));
223 sk_mem_uncharge(sk, sg->length);
225 ctx->partially_sent_record = NULL;
/* sk_write_space replacement installed while TLS is active: dispatches to
 * the device or software TX path, then chains to the original handler.
 */
228 static void tls_write_space(struct sock *sk)
230 struct tls_context *ctx = tls_get_ctx(sk);
232 /* If in_tcp_sendpages call lower protocol write space handler
233 * to ensure we wake up any waiting operations there. For example
234 * if do_tcp_sendpages were to call sk_wait_event.
236 if (ctx->in_tcp_sendpages) {
237 ctx->sk_write_space(sk);
241 #ifdef CONFIG_TLS_DEVICE
242 if (ctx->tx_conf == TLS_HW)
243 tls_device_write_space(sk, ctx);
246 tls_sw_write_space(sk, ctx);
/* Chain to the handler saved when TLS was attached. */
248 ctx->sk_write_space(sk);
252 * tls_ctx_free() - free TLS ULP context
253 * @sk: socket to which @ctx is attached
254 * @ctx: TLS context structure
256 * Free TLS context. If @sk is %NULL caller guarantees that the socket
257 * to which @ctx was attached has no outstanding references.
259 void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
/* Scrub key material before freeing; memzero_explicit() cannot be
 * optimized away like a plain memset().
 */
264 memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
265 memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
266 mutex_destroy(&ctx->tx_lock);
/* Tear down per-direction TLS state on socket close: wait for in-flight
 * writers, flush any open record, then release SW or HW resources and
 * decrement the matching MIB counters for each configured direction.
 */
274 static void tls_sk_proto_cleanup(struct sock *sk,
275 struct tls_context *ctx, long timeo)
277 if (unlikely(sk->sk_write_pending) &&
278 !wait_on_pending_writer(sk, &timeo))
279 tls_handle_open_record(sk, 0);
281 /* We need these for tls_sw_fallback handling of other packets */
282 if (ctx->tx_conf == TLS_SW) {
283 kfree(ctx->tx.rec_seq);
285 tls_sw_release_resources_tx(sk);
286 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
287 } else if (ctx->tx_conf == TLS_HW) {
288 tls_device_free_resources_tx(sk);
289 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
292 if (ctx->rx_conf == TLS_SW) {
293 tls_sw_release_resources_rx(sk);
294 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
295 } else if (ctx->rx_conf == TLS_HW) {
296 tls_device_offload_cleanup_rx(sk);
297 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
/* .close handler for TLS sockets: cancel TX work, release per-direction
 * resources, restore the original proto and write_space callback under
 * sk_callback_lock, then delegate to the base protocol's close.
 */
301 static void tls_sk_proto_close(struct sock *sk, long timeout)
303 struct inet_connection_sock *icsk = inet_csk(sk);
304 struct tls_context *ctx = tls_get_ctx(sk);
305 long timeo = sock_sndtimeo(sk, 0);
308 if (ctx->tx_conf == TLS_SW)
309 tls_sw_cancel_work_tx(ctx);
/* HW offload paths free the ctx themselves; only free it here otherwise. */
312 free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
314 if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
315 tls_sk_proto_cleanup(sk, ctx, timeo);
317 write_lock_bh(&sk->sk_callback_lock);
319 rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
320 WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
321 if (sk->sk_write_space == tls_write_space)
322 sk->sk_write_space = ctx->sk_write_space;
323 write_unlock_bh(&sk->sk_callback_lock);
325 if (ctx->tx_conf == TLS_SW)
326 tls_sw_free_ctx_tx(ctx);
327 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
328 tls_sw_strparser_done(ctx);
329 if (ctx->rx_conf == TLS_SW)
330 tls_sw_free_ctx_rx(ctx);
331 ctx->sk_proto->close(sk, timeout);
334 tls_ctx_free(sk, ctx);
/* getsockopt(TLS_TX/TLS_RX): copy the crypto_info for the requested
 * direction (@tx selects send vs. recv) back to userspace. A short buffer
 * of exactly sizeof(tls_crypto_info) gets only the version/cipher header;
 * a full-size buffer additionally gets the live IV and record sequence
 * copied out of the cipher context, dispatched per cipher_type.
 */
337 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
338 int __user *optlen, int tx)
341 struct tls_context *ctx = tls_get_ctx(sk);
342 struct tls_crypto_info *crypto_info;
343 struct cipher_context *cctx;
346 if (get_user(len, optlen))
349 if (!optval || (len < sizeof(*crypto_info))) {
359 /* get user crypto info */
361 crypto_info = &ctx->crypto_send.info;
364 crypto_info = &ctx->crypto_recv.info;
/* Direction must have been configured via setsockopt first. */
368 if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
/* Header-only query: just version + cipher_type. */
373 if (len == sizeof(*crypto_info)) {
374 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
/* Full query: fill per-cipher IV/rec_seq from the live cipher context. */
379 switch (crypto_info->cipher_type) {
380 case TLS_CIPHER_AES_GCM_128: {
381 struct tls12_crypto_info_aes_gcm_128 *
382 crypto_info_aes_gcm_128 =
383 container_of(crypto_info,
384 struct tls12_crypto_info_aes_gcm_128,
387 if (len != sizeof(*crypto_info_aes_gcm_128)) {
392 memcpy(crypto_info_aes_gcm_128->iv,
393 cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
394 TLS_CIPHER_AES_GCM_128_IV_SIZE);
395 memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
396 TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
398 if (copy_to_user(optval,
399 crypto_info_aes_gcm_128,
400 sizeof(*crypto_info_aes_gcm_128)))
404 case TLS_CIPHER_AES_GCM_256: {
405 struct tls12_crypto_info_aes_gcm_256 *
406 crypto_info_aes_gcm_256 =
407 container_of(crypto_info,
408 struct tls12_crypto_info_aes_gcm_256,
411 if (len != sizeof(*crypto_info_aes_gcm_256)) {
416 memcpy(crypto_info_aes_gcm_256->iv,
417 cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
418 TLS_CIPHER_AES_GCM_256_IV_SIZE);
419 memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
420 TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
422 if (copy_to_user(optval,
423 crypto_info_aes_gcm_256,
424 sizeof(*crypto_info_aes_gcm_256)))
428 case TLS_CIPHER_AES_CCM_128: {
429 struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
430 container_of(crypto_info,
431 struct tls12_crypto_info_aes_ccm_128, info);
433 if (len != sizeof(*aes_ccm_128)) {
438 memcpy(aes_ccm_128->iv,
439 cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
440 TLS_CIPHER_AES_CCM_128_IV_SIZE);
441 memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
442 TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
444 if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
448 case TLS_CIPHER_CHACHA20_POLY1305: {
449 struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
450 container_of(crypto_info,
451 struct tls12_crypto_info_chacha20_poly1305,
454 if (len != sizeof(*chacha20_poly1305)) {
459 memcpy(chacha20_poly1305->iv,
460 cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
461 TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
462 memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
463 TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
465 if (copy_to_user(optval, chacha20_poly1305,
466 sizeof(*chacha20_poly1305)))
470 case TLS_CIPHER_SM4_GCM: {
471 struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
472 container_of(crypto_info,
473 struct tls12_crypto_info_sm4_gcm, info);
475 if (len != sizeof(*sm4_gcm_info)) {
480 memcpy(sm4_gcm_info->iv,
481 cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
482 TLS_CIPHER_SM4_GCM_IV_SIZE);
483 memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
484 TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
486 if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
490 case TLS_CIPHER_SM4_CCM: {
491 struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
492 container_of(crypto_info,
493 struct tls12_crypto_info_sm4_ccm, info);
495 if (len != sizeof(*sm4_ccm_info)) {
500 memcpy(sm4_ccm_info->iv,
501 cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
502 TLS_CIPHER_SM4_CCM_IV_SIZE);
503 memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
504 TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
506 if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
/* getsockopt(TLS_TX_ZEROCOPY_RO): return the zerocopy_sendfile flag;
 * userspace buffer must be exactly sizeof(value).
 */
518 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
521 struct tls_context *ctx = tls_get_ctx(sk);
525 if (get_user(len, optlen))
528 if (len != sizeof(value))
531 value = ctx->zerocopy_sendfile;
532 if (copy_to_user(optval, &value, sizeof(value)))
/* getsockopt(TLS_RX_EXPECT_NO_PAD): only meaningful for TLS 1.3 with RX
 * configured; returns the rx_no_pad flag and writes back the value size.
 */
538 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
541 struct tls_context *ctx = tls_get_ctx(sk);
544 if (ctx->prot_info.version != TLS_1_3_VERSION)
547 if (get_user(len, optlen))
549 if (len < sizeof(value))
554 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
555 value = ctx->rx_no_pad;
560 if (put_user(sizeof(value), optlen))
562 if (copy_to_user(optval, &value, sizeof(value)))
/* Dispatch SOL_TLS getsockopt by option name to the per-option helpers. */
568 static int do_tls_getsockopt(struct sock *sk, int optname,
569 char __user *optval, int __user *optlen)
576 rc = do_tls_getsockopt_conf(sk, optval, optlen,
579 case TLS_TX_ZEROCOPY_RO:
580 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
582 case TLS_RX_EXPECT_NO_PAD:
583 rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
/* .getsockopt handler: non-SOL_TLS levels pass through to the saved base
 * protocol; SOL_TLS goes to do_tls_getsockopt().
 */
592 static int tls_getsockopt(struct sock *sk, int level, int optname,
593 char __user *optval, int __user *optlen)
595 struct tls_context *ctx = tls_get_ctx(sk);
597 if (level != SOL_TLS)
598 return ctx->sk_proto->getsockopt(sk, level,
599 optname, optval, optlen);
601 return do_tls_getsockopt(sk, optname, optval, optlen);
/* setsockopt(TLS_TX/TLS_RX): validate and install crypto parameters for
 * one direction, then attempt device offload and fall back to the SW
 * path. On success the socket's proto tables and write_space handler are
 * swapped in; on failure the staged crypto_info is scrubbed.
 */
604 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
605 unsigned int optlen, int tx)
607 struct tls_crypto_info *crypto_info;
608 struct tls_crypto_info *alt_crypto_info;
609 struct tls_context *ctx = tls_get_ctx(sk);
614 if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
618 crypto_info = &ctx->crypto_send.info;
619 alt_crypto_info = &ctx->crypto_recv.info;
621 crypto_info = &ctx->crypto_recv.info;
622 alt_crypto_info = &ctx->crypto_send.info;
625 /* Currently we don't support set crypto info more than one time */
626 if (TLS_CRYPTO_INFO_READY(crypto_info))
/* Copy the fixed header first to learn version and cipher_type. */
629 rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
632 goto err_crypto_info;
636 if (crypto_info->version != TLS_1_2_VERSION &&
637 crypto_info->version != TLS_1_3_VERSION) {
639 goto err_crypto_info;
642 /* Ensure that TLS version and ciphers are same in both directions */
643 if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
644 if (alt_crypto_info->version != crypto_info->version ||
645 alt_crypto_info->cipher_type != crypto_info->cipher_type) {
647 goto err_crypto_info;
651 switch (crypto_info->cipher_type) {
652 case TLS_CIPHER_AES_GCM_128:
653 optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
655 case TLS_CIPHER_AES_GCM_256: {
656 optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
659 case TLS_CIPHER_AES_CCM_128:
660 optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
662 case TLS_CIPHER_CHACHA20_POLY1305:
663 optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
665 case TLS_CIPHER_SM4_GCM:
666 optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
668 case TLS_CIPHER_SM4_CCM:
669 optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
673 goto err_crypto_info;
676 if (optlen != optsize) {
678 goto err_crypto_info;
/* Copy the cipher-specific tail directly after the header. */
681 rc = copy_from_sockptr_offset(crypto_info + 1, optval,
682 sizeof(*crypto_info),
683 optlen - sizeof(*crypto_info));
686 goto err_crypto_info;
/* TX: try HW offload, fall back to SW; bump the matching MIB stats. */
690 rc = tls_set_device_offload(sk, ctx);
693 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
694 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
696 rc = tls_set_sw_offload(sk, ctx, 1);
698 goto err_crypto_info;
699 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
700 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
/* RX: same HW-then-SW fallback, plus arming the strparser. */
704 rc = tls_set_device_offload_rx(sk, ctx);
707 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
708 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
710 rc = tls_set_sw_offload(sk, ctx, 0);
712 goto err_crypto_info;
713 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
714 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
717 tls_sw_strparser_arm(sk, ctx);
724 update_sk_prot(sk, ctx);
726 ctx->sk_write_space = sk->sk_write_space;
727 sk->sk_write_space = tls_write_space;
/* Error path: scrub the partially-installed key material. */
732 memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
/* setsockopt(TLS_TX_ZEROCOPY_RO): set the zerocopy_sendfile flag from a
 * value of exactly sizeof(value) bytes.
 */
736 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
739 struct tls_context *ctx = tls_get_ctx(sk);
742 if (sockptr_is_null(optval) || optlen != sizeof(value))
745 if (copy_from_sockptr(&value, optval, sizeof(value)))
751 ctx->zerocopy_sendfile = value;
/* setsockopt(TLS_RX_EXPECT_NO_PAD): TLS 1.3 only. Any bytes beyond the
 * value must be zero; the flag feeds tls_update_rx_zc_capable().
 */
756 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
759 struct tls_context *ctx = tls_get_ctx(sk);
763 if (ctx->prot_info.version != TLS_1_3_VERSION ||
764 sockptr_is_null(optval) || optlen < sizeof(val))
767 rc = copy_from_sockptr(&val, optval, sizeof(val));
/* Reject trailing garbage so the option can grow compatibly later. */
772 rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
774 return rc == 0 ? -EINVAL : rc;
778 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
779 ctx->rx_no_pad = val;
780 tls_update_rx_zc_capable(ctx);
/* Dispatch SOL_TLS setsockopt by option name to the per-option helpers. */
788 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
797 rc = do_tls_setsockopt_conf(sk, optval, optlen,
801 case TLS_TX_ZEROCOPY_RO:
803 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
806 case TLS_RX_EXPECT_NO_PAD:
807 rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
/* .setsockopt handler: non-SOL_TLS levels pass through to the saved base
 * protocol; SOL_TLS goes to do_tls_setsockopt().
 */
816 static int tls_setsockopt(struct sock *sk, int level, int optname,
817 sockptr_t optval, unsigned int optlen)
819 struct tls_context *ctx = tls_get_ctx(sk);
821 if (level != SOL_TLS)
822 return ctx->sk_proto->setsockopt(sk, level, optname, optval,
825 return do_tls_setsockopt(sk, optname, optval, optlen);
/* Allocate and attach a zeroed TLS context to the socket's ULP data slot,
 * recording the original proto so it can be restored on close.
 * GFP_ATOMIC: callers may hold sk_callback_lock (see tls_init()).
 */
828 struct tls_context *tls_ctx_create(struct sock *sk)
830 struct inet_connection_sock *icsk = inet_csk(sk);
831 struct tls_context *ctx;
833 ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
837 mutex_init(&ctx->tx_lock);
838 rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
839 ctx->sk_proto = READ_ONCE(sk->sk_prot);
/* Populate the [tx_conf][rx_conf] proto_ops table from the base TCP ops,
 * overriding splice_read/sendpage_locked where a TLS path needs it. Each
 * entry is derived from a previously-filled one, so ordering matters.
 */
844 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
845 const struct proto_ops *base)
847 ops[TLS_BASE][TLS_BASE] = *base;
849 ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
850 ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;
852 ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE];
853 ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read;
855 ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE];
856 ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read;
858 #ifdef CONFIG_TLS_DEVICE
/* HW TX cannot use the locked sendpage path. */
859 ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
860 ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL;
862 ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ];
863 ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL;
865 ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ];
867 ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ];
869 ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ];
870 ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL;
872 #ifdef CONFIG_TLS_TOE
873 ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
/* Lazily (re)build the per-family TLS proto tables when the base TCP
 * proto pointer changes. Double-checked under the per-family mutex;
 * smp_load_acquire/smp_store_release pair for the lock-free fast path.
 */
877 static void tls_build_proto(struct sock *sk)
879 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
880 struct proto *prot = READ_ONCE(sk->sk_prot);
882 /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
883 if (ip_ver == TLSV6 &&
884 unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
885 mutex_lock(&tcpv6_prot_mutex);
886 if (likely(prot != saved_tcpv6_prot)) {
887 build_protos(tls_prots[TLSV6], prot);
888 build_proto_ops(tls_proto_ops[TLSV6],
890 smp_store_release(&saved_tcpv6_prot, prot);
892 mutex_unlock(&tcpv6_prot_mutex);
895 if (ip_ver == TLSV4 &&
896 unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
897 mutex_lock(&tcpv4_prot_mutex);
898 if (likely(prot != saved_tcpv4_prot)) {
899 build_protos(tls_prots[TLSV4], prot);
900 build_proto_ops(tls_proto_ops[TLSV4],
902 smp_store_release(&saved_tcpv4_prot, prot);
904 mutex_unlock(&tcpv4_prot_mutex);
/* Populate the [tx_conf][rx_conf] struct proto table from the base TCP
 * proto, overriding sendmsg/sendpage for TX paths and recvmsg/
 * sock_is_readable for RX paths. Entries build on earlier ones, so
 * ordering matters.
 */
908 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
909 const struct proto *base)
911 prot[TLS_BASE][TLS_BASE] = *base;
912 prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
913 prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
914 prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
916 prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
917 prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
918 prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;
920 prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
921 prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
922 prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
923 prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;
925 prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
926 prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
927 prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
928 prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;
930 #ifdef CONFIG_TLS_DEVICE
931 prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
932 prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
933 prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;
935 prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
936 prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
937 prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
939 prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
941 prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
943 prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
945 #ifdef CONFIG_TLS_TOE
946 prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
947 prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash;
948 prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash;
/* ULP init hook: attach a TLS context to an ESTABLISHED TCP socket in the
 * default BASE/BASE (pass-through) configuration. TOE sockets bypass the
 * generic path entirely.
 */
952 static int tls_init(struct sock *sk)
954 struct tls_context *ctx;
959 #ifdef CONFIG_TLS_TOE
960 if (tls_toe_bypass(sk))
964 /* The TLS ulp is currently supported only for TCP sockets
965 * in ESTABLISHED state.
966 * Supporting sockets in LISTEN state will require us
967 * to modify the accept implementation to clone rather than
968 * share the ulp context.
970 if (sk->sk_state != TCP_ESTABLISHED)
973 /* allocate tls context */
974 write_lock_bh(&sk->sk_callback_lock);
975 ctx = tls_ctx_create(sk);
981 ctx->tx_conf = TLS_BASE;
982 ctx->rx_conf = TLS_BASE;
983 update_sk_prot(sk, ctx);
985 write_unlock_bh(&sk->sk_callback_lock);
/* ULP update hook: another component replaces the socket's proto and
 * write_space callback; remember the new write_space in our context so
 * tls_write_space() chains to it.
 */
989 static void tls_update(struct sock *sk, struct proto *p,
990 void (*write_space)(struct sock *sk))
992 struct tls_context *ctx;
994 WARN_ON_ONCE(sk->sk_prot == p);
996 ctx = tls_get_ctx(sk);
998 ctx->sk_write_space = write_space;
1001 /* Pairs with lockless read in sk_clone_lock(). */
1002 WRITE_ONCE(sk->sk_prot, p);
1003 sk->sk_write_space = write_space;
/* Map the internal tx/rx config of @ctx to the TLS_CONF_* values exposed
 * to userspace via inet_diag.
 */
1007 static u16 tls_user_config(struct tls_context *ctx, bool tx)
1009 u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
1013 return TLS_CONF_BASE;
1019 return TLS_CONF_HW_RECORD;
/* inet_diag get_info hook: emit a nested INET_ULP_INFO_TLS attribute with
 * version, cipher, TX/RX config, and optional flag attributes. The nest
 * is cancelled on any nla_put failure.
 */
1024 static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
1026 u16 version, cipher_type;
1027 struct tls_context *ctx;
1028 struct nlattr *start;
1031 start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
/* icsk_ulp_data is RCU-protected; may be NULL if TLS already detached. */
1036 ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
1041 version = ctx->prot_info.version;
1043 err = nla_put_u16(skb, TLS_INFO_VERSION, version);
1047 cipher_type = ctx->prot_info.cipher_type;
1049 err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
1053 err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
1057 err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
1061 if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
1062 err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
1066 if (ctx->rx_no_pad) {
1067 err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
1073 nla_nest_end(skb, start);
1078 nla_nest_cancel(skb, start);
/* Worst-case netlink payload size for tls_get_info(); must account for
 * every attribute that function may emit.
 */
1082 static size_t tls_get_info_size(const struct sock *sk)
1086 size += nla_total_size(0) + /* INET_ULP_INFO_TLS */
1087 nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */
1088 nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */
1089 nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */
1090 nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */
1091 nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */
1092 nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */
/* Per-netns init: allocate TLS MIB counters and register the /proc
 * statistics entries; unwind the allocation if proc setup fails.
 */
1098 static int __net_init tls_init_net(struct net *net)
1102 net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
1103 if (!net->mib.tls_statistics)
1106 err = tls_proc_init(net);
1108 goto err_free_stats;
1112 free_percpu(net->mib.tls_statistics);
/* Per-netns teardown: release the per-cpu TLS MIB counters. */
1116 static void __net_exit tls_exit_net(struct net *net)
1119 free_percpu(net->mib.tls_statistics);
/* pernet ops wiring the per-netns init/exit above. */
1122 static struct pernet_operations tls_proc_ops = {
1123 .init = tls_init_net,
1124 .exit = tls_exit_net,
/* TCP ULP registration: entry points invoked when userspace does
 * setsockopt(TCP_ULP, "tls") and by the diag infrastructure.
 */
1127 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
1129 .owner = THIS_MODULE,
1131 .update = tls_update,
1132 .get_info = tls_get_info,
1133 .get_info_size = tls_get_info_size,
/* Module init: register pernet stats, then device offload support, then
 * the "tls" ULP; pernet registration is unwound if device init fails.
 */
1136 static int __init tls_register(void)
1140 err = register_pernet_subsys(&tls_proc_ops);
1144 err = tls_device_init();
1146 unregister_pernet_subsys(&tls_proc_ops);
1150 tcp_register_ulp(&tcp_tls_ulp_ops);
/* Module exit: tear down in reverse registration order. */
1155 static void __exit tls_unregister(void)
1157 tcp_unregister_ulp(&tcp_tls_ulp_ops);
1158 tls_device_cleanup();
1159 unregister_pernet_subsys(&tls_proc_ops);
1162 module_init(tls_register);
1163 module_exit(tls_unregister);