Merge branch 'next-general' of git://git.kernel.org/pub/scm/linux/kernel/git/jmorris...
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 057a558..4dc766b 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
  * SOFTWARE.
  */
 
+#include <linux/sched/signal.h>
 #include <linux/module.h>
 #include <crypto/aead.h>
 
+#include <net/strparser.h>
 #include <net/tls.h>
 
+static int tls_do_decryption(struct sock *sk,
+                            struct scatterlist *sgin,
+                            struct scatterlist *sgout,
+                            char *iv_recv,
+                            size_t data_len,
+                            struct sk_buff *skb,
+                            gfp_t flags)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct strp_msg *rxm = strp_msg(skb);
+       struct aead_request *aead_req;
+
+       int ret;
+       unsigned int req_size = sizeof(struct aead_request) +
+               crypto_aead_reqsize(ctx->aead_recv);
+
+       aead_req = kzalloc(req_size, flags);
+       if (!aead_req)
+               return -ENOMEM;
+
+       aead_request_set_tfm(aead_req, ctx->aead_recv);
+       aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+       aead_request_set_crypt(aead_req, sgin, sgout,
+                              data_len + tls_ctx->rx.tag_size,
+                              (u8 *)iv_recv);
+       aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
+                                 crypto_req_done, &ctx->async_wait);
+
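+       /* crypto_aead_decrypt() may complete asynchronously;
+        * crypto_wait_req() parks on ctx->async_wait until the
+        * completion callback set above fires.
+        */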
+       ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+
+       if (ret < 0)
+               goto out;
+
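+       /* Success: hide the record header and auth tag from readers and
+        * advance the expected record sequence number.
+        */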
+       rxm->offset += tls_ctx->rx.prepend_size;
+       rxm->full_len -= tls_ctx->rx.overhead_size;
+       tls_advance_record_sn(sk, &tls_ctx->rx);
+
+       ctx->decrypted = true;
+
+       ctx->saved_data_ready(sk);
+
+out:
+       kfree(aead_req);
+       return ret;
+}
+
 static void trim_sg(struct sock *sk, struct scatterlist *sg,
                    int *sg_num_elem, unsigned int *sg_size, int target_size)
 {
@@ -79,7 +128,7 @@ static void trim_both_sgl(struct sock *sk, int target_size)
                target_size);
 
        if (target_size > 0)
-               target_size += tls_ctx->overhead_size;
+               target_size += tls_ctx->tx.overhead_size;
 
        trim_sg(sk, ctx->sg_encrypted_data,
                &ctx->sg_encrypted_num_elem,
@@ -152,21 +201,21 @@ static int tls_do_encryption(struct tls_context *tls_ctx,
        if (!aead_req)
                return -ENOMEM;
 
-       ctx->sg_encrypted_data[0].offset += tls_ctx->prepend_size;
-       ctx->sg_encrypted_data[0].length -= tls_ctx->prepend_size;
+       ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
+       ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
 
        aead_request_set_tfm(aead_req, ctx->aead_send);
        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
        aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
-                              data_len, tls_ctx->iv);
+                              data_len, tls_ctx->tx.iv);
 
        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                                  crypto_req_done, &ctx->async_wait);
 
        rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
 
-       ctx->sg_encrypted_data[0].offset -= tls_ctx->prepend_size;
-       ctx->sg_encrypted_data[0].length += tls_ctx->prepend_size;
+       ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
+       ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
 
        kfree(aead_req);
        return rc;
@@ -183,7 +232,7 @@ static int tls_push_record(struct sock *sk, int flags,
        sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
 
        tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
-                    tls_ctx->rec_seq, tls_ctx->rec_seq_size,
+                    tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
                     record_type);
 
        tls_fill_prepend(tls_ctx,
@@ -214,9 +263,9 @@ static int tls_push_record(struct sock *sk, int flags,
        /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
        rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
        if (rc < 0 && rc != -EAGAIN)
-               tls_err_abort(sk);
+               tls_err_abort(sk, EBADMSG);
 
-       tls_advance_record_sn(sk, tls_ctx);
+       tls_advance_record_sn(sk, &tls_ctx->tx);
        return rc;
 }
 
@@ -226,23 +275,24 @@ static int tls_sw_push_pending_record(struct sock *sk, int flags)
 }
 
 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
-                             int length)
+                             int length, int *pages_used,
+                             unsigned int *size_used,
+                             struct scatterlist *to, int to_max_pages,
+                             bool charge)
 {
-       struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
        struct page *pages[MAX_SKB_FRAGS];
 
        size_t offset;
        ssize_t copied, use;
        int i = 0;
-       unsigned int size = ctx->sg_plaintext_size;
-       int num_elem = ctx->sg_plaintext_num_elem;
+       unsigned int size = *size_used;
+       int num_elem = *pages_used;
        int rc = 0;
        int maxpages;
 
        while (length > 0) {
                i = 0;
-               maxpages = ARRAY_SIZE(ctx->sg_plaintext_data) - num_elem;
+               maxpages = to_max_pages - num_elem;
                if (maxpages == 0) {
                        rc = -EFAULT;
                        goto out;
@@ -262,10 +312,11 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                while (copied) {
                        use = min_t(int, copied, PAGE_SIZE - offset);
 
-                       sg_set_page(&ctx->sg_plaintext_data[num_elem],
+                       sg_set_page(&to[num_elem],
                                    pages[i], use, offset);
-                       sg_unmark_end(&ctx->sg_plaintext_data[num_elem]);
-                       sk_mem_charge(sk, use);
+                       sg_unmark_end(&to[num_elem]);
+                       if (charge)
+                               sk_mem_charge(sk, use);
 
                        offset = 0;
                        copied -= use;
@@ -276,8 +327,9 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
        }
 
 out:
-       ctx->sg_plaintext_size = size;
-       ctx->sg_plaintext_num_elem = num_elem;
+       *size_used = size;
+       *pages_used = num_elem;
+
        return rc;
 }
 
@@ -354,7 +406,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
                }
 
                required_size = ctx->sg_plaintext_size + try_to_copy +
-                               tls_ctx->overhead_size;
+                               tls_ctx->tx.overhead_size;
 
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
@@ -374,7 +426,11 @@ alloc_encrypted:
 
                if (full_record || eor) {
                        ret = zerocopy_from_iter(sk, &msg->msg_iter,
-                                                try_to_copy);
+                               try_to_copy, &ctx->sg_plaintext_num_elem,
+                               &ctx->sg_plaintext_size,
+                               ctx->sg_plaintext_data,
+                               ARRAY_SIZE(ctx->sg_plaintext_data),
+                               true);
                        if (ret)
                                goto fallback_to_reg_send;
 
@@ -413,7 +469,7 @@ alloc_plaintext:
                                &ctx->sg_encrypted_num_elem,
                                &ctx->sg_encrypted_size,
                                ctx->sg_plaintext_size +
-                               tls_ctx->overhead_size);
+                               tls_ctx->tx.overhead_size);
                }
 
                ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
@@ -505,7 +561,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
                        full_record = true;
                }
                required_size = ctx->sg_plaintext_size + copy +
-                             tls_ctx->overhead_size;
+                             tls_ctx->tx.overhead_size;
 
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
@@ -574,13 +630,404 @@ sendpage_end:
        return ret;
 }
 
-void tls_sw_free_tx_resources(struct sock *sk)
+static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
+                                    long timeo, int *err)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct sk_buff *skb;
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
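+       /* Block, honouring timeo and MSG_DONTWAIT, until the strparser
+        * queues a complete record in ctx->recv_pkt.
+        */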
+       while (!(skb = ctx->recv_pkt)) {
+               if (sk->sk_err) {
+                       *err = sock_error(sk);
+                       return NULL;
+               }
+
+               if (sock_flag(sk, SOCK_DONE))
+                       return NULL;
+
+               if ((flags & MSG_DONTWAIT) || !timeo) {
+                       *err = -EAGAIN;
+                       return NULL;
+               }
+
+               add_wait_queue(sk_sleep(sk), &wait);
+               sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+               sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
+               sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+               remove_wait_queue(sk_sleep(sk), &wait);
+
+               /* Handle signals */
+               if (signal_pending(current)) {
+                       *err = sock_intr_errno(timeo);
+                       return NULL;
+               }
+       }
+
+       return skb;
+}
+
+static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+                      struct scatterlist *sgout)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + tls_ctx->rx.iv_size];
+       struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
+       struct scatterlist *sgin = &sgin_arr[0];
+       struct strp_msg *rxm = strp_msg(skb);
+       int ret, nsg = ARRAY_SIZE(sgin_arr);
+       char aad_recv[TLS_AAD_SPACE_SIZE];
+       struct sk_buff *unused;
+
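+       /* The GCM nonce is the 4-byte implicit salt followed by the
+        * per-record explicit nonce carried just after the record header.
+        */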
+       ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+                           iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+                           tls_ctx->rx.iv_size);
+       if (ret < 0)
+               return ret;
+
+       memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+       if (!sgout) {
+               nsg = skb_cow_data(skb, 0, &unused) + 1;
+               sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
+               if (!sgin)
+                       return -ENOMEM;
+               sgout = sgin;
+       }
+
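+       /* sgin[0] carries the AAD; the remaining entries map the ciphertext */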
+       sg_init_table(sgin, nsg);
+       sg_set_buf(&sgin[0], aad_recv, sizeof(aad_recv));
+
+       nsg = skb_to_sgvec(skb, &sgin[1],
+                          rxm->offset + tls_ctx->rx.prepend_size,
+                          rxm->full_len - tls_ctx->rx.prepend_size);
+       if (unlikely(nsg < 0)) {
+               if (sgin != &sgin_arr[0])
+                       kfree(sgin);
+               return nsg;
+       }
+
+       tls_make_aad(aad_recv,
+                    rxm->full_len - tls_ctx->rx.overhead_size,
+                    tls_ctx->rx.rec_seq,
+                    tls_ctx->rx.rec_seq_size,
+                    ctx->control);
+
+       ret = tls_do_decryption(sk, sgin, sgout, iv,
+                               rxm->full_len - tls_ctx->rx.overhead_size,
+                               skb, sk->sk_allocation);
+
+       if (sgin != &sgin_arr[0])
+               kfree(sgin);
+
+       return ret;
+}
+
+static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
+                              unsigned int len)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct strp_msg *rxm = strp_msg(skb);
+
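+       /* Partial read: leave the rest of the record for the next call */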
+       if (len < rxm->full_len) {
+               rxm->offset += len;
+               rxm->full_len -= len;
+
+               return false;
+       }
+
+       /* Finished with message */
+       ctx->recv_pkt = NULL;
+       kfree_skb(skb);
+       strp_unpause(&ctx->strp);
+
+       return true;
+}
+
+int tls_sw_recvmsg(struct sock *sk,
+                  struct msghdr *msg,
+                  size_t len,
+                  int nonblock,
+                  int flags,
+                  int *addr_len)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       unsigned char control;
+       struct strp_msg *rxm;
+       struct sk_buff *skb;
+       ssize_t copied = 0;
+       bool cmsg = false;
+       int err = 0;
+       long timeo;
+
+       flags |= nonblock;
+
+       if (unlikely(flags & MSG_ERRQUEUE))
+               return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
+
+       lock_sock(sk);
+
+       timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+       do {
+               bool zc = false;
+               int chunk = 0;
+
+               skb = tls_wait_data(sk, flags, timeo, &err);
+               if (!skb)
+                       goto recv_end;
+
+               rxm = strp_msg(skb);
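+               /* Surface the record type once per call via a
+                * TLS_GET_RECORD_TYPE cmsg; one read never spans two
+                * different record types.
+                */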
+               if (!cmsg) {
+                       int cerr;
+
+                       cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+                                       sizeof(ctx->control), &ctx->control);
+                       cmsg = true;
+                       control = ctx->control;
+                       if (ctx->control != TLS_RECORD_TYPE_DATA) {
+                               if (cerr || msg->msg_flags & MSG_CTRUNC) {
+                                       err = -EIO;
+                                       goto recv_end;
+                               }
+                       }
+               } else if (control != ctx->control) {
+                       goto recv_end;
+               }
+
+               if (!ctx->decrypted) {
+                       int page_count;
+                       int to_copy;
+
+                       page_count = iov_iter_npages(&msg->msg_iter,
+                                                    MAX_SKB_FRAGS);
+                       to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
+                       if (to_copy <= len && page_count < MAX_SKB_FRAGS &&
+                           likely(!(flags & MSG_PEEK)))  {
+                               struct scatterlist sgin[MAX_SKB_FRAGS + 1];
+                               char unused[21];
+                               int pages = 0;
+
+                               zc = true;
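+                               /* sgin[0] is scratch for the 13 AAD bytes
+                                * (TLS_AAD_SPACE_SIZE) in the AEAD dst list;
+                                * the user's pages follow in sgin[1..].
+                                */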
+                               sg_init_table(sgin, MAX_SKB_FRAGS + 1);
+                               sg_set_buf(&sgin[0], unused, 13);
+
+                               err = zerocopy_from_iter(sk, &msg->msg_iter,
+                                                        to_copy, &pages,
+                                                        &chunk, &sgin[1],
+                                                        MAX_SKB_FRAGS, false);
+                               if (err < 0)
+                                       goto fallback_to_reg_recv;
+
+                               err = decrypt_skb(sk, skb, sgin);
+                               for (; pages > 0; pages--)
+                                       put_page(sg_page(&sgin[pages]));
+                               if (err < 0) {
+                                       tls_err_abort(sk, EBADMSG);
+                                       goto recv_end;
+                               }
+                       } else {
+fallback_to_reg_recv:
+                               err = decrypt_skb(sk, skb, NULL);
+                               if (err < 0) {
+                                       tls_err_abort(sk, EBADMSG);
+                                       goto recv_end;
+                               }
+                       }
+                       ctx->decrypted = true;
+               }
+
+               if (!zc) {
+                       chunk = min_t(unsigned int, rxm->full_len, len);
+                       err = skb_copy_datagram_msg(skb, rxm->offset, msg,
+                                                   chunk);
+                       if (err < 0)
+                               goto recv_end;
+               }
+
+               copied += chunk;
+               len -= chunk;
+               if (likely(!(flags & MSG_PEEK))) {
+                       u8 control = ctx->control;
+
+                       if (tls_sw_advance_skb(sk, skb, chunk)) {
+                               /* Return full control message to
+                                * userspace before trying to parse
+                                * another message type
+                                */
+                               msg->msg_flags |= MSG_EOR;
+                               if (control != TLS_RECORD_TYPE_DATA)
+                                       goto recv_end;
+                       }
+               }
+       } while (len);
+
+recv_end:
+       release_sock(sk);
+       return copied ? : err;
+}
+
+ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
+                          struct pipe_inode_info *pipe,
+                          size_t len, unsigned int flags)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct strp_msg *rxm = NULL;
+       struct sock *sk = sock->sk;
+       struct sk_buff *skb;
+       ssize_t copied = 0;
+       int err = 0;
+       long timeo;
+       int chunk;
+
+       lock_sock(sk);
+
+       timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+
+       skb = tls_wait_data(sk, flags, timeo, &err);
+       if (!skb)
+               goto splice_read_end;
+
+       /* splice does not support reading control messages */
+       if (ctx->control != TLS_RECORD_TYPE_DATA) {
+               err = -ENOTSUPP;
+               goto splice_read_end;
+       }
+
+       if (!ctx->decrypted) {
+               err = decrypt_skb(sk, skb, NULL);
+
+               if (err < 0) {
+                       tls_err_abort(sk, EBADMSG);
+                       goto splice_read_end;
+               }
+               ctx->decrypted = true;
+       }
+       rxm = strp_msg(skb);
+
+       chunk = min_t(unsigned int, rxm->full_len, len);
+       copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
+       if (copied < 0)
+               goto splice_read_end;
+
+       if (likely(!(flags & MSG_PEEK)))
+               tls_sw_advance_skb(sk, skb, copied);
+
+splice_read_end:
+       release_sock(sk);
+       return copied ? : err;
+}
+
+unsigned int tls_sw_poll(struct file *file, struct socket *sock,
+                        struct poll_table_struct *wait)
+{
+       unsigned int ret;
+       struct sock *sk = sock->sk;
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+
+       /* Grab POLLOUT and POLLHUP from the underlying socket */
+       ret = ctx->sk_poll(file, sock, wait);
+
+       /* Clear POLLIN bits, and set based on recv_pkt */
+       ret &= ~(POLLIN | POLLRDNORM);
+       if (ctx->recv_pkt)
+               ret |= POLLIN | POLLRDNORM;
+
+       return ret;
+}
+
+static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       char header[tls_ctx->rx.prepend_size];
+       struct strp_msg *rxm = strp_msg(skb);
+       size_t cipher_overhead;
+       size_t data_len = 0;
+       int ret;
+
+       /* Verify that we have a full TLS header, or wait for more data */
+       if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
+               return 0;
+
+       /* Linearize header to local buffer */
+       ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
+
+       if (ret < 0)
+               goto read_failure;
+
+       ctx->control = header[0];
+
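+       /* Payload length is a 16-bit big-endian field in header bytes 3-4 */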
+       data_len = ((header[4] & 0xFF) | (header[3] << 8));
+
+       cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
+
+       if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
+               ret = -EMSGSIZE;
+               goto read_failure;
+       }
+       if (data_len < cipher_overhead) {
+               ret = -EBADMSG;
+               goto read_failure;
+       }
+
+       /* The record header carries the major version byte first */
+       if (header[1] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version) ||
+           header[2] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version)) {
+               ret = -EINVAL;
+               goto read_failure;
+       }
+
+       return data_len + TLS_HEADER_SIZE;
+
+read_failure:
+       tls_err_abort(strp->sk, ret);
+
+       return ret;
+}
+
+static void tls_queue(struct strparser *strp, struct sk_buff *skb)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+       struct strp_msg *rxm;
+
+       rxm = strp_msg(skb);
+
+       ctx->decrypted = false;
+
+       ctx->recv_pkt = skb;
+       strp_pause(strp);
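+       /* Stay paused until tls_sw_advance_skb() consumes this record */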
+
+       strp->sk->sk_state_change(strp->sk);
+}
+
+static void tls_data_ready(struct sock *sk)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+
+       strp_data_ready(&ctx->strp);
+}
+
+void tls_sw_free_resources(struct sock *sk)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 
        if (ctx->aead_send)
                crypto_free_aead(ctx->aead_send);
+       if (ctx->aead_recv) {
+               if (ctx->recv_pkt) {
+                       kfree_skb(ctx->recv_pkt);
+                       ctx->recv_pkt = NULL;
+               }
+               crypto_free_aead(ctx->aead_recv);
+               strp_stop(&ctx->strp);
+               write_lock_bh(&sk->sk_callback_lock);
+               sk->sk_data_ready = ctx->saved_data_ready;
+               write_unlock_bh(&sk->sk_callback_lock);
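+               /* strp_done() flushes strparser work that may itself take
+                * the socket lock, so drop the lock across the call.
+                */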
+               release_sock(sk);
+               strp_done(&ctx->strp);
+               lock_sock(sk);
+       }
 
        tls_free_both_sg(sk);
 
@@ -588,12 +1035,15 @@ void tls_sw_free_tx_resources(struct sock *sk)
        kfree(tls_ctx);
 }
 
-int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
        char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
        struct tls_crypto_info *crypto_info;
        struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
        struct tls_sw_context *sw_ctx;
+       struct cipher_context *cctx;
+       struct crypto_aead **aead;
+       struct strp_callbacks cb;
        u16 nonce_size, tag_size, iv_size, rec_seq_size;
        char *iv, *rec_seq;
        int rc = 0;
@@ -603,22 +1053,29 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
                goto out;
        }
 
-       if (ctx->priv_ctx) {
-               rc = -EEXIST;
-               goto out;
-       }
-
-       sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
-       if (!sw_ctx) {
-               rc = -ENOMEM;
-               goto out;
+       if (!ctx->priv_ctx) {
+               sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
+               if (!sw_ctx) {
+                       rc = -ENOMEM;
+                       goto out;
+               }
+               crypto_init_wait(&sw_ctx->async_wait);
+       } else {
+               sw_ctx = ctx->priv_ctx;
        }
 
-       crypto_init_wait(&sw_ctx->async_wait);
-
        ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
 
-       crypto_info = &ctx->crypto_send;
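+       /* Bind the direction being configured: TX state for sendmsg and
+        * sendpage, RX state for recvmsg and splice_read.
+        */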
+       if (tx) {
+               crypto_info = &ctx->crypto_send;
+               cctx = &ctx->tx;
+               aead = &sw_ctx->aead_send;
+       } else {
+               crypto_info = &ctx->crypto_recv;
+               cctx = &ctx->rx;
+               aead = &sw_ctx->aead_recv;
+       }
+
        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128: {
                nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
@@ -637,46 +1094,49 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
                goto free_priv;
        }
 
-       ctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
-       ctx->tag_size = tag_size;
-       ctx->overhead_size = ctx->prepend_size + ctx->tag_size;
-       ctx->iv_size = iv_size;
-       ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL);
-       if (!ctx->iv) {
+       cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
+       cctx->tag_size = tag_size;
+       cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
+       cctx->iv_size = iv_size;
+       cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+                          GFP_KERNEL);
+       if (!cctx->iv) {
                rc = -ENOMEM;
                goto free_priv;
        }
-       memcpy(ctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-       memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
-       ctx->rec_seq_size = rec_seq_size;
-       ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
-       if (!ctx->rec_seq) {
+       memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+       memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+       cctx->rec_seq_size = rec_seq_size;
+       cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+       if (!cctx->rec_seq) {
                rc = -ENOMEM;
                goto free_iv;
        }
-       memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
-
-       sg_init_table(sw_ctx->sg_encrypted_data,
-                     ARRAY_SIZE(sw_ctx->sg_encrypted_data));
-       sg_init_table(sw_ctx->sg_plaintext_data,
-                     ARRAY_SIZE(sw_ctx->sg_plaintext_data));
-
-       sg_init_table(sw_ctx->sg_aead_in, 2);
-       sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
-                  sizeof(sw_ctx->aad_space));
-       sg_unmark_end(&sw_ctx->sg_aead_in[1]);
-       sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
-       sg_init_table(sw_ctx->sg_aead_out, 2);
-       sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
-                  sizeof(sw_ctx->aad_space));
-       sg_unmark_end(&sw_ctx->sg_aead_out[1]);
-       sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
-
-       if (!sw_ctx->aead_send) {
-               sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0);
-               if (IS_ERR(sw_ctx->aead_send)) {
-                       rc = PTR_ERR(sw_ctx->aead_send);
-                       sw_ctx->aead_send = NULL;
+       memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
+
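+       /* TX chains two scatterlists: [AAD | plaintext] feeds the AEAD
+        * and [AAD | ciphertext] receives its output.
+        */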
+       if (tx) {
+               sg_init_table(sw_ctx->sg_encrypted_data,
+                             ARRAY_SIZE(sw_ctx->sg_encrypted_data));
+               sg_init_table(sw_ctx->sg_plaintext_data,
+                             ARRAY_SIZE(sw_ctx->sg_plaintext_data));
+
+               sg_init_table(sw_ctx->sg_aead_in, 2);
+               sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
+                          sizeof(sw_ctx->aad_space));
+               sg_unmark_end(&sw_ctx->sg_aead_in[1]);
+               sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
+               sg_init_table(sw_ctx->sg_aead_out, 2);
+               sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
+                          sizeof(sw_ctx->aad_space));
+               sg_unmark_end(&sw_ctx->sg_aead_out[1]);
+               sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
+       }
+
+       if (!*aead) {
+               *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
+               if (IS_ERR(*aead)) {
+                       rc = PTR_ERR(*aead);
+                       *aead = NULL;
                        goto free_rec_seq;
                }
        }
@@ -685,24 +1145,44 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 
        memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
 
-       rc = crypto_aead_setkey(sw_ctx->aead_send, keyval,
+       rc = crypto_aead_setkey(*aead, keyval,
                                TLS_CIPHER_AES_GCM_128_KEY_SIZE);
        if (rc)
                goto free_aead;
 
-       rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tag_size);
-       if (!rc)
-               return 0;
+       rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
+       if (rc)
+               goto free_aead;
+
+       if (!tx) {
+               /* Set up strparser */
+               memset(&cb, 0, sizeof(cb));
+               cb.rcv_msg = tls_queue;
+               cb.parse_msg = tls_read_size;
+
+               strp_init(&sw_ctx->strp, sk, &cb);
+
+               write_lock_bh(&sk->sk_callback_lock);
+               sw_ctx->saved_data_ready = sk->sk_data_ready;
+               sk->sk_data_ready = tls_data_ready;
+               write_unlock_bh(&sk->sk_callback_lock);
+
+               sw_ctx->sk_poll = sk->sk_socket->ops->poll;
+
+               strp_check_rcv(&sw_ctx->strp);
+       }
+
+       goto out;
 
 free_aead:
-       crypto_free_aead(sw_ctx->aead_send);
-       sw_ctx->aead_send = NULL;
+       crypto_free_aead(*aead);
+       *aead = NULL;
 free_rec_seq:
-       kfree(ctx->rec_seq);
-       ctx->rec_seq = NULL;
+       kfree(cctx->rec_seq);
+       cctx->rec_seq = NULL;
 free_iv:
-       kfree(ctx->iv);
-       ctx->iv = NULL;
+       kfree(cctx->iv);
+       cctx->iv = NULL;
 free_priv:
        kfree(ctx->priv_ctx);
        ctx->priv_ctx = NULL;