tls: rx: wrap decrypt params in a struct
author Jakub Kicinski <kuba@kernel.org>
Fri, 8 Jul 2022 01:03:11 +0000 (18:03 -0700)
committer Jakub Kicinski <kuba@kernel.org>
Sat, 9 Jul 2022 01:38:45 +0000 (18:38 -0700)
The max size of iv + aad + tail is 22B. That's smaller
than a single sg entry (32B). Don't bother with tight
memory packing; just create a struct which holds those
members at their max sizes.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
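
For readers who want to see the new layout in isolation, below is a
minimal userspace sketch of the allocation scheme this patch switches
to. The scatterlist stand-in, the iv/aad sizes, and the simplified
struct_size() are illustrative assumptions rather than the kernel's
definitions (the real struct_size() in <linux/overflow.h> is also
overflow-checked, and the real constants are what yield the 22B figure
above):

  #include <stddef.h>
  #include <stdio.h>
  #include <stdlib.h>

  #define MAX_IV_SIZE      16   /* stand-in value */
  #define TLS_MAX_AAD_SIZE 13   /* stand-in value */

  /* stand-in for the kernel's struct scatterlist, ~32B on 64-bit */
  struct scatterlist {
          unsigned long page_link;
          unsigned int offset;
          unsigned int length;
          unsigned long dma_address;
          unsigned int dma_length;
  };

  struct tls_decrypt_ctx {
          unsigned char iv[MAX_IV_SIZE];
          unsigned char aad[TLS_MAX_AAD_SIZE];
          unsigned char tail;
          struct scatterlist sg[];      /* n_sgin + n_sgout entries */
  };

  /* the commit's observation as a build-time check: the fixed part
   * (iv + aad + tail, plus padding) costs at most one sg entry */
  _Static_assert(offsetof(struct tls_decrypt_ctx, sg) <=
                 sizeof(struct scatterlist),
                 "fixed part should fit in one sg entry");

  /* simplified struct_size(); the kernel's version is overflow-safe */
  #define struct_size(p, member, n) \
          (sizeof(*(p)) + (n) * sizeof((p)->member[0]))

  int main(void)
  {
          size_t aead_size = 64; /* stand-in for aead_req + reqsize */
          int n_sgin = 3, n_sgout = 2;
          struct tls_decrypt_ctx *dctx;
          struct scatterlist *sgin, *sgout;
          unsigned char *mem;

          /* one allocation: aead_req || tls_decrypt_ctx */
          mem = malloc(aead_size +
                       struct_size(dctx, sg, n_sgin + n_sgout));
          if (!mem)
                  return 1;

          /* segment it the way decrypt_internal() now does */
          dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
          sgin = &dctx->sg[0];
          sgout = &dctx->sg[n_sgin];

          printf("sg[] starts %zu bytes into dctx\n",
                 offsetof(struct tls_decrypt_ctx, sg));
          printf("sgout sits %d entries after sgin\n",
                 (int)(sgout - sgin));

          free(mem);
          return 0;
  }

Because sg[] is a flexible array member, sgin and sgout stay contiguous
and correctly aligned by construction, which is why the hand-computed
offsets (and the alignment comment that went with them) can be deleted.
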
net/tls/tls_sw.c

index 377c0f6..5965649 100644
@@ -50,6 +50,13 @@ struct tls_decrypt_arg {
        u8 tail;
 };
 
+struct tls_decrypt_ctx {
+       u8 iv[MAX_IV_SIZE];
+       u8 aad[TLS_MAX_AAD_SIZE];
+       u8 tail;
+       struct scatterlist sg[];
+};
+
 noinline void tls_err_abort(struct sock *sk, int err)
 {
        WARN_ON_ONCE(err >= 0);
@@ -1414,17 +1421,18 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct tls_prot_info *prot = &tls_ctx->prot_info;
+       int n_sgin, n_sgout, aead_size, err, pages = 0;
        struct strp_msg *rxm = strp_msg(skb);
        struct tls_msg *tlm = tls_msg(skb);
-       int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
-       u8 *aad, *iv, *tail, *mem = NULL;
        struct aead_request *aead_req;
        struct sk_buff *unused;
        struct scatterlist *sgin = NULL;
        struct scatterlist *sgout = NULL;
        const int data_len = rxm->full_len - prot->overhead_size;
        int tail_pages = !!prot->tail_size;
+       struct tls_decrypt_ctx *dctx;
        int iv_offset = 0;
+       u8 *mem;
 
        if (darg->zc && (out_iov || out_sg)) {
                if (out_iov)
@@ -1446,38 +1454,30 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        /* Increment to accommodate AAD */
        n_sgin = n_sgin + 1;
 
-       nsg = n_sgin + n_sgout;
-
-       aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
-       mem_size = aead_size + (nsg * sizeof(struct scatterlist));
-       mem_size = mem_size + TLS_MAX_AAD_SIZE;
-       mem_size = mem_size + MAX_IV_SIZE;
-       mem_size = mem_size + prot->tail_size;
-
        /* Allocate a single block of memory which contains
-        * aead_req || sgin[] || sgout[] || aad || iv || tail.
-        * This order achieves correct alignment for aead_req, sgin, sgout.
+        *   aead_req || tls_decrypt_ctx.
+        * Both structs are variable length.
         */
-       mem = kmalloc(mem_size, sk->sk_allocation);
+       aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
+       mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
+                     sk->sk_allocation);
        if (!mem)
                return -ENOMEM;
 
        /* Segment the allocated memory */
        aead_req = (struct aead_request *)mem;
-       sgin = (struct scatterlist *)(mem + aead_size);
-       sgout = sgin + n_sgin;
-       aad = (u8 *)(sgout + n_sgout);
-       iv = aad + TLS_MAX_AAD_SIZE;
-       tail = iv + MAX_IV_SIZE;
+       dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
+       sgin = &dctx->sg[0];
+       sgout = &dctx->sg[n_sgin];
 
        /* For CCM based ciphers, first byte of nonce+iv is a constant */
        switch (prot->cipher_type) {
        case TLS_CIPHER_AES_CCM_128:
-               iv[0] = TLS_AES_CCM_IV_B0_BYTE;
+               dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
                iv_offset = 1;
                break;
        case TLS_CIPHER_SM4_CCM:
-               iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
+               dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
                iv_offset = 1;
                break;
        }
@@ -1485,28 +1485,28 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        /* Prepare IV */
        if (prot->version == TLS_1_3_VERSION ||
            prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
-               memcpy(iv + iv_offset, tls_ctx->rx.iv,
+               memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
                       prot->iv_size + prot->salt_size);
        } else {
                err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-                                   iv + iv_offset + prot->salt_size,
+                                   &dctx->iv[iv_offset] + prot->salt_size,
                                    prot->iv_size);
                if (err < 0) {
                        kfree(mem);
                        return err;
                }
-               memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
+               memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
        }
-       xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
+       xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
 
        /* Prepare AAD */
-       tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+       tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
                     prot->tail_size,
                     tls_ctx->rx.rec_seq, tlm->control, prot);
 
        /* Prepare sgin */
        sg_init_table(sgin, n_sgin);
-       sg_set_buf(&sgin[0], aad, prot->aad_size);
+       sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
        err = skb_to_sgvec(skb, &sgin[1],
                           rxm->offset + prot->prepend_size,
                           rxm->full_len - prot->prepend_size);
@@ -1518,7 +1518,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        if (n_sgout) {
                if (out_iov) {
                        sg_init_table(sgout, n_sgout);
-                       sg_set_buf(&sgout[0], aad, prot->aad_size);
+                       sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
 
                        err = tls_setup_from_iter(out_iov, data_len,
                                                  &pages, &sgout[1],
@@ -1528,7 +1528,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 
                        if (prot->tail_size) {
                                sg_unmark_end(&sgout[pages]);
-                               sg_set_buf(&sgout[pages + 1], tail,
+                               sg_set_buf(&sgout[pages + 1], &dctx->tail,
                                           prot->tail_size);
                                sg_mark_end(&sgout[pages + 1]);
                        }
@@ -1545,13 +1545,13 @@ fallback_to_reg_recv:
        }
 
        /* Prepare and submit AEAD request */
-       err = tls_do_decryption(sk, skb, sgin, sgout, iv,
+       err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
                                data_len + prot->tail_size, aead_req, darg);
        if (darg->async)
                return 0;
 
        if (prot->tail_size)
-               darg->tail = *tail;
+               darg->tail = dctx->tail;
 
        /* Release the pages in case iov was mapped to pages */
        for (; pages > 0; pages--)
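
A closing note on the allocation change: replacing the open-coded
mem_size additions with struct_size() also buys overflow safety, since
struct_size() saturates to SIZE_MAX instead of wrapping, and a SIZE_MAX
request simply fails to allocate. A userspace sketch of that saturating
behavior (the kernel's implementation in <linux/overflow.h> differs in
detail):

  #include <stdint.h>
  #include <stdio.h>

  /* saturate to SIZE_MAX on overflow so the allocator rejects the
   * request, instead of wrapping to a small, exploitable size */
  static size_t struct_size_sat(size_t fixed, size_t elem, size_t n)
  {
          if (n && elem > (SIZE_MAX - fixed) / n)
                  return SIZE_MAX;
          return fixed + elem * n;
  }

  int main(void)
  {
          /* 32B fixed part plus five 32B sg entries: fine */
          printf("%zu\n", struct_size_sat(32, 32, 5));
          /* absurd element count: saturates instead of wrapping */
          printf("%zu\n", struct_size_sat(32, 32, SIZE_MAX / 16));
          return 0;
  }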