tls: rx: add sockopt for enabling optimistic decrypt with TLS 1.3
authorJakub Kicinski <kuba@kernel.org>
Tue, 5 Jul 2022 23:59:24 +0000 (16:59 -0700)
committerDavid S. Miller <davem@davemloft.net>
Wed, 6 Jul 2022 11:56:35 +0000 (12:56 +0100)
Since optimisitic decrypt may add extra load in case of retries
require socket owner to explicitly opt-in.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
Documentation/networking/tls.rst
include/linux/sockptr.h
include/net/tls.h
include/uapi/linux/snmp.h
include/uapi/linux/tls.h
net/tls/tls_main.c
net/tls/tls_proc.c
net/tls/tls_sw.c

index be8e10c..7a66438 100644 (file)
@@ -239,6 +239,19 @@ for the original TCP transmission and TCP retransmissions. To the receiver
 this will look like TLS records had been tampered with and will result
 in record authentication failures.
 
+TLS_RX_EXPECT_NO_PAD
+~~~~~~~~~~~~~~~~~~~~
+
+TLS 1.3 only. Expect the sender to not pad records. This allows the data
+to be decrypted directly into user space buffers with TLS 1.3.
+
+This optimization is safe to enable only if the remote end is trusted,
+otherwise it is an attack vector to doubling the TLS processing cost.
+
+If the record decrypted turns out to had been padded or is not a data
+record it will be decrypted again into a kernel buffer without zero copy.
+Such events are counted in the ``TlsDecryptRetry`` statistic.
+
 Statistics
 ==========
 
@@ -264,3 +277,8 @@ TLS implementation exposes the following per-namespace statistics
 
 - ``TlsDeviceRxResync`` -
   number of RX resyncs sent to NICs handling cryptography
+
+- ``TlsDecryptRetry`` -
+  number of RX records which had to be re-decrypted due to
+  ``TLS_RX_EXPECT_NO_PAD`` mis-prediction. Note that this counter will
+  also increment for non-data records.
index ea19341..d45902f 100644 (file)
@@ -102,4 +102,12 @@ static inline long strncpy_from_sockptr(char *dst, sockptr_t src, size_t count)
        return strncpy_from_user(dst, src.user, count);
 }
 
+static inline int check_zeroed_sockptr(sockptr_t src, size_t offset,
+                                      size_t size)
+{
+       if (!sockptr_is_kernel(src))
+               return check_zeroed_user(src.user + offset, size);
+       return memchr_inv(src.kernel + offset, 0, size) == NULL;
+}
+
 #endif /* _LINUX_SOCKPTR_H */
index 8017f17..4fc16ca 100644 (file)
@@ -149,6 +149,7 @@ struct tls_sw_context_rx {
 
        struct sk_buff *recv_pkt;
        u8 async_capable:1;
+       u8 zc_capable:1;
        atomic_t decrypt_pending;
        /* protect crypto_wait with decrypt_pending*/
        spinlock_t decrypt_compl_lock;
@@ -239,6 +240,7 @@ struct tls_context {
        u8 tx_conf:3;
        u8 rx_conf:3;
        u8 zerocopy_sendfile:1;
+       u8 rx_no_pad:1;
 
        int (*push_pending_record)(struct sock *sk, int flags);
        void (*sk_write_space)(struct sock *sk);
@@ -358,6 +360,7 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
 void tls_err_abort(struct sock *sk, int err);
 
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
+void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
 void tls_sw_strparser_done(struct tls_context *tls_ctx);
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
index 904909d..1c9152a 100644 (file)
@@ -344,6 +344,7 @@ enum
        LINUX_MIB_TLSRXDEVICE,                  /* TlsRxDevice */
        LINUX_MIB_TLSDECRYPTERROR,              /* TlsDecryptError */
        LINUX_MIB_TLSRXDEVICERESYNC,            /* TlsRxDeviceResync */
+       LINUX_MIN_TLSDECRYPTRETRY,              /* TlsDecryptRetry */
        __LINUX_MIB_TLSMAX
 };
 
index bb8f808..f1157d8 100644 (file)
@@ -40,6 +40,7 @@
 #define TLS_TX                 1       /* Set transmit parameters */
 #define TLS_RX                 2       /* Set receive parameters */
 #define TLS_TX_ZEROCOPY_RO     3       /* TX zerocopy (only sendfile now) */
+#define TLS_RX_EXPECT_NO_PAD   4       /* Attempt opportunistic zero-copy */
 
 /* Supported versions */
 #define TLS_VERSION_MINOR(ver) ((ver) & 0xFF)
@@ -162,6 +163,7 @@ enum {
        TLS_INFO_TXCONF,
        TLS_INFO_RXCONF,
        TLS_INFO_ZC_RO_TX,
+       TLS_INFO_RX_NO_PAD,
        __TLS_INFO_MAX,
 };
 #define TLS_INFO_MAX (__TLS_INFO_MAX - 1)
index 2ffede4..1b3efc9 100644 (file)
@@ -533,6 +533,37 @@ static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
        return 0;
 }
 
