tls: rx: factor SW handling out of tls_rx_one_record()
authorJakub Kicinski <kuba@kernel.org>
Fri, 22 Jul 2022 23:50:28 +0000 (16:50 -0700)
committerJakub Kicinski <kuba@kernel.org>
Tue, 26 Jul 2022 21:38:50 +0000 (14:38 -0700)
After recent changes the SW side of tls_rx_one_record() can
be nicely encapsulated in its own function. Move the pad handling
as well. This will be useful for ->zc handling in tls_decrypt_device().

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/tls/tls_sw.c

index cb99fc1..eed52f8 100644 (file)
@@ -1409,7 +1409,7 @@ tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
 
 /* Decrypt handlers
  *
- * tls_decrypt_sg() and tls_decrypt_device() are decrypt handlers.
+ * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers.
  * They must transform the darg in/out argument are as follows:
  *       |          Input            |         Output
  * -------------------------------------------------------------------
@@ -1589,49 +1589,22 @@ exit_free_skb:
 }
 
 static int
-tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
-                  struct tls_decrypt_arg *darg)
-{
-       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-       int err;
-
-       if (tls_ctx->rx_conf != TLS_HW)
-               return 0;
-
-       err = tls_device_decrypted(sk, tls_ctx);
-       if (err <= 0)
-               return err;
-
-       darg->zc = false;
-       darg->async = false;
-       darg->skb = tls_strp_msg(ctx);
-       ctx->recv_pkt = NULL;
-       return 1;
-}
-
-static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
-                            struct tls_decrypt_arg *darg)
+tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx,
+              struct msghdr *msg, struct tls_decrypt_arg *darg)
 {
-       struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct tls_prot_info *prot = &tls_ctx->prot_info;
        struct strp_msg *rxm;
        int pad, err;
 
-       err = tls_decrypt_device(sk, tls_ctx, darg);
-       if (err < 0)
-               return err;
-       if (err)
-               goto decrypt_done;
-
-       err = tls_decrypt_sg(sk, dest, NULL, darg);
+       err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg);
        if (err < 0) {
                if (err == -EBADMSG)
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
                return err;
        }
-       if (darg->async)
-               goto decrypt_done;
+       /* keep going even for ->async, the code below is TLS 1.3 */
+
        /* If opportunistic TLS 1.3 ZC failed retry without ZC */
        if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
                     darg->tail != TLS_RECORD_TYPE_DATA)) {
@@ -1639,10 +1612,9 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
                if (!darg->tail)
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
                TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
-               return tls_rx_one_record(sk, dest, darg);
+               return tls_decrypt_sw(sk, tls_ctx, msg, darg);
        }
 
-decrypt_done:
        if (darg->skb == ctx->recv_pkt)
                ctx->recv_pkt = NULL;
 
@@ -1654,6 +1626,55 @@ decrypt_done:
 
        rxm = strp_msg(darg->skb);
        rxm->full_len -= pad;
+
+       return 0;
+}
+
+static int
+tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
+                  struct tls_decrypt_arg *darg)
+{
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+       struct tls_prot_info *prot = &tls_ctx->prot_info;
+       struct strp_msg *rxm;
+       int pad, err;
+
+       if (tls_ctx->rx_conf != TLS_HW)
+               return 0;
+
+       err = tls_device_decrypted(sk, tls_ctx);
+       if (err <= 0)
+               return err;
+
+       pad = tls_padding_length(prot, tls_strp_msg(ctx), darg);
+       if (pad < 0)
+               return pad;
+
+       darg->zc = false;
+       darg->async = false;
+       darg->skb = tls_strp_msg(ctx);
+       ctx->recv_pkt = NULL;
+
+       rxm = strp_msg(darg->skb);
+       rxm->full_len -= pad;
+       return 1;
+}
+
+static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
+                            struct tls_decrypt_arg *darg)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_prot_info *prot = &tls_ctx->prot_info;
+       struct strp_msg *rxm;
+       int err;
+
+       err = tls_decrypt_device(sk, tls_ctx, darg);
+       if (!err)
+               err = tls_decrypt_sw(sk, tls_ctx, msg, darg);
+       if (err < 0)
+               return err;
+
+       rxm = strp_msg(darg->skb);
        rxm->offset += prot->prepend_size;
        rxm->full_len -= prot->overhead_size;
        tls_advance_record_sn(sk, prot, &tls_ctx->rx);
@@ -1943,7 +1964,7 @@ int tls_sw_recvmsg(struct sock *sk,
                else
                        darg.async = false;
 
-               err = tls_rx_one_record(sk, &msg->msg_iter, &darg);
+               err = tls_rx_one_record(sk, msg, &darg);
                if (err < 0) {
                        tls_err_abort(sk, -EBADMSG);
                        goto recv_end;