tls: rx: don't report text length from the bowels of decrypt
authorJakub Kicinski <kuba@kernel.org>
Fri, 8 Apr 2022 18:31:25 +0000 (11:31 -0700)
committerDavid S. Miller <davem@davemloft.net>
Sun, 10 Apr 2022 16:32:11 +0000 (17:32 +0100)
We plumb pointer to chunk all the way to the decryption method.
It's set to the length of the text when decrypt_skb_update()
returns.

I think the code is written this way because original TLS
implementation passed &chunk to zerocopy_from_iter() and this
was carried forward as the code gotten more complex, without
any refactoring.

The fix for peek() introduced a new variable - to_decrypt
which for all practical purposes is what chunk is going to
get set to. Spare ourselves the pointer passing, use to_decrypt.

Use this opportunity to clean things up a little further.

Note that chunk / to_decrypt was mostly needed for the async
path, since the sync path would access rxm->full_len (decryption
transforms full_len from record size to text size). Use the
right source of truth more explicitly.

We have three cases:
 - async - it's TLS 1.2 only, so chunk == to_decrypt, but we
           need the min() because to_decrypt is a whole record
   and we don't want to underflow len. Note that we can't
   handle partial record by falling back to sync as it
   would introduce reordering against records in flight.
 - zc - again, TLS 1.2 only for now, so chunk == to_decrypt,
        we don't do zc if len < to_decrypt, no need to check again.
 - normal - it already handles chunk > len, we can factor out the
            assignment to rxm->full_len and share it with zc.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/tls/tls_sw.c

index 86f77f8..c321c5f 100644 (file)
@@ -1412,7 +1412,7 @@ out:
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
                            struct iov_iter *out_iov,
                            struct scatterlist *out_sg,
-                           int *chunk, bool *zc, bool async)
+                           bool *zc, bool async)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1526,7 +1526,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
                                                  (n_sgout - 1));
                        if (err < 0)
                                goto fallback_to_reg_recv;
-                       *chunk = data_len;
                } else if (out_sg) {
                        memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
                } else {
@@ -1536,7 +1535,6 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 fallback_to_reg_recv:
                sgout = sgin;
                pages = 0;
-               *chunk = data_len;
                *zc = false;
        }
 
@@ -1555,8 +1553,7 @@ fallback_to_reg_recv:
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-                             struct iov_iter *dest, int *chunk, bool *zc,
-                             bool async)
+                             struct iov_iter *dest, bool *zc, bool async)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_prot_info *prot = &tls_ctx->prot_info;
@@ -1580,7 +1577,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
                }
        }
 
-       err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
+       err = decrypt_internal(sk, skb, dest, NULL, zc, async);
        if (err < 0) {
                if (err == -EINPROGRESS)
                        tls_advance_record_sn(sk, prot, &tls_ctx->rx);
@@ -1607,9 +1604,8 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb,
                struct scatterlist *sgout)
 {
        bool zc = true;
-       int chunk;
 
-       return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
+       return decrypt_internal(sk, skb, NULL, sgout, &zc, false);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -1799,9 +1795,8 @@ int tls_sw_recvmsg(struct sock *sk,
        num_async = 0;
        while (len && (decrypted + copied < target || ctx->recv_pkt)) {
                bool retain_skb = false;
+               int to_decrypt, chunk;
                bool zc = false;
-               int to_decrypt;
-               int chunk = 0;
                bool async_capable;
                bool async = false;
 
@@ -1838,7 +1833,7 @@ int tls_sw_recvmsg(struct sock *sk,
                        async_capable = false;
 
                err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-                                        &chunk, &zc, async_capable);
+                                        &zc, async_capable);
                if (err < 0 && err != -EINPROGRESS) {
                        tls_err_abort(sk, -EBADMSG);
                        goto recv_end;
@@ -1876,8 +1871,13 @@ int tls_sw_recvmsg(struct sock *sk,
                        }
                }
 
-               if (async)
+               if (async) {
+                       /* TLS 1.2-only, to_decrypt must be text length */
+                       chunk = min_t(int, to_decrypt, len);
                        goto pick_next_record;
+               }
+               /* TLS 1.3 may have updated the length by more than overhead */
+               chunk = rxm->full_len;
 
                if (!zc) {
                        if (bpf_strp_enabled) {
@@ -1893,11 +1893,9 @@ int tls_sw_recvmsg(struct sock *sk,
                                }
                        }
 
-                       if (rxm->full_len > len) {
+                       if (chunk > len) {
                                retain_skb = true;
                                chunk = len;
-                       } else {
-                               chunk = rxm->full_len;
                        }
 
                        err = skb_copy_datagram_msg(skb, rxm->offset,
@@ -1912,9 +1910,6 @@ int tls_sw_recvmsg(struct sock *sk,
                }
 
 pick_next_record:
-               if (chunk > len)
-                       chunk = len;
-
                decrypted += chunk;
                len -= chunk;
 
@@ -2016,7 +2011,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
                if (!skb)
                        goto splice_read_end;
 
-               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
+               err = decrypt_skb_update(sk, skb, NULL, &zc, false);
                if (err < 0) {
                        tls_err_abort(sk, -EBADMSG);
                        goto splice_read_end;