tls: rx: return the decrypted skb via darg
authorJakub Kicinski <kuba@kernel.org>
Fri, 15 Jul 2022 05:22:31 +0000 (22:22 -0700)
committerDavid S. Miller <davem@davemloft.net>
Mon, 18 Jul 2022 10:24:11 +0000 (11:24 +0100)
Instead of using ctx->recv_pkt after decryption read the skb
from darg.skb. This moves the decision of what the "output skb"
is to the decrypt handlers. For now after decrypt handler returns
successfully ctx->recv_pkt is simply moved to darg.skb, but it
will change soon.

Note that tls_decrypt_sg() cannot clear the ctx->recv_pkt
because it gets called to re-encrypt (i.e. by the device offload).
So we need an awkward temporary if() in tls_rx_one_record().

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/tls/tls_sw.c

index 6205ad1..6a98754 100644 (file)
 #include "tls.h"
 
 struct tls_decrypt_arg {
+       struct_group(inargs,
        bool zc;
        bool async;
        u8 tail;
+       );
+
+       struct sk_buff *skb;
 };
 
 struct tls_decrypt_ctx {
@@ -1412,6 +1416,7 @@ out:
  * -------------------------------------------------------------------
  *    zc | Zero-copy decrypt allowed | Zero-copy performed
  * async | Async decrypt allowed     | Async crypto used / in progress
+ *   skb |            *              | Output skb
  */
 
 /* This function decrypts the input skb into either out_iov or in out_sg
@@ -1551,12 +1556,17 @@ fallback_to_reg_recv:
        /* Prepare and submit AEAD request */
        err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
                                data_len + prot->tail_size, aead_req, darg);
+       if (err)
+               goto exit_free_pages;
+
+       darg->skb = tls_strp_msg(ctx);
        if (darg->async)
                return 0;
 
        if (prot->tail_size)
                darg->tail = dctx->tail;
 
+exit_free_pages:
        /* Release the pages in case iov was mapped to pages */
        for (; pages > 0; pages--)
                put_page(sg_page(&sgout[pages]));
@@ -1569,6 +1579,7 @@ static int
 tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
                   struct tls_decrypt_arg *darg)
 {
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        int err;
 
        if (tls_ctx->rx_conf != TLS_HW)
@@ -1580,6 +1591,8 @@ tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
 
        darg->zc = false;
        darg->async = false;
+       darg->skb = tls_strp_msg(ctx);
+       ctx->recv_pkt = NULL;
        return 1;
 }
 
@@ -1604,8 +1617,11 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
                return err;
        }
-       if (darg->async)
+       if (darg->async) {
+               if (darg->skb == ctx->recv_pkt)
+                       ctx->recv_pkt = NULL;
                goto decrypt_next;
+       }
        /* If opportunistic TLS 1.3 ZC failed retry without ZC */
        if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
                     darg->tail != TLS_RECORD_TYPE_DATA)) {
@@ -1616,12 +1632,17 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
                return tls_rx_one_record(sk, dest, darg);
        }
 
+       if (darg->skb == ctx->recv_pkt)
+               ctx->recv_pkt = NULL;
+
 decrypt_done:
-       pad = tls_padding_length(prot, ctx->recv_pkt, darg);
-       if (pad < 0)
+       pad = tls_padding_length(prot, darg->skb, darg);
+       if (pad < 0) {
+               consume_skb(darg->skb);
                return pad;
+       }
 
-       rxm = strp_msg(ctx->recv_pkt);
+       rxm = strp_msg(darg->skb);
        rxm->full_len -= pad;
        rxm->offset += prot->prepend_size;
        rxm->full_len -= prot->overhead_size;
@@ -1663,6 +1684,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
 
 static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
 {
+       consume_skb(ctx->recv_pkt);
        ctx->recv_pkt = NULL;
        __strp_unpause(&ctx->strp);
 }
@@ -1872,7 +1894,7 @@ int tls_sw_recvmsg(struct sock *sk,
                ctx->zc_capable;
        decrypted = 0;
        while (len && (decrypted + copied < target || ctx->recv_pkt)) {
-               struct tls_decrypt_arg darg = {};
+               struct tls_decrypt_arg darg;
                int to_decrypt, chunk;
 
                err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
@@ -1889,9 +1911,10 @@ int tls_sw_recvmsg(struct sock *sk,
                        goto recv_end;
                }
 
-               skb = ctx->recv_pkt;
-               rxm = strp_msg(skb);
-               tlm = tls_msg(skb);
+               memset(&darg.inargs, 0, sizeof(darg.inargs));
+
+               rxm = strp_msg(ctx->recv_pkt);
+               tlm = tls_msg(ctx->recv_pkt);
 
                to_decrypt = rxm->full_len - prot->overhead_size;
 
@@ -1911,6 +1934,10 @@ int tls_sw_recvmsg(struct sock *sk,
                        goto recv_end;
                }
 
+               skb = darg.skb;
+               rxm = strp_msg(skb);
+               tlm = tls_msg(skb);
+
                async |= darg.async;
 
                /* If the type of records being processed is not known yet,
@@ -2051,21 +2078,23 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        if (!skb_queue_empty(&ctx->rx_list)) {
                skb = __skb_dequeue(&ctx->rx_list);
        } else {
-               struct tls_decrypt_arg darg = {};
+               struct tls_decrypt_arg darg;
 
                err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
                                      timeo);
                if (err <= 0)
                        goto splice_read_end;
 
+               memset(&darg.inargs, 0, sizeof(darg.inargs));
+
                err = tls_rx_one_record(sk, NULL, &darg);
                if (err < 0) {
                        tls_err_abort(sk, -EBADMSG);
                        goto splice_read_end;
                }
 
-               skb = ctx->recv_pkt;
                tls_rx_rec_done(ctx);
+               skb = darg.skb;
        }
 
        rxm = strp_msg(skb);