+static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
+                                   int __user *optlen)
+{
+       struct tls_context *ctx = tls_get_ctx(sk);
+       unsigned int value;
+       int err, len;
+
+       if (ctx->prot_info.version != TLS_1_3_VERSION)
+               return -EINVAL;
+
+       if (get_user(len, optlen))
+               return -EFAULT;
+       if (len < sizeof(value))
+               return -EINVAL;
+
+       lock_sock(sk);
+       err = -EINVAL;
+       if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+               value = ctx->rx_no_pad;
+       release_sock(sk);
+       if (err)
+               return err;
+
+       if (put_user(sizeof(value), optlen))
+               return -EFAULT;
+       if (copy_to_user(optval, &value, sizeof(value)))
+               return -EFAULT;
+
+       return 0;
+}
+
 static int do_tls_getsockopt(struct sock *sk, int optname,
                             char __user *optval, int __user *optlen)
 {
@@ -547,6 +578,9 @@ static int do_tls_getsockopt(struct sock *sk, int optname,
        case TLS_TX_ZEROCOPY_RO:
                rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
                break;
+       case TLS_RX_EXPECT_NO_PAD:
+               rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
+               break;
        default:
                rc = -ENOPROTOOPT;
                break;
@@ -718,6 +752,38 @@ static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
        return 0;
 }
 
+static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
+                                   unsigned int optlen)
+{
+       struct tls_context *ctx = tls_get_ctx(sk);
+       u32 val;
+       int rc;
+
+       if (ctx->prot_info.version != TLS_1_3_VERSION ||
+           sockptr_is_null(optval) || optlen < sizeof(val))
+               return -EINVAL;
+
+       rc = copy_from_sockptr(&val, optval, sizeof(val));
+       if (rc)
+               return -EFAULT;
+       if (val > 1)
+               return -EINVAL;
+       rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
+       if (rc < 1)
+               return rc == 0 ? -EINVAL : rc;
+
+       lock_sock(sk);
+       rc = -EINVAL;
+       if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
+               ctx->rx_no_pad = val;
+               tls_update_rx_zc_capable(ctx);
+               rc = 0;
+       }
+       release_sock(sk);
+
+       return rc;
+}
+
 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
                             unsigned int optlen)
 {
@@ -736,6 +802,9 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
                rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
                release_sock(sk);
                break;
+       case TLS_RX_EXPECT_NO_PAD:
+               rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
+               break;
        default:
                rc = -ENOPROTOOPT;
                break;
@@ -976,6 +1045,11 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
                if (err)
                        goto nla_failure;
        }
+       if (ctx->rx_no_pad) {
+               err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
+               if (err)
+                       goto nla_failure;
+       }
 
        rcu_read_unlock();
        nla_nest_end(skb, start);
@@ -997,6 +1071,7 @@ static size_t tls_get_info_size(const struct sock *sk)
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_RXCONF */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_TXCONF */
                nla_total_size(0) +             /* TLS_INFO_ZC_RO_TX */
+               nla_total_size(0) +             /* TLS_INFO_RX_NO_PAD */
                0;
 
        return size;
index feeceb0..0c20000 100644 (file)
@@ -18,6 +18,7 @@ static const struct snmp_mib tls_mib_list[] = {
        SNMP_MIB_ITEM("TlsRxDevice", LINUX_MIB_TLSRXDEVICE),
        SNMP_MIB_ITEM("TlsDecryptError", LINUX_MIB_TLSDECRYPTERROR),
        SNMP_MIB_ITEM("TlsRxDeviceResync", LINUX_MIB_TLSRXDEVICERESYNC),
+       SNMP_MIB_ITEM("TlsDecryptRetry", LINUX_MIN_TLSDECRYPTRETRY),
        SNMP_MIB_SENTINEL
 };
 
index 2bac576..7592b65 100644 (file)
@@ -1601,6 +1601,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
        if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
                     darg->tail != TLS_RECORD_TYPE_DATA)) {
                darg->zc = false;
+               TLS_INC_STATS(sock_net(sk), LINUX_MIN_TLSDECRYPTRETRY);
                return decrypt_skb_update(sk, skb, dest, darg);
        }
 
@@ -1787,7 +1788,7 @@ int tls_sw_recvmsg(struct sock *sk,
        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
        zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
-                    prot->version != TLS_1_3_VERSION;
+               ctx->zc_capable;
        decrypted = 0;
        while (len && (decrypted + copied < target || ctx->recv_pkt)) {
                struct tls_decrypt_arg darg = {};
@@ -2269,6 +2270,14 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
        strp_check_rcv(&rx_ctx->strp);
 }
 
+void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
+{
+       struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+       rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
+               tls_ctx->prot_info.version != TLS_1_3_VERSION;
+}
+
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -2504,12 +2513,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
        if (sw_ctx_rx) {
                tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
 
-               if (crypto_info->version == TLS_1_3_VERSION)
-                       sw_ctx_rx->async_capable = 0;
-               else
-                       sw_ctx_rx->async_capable =
-                               !!(tfm->__crt_alg->cra_flags &
-                                  CRYPTO_ALG_ASYNC);
+               tls_update_rx_zc_capable(ctx);
+               sw_ctx_rx->async_capable =
+                       crypto_info->version != TLS_1_3_VERSION &&
+                       !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
 
                /* Set up strparser */
                memset(&cb, 0, sizeof(cb));