tls: rx: decrypt into a fresh skb
authorJakub Kicinski <kuba@kernel.org>
Fri, 15 Jul 2022 05:22:35 +0000 (22:22 -0700)
committerDavid S. Miller <davem@davemloft.net>
Mon, 18 Jul 2022 10:24:11 +0000 (11:24 +0100)
We currently CoW Rx skbs whenever we can't decrypt to a user
space buffer. The skbs can be enormous (64kB) and CoW does
a linear alloc which has a strong chance of failing under
memory pressure. Or even without, skb_cow_data() assumes
GFP_ATOMIC.

Allocate a new frag'd skb and decrypt into it. We finally
take advantage of the decrypted skb getting returned via
darg.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/tls/tls.h
net/tls/tls_sw.c

index c818dc6..3740740 100644 (file)
@@ -39,6 +39,9 @@
 #include <linux/skmsg.h>
 #include <net/tls.h>
 
+#define TLS_PAGE_ORDER (min_t(unsigned int, PAGE_ALLOC_COSTLY_ORDER,   \
+                              TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT))
+
 #define __TLS_INC_STATS(net, field)                            \
        __SNMP_INC_STATS((net)->mib.tls_statistics, field)
 #define TLS_INC_STATS(net, field)                              \
index 1c9a070..859ea02 100644 (file)
@@ -1383,6 +1383,29 @@ out:
        return rc;
 }
 
+static struct sk_buff *
+tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
+                    unsigned int full_len)
+{
+       struct strp_msg *clr_rxm;
+       struct sk_buff *clr_skb;
+       int err;
+
+       clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
+                                      &err, sk->sk_allocation);
+       if (!clr_skb)
+               return NULL;
+
+       skb_copy_header(clr_skb, skb);
+       clr_skb->len = full_len;
+       clr_skb->data_len = full_len;
+
+       clr_rxm = strp_msg(clr_skb);
+       clr_rxm->offset = 0;
+
+       return clr_skb;
+}
+
 /* Decrypt handlers
  *
  * tls_decrypt_sg() and tls_decrypt_device() are decrypt handlers.
@@ -1410,34 +1433,40 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
        struct tls_prot_info *prot = &tls_ctx->prot_info;
        int n_sgin, n_sgout, aead_size, err, pages = 0;
        struct sk_buff *skb = tls_strp_msg(ctx);
-       struct strp_msg *rxm = strp_msg(skb);
-       struct tls_msg *tlm = tls_msg(skb);
+       const struct strp_msg *rxm = strp_msg(skb);
+       const struct tls_msg *tlm = tls_msg(skb);
        struct aead_request *aead_req;
-       struct sk_buff *unused;
        struct scatterlist *sgin = NULL;
        struct scatterlist *sgout = NULL;
        const int data_len = rxm->full_len - prot->overhead_size;
        int tail_pages = !!prot->tail_size;
        struct tls_decrypt_ctx *dctx;
+       struct sk_buff *clear_skb;
        int iv_offset = 0;
        u8 *mem;
 
+       n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+                        rxm->full_len - prot->prepend_size);
+       if (n_sgin < 1)
+               return n_sgin ?: -EBADMSG;
+
        if (darg->zc && (out_iov || out_sg)) {
+               clear_skb = NULL;
+
                if (out_iov)
                        n_sgout = 1 + tail_pages +
                                iov_iter_npages_cap(out_iov, INT_MAX, data_len);
                else
                        n_sgout = sg_nents(out_sg);
-               n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
-                                rxm->full_len - prot->prepend_size);
        } else {
-               n_sgout = 0;
                darg->zc = false;
-               n_sgin = skb_cow_data(skb, 0, &unused);
-       }
 
-       if (n_sgin < 1)
-               return -EBADMSG;
+               clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
+               if (!clear_skb)
+                       return -ENOMEM;
+
+               n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
+       }
 
        /* Increment to accommodate AAD */
        n_sgin = n_sgin + 1;
@@ -1449,8 +1478,10 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
        aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
        mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
                      sk->sk_allocation);
-       if (!mem)
-               return -ENOMEM;
+       if (!mem) {
+               err = -ENOMEM;
+               goto exit_free_skb;
+       }
 
        /* Segment the allocated memory */
        aead_req = (struct aead_request *)mem;
@@ -1499,33 +1530,31 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
        if (err < 0)
                goto exit_free;
 
-       if (n_sgout) {
-               if (out_iov) {
-                       sg_init_table(sgout, n_sgout);
-                       sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
+       if (clear_skb) {
+               sg_init_table(sgout, n_sgout);
+               sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
 
-                       err = tls_setup_from_iter(out_iov, data_len,
-                                                 &pages, &sgout[1],
-                                                 (n_sgout - 1 - tail_pages));
-                       if (err < 0)
-                               goto fallback_to_reg_recv;
+               err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
+                                  data_len + prot->tail_size);
+               if (err < 0)
+                       goto exit_free;
+       } else if (out_iov) {
+               sg_init_table(sgout, n_sgout);
+               sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
 
-                       if (prot->tail_size) {
-                               sg_unmark_end(&sgout[pages]);
-                               sg_set_buf(&sgout[pages + 1], &dctx->tail,
-                                          prot->tail_size);
-                               sg_mark_end(&sgout[pages + 1]);
-                       }
-               } else if (out_sg) {
-                       memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
-               } else {
-                       goto fallback_to_reg_recv;
+               err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
+                                         (n_sgout - 1 - tail_pages));
+               if (err < 0)
+                       goto exit_free_pages;
+
+               if (prot->tail_size) {
+                       sg_unmark_end(&sgout[pages]);
+                       sg_set_buf(&sgout[pages + 1], &dctx->tail,
+                                  prot->tail_size);
+                       sg_mark_end(&sgout[pages + 1]);
                }
-       } else {
-fallback_to_reg_recv:
-               sgout = sgin;
-               pages = 0;
-               darg->zc = false;
+       } else if (out_sg) {
+               memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
        }
 
        /* Prepare and submit AEAD request */
@@ -1534,7 +1563,8 @@ fallback_to_reg_recv:
        if (err)
                goto exit_free_pages;
 
-       darg->skb = tls_strp_msg(ctx);
+       darg->skb = clear_skb ?: tls_strp_msg(ctx);
+       clear_skb = NULL;
 
        if (unlikely(darg->async)) {
                err = tls_strp_msg_hold(sk, skb, &ctx->async_hold);
@@ -1552,6 +1582,8 @@ exit_free_pages:
                put_page(sg_page(&sgout[pages]));
 exit_free:
        kfree(mem);
+exit_free_skb:
+       consume_skb(clear_skb);
        return err;
 }