tls: rx: return the decrypted skb via darg
[linux-2.6-microblaze.git] / net/tls/tls_sw.c
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
8  *
9  * This software is available to you under a choice of one of two
10  * licenses.  You may choose to be licensed under the terms of the GNU
11  * General Public License (GPL) Version 2, available from the file
12  * COPYING in the main directory of this source tree, or the
13  * OpenIB.org BSD license below:
14  *
15  *     Redistribution and use in source and binary forms, with or
16  *     without modification, are permitted provided that the following
17  *     conditions are met:
18  *
19  *      - Redistributions of source code must retain the above
20  *        copyright notice, this list of conditions and the following
21  *        disclaimer.
22  *
23  *      - Redistributions in binary form must reproduce the above
24  *        copyright notice, this list of conditions and the following
25  *        disclaimer in the documentation and/or other materials
26  *        provided with the distribution.
27  *
28  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
29  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
30  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
31  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
32  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
33  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
34  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35  * SOFTWARE.
36  */
37
38 #include <linux/bug.h>
39 #include <linux/sched/signal.h>
40 #include <linux/module.h>
41 #include <linux/splice.h>
42 #include <crypto/aead.h>
43
44 #include <net/strparser.h>
45 #include <net/tls.h>
46
47 #include "tls.h"
48
49 struct tls_decrypt_arg {
50         struct_group(inargs,
51         bool zc;
52         bool async;
53         u8 tail;
54         );
55
56         struct sk_buff *skb;
57 };
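/* The struct_group() above collects the input half of the argument
 * ("inargs"); skb is output only: the decrypt handlers hand the decrypted
 * record back through it (see the handler contract further down in this
 * file).
 */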
58
59 struct tls_decrypt_ctx {
60         u8 iv[MAX_IV_SIZE];
61         u8 aad[TLS_MAX_AAD_SIZE];
62         u8 tail;
63         struct scatterlist sg[];
64 };
65
66 noinline void tls_err_abort(struct sock *sk, int err)
67 {
68         WARN_ON_ONCE(err >= 0);
69         /* sk->sk_err should contain a positive error code. */
70         sk->sk_err = -err;
71         sk_error_report(sk);
72 }
73
74 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
75                      unsigned int recursion_level)
76 {
77         int start = skb_headlen(skb);
78         int i, chunk = start - offset;
79         struct sk_buff *frag_iter;
80         int elt = 0;
81
82         if (unlikely(recursion_level >= 24))
83                 return -EMSGSIZE;
84
85         if (chunk > 0) {
86                 if (chunk > len)
87                         chunk = len;
88                 elt++;
89                 len -= chunk;
90                 if (len == 0)
91                         return elt;
92                 offset += chunk;
93         }
94
95         for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
96                 int end;
97
98                 WARN_ON(start > offset + len);
99
100                 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
101                 chunk = end - offset;
102                 if (chunk > 0) {
103                         if (chunk > len)
104                                 chunk = len;
105                         elt++;
106                         len -= chunk;
107                         if (len == 0)
108                                 return elt;
109                         offset += chunk;
110                 }
111                 start = end;
112         }
113
114         if (unlikely(skb_has_frag_list(skb))) {
115                 skb_walk_frags(skb, frag_iter) {
116                         int end, ret;
117
118                         WARN_ON(start > offset + len);
119
120                         end = start + frag_iter->len;
121                         chunk = end - offset;
122                         if (chunk > 0) {
123                                 if (chunk > len)
124                                         chunk = len;
125                                 ret = __skb_nsg(frag_iter, offset - start, chunk,
126                                                 recursion_level + 1);
127                                 if (unlikely(ret < 0))
128                                         return ret;
129                                 elt += ret;
130                                 len -= chunk;
131                                 if (len == 0)
132                                         return elt;
133                                 offset += chunk;
134                         }
135                         start = end;
136                 }
137         }
138         BUG_ON(len);
139         return elt;
140 }
141
142 /* Return the number of scatterlist elements required to completely map the
143  * skb, or -EMSGSIZE if the recursion depth is exceeded.
144  */
145 static int skb_nsg(struct sk_buff *skb, int offset, int len)
146 {
147         return __skb_nsg(skb, offset, len, 0);
148 }
149
150 static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
151                               struct tls_decrypt_arg *darg)
152 {
153         struct strp_msg *rxm = strp_msg(skb);
154         struct tls_msg *tlm = tls_msg(skb);
155         int sub = 0;
156
157         /* Determine zero-padding length */
158         if (prot->version == TLS_1_3_VERSION) {
159                 int offset = rxm->full_len - TLS_TAG_SIZE - 1;
160                 char content_type = darg->zc ? darg->tail : 0;
161                 int err;
162
163                 while (content_type == 0) {
164                         if (offset < prot->prepend_size)
165                                 return -EBADMSG;
166                         err = skb_copy_bits(skb, rxm->offset + offset,
167                                             &content_type, 1);
168                         if (err)
169                                 return err;
170                         if (content_type)
171                                 break;
172                         sub++;
173                         offset--;
174                 }
175                 tlm->control = content_type;
176         }
177         return sub;
178 }
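/* Worked example: a TLS 1.3 record whose decrypted payload ends in
 *
 *	... | 0x17 (application_data) | 0x00 0x00 0x00 | <auth tag>
 *
 * makes the loop above start at the last byte before the tag, walk back
 * over the three zero padding bytes (so it returns sub == 3) and store
 * the 0x17 inner content type in tlm->control.
 */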
179
180 static void tls_decrypt_done(struct crypto_async_request *req, int err)
181 {
182         struct aead_request *aead_req = (struct aead_request *)req;
183         struct scatterlist *sgout = aead_req->dst;
184         struct scatterlist *sgin = aead_req->src;
185         struct tls_sw_context_rx *ctx;
186         struct tls_context *tls_ctx;
187         struct tls_prot_info *prot;
188         struct scatterlist *sg;
189         struct sk_buff *skb;
190         unsigned int pages;
191
192         skb = (struct sk_buff *)req->data;
193         tls_ctx = tls_get_ctx(skb->sk);
194         ctx = tls_sw_ctx_rx(tls_ctx);
195         prot = &tls_ctx->prot_info;
196
197         /* Propagate if there was an error */
198         if (err) {
199                 if (err == -EBADMSG)
200                         TLS_INC_STATS(sock_net(skb->sk),
201                                       LINUX_MIB_TLSDECRYPTERROR);
202                 ctx->async_wait.err = err;
203                 tls_err_abort(skb->sk, err);
204         } else {
205                 struct strp_msg *rxm = strp_msg(skb);
206
207                 /* No TLS 1.3 support with async crypto */
208                 WARN_ON(prot->tail_size);
209
210                 rxm->offset += prot->prepend_size;
211                 rxm->full_len -= prot->overhead_size;
212         }
213
214         /* After using skb->sk to propagate sk through crypto async callback
215          * we need to NULL it again.
216          */
217         skb->sk = NULL;
218
219
220         /* Free the destination pages if skb was not decrypted in place */
221         if (sgout != sgin) {
222                 /* Skip the first S/G entry as it points to AAD */
223                 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
224                         if (!sg)
225                                 break;
226                         put_page(sg_page(sg));
227                 }
228         }
229
230         kfree(aead_req);
231
232         spin_lock_bh(&ctx->decrypt_compl_lock);
233         if (!atomic_dec_return(&ctx->decrypt_pending))
234                 complete(&ctx->async_wait.completion);
235         spin_unlock_bh(&ctx->decrypt_compl_lock);
236 }
237
238 static int tls_do_decryption(struct sock *sk,
239                              struct sk_buff *skb,
240                              struct scatterlist *sgin,
241                              struct scatterlist *sgout,
242                              char *iv_recv,
243                              size_t data_len,
244                              struct aead_request *aead_req,
245                              struct tls_decrypt_arg *darg)
246 {
247         struct tls_context *tls_ctx = tls_get_ctx(sk);
248         struct tls_prot_info *prot = &tls_ctx->prot_info;
249         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
250         int ret;
251
252         aead_request_set_tfm(aead_req, ctx->aead_recv);
253         aead_request_set_ad(aead_req, prot->aad_size);
254         aead_request_set_crypt(aead_req, sgin, sgout,
255                                data_len + prot->tag_size,
256                                (u8 *)iv_recv);
257
258         if (darg->async) {
259                 /* Using skb->sk to push sk through to crypto async callback
260                  * handler. This allows propagating errors up to the socket
261                  * if needed. It _must_ be cleared in the async handler
262                  * before consume_skb is called. We _know_ skb->sk is NULL
263                  * because it is a clone from strparser.
264                  */
265                 skb->sk = sk;
266                 aead_request_set_callback(aead_req,
267                                           CRYPTO_TFM_REQ_MAY_BACKLOG,
268                                           tls_decrypt_done, skb);
269                 atomic_inc(&ctx->decrypt_pending);
270         } else {
271                 aead_request_set_callback(aead_req,
272                                           CRYPTO_TFM_REQ_MAY_BACKLOG,
273                                           crypto_req_done, &ctx->async_wait);
274         }
275
276         ret = crypto_aead_decrypt(aead_req);
277         if (ret == -EINPROGRESS) {
278                 if (darg->async)
279                         return 0;
280
281                 ret = crypto_wait_req(ret, &ctx->async_wait);
282         }
283         darg->async = false;
284
285         return ret;
286 }
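/* Return convention of tls_do_decryption(): if the AEAD reports -EINPROGRESS
 * and the caller allowed async, it returns 0 right away and leaves
 * darg->async set; completion is then signalled via tls_decrypt_done().
 * In every other case it waits for the crypto if needed, clears darg->async
 * and returns the final result.
 */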
287
288 static void tls_trim_both_msgs(struct sock *sk, int target_size)
289 {
290         struct tls_context *tls_ctx = tls_get_ctx(sk);
291         struct tls_prot_info *prot = &tls_ctx->prot_info;
292         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
293         struct tls_rec *rec = ctx->open_rec;
294
295         sk_msg_trim(sk, &rec->msg_plaintext, target_size);
296         if (target_size > 0)
297                 target_size += prot->overhead_size;
298         sk_msg_trim(sk, &rec->msg_encrypted, target_size);
299 }
300
301 static int tls_alloc_encrypted_msg(struct sock *sk, int len)
302 {
303         struct tls_context *tls_ctx = tls_get_ctx(sk);
304         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
305         struct tls_rec *rec = ctx->open_rec;
306         struct sk_msg *msg_en = &rec->msg_encrypted;
307
308         return sk_msg_alloc(sk, msg_en, len, 0);
309 }
310
311 static int tls_clone_plaintext_msg(struct sock *sk, int required)
312 {
313         struct tls_context *tls_ctx = tls_get_ctx(sk);
314         struct tls_prot_info *prot = &tls_ctx->prot_info;
315         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
316         struct tls_rec *rec = ctx->open_rec;
317         struct sk_msg *msg_pl = &rec->msg_plaintext;
318         struct sk_msg *msg_en = &rec->msg_encrypted;
319         int skip, len;
320
321         /* We add page references worth len bytes from encrypted sg
322          * at the end of plaintext sg. The caller guarantees that msg_en
323          * already has enough room for them.
324          */
325         len = required - msg_pl->sg.size;
326
327         /* Skip initial bytes in msg_en's data to be able to use
328          * same offset of both plain and encrypted data.
329          */
330         skip = prot->prepend_size + msg_pl->sg.size;
331
332         return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
333 }
334
335 static struct tls_rec *tls_get_rec(struct sock *sk)
336 {
337         struct tls_context *tls_ctx = tls_get_ctx(sk);
338         struct tls_prot_info *prot = &tls_ctx->prot_info;
339         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
340         struct sk_msg *msg_pl, *msg_en;
341         struct tls_rec *rec;
342         int mem_size;
343
344         mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
345
346         rec = kzalloc(mem_size, sk->sk_allocation);
347         if (!rec)
348                 return NULL;
349
350         msg_pl = &rec->msg_plaintext;
351         msg_en = &rec->msg_encrypted;
352
353         sk_msg_init(msg_pl);
354         sk_msg_init(msg_en);
355
356         sg_init_table(rec->sg_aead_in, 2);
357         sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
358         sg_unmark_end(&rec->sg_aead_in[1]);
359
360         sg_init_table(rec->sg_aead_out, 2);
361         sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
362         sg_unmark_end(&rec->sg_aead_out[1]);
363
364         return rec;
365 }
366
367 static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
368 {
369         sk_msg_free(sk, &rec->msg_encrypted);
370         sk_msg_free(sk, &rec->msg_plaintext);
371         kfree(rec);
372 }
373
374 static void tls_free_open_rec(struct sock *sk)
375 {
376         struct tls_context *tls_ctx = tls_get_ctx(sk);
377         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
378         struct tls_rec *rec = ctx->open_rec;
379
380         if (rec) {
381                 tls_free_rec(sk, rec);
382                 ctx->open_rec = NULL;
383         }
384 }
385
386 int tls_tx_records(struct sock *sk, int flags)
387 {
388         struct tls_context *tls_ctx = tls_get_ctx(sk);
389         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
390         struct tls_rec *rec, *tmp;
391         struct sk_msg *msg_en;
392         int tx_flags, rc = 0;
393
394         if (tls_is_partially_sent_record(tls_ctx)) {
395                 rec = list_first_entry(&ctx->tx_list,
396                                        struct tls_rec, list);
397
398                 if (flags == -1)
399                         tx_flags = rec->tx_flags;
400                 else
401                         tx_flags = flags;
402
403                 rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
404                 if (rc)
405                         goto tx_err;
406
407                 /* Full record has been transmitted.
408                  * Remove the head of tx_list
409                  */
410                 list_del(&rec->list);
411                 sk_msg_free(sk, &rec->msg_plaintext);
412                 kfree(rec);
413         }
414
415         /* Tx all ready records */
416         list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
417                 if (READ_ONCE(rec->tx_ready)) {
418                         if (flags == -1)
419                                 tx_flags = rec->tx_flags;
420                         else
421                                 tx_flags = flags;
422
423                         msg_en = &rec->msg_encrypted;
424                         rc = tls_push_sg(sk, tls_ctx,
425                                          &msg_en->sg.data[msg_en->sg.curr],
426                                          0, tx_flags);
427                         if (rc)
428                                 goto tx_err;
429
430                         list_del(&rec->list);
431                         sk_msg_free(sk, &rec->msg_plaintext);
432                         kfree(rec);
433                 } else {
434                         break;
435                 }
436         }
437
438 tx_err:
439         if (rc < 0 && rc != -EAGAIN)
440                 tls_err_abort(sk, -EBADMSG);
441
442         return rc;
443 }
444
445 static void tls_encrypt_done(struct crypto_async_request *req, int err)
446 {
447         struct aead_request *aead_req = (struct aead_request *)req;
448         struct sock *sk = req->data;
449         struct tls_context *tls_ctx = tls_get_ctx(sk);
450         struct tls_prot_info *prot = &tls_ctx->prot_info;
451         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
452         struct scatterlist *sge;
453         struct sk_msg *msg_en;
454         struct tls_rec *rec;
455         bool ready = false;
456         int pending;
457
458         rec = container_of(aead_req, struct tls_rec, aead_req);
459         msg_en = &rec->msg_encrypted;
460
461         sge = sk_msg_elem(msg_en, msg_en->sg.curr);
462         sge->offset -= prot->prepend_size;
463         sge->length += prot->prepend_size;
464
465         /* Check if error is previously set on socket */
466         if (err || sk->sk_err) {
467                 rec = NULL;
468
469                 /* If err is already set on socket, return the same code */
470                 if (sk->sk_err) {
471                         ctx->async_wait.err = -sk->sk_err;
472                 } else {
473                         ctx->async_wait.err = err;
474                         tls_err_abort(sk, err);
475                 }
476         }
477
478         if (rec) {
479                 struct tls_rec *first_rec;
480
481                 /* Mark the record as ready for transmission */
482                 smp_store_mb(rec->tx_ready, true);
483
484                 /* If received record is at head of tx_list, schedule tx */
485                 first_rec = list_first_entry(&ctx->tx_list,
486                                              struct tls_rec, list);
487                 if (rec == first_rec)
488                         ready = true;
489         }
490
491         spin_lock_bh(&ctx->encrypt_compl_lock);
492         pending = atomic_dec_return(&ctx->encrypt_pending);
493
494         if (!pending && ctx->async_notify)
495                 complete(&ctx->async_wait.completion);
496         spin_unlock_bh(&ctx->encrypt_compl_lock);
497
498         if (!ready)
499                 return;
500
501         /* Schedule the transmission */
502         if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
503                 schedule_delayed_work(&ctx->tx_work.work, 1);
504 }
505
506 static int tls_do_encryption(struct sock *sk,
507                              struct tls_context *tls_ctx,
508                              struct tls_sw_context_tx *ctx,
509                              struct aead_request *aead_req,
510                              size_t data_len, u32 start)
511 {
512         struct tls_prot_info *prot = &tls_ctx->prot_info;
513         struct tls_rec *rec = ctx->open_rec;
514         struct sk_msg *msg_en = &rec->msg_encrypted;
515         struct scatterlist *sge = sk_msg_elem(msg_en, start);
516         int rc, iv_offset = 0;
517
518         /* For CCM based ciphers, first byte of IV is a constant */
519         switch (prot->cipher_type) {
520         case TLS_CIPHER_AES_CCM_128:
521                 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
522                 iv_offset = 1;
523                 break;
524         case TLS_CIPHER_SM4_CCM:
525                 rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
526                 iv_offset = 1;
527                 break;
528         }
529
530         memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
531                prot->iv_size + prot->salt_size);
532
533         tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
534                             tls_ctx->tx.rec_seq);
535
536         sge->offset += prot->prepend_size;
537         sge->length -= prot->prepend_size;
538
539         msg_en->sg.curr = start;
540
541         aead_request_set_tfm(aead_req, ctx->aead_send);
542         aead_request_set_ad(aead_req, prot->aad_size);
543         aead_request_set_crypt(aead_req, rec->sg_aead_in,
544                                rec->sg_aead_out,
545                                data_len, rec->iv_data);
546
547         aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
548                                   tls_encrypt_done, sk);
549
550         /* Add the record in tx_list */
551         list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
552         atomic_inc(&ctx->encrypt_pending);
553
554         rc = crypto_aead_encrypt(aead_req);
555         if (!rc || rc != -EINPROGRESS) {
556                 atomic_dec(&ctx->encrypt_pending);
557                 sge->offset -= prot->prepend_size;
558                 sge->length += prot->prepend_size;
559         }
560
561         if (!rc) {
562                 WRITE_ONCE(rec->tx_ready, true);
563         } else if (rc != -EINPROGRESS) {
564                 list_del(&rec->list);
565                 return rc;
566         }
567
568         /* Unhook the record from the context if encryption did not fail */
569         ctx->open_rec = NULL;
570         tls_advance_record_sn(sk, prot, &tls_ctx->tx);
571         return rc;
572 }
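/* Outcomes of tls_do_encryption(): 0 means the record was encrypted
 * synchronously and marked tx_ready; -EINPROGRESS means it stays queued on
 * tx_list and tls_encrypt_done() will mark it ready later; any other error
 * unlinks the record from tx_list again and is returned to the caller. In
 * the first two cases the open record is detached and the record sequence
 * number is advanced.
 */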
573
574 static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
575                                  struct tls_rec **to, struct sk_msg *msg_opl,
576                                  struct sk_msg *msg_oen, u32 split_point,
577                                  u32 tx_overhead_size, u32 *orig_end)
578 {
579         u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
580         struct scatterlist *sge, *osge, *nsge;
581         u32 orig_size = msg_opl->sg.size;
582         struct scatterlist tmp = { };
583         struct sk_msg *msg_npl;
584         struct tls_rec *new;
585         int ret;
586
587         new = tls_get_rec(sk);
588         if (!new)
589                 return -ENOMEM;
590         ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
591                            tx_overhead_size, 0);
592         if (ret < 0) {
593                 tls_free_rec(sk, new);
594                 return ret;
595         }
596
597         *orig_end = msg_opl->sg.end;
598         i = msg_opl->sg.start;
599         sge = sk_msg_elem(msg_opl, i);
600         while (apply && sge->length) {
601                 if (sge->length > apply) {
602                         u32 len = sge->length - apply;
603
604                         get_page(sg_page(sge));
605                         sg_set_page(&tmp, sg_page(sge), len,
606                                     sge->offset + apply);
607                         sge->length = apply;
608                         bytes += apply;
609                         apply = 0;
610                 } else {
611                         apply -= sge->length;
612                         bytes += sge->length;
613                 }
614
615                 sk_msg_iter_var_next(i);
616                 if (i == msg_opl->sg.end)
617                         break;
618                 sge = sk_msg_elem(msg_opl, i);
619         }
620
621         msg_opl->sg.end = i;
622         msg_opl->sg.curr = i;
623         msg_opl->sg.copybreak = 0;
624         msg_opl->apply_bytes = 0;
625         msg_opl->sg.size = bytes;
626
627         msg_npl = &new->msg_plaintext;
628         msg_npl->apply_bytes = apply;
629         msg_npl->sg.size = orig_size - bytes;
630
631         j = msg_npl->sg.start;
632         nsge = sk_msg_elem(msg_npl, j);
633         if (tmp.length) {
634                 memcpy(nsge, &tmp, sizeof(*nsge));
635                 sk_msg_iter_var_next(j);
636                 nsge = sk_msg_elem(msg_npl, j);
637         }
638
639         osge = sk_msg_elem(msg_opl, i);
640         while (osge->length) {
641                 memcpy(nsge, osge, sizeof(*nsge));
642                 sg_unmark_end(nsge);
643                 sk_msg_iter_var_next(i);
644                 sk_msg_iter_var_next(j);
645                 if (i == *orig_end)
646                         break;
647                 osge = sk_msg_elem(msg_opl, i);
648                 nsge = sk_msg_elem(msg_npl, j);
649         }
650
651         msg_npl->sg.end = j;
652         msg_npl->sg.curr = j;
653         msg_npl->sg.copybreak = 0;
654
655         *to = new;
656         return 0;
657 }
658
659 static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
660                                   struct tls_rec *from, u32 orig_end)
661 {
662         struct sk_msg *msg_npl = &from->msg_plaintext;
663         struct sk_msg *msg_opl = &to->msg_plaintext;
664         struct scatterlist *osge, *nsge;
665         u32 i, j;
666
667         i = msg_opl->sg.end;
668         sk_msg_iter_var_prev(i);
669         j = msg_npl->sg.start;
670
671         osge = sk_msg_elem(msg_opl, i);
672         nsge = sk_msg_elem(msg_npl, j);
673
674         if (sg_page(osge) == sg_page(nsge) &&
675             osge->offset + osge->length == nsge->offset) {
676                 osge->length += nsge->length;
677                 put_page(sg_page(nsge));
678         }
679
680         msg_opl->sg.end = orig_end;
681         msg_opl->sg.curr = orig_end;
682         msg_opl->sg.copybreak = 0;
683         msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
684         msg_opl->sg.size += msg_npl->sg.size;
685
686         sk_msg_free(sk, &to->msg_encrypted);
687         sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
688
689         kfree(from);
690 }
691
692 static int tls_push_record(struct sock *sk, int flags,
693                            unsigned char record_type)
694 {
695         struct tls_context *tls_ctx = tls_get_ctx(sk);
696         struct tls_prot_info *prot = &tls_ctx->prot_info;
697         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
698         struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
699         u32 i, split_point, orig_end;
700         struct sk_msg *msg_pl, *msg_en;
701         struct aead_request *req;
702         bool split;
703         int rc;
704
705         if (!rec)
706                 return 0;
707
708         msg_pl = &rec->msg_plaintext;
709         msg_en = &rec->msg_encrypted;
710
711         split_point = msg_pl->apply_bytes;
712         split = split_point && split_point < msg_pl->sg.size;
713         if (unlikely((!split &&
714                       msg_pl->sg.size +
715                       prot->overhead_size > msg_en->sg.size) ||
716                      (split &&
717                       split_point +
718                       prot->overhead_size > msg_en->sg.size))) {
719                 split = true;
720                 split_point = msg_en->sg.size;
721         }
722         if (split) {
723                 rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
724                                            split_point, prot->overhead_size,
725                                            &orig_end);
726                 if (rc < 0)
727                         return rc;
728                 /* This can happen if above tls_split_open_record allocates
729                  * a single large encryption buffer instead of two smaller
730                  * ones. In this case adjust pointers and continue without
731                  * split.
732                  */
733                 if (!msg_pl->sg.size) {
734                         tls_merge_open_record(sk, rec, tmp, orig_end);
735                         msg_pl = &rec->msg_plaintext;
736                         msg_en = &rec->msg_encrypted;
737                         split = false;
738                 }
739                 sk_msg_trim(sk, msg_en, msg_pl->sg.size +
740                             prot->overhead_size);
741         }
742
743         rec->tx_flags = flags;
744         req = &rec->aead_req;
745
746         i = msg_pl->sg.end;
747         sk_msg_iter_var_prev(i);
748
749         rec->content_type = record_type;
750         if (prot->version == TLS_1_3_VERSION) {
751                 /* Add content type to end of message.  No padding added */
752                 sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
753                 sg_mark_end(&rec->sg_content_type);
754                 sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
755                          &rec->sg_content_type);
756         } else {
757                 sg_mark_end(sk_msg_elem(msg_pl, i));
758         }
759
760         if (msg_pl->sg.end < msg_pl->sg.start) {
761                 sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
762                          MAX_SKB_FRAGS - msg_pl->sg.start + 1,
763                          msg_pl->sg.data);
764         }
765
766         i = msg_pl->sg.start;
767         sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
768
769         i = msg_en->sg.end;
770         sk_msg_iter_var_prev(i);
771         sg_mark_end(sk_msg_elem(msg_en, i));
772
773         i = msg_en->sg.start;
774         sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
775
776         tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
777                      tls_ctx->tx.rec_seq, record_type, prot);
778
779         tls_fill_prepend(tls_ctx,
780                          page_address(sg_page(&msg_en->sg.data[i])) +
781                          msg_en->sg.data[i].offset,
782                          msg_pl->sg.size + prot->tail_size,
783                          record_type);
784
785         tls_ctx->pending_open_record_frags = false;
786
787         rc = tls_do_encryption(sk, tls_ctx, ctx, req,
788                                msg_pl->sg.size + prot->tail_size, i);
789         if (rc < 0) {
790                 if (rc != -EINPROGRESS) {
791                         tls_err_abort(sk, -EBADMSG);
792                         if (split) {
793                                 tls_ctx->pending_open_record_frags = true;
794                                 tls_merge_open_record(sk, rec, tmp, orig_end);
795                         }
796                 }
797                 ctx->async_capable = 1;
798                 return rc;
799         } else if (split) {
800                 msg_pl = &tmp->msg_plaintext;
801                 msg_en = &tmp->msg_encrypted;
802                 sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
803                 tls_ctx->pending_open_record_frags = true;
804                 ctx->open_rec = tmp;
805         }
806
807         return tls_tx_records(sk, flags);
808 }
809
810 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
811                                bool full_record, u8 record_type,
812                                ssize_t *copied, int flags)
813 {
814         struct tls_context *tls_ctx = tls_get_ctx(sk);
815         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
816         struct sk_msg msg_redir = { };
817         struct sk_psock *psock;
818         struct sock *sk_redir;
819         struct tls_rec *rec;
820         bool enospc, policy;
821         int err = 0, send;
822         u32 delta = 0;
823
824         policy = !(flags & MSG_SENDPAGE_NOPOLICY);
825         psock = sk_psock_get(sk);
826         if (!psock || !policy) {
827                 err = tls_push_record(sk, flags, record_type);
828                 if (err && sk->sk_err == EBADMSG) {
829                         *copied -= sk_msg_free(sk, msg);
830                         tls_free_open_rec(sk);
831                         err = -sk->sk_err;
832                 }
833                 if (psock)
834                         sk_psock_put(sk, psock);
835                 return err;
836         }
837 more_data:
838         enospc = sk_msg_full(msg);
839         if (psock->eval == __SK_NONE) {
840                 delta = msg->sg.size;
841                 psock->eval = sk_psock_msg_verdict(sk, psock, msg);
842                 delta -= msg->sg.size;
843         }
844         if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
845             !enospc && !full_record) {
846                 err = -ENOSPC;
847                 goto out_err;
848         }
849         msg->cork_bytes = 0;
850         send = msg->sg.size;
851         if (msg->apply_bytes && msg->apply_bytes < send)
852                 send = msg->apply_bytes;
853
854         switch (psock->eval) {
855         case __SK_PASS:
856                 err = tls_push_record(sk, flags, record_type);
857                 if (err && sk->sk_err == EBADMSG) {
858                         *copied -= sk_msg_free(sk, msg);
859                         tls_free_open_rec(sk);
860                         err = -sk->sk_err;
861                         goto out_err;
862                 }
863                 break;
864         case __SK_REDIRECT:
865                 sk_redir = psock->sk_redir;
866                 memcpy(&msg_redir, msg, sizeof(*msg));
867                 if (msg->apply_bytes < send)
868                         msg->apply_bytes = 0;
869                 else
870                         msg->apply_bytes -= send;
871                 sk_msg_return_zero(sk, msg, send);
872                 msg->sg.size -= send;
873                 release_sock(sk);
874                 err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
875                 lock_sock(sk);
876                 if (err < 0) {
877                         *copied -= sk_msg_free_nocharge(sk, &msg_redir);
878                         msg->sg.size = 0;
879                 }
880                 if (msg->sg.size == 0)
881                         tls_free_open_rec(sk);
882                 break;
883         case __SK_DROP:
884         default:
885                 sk_msg_free_partial(sk, msg, send);
886                 if (msg->apply_bytes < send)
887                         msg->apply_bytes = 0;
888                 else
889                         msg->apply_bytes -= send;
890                 if (msg->sg.size == 0)
891                         tls_free_open_rec(sk);
892                 *copied -= (send + delta);
893                 err = -EACCES;
894         }
895
896         if (likely(!err)) {
897                 bool reset_eval = !ctx->open_rec;
898
899                 rec = ctx->open_rec;
900                 if (rec) {
901                         msg = &rec->msg_plaintext;
902                         if (!msg->apply_bytes)
903                                 reset_eval = true;
904                 }
905                 if (reset_eval) {
906                         psock->eval = __SK_NONE;
907                         if (psock->sk_redir) {
908                                 sock_put(psock->sk_redir);
909                                 psock->sk_redir = NULL;
910                         }
911                 }
912                 if (rec)
913                         goto more_data;
914         }
915  out_err:
916         sk_psock_put(sk, psock);
917         return err;
918 }
919
920 static int tls_sw_push_pending_record(struct sock *sk, int flags)
921 {
922         struct tls_context *tls_ctx = tls_get_ctx(sk);
923         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
924         struct tls_rec *rec = ctx->open_rec;
925         struct sk_msg *msg_pl;
926         size_t copied;
927
928         if (!rec)
929                 return 0;
930
931         msg_pl = &rec->msg_plaintext;
932         copied = msg_pl->sg.size;
933         if (!copied)
934                 return 0;
935
936         return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
937                                    &copied, flags);
938 }
939
940 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
941 {
942         long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
943         struct tls_context *tls_ctx = tls_get_ctx(sk);
944         struct tls_prot_info *prot = &tls_ctx->prot_info;
945         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
946         bool async_capable = ctx->async_capable;
947         unsigned char record_type = TLS_RECORD_TYPE_DATA;
948         bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
949         bool eor = !(msg->msg_flags & MSG_MORE);
950         size_t try_to_copy;
951         ssize_t copied = 0;
952         struct sk_msg *msg_pl, *msg_en;
953         struct tls_rec *rec;
954         int required_size;
955         int num_async = 0;
956         bool full_record;
957         int record_room;
958         int num_zc = 0;
959         int orig_size;
960         int ret = 0;
961         int pending;
962
963         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
964                                MSG_CMSG_COMPAT))
965                 return -EOPNOTSUPP;
966
967         mutex_lock(&tls_ctx->tx_lock);
968         lock_sock(sk);
969
970         if (unlikely(msg->msg_controllen)) {
971                 ret = tls_process_cmsg(sk, msg, &record_type);
972                 if (ret) {
973                         if (ret == -EINPROGRESS)
974                                 num_async++;
975                         else if (ret != -EAGAIN)
976                                 goto send_end;
977                 }
978         }
979
980         while (msg_data_left(msg)) {
981                 if (sk->sk_err) {
982                         ret = -sk->sk_err;
983                         goto send_end;
984                 }
985
986                 if (ctx->open_rec)
987                         rec = ctx->open_rec;
988                 else
989                         rec = ctx->open_rec = tls_get_rec(sk);
990                 if (!rec) {
991                         ret = -ENOMEM;
992                         goto send_end;
993                 }
994
995                 msg_pl = &rec->msg_plaintext;
996                 msg_en = &rec->msg_encrypted;
997
998                 orig_size = msg_pl->sg.size;
999                 full_record = false;
1000                 try_to_copy = msg_data_left(msg);
1001                 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1002                 if (try_to_copy >= record_room) {
1003                         try_to_copy = record_room;
1004                         full_record = true;
1005                 }
1006
1007                 required_size = msg_pl->sg.size + try_to_copy +
1008                                 prot->overhead_size;
1009
1010                 if (!sk_stream_memory_free(sk))
1011                         goto wait_for_sndbuf;
1012
1013 alloc_encrypted:
1014                 ret = tls_alloc_encrypted_msg(sk, required_size);
1015                 if (ret) {
1016                         if (ret != -ENOSPC)
1017                                 goto wait_for_memory;
1018
1019                         /* Adjust try_to_copy according to the amount that was
1020                          * actually allocated. The difference is due
1021                          * to max sg elements limit
1022                          */
1023                         try_to_copy -= required_size - msg_en->sg.size;
1024                         full_record = true;
1025                 }
1026
1027                 if (!is_kvec && (full_record || eor) && !async_capable) {
1028                         u32 first = msg_pl->sg.end;
1029
1030                         ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
1031                                                         msg_pl, try_to_copy);
1032                         if (ret)
1033                                 goto fallback_to_reg_send;
1034
1035                         num_zc++;
1036                         copied += try_to_copy;
1037
1038                         sk_msg_sg_copy_set(msg_pl, first);
1039                         ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1040                                                   record_type, &copied,
1041                                                   msg->msg_flags);
1042                         if (ret) {
1043                                 if (ret == -EINPROGRESS)
1044                                         num_async++;
1045                                 else if (ret == -ENOMEM)
1046                                         goto wait_for_memory;
1047                                 else if (ctx->open_rec && ret == -ENOSPC)
1048                                         goto rollback_iter;
1049                                 else if (ret != -EAGAIN)
1050                                         goto send_end;
1051                         }
1052                         continue;
1053 rollback_iter:
1054                         copied -= try_to_copy;
1055                         sk_msg_sg_copy_clear(msg_pl, first);
1056                         iov_iter_revert(&msg->msg_iter,
1057                                         msg_pl->sg.size - orig_size);
1058 fallback_to_reg_send:
1059                         sk_msg_trim(sk, msg_pl, orig_size);
1060                 }
1061
1062                 required_size = msg_pl->sg.size + try_to_copy;
1063
1064                 ret = tls_clone_plaintext_msg(sk, required_size);
1065                 if (ret) {
1066                         if (ret != -ENOSPC)
1067                                 goto send_end;
1068
1069                         /* Adjust try_to_copy according to the amount that was
1070                          * actually allocated. The difference is due
1071                          * to max sg elements limit
1072                          */
1073                         try_to_copy -= required_size - msg_pl->sg.size;
1074                         full_record = true;
1075                         sk_msg_trim(sk, msg_en,
1076                                     msg_pl->sg.size + prot->overhead_size);
1077                 }
1078
1079                 if (try_to_copy) {
1080                         ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
1081                                                        msg_pl, try_to_copy);
1082                         if (ret < 0)
1083                                 goto trim_sgl;
1084                 }
1085
1086                 /* Mark the open record's frags as pending only if the copy succeeded;
1087                  * otherwise we would trim the sg but not reset the open record frags.
1088                  */
1089                 tls_ctx->pending_open_record_frags = true;
1090                 copied += try_to_copy;
1091                 if (full_record || eor) {
1092                         ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1093                                                   record_type, &copied,
1094                                                   msg->msg_flags);
1095                         if (ret) {
1096                                 if (ret == -EINPROGRESS)
1097                                         num_async++;
1098                                 else if (ret == -ENOMEM)
1099                                         goto wait_for_memory;
1100                                 else if (ret != -EAGAIN) {
1101                                         if (ret == -ENOSPC)
1102                                                 ret = 0;
1103                                         goto send_end;
1104                                 }
1105                         }
1106                 }
1107
1108                 continue;
1109
1110 wait_for_sndbuf:
1111                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1112 wait_for_memory:
1113                 ret = sk_stream_wait_memory(sk, &timeo);
1114                 if (ret) {
1115 trim_sgl:
1116                         if (ctx->open_rec)
1117                                 tls_trim_both_msgs(sk, orig_size);
1118                         goto send_end;
1119                 }
1120
1121                 if (ctx->open_rec && msg_en->sg.size < required_size)
1122                         goto alloc_encrypted;
1123         }
1124
1125         if (!num_async) {
1126                 goto send_end;
1127         } else if (num_zc) {
1128                 /* Wait for pending encryptions to get completed */
1129                 spin_lock_bh(&ctx->encrypt_compl_lock);
1130                 ctx->async_notify = true;
1131
1132                 pending = atomic_read(&ctx->encrypt_pending);
1133                 spin_unlock_bh(&ctx->encrypt_compl_lock);
1134                 if (pending)
1135                         crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1136                 else
1137                         reinit_completion(&ctx->async_wait.completion);
1138
1139                 /* There can be no concurrent accesses, since we have no
1140                  * pending encrypt operations
1141                  */
1142                 WRITE_ONCE(ctx->async_notify, false);
1143
1144                 if (ctx->async_wait.err) {
1145                         ret = ctx->async_wait.err;
1146                         copied = 0;
1147                 }
1148         }
1149
1150         /* Transmit if any encryptions have completed */
1151         if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1152                 cancel_delayed_work(&ctx->tx_work.work);
1153                 tls_tx_records(sk, msg->msg_flags);
1154         }
1155
1156 send_end:
1157         ret = sk_stream_error(sk, msg->msg_flags, ret);
1158
1159         release_sock(sk);
1160         mutex_unlock(&tls_ctx->tx_lock);
1161         return copied > 0 ? copied : ret;
1162 }
1163
1164 static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
1165                               int offset, size_t size, int flags)
1166 {
1167         long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1168         struct tls_context *tls_ctx = tls_get_ctx(sk);
1169         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1170         struct tls_prot_info *prot = &tls_ctx->prot_info;
1171         unsigned char record_type = TLS_RECORD_TYPE_DATA;
1172         struct sk_msg *msg_pl;
1173         struct tls_rec *rec;
1174         int num_async = 0;
1175         ssize_t copied = 0;
1176         bool full_record;
1177         int record_room;
1178         int ret = 0;
1179         bool eor;
1180
1181         eor = !(flags & MSG_SENDPAGE_NOTLAST);
1182         sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1183
1184         /* Call the sk_stream functions to manage the sndbuf mem. */
1185         while (size > 0) {
1186                 size_t copy, required_size;
1187
1188                 if (sk->sk_err) {
1189                         ret = -sk->sk_err;
1190                         goto sendpage_end;
1191                 }
1192
1193                 if (ctx->open_rec)
1194                         rec = ctx->open_rec;
1195                 else
1196                         rec = ctx->open_rec = tls_get_rec(sk);
1197                 if (!rec) {
1198                         ret = -ENOMEM;
1199                         goto sendpage_end;
1200                 }
1201
1202                 msg_pl = &rec->msg_plaintext;
1203
1204                 full_record = false;
1205                 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1206                 copy = size;
1207                 if (copy >= record_room) {
1208                         copy = record_room;
1209                         full_record = true;
1210                 }
1211
1212                 required_size = msg_pl->sg.size + copy + prot->overhead_size;
1213
1214                 if (!sk_stream_memory_free(sk))
1215                         goto wait_for_sndbuf;
1216 alloc_payload:
1217                 ret = tls_alloc_encrypted_msg(sk, required_size);
1218                 if (ret) {
1219                         if (ret != -ENOSPC)
1220                                 goto wait_for_memory;
1221
1222                         /* Adjust copy according to the amount that was
1223                          * actually allocated. The difference is due
1224                          * to max sg elements limit
1225                          */
1226                         copy -= required_size - msg_pl->sg.size;
1227                         full_record = true;
1228                 }
1229
1230                 sk_msg_page_add(msg_pl, page, copy, offset);
1231                 sk_mem_charge(sk, copy);
1232
1233                 offset += copy;
1234                 size -= copy;
1235                 copied += copy;
1236
1237                 tls_ctx->pending_open_record_frags = true;
1238                 if (full_record || eor || sk_msg_full(msg_pl)) {
1239                         ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1240                                                   record_type, &copied, flags);
1241                         if (ret) {
1242                                 if (ret == -EINPROGRESS)
1243                                         num_async++;
1244                                 else if (ret == -ENOMEM)
1245                                         goto wait_for_memory;
1246                                 else if (ret != -EAGAIN) {
1247                                         if (ret == -ENOSPC)
1248                                                 ret = 0;
1249                                         goto sendpage_end;
1250                                 }
1251                         }
1252                 }
1253                 continue;
1254 wait_for_sndbuf:
1255                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1256 wait_for_memory:
1257                 ret = sk_stream_wait_memory(sk, &timeo);
1258                 if (ret) {
1259                         if (ctx->open_rec)
1260                                 tls_trim_both_msgs(sk, msg_pl->sg.size);
1261                         goto sendpage_end;
1262                 }
1263
1264                 if (ctx->open_rec)
1265                         goto alloc_payload;
1266         }
1267
1268         if (num_async) {
1269                 /* Transmit if any encryptions have completed */
1270                 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1271                         cancel_delayed_work(&ctx->tx_work.work);
1272                         tls_tx_records(sk, flags);
1273                 }
1274         }
1275 sendpage_end:
1276         ret = sk_stream_error(sk, flags, ret);
1277         return copied > 0 ? copied : ret;
1278 }
1279
1280 int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
1281                            int offset, size_t size, int flags)
1282 {
1283         if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1284                       MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
1285                       MSG_NO_SHARED_FRAGS))
1286                 return -EOPNOTSUPP;
1287
1288         return tls_sw_do_sendpage(sk, page, offset, size, flags);
1289 }
1290
1291 int tls_sw_sendpage(struct sock *sk, struct page *page,
1292                     int offset, size_t size, int flags)
1293 {
1294         struct tls_context *tls_ctx = tls_get_ctx(sk);
1295         int ret;
1296
1297         if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1298                       MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
1299                 return -EOPNOTSUPP;
1300
1301         mutex_lock(&tls_ctx->tx_lock);
1302         lock_sock(sk);
1303         ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
1304         release_sock(sk);
1305         mutex_unlock(&tls_ctx->tx_lock);
1306         return ret;
1307 }
1308
1309 static int
1310 tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
1311                 long timeo)
1312 {
1313         struct tls_context *tls_ctx = tls_get_ctx(sk);
1314         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1315         DEFINE_WAIT_FUNC(wait, woken_wake_function);
1316
1317         while (!ctx->recv_pkt) {
1318                 if (!sk_psock_queue_empty(psock))
1319                         return 0;
1320
1321                 if (sk->sk_err)
1322                         return sock_error(sk);
1323
1324                 if (!skb_queue_empty(&sk->sk_receive_queue)) {
1325                         __strp_unpause(&ctx->strp);
1326                         if (ctx->recv_pkt)
1327                                 break;
1328                 }
1329
1330                 if (sk->sk_shutdown & RCV_SHUTDOWN)
1331                         return 0;
1332
1333                 if (sock_flag(sk, SOCK_DONE))
1334                         return 0;
1335
1336                 if (nonblock || !timeo)
1337                         return -EAGAIN;
1338
1339                 add_wait_queue(sk_sleep(sk), &wait);
1340                 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1341                 sk_wait_event(sk, &timeo,
1342                               ctx->recv_pkt || !sk_psock_queue_empty(psock),
1343                               &wait);
1344                 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1345                 remove_wait_queue(sk_sleep(sk), &wait);
1346
1347                 /* Handle signals */
1348                 if (signal_pending(current))
1349                         return sock_intr_errno(timeo);
1350         }
1351
1352         return 1;
1353 }
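/* tls_rx_rec_wait() return values: 1 once a parsed record is available in
 * ctx->recv_pkt, 0 when the psock queue has data or the socket is shut down
 * or done (the caller must check those sources itself), -EAGAIN for a
 * non-blocking or timed-out wait, and a negative errno on a pending socket
 * error or signal.
 */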
1354
1355 static int tls_setup_from_iter(struct iov_iter *from,
1356                                int length, int *pages_used,
1357                                struct scatterlist *to,
1358                                int to_max_pages)
1359 {
1360         int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1361         struct page *pages[MAX_SKB_FRAGS];
1362         unsigned int size = 0;
1363         ssize_t copied, use;
1364         size_t offset;
1365
1366         while (length > 0) {
1367                 i = 0;
1368                 maxpages = to_max_pages - num_elem;
1369                 if (maxpages == 0) {
1370                         rc = -EFAULT;
1371                         goto out;
1372                 }
1373                 copied = iov_iter_get_pages(from, pages,
1374                                             length,
1375                                             maxpages, &offset);
1376                 if (copied <= 0) {
1377                         rc = -EFAULT;
1378                         goto out;
1379                 }
1380
1381                 iov_iter_advance(from, copied);
1382
1383                 length -= copied;
1384                 size += copied;
1385                 while (copied) {
1386                         use = min_t(int, copied, PAGE_SIZE - offset);
1387
1388                         sg_set_page(&to[num_elem],
1389                                     pages[i], use, offset);
1390                         sg_unmark_end(&to[num_elem]);
1391                         /* We do not uncharge memory from this API */
1392
1393                         offset = 0;
1394                         copied -= use;
1395
1396                         i++;
1397                         num_elem++;
1398                 }
1399         }
1400         /* Mark the end in the last sg entry if newly added */
1401         if (num_elem > *pages_used)
1402                 sg_mark_end(&to[num_elem - 1]);
1403 out:
1404         if (rc)
1405                 iov_iter_revert(from, size);
1406         *pages_used = num_elem;
1407
1408         return rc;
1409 }
1410
1411 /* Decrypt handlers
1412  *
1413  * tls_decrypt_sg() and tls_decrypt_device() are decrypt handlers.
1414  * They must transform the darg in/out argument as follows:
1415  *       |          Input            |         Output
1416  * -------------------------------------------------------------------
1417  *    zc | Zero-copy decrypt allowed | Zero-copy performed
1418  * async | Async decrypt allowed     | Async crypto used / in progress
1419  *   skb |            *              | Output skb
1420  */
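/* A minimal caller sketch for the contract above: fill in the inputs,
 * call a handler, then read the outputs back from the same struct.
 *
 *	struct tls_decrypt_arg darg = { .zc = true, .async = false };
 *	int err;
 *
 *	err = tls_decrypt_sg(sk, out_iov, NULL, &darg);
 *	if (err)
 *		return err;
 *	// darg.zc now reports whether zero-copy was actually performed and
 *	// darg.skb points at the decrypted record (the "output skb").
 */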
1421
1422 /* This function decrypts the input skb either into out_iov, into out_sg,
1423  * or into the skb's own buffers. The input parameter 'darg->zc' indicates if
1424  * zero-copy mode needs to be tried or not. With zero-copy mode, either
1425  * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
1426  * NULL, then the decryption happens inside skb buffers itself, i.e.
1427  * zero-copy gets disabled and 'darg->zc' is updated.
1428  */
1429 static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
1430                           struct scatterlist *out_sg,
1431                           struct tls_decrypt_arg *darg)
1432 {
1433         struct tls_context *tls_ctx = tls_get_ctx(sk);
1434         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1435         struct tls_prot_info *prot = &tls_ctx->prot_info;
1436         int n_sgin, n_sgout, aead_size, err, pages = 0;
1437         struct sk_buff *skb = tls_strp_msg(ctx);
1438         struct strp_msg *rxm = strp_msg(skb);
1439         struct tls_msg *tlm = tls_msg(skb);
1440         struct aead_request *aead_req;
1441         struct sk_buff *unused;
1442         struct scatterlist *sgin = NULL;
1443         struct scatterlist *sgout = NULL;
1444         const int data_len = rxm->full_len - prot->overhead_size;
1445         int tail_pages = !!prot->tail_size;
1446         struct tls_decrypt_ctx *dctx;
1447         int iv_offset = 0;
1448         u8 *mem;
1449
1450         if (darg->zc && (out_iov || out_sg)) {
1451                 if (out_iov)
1452                         n_sgout = 1 + tail_pages +
1453                                 iov_iter_npages_cap(out_iov, INT_MAX, data_len);
1454                 else
1455                         n_sgout = sg_nents(out_sg);
1456                 n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
1457                                  rxm->full_len - prot->prepend_size);
1458         } else {
1459                 n_sgout = 0;
1460                 darg->zc = false;
1461                 n_sgin = skb_cow_data(skb, 0, &unused);
1462         }
1463
1464         if (n_sgin < 1)
1465                 return -EBADMSG;
1466
1467         /* Increment to accommodate AAD */
1468         n_sgin = n_sgin + 1;
1469
1470         /* Allocate a single block of memory which contains
1471          *   aead_req || tls_decrypt_ctx.
1472          * Both structs are variable length.
1473          */
1474         aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
1475         mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
1476                       sk->sk_allocation);
1477         if (!mem)
1478                 return -ENOMEM;
1479
1480         /* Segment the allocated memory */
1481         aead_req = (struct aead_request *)mem;
1482         dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
1483         sgin = &dctx->sg[0];
1484         sgout = &dctx->sg[n_sgin];
1485
1486         /* For CCM-based ciphers, the first byte of nonce+IV is a constant */
1487         switch (prot->cipher_type) {
1488         case TLS_CIPHER_AES_CCM_128:
1489                 dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
1490                 iv_offset = 1;
1491                 break;
1492         case TLS_CIPHER_SM4_CCM:
1493                 dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
1494                 iv_offset = 1;
1495                 break;
1496         }
1497
1498         /* Prepare IV */
1499         if (prot->version == TLS_1_3_VERSION ||
1500             prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
1501                 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
1502                        prot->iv_size + prot->salt_size);
1503         } else {
1504                 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
1505                                     &dctx->iv[iv_offset] + prot->salt_size,
1506                                     prot->iv_size);
1507                 if (err < 0)
1508                         goto exit_free;
1509                 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
1510         }
1511         tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
1512
1513         /* Prepare AAD */
1514         tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
1515                      prot->tail_size,
1516                      tls_ctx->rx.rec_seq, tlm->control, prot);
1517
1518         /* Prepare sgin */
1519         sg_init_table(sgin, n_sgin);
1520         sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
1521         err = skb_to_sgvec(skb, &sgin[1],
1522                            rxm->offset + prot->prepend_size,
1523                            rxm->full_len - prot->prepend_size);
1524         if (err < 0)
1525                 goto exit_free;
1526
1527         if (n_sgout) {
1528                 if (out_iov) {
1529                         sg_init_table(sgout, n_sgout);
1530                         sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
1531
1532                         err = tls_setup_from_iter(out_iov, data_len,
1533                                                   &pages, &sgout[1],
1534                                                   (n_sgout - 1 - tail_pages));
1535                         if (err < 0)
1536                                 goto fallback_to_reg_recv;
1537
1538                         if (prot->tail_size) {
1539                                 sg_unmark_end(&sgout[pages]);
1540                                 sg_set_buf(&sgout[pages + 1], &dctx->tail,
1541                                            prot->tail_size);
1542                                 sg_mark_end(&sgout[pages + 1]);
1543                         }
1544                 } else if (out_sg) {
1545                         memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
1546                 } else {
1547                         goto fallback_to_reg_recv;
1548                 }
1549         } else {
1550 fallback_to_reg_recv:
1551                 sgout = sgin;
1552                 pages = 0;
1553                 darg->zc = false;
1554         }
1555
1556         /* Prepare and submit AEAD request */
1557         err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
1558                                 data_len + prot->tail_size, aead_req, darg);
1559         if (err)
1560                 goto exit_free_pages;
1561
1562         darg->skb = tls_strp_msg(ctx);
1563         if (darg->async)
1564                 return 0;
1565
1566         if (prot->tail_size)
1567                 darg->tail = dctx->tail;
1568
1569 exit_free_pages:
1570         /* Release the pages in case the iov was mapped to pages */
1571         for (; pages > 0; pages--)
1572                 put_page(sg_page(&sgout[pages]));
1573 exit_free:
1574         kfree(mem);
1575         return err;
1576 }
1577
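/* Ask the device offload (TLS_HW RX) whether it already decrypted the queued
 * record. Returns 0 if software decryption is still required, 1 if the record
 * is ready (darg is filled in and the record is taken from recv_pkt), or a
 * negative error.
 */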
1578 static int
1579 tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
1580                    struct tls_decrypt_arg *darg)
1581 {
1582         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1583         int err;
1584
1585         if (tls_ctx->rx_conf != TLS_HW)
1586                 return 0;
1587
1588         err = tls_device_decrypted(sk, tls_ctx);
1589         if (err <= 0)
1590                 return err;
1591
1592         darg->zc = false;
1593         darg->async = false;
1594         darg->skb = tls_strp_msg(ctx);
1595         ctx->recv_pkt = NULL;
1596         return 1;
1597 }
1598
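/* Decrypt a single record: try the device offload first, then software
 * decryption (zero-copy into @dest when darg allows it). For TLS 1.3 a
 * zero-copy attempt is retried without zero-copy if the record turns out
 * not to carry application data. For synchronous completion the message
 * offset/length are adjusted to cover only the plaintext; the RX record
 * sequence number is advanced in either case.
 */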
1599 static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
1600                              struct tls_decrypt_arg *darg)
1601 {
1602         struct tls_context *tls_ctx = tls_get_ctx(sk);
1603         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1604         struct tls_prot_info *prot = &tls_ctx->prot_info;
1605         struct strp_msg *rxm;
1606         int pad, err;
1607
1608         err = tls_decrypt_device(sk, tls_ctx, darg);
1609         if (err < 0)
1610                 return err;
1611         if (err)
1612                 goto decrypt_done;
1613
1614         err = tls_decrypt_sg(sk, dest, NULL, darg);
1615         if (err < 0) {
1616                 if (err == -EBADMSG)
1617                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
1618                 return err;
1619         }
1620         if (darg->async) {
1621                 if (darg->skb == ctx->recv_pkt)
1622                         ctx->recv_pkt = NULL;
1623                 goto decrypt_next;
1624         }
1625         /* If opportunistic TLS 1.3 ZC failed retry without ZC */
1626         if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
1627                      darg->tail != TLS_RECORD_TYPE_DATA)) {
1628                 darg->zc = false;
1629                 if (!darg->tail)
1630                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
1631                 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
1632                 return tls_rx_one_record(sk, dest, darg);
1633         }
1634
1635         if (darg->skb == ctx->recv_pkt)
1636                 ctx->recv_pkt = NULL;
1637
1638 decrypt_done:
1639         pad = tls_padding_length(prot, darg->skb, darg);
1640         if (pad < 0) {
1641                 consume_skb(darg->skb);
1642                 return pad;
1643         }
1644
1645         rxm = strp_msg(darg->skb);
1646         rxm->full_len -= pad;
1647         rxm->offset += prot->prepend_size;
1648         rxm->full_len -= prot->overhead_size;
1649 decrypt_next:
1650         tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1651
1652         return 0;
1653 }
1654
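/* Decrypt the currently queued record into the caller-provided scatterlist
 * @sgout.
 */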
1655 int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
1656 {
1657         struct tls_decrypt_arg darg = { .zc = true, };
1658
1659         return tls_decrypt_sg(sk, NULL, sgout, &darg);
1660 }
1661
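/* Latch the content type of the record being processed into *control and put
 * a TLS_GET_RECORD_TYPE cmsg for the first record (failure to deliver the
 * cmsg is only fatal for non-data records). Returns 1 to keep going, 0 once
 * a record of a different type is encountered, or a negative error.
 */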
1662 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
1663                                    u8 *control)
1664 {
1665         int err;
1666
1667         if (!*control) {
1668                 *control = tlm->control;
1669                 if (!*control)
1670                         return -EBADMSG;
1671
1672                 err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1673                                sizeof(*control), control);
1674                 if (*control != TLS_RECORD_TYPE_DATA) {
1675                         if (err || msg->msg_flags & MSG_CTRUNC)
1676                                 return -EIO;
1677                 }
1678         } else if (*control != tlm->control) {
1679                 return 0;
1680         }
1681
1682         return 1;
1683 }
1684
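/* Done with the current record: free it and unpause the strparser so the
 * next record can be parsed.
 */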
1685 static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
1686 {
1687         consume_skb(ctx->recv_pkt);
1688         ctx->recv_pkt = NULL;
1689         __strp_unpause(&ctx->strp);
1690 }
1691
1692 /* This function traverses the rx_list in the TLS receive context and copies
1693  * the decrypted records into the buffer provided by the caller when zero-copy
1694  * is not used. Records are removed from the rx_list if this is not a peek
1695  * and the record has been consumed completely.
1696  */
1697 static int process_rx_list(struct tls_sw_context_rx *ctx,
1698                            struct msghdr *msg,
1699                            u8 *control,
1700                            size_t skip,
1701                            size_t len,
1702                            bool zc,
1703                            bool is_peek)
1704 {
1705         struct sk_buff *skb = skb_peek(&ctx->rx_list);
1706         struct tls_msg *tlm;
1707         ssize_t copied = 0;
1708         int err;
1709
1710         while (skip && skb) {
1711                 struct strp_msg *rxm = strp_msg(skb);
1712                 tlm = tls_msg(skb);
1713
1714                 err = tls_record_content_type(msg, tlm, control);
1715                 if (err <= 0)
1716                         goto out;
1717
1718                 if (skip < rxm->full_len)
1719                         break;
1720
1721                 skip = skip - rxm->full_len;
1722                 skb = skb_peek_next(skb, &ctx->rx_list);
1723         }
1724
1725         while (len && skb) {
1726                 struct sk_buff *next_skb;
1727                 struct strp_msg *rxm = strp_msg(skb);
1728                 int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1729
1730                 tlm = tls_msg(skb);
1731
1732                 err = tls_record_content_type(msg, tlm, control);
1733                 if (err <= 0)
1734                         goto out;
1735
1736                 if (!zc || (rxm->full_len - skip) > len) {
1737                         err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1738                                                     msg, chunk);
1739                         if (err < 0)
1740                                 goto out;
1741                 }
1742
1743                 len = len - chunk;
1744                 copied = copied + chunk;
1745
1746                 /* Consume the data from the record in the non-peek case */
1747                 if (!is_peek) {
1748                         rxm->offset = rxm->offset + chunk;
1749                         rxm->full_len = rxm->full_len - chunk;
1750
1751                         /* Return if there is unconsumed data in the record */
1752                         if (rxm->full_len - skip)
1753                                 break;
1754                 }
1755
1756                 /* The remaining skip-bytes must lie in 1st record in rx_list.
1757                  * So from the 2nd record, 'skip' should be 0.
1758                  */
1759                 skip = 0;
1760
1761                 if (msg)
1762                         msg->msg_flags |= MSG_EOR;
1763
1764                 next_skb = skb_peek_next(skb, &ctx->rx_list);
1765
1766                 if (!is_peek) {
1767                         __skb_unlink(skb, &ctx->rx_list);
1768                         consume_skb(skb);
1769                 }
1770
1771                 skb = next_skb;
1772         }
1773         err = 0;
1774
1775 out:
1776         return copied ? : err;
1777 }
1778
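/* Periodically flush the socket backlog during a long read so the TCP layer
 * keeps feeding the strparser. Flush once fewer than one full record is
 * queued in TCP, or after roughly 128K has been copied since the last flush,
 * but only while more data is still wanted.
 */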
1779 static void
1780 tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
1781                        size_t len_left, size_t decrypted, ssize_t done,
1782                        size_t *flushed_at)
1783 {
1784         size_t max_rec;
1785
1786         if (len_left <= decrypted)
1787                 return;
1788
1789         max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
1790         if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
1791                 return;
1792
1793         *flushed_at = done;
1794         sk_flush_backlog(sk);
1795 }
1796
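/* Take the socket lock and become the (only) TLS RX reader. If another
 * reader is already present, wait for it to finish. Returns the remaining
 * receive timeout, or a negative error (-EAGAIN on timeout, or the signal
 * errno).
 */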
1797 static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
1798                                bool nonblock)
1799 {
1800         long timeo;
1801
1802         lock_sock(sk);
1803
1804         timeo = sock_rcvtimeo(sk, nonblock);
1805
1806         while (unlikely(ctx->reader_present)) {
1807                 DEFINE_WAIT_FUNC(wait, woken_wake_function);
1808
1809                 ctx->reader_contended = 1;
1810
1811                 add_wait_queue(&ctx->wq, &wait);
1812                 sk_wait_event(sk, &timeo,
1813                               !READ_ONCE(ctx->reader_present), &wait);
1814                 remove_wait_queue(&ctx->wq, &wait);
1815
1816                 if (!timeo)
1817                         return -EAGAIN;
1818                 if (signal_pending(current))
1819                         return sock_intr_errno(timeo);
1820         }
1821
1822         WRITE_ONCE(ctx->reader_present, 1);
1823
1824         return timeo;
1825 }
1826
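/* Give up RX reader ownership, wake up a contended waiter if there is one,
 * and release the socket.
 */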
1827 static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
1828 {
1829         if (unlikely(ctx->reader_contended)) {
1830                 if (wq_has_sleeper(&ctx->wq))
1831                         wake_up(&ctx->wq);
1832                 else
1833                         ctx->reader_contended = 0;
1834
1835                 WARN_ON_ONCE(!ctx->reader_present);
1836         }
1837
1838         WRITE_ONCE(ctx->reader_present, 0);
1839         release_sock(sk);
1840 }
1841
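/* recvmsg() for a SW TLS socket: first drain records already decrypted onto
 * rx_list, then decrypt fresh records (zero-copy into the user buffer and/or
 * asynchronously when possible) until @len bytes or the rcvlowat target have
 * been copied.
 */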
1842 int tls_sw_recvmsg(struct sock *sk,
1843                    struct msghdr *msg,
1844                    size_t len,
1845                    int flags,
1846                    int *addr_len)
1847 {
1848         struct tls_context *tls_ctx = tls_get_ctx(sk);
1849         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1850         struct tls_prot_info *prot = &tls_ctx->prot_info;
1851         struct sk_psock *psock;
1852         unsigned char control = 0;
1853         ssize_t decrypted = 0;
1854         size_t flushed_at = 0;
1855         struct strp_msg *rxm;
1856         struct tls_msg *tlm;
1857         struct sk_buff *skb;
1858         ssize_t copied = 0;
1859         bool async = false;
1860         int target, err = 0;
1861         long timeo;
1862         bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1863         bool is_peek = flags & MSG_PEEK;
1864         bool bpf_strp_enabled;
1865         bool zc_capable;
1866
1867         if (unlikely(flags & MSG_ERRQUEUE))
1868                 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1869
1870         psock = sk_psock_get(sk);
1871         timeo = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
1872         if (timeo < 0)
1873                 return timeo;
1874         bpf_strp_enabled = sk_psock_strp_enabled(psock);
1875
1876         /* If crypto failed the connection is broken */
1877         err = ctx->async_wait.err;
1878         if (err)
1879                 goto end;
1880
1881         /* Process any pending decrypted records; these must be copied (non-zero-copy) */
1882         err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
1883         if (err < 0)
1884                 goto end;
1885
1886         copied = err;
1887         if (len <= copied)
1888                 goto end;
1889
1890         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1891         len = len - copied;
1892
1893         zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
1894                 ctx->zc_capable;
1895         decrypted = 0;
1896         while (len && (decrypted + copied < target || ctx->recv_pkt)) {
1897                 struct tls_decrypt_arg darg;
1898                 int to_decrypt, chunk;
1899
1900                 err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
1901                 if (err <= 0) {
1902                         if (psock) {
1903                                 chunk = sk_msg_recvmsg(sk, psock, msg, len,
1904                                                        flags);
1905                                 if (chunk > 0) {
1906                                         decrypted += chunk;
1907                                         len -= chunk;
1908                                         continue;
1909                                 }
1910                         }
1911                         goto recv_end;
1912                 }
1913
1914                 memset(&darg.inargs, 0, sizeof(darg.inargs));
1915
1916                 rxm = strp_msg(ctx->recv_pkt);
1917                 tlm = tls_msg(ctx->recv_pkt);
1918
1919                 to_decrypt = rxm->full_len - prot->overhead_size;
1920
1921                 if (zc_capable && to_decrypt <= len &&
1922                     tlm->control == TLS_RECORD_TYPE_DATA)
1923                         darg.zc = true;
1924
1925                 /* Do not use async mode if record is non-data */
1926                 if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1927                         darg.async = ctx->async_capable;
1928                 else
1929                         darg.async = false;
1930
1931                 err = tls_rx_one_record(sk, &msg->msg_iter, &darg);
1932                 if (err < 0) {
1933                         tls_err_abort(sk, -EBADMSG);
1934                         goto recv_end;
1935                 }
1936
1937                 skb = darg.skb;
1938                 rxm = strp_msg(skb);
1939                 tlm = tls_msg(skb);
1940
1941                 async |= darg.async;
1942
1943                 /* If the type of records being processed is not known yet,
1944                  * set it to the record type just dequeued. If it is already
1945                  * known but does not match the just-dequeued record type, stop.
1946                  * We always have the record type here since for TLS 1.2 the
1947                  * type is known as soon as the record is dequeued from the
1948                  * stream parser. For TLS 1.3 we disable async mode.
1949                  */
1950                 err = tls_record_content_type(msg, tlm, &control);
1951                 if (err <= 0) {
1952                         tls_rx_rec_done(ctx);
1953 put_on_rx_list_err:
1954                         __skb_queue_tail(&ctx->rx_list, skb);
1955                         goto recv_end;
1956                 }
1957
1958                 /* periodically flush backlog, and feed strparser */
1959                 tls_read_flush_backlog(sk, prot, len, to_decrypt,
1960                                        decrypted + copied, &flushed_at);
1961
1962                 /* TLS 1.3 may have updated the length by more than overhead */
1963                 chunk = rxm->full_len;
1964                 tls_rx_rec_done(ctx);
1965
1966                 if (async) {
1967                         /* TLS 1.2-only, to_decrypt must be text length */
1968                         chunk = min_t(int, to_decrypt, len);
1969 put_on_rx_list:
1970                         decrypted += chunk;
1971                         len -= chunk;
1972                         __skb_queue_tail(&ctx->rx_list, skb);
1973                         continue;
1974                 }
1975
1976                 if (!darg.zc) {
1977                         bool partially_consumed = chunk > len;
1978
1979                         if (bpf_strp_enabled) {
1980                                 err = sk_psock_tls_strp_read(psock, skb);
1981                                 if (err != __SK_PASS) {
1982                                         rxm->offset = rxm->offset + rxm->full_len;
1983                                         rxm->full_len = 0;
1984                                         if (err == __SK_DROP)
1985                                                 consume_skb(skb);
1986                                         continue;
1987                                 }
1988                         }
1989
1990                         if (partially_consumed)
1991                                 chunk = len;
1992
1993                         err = skb_copy_datagram_msg(skb, rxm->offset,
1994                                                     msg, chunk);
1995                         if (err < 0)
1996                                 goto put_on_rx_list_err;
1997
1998                         if (is_peek)
1999                                 goto put_on_rx_list;
2000
2001                         if (partially_consumed) {
2002                                 rxm->offset += chunk;
2003                                 rxm->full_len -= chunk;
2004                                 goto put_on_rx_list;
2005                         }
2006                 }
2007
2008                 decrypted += chunk;
2009                 len -= chunk;
2010
2011                 consume_skb(skb);
2012
2013                 /* Return full control message to userspace before trying
2014                  * to parse another message type
2015                  */
2016                 msg->msg_flags |= MSG_EOR;
2017                 if (control != TLS_RECORD_TYPE_DATA)
2018                         break;
2019         }
2020
2021 recv_end:
2022         if (async) {
2023                 int ret, pending;
2024
2025                 /* Wait for all previously submitted records to be decrypted */
2026                 spin_lock_bh(&ctx->decrypt_compl_lock);
2027                 reinit_completion(&ctx->async_wait.completion);
2028                 pending = atomic_read(&ctx->decrypt_pending);
2029                 spin_unlock_bh(&ctx->decrypt_compl_lock);
2030                 if (pending) {
2031                         ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2032                         if (ret) {
2033                                 if (err >= 0 || err == -EINPROGRESS)
2034                                         err = ret;
2035                                 decrypted = 0;
2036                                 goto end;
2037                         }
2038                 }
2039
2040                 /* Drain records from the rx_list & copy if required */
2041                 if (is_peek || is_kvec)
2042                         err = process_rx_list(ctx, msg, &control, copied,
2043                                               decrypted, false, is_peek);
2044                 else
2045                         err = process_rx_list(ctx, msg, &control, 0,
2046                                               decrypted, true, is_peek);
2047                 decrypted = max(err, 0);
2048         }
2049
2050         copied += decrypted;
2051
2052 end:
2053         tls_rx_reader_unlock(sk, ctx);
2054         if (psock)
2055                 sk_psock_put(sk, psock);
2056         return copied ? : err;
2057 }
2058
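/* splice_read() for a SW TLS socket: take one decrypted data record (from
 * rx_list or freshly decrypted) and splice up to @len bytes of it into the
 * pipe. Control records cannot be spliced.
 */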
2059 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
2060                            struct pipe_inode_info *pipe,
2061                            size_t len, unsigned int flags)
2062 {
2063         struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
2064         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2065         struct strp_msg *rxm = NULL;
2066         struct sock *sk = sock->sk;
2067         struct tls_msg *tlm;
2068         struct sk_buff *skb;
2069         ssize_t copied = 0;
2070         int err = 0;
2071         long timeo;
2072         int chunk;
2073
2074         timeo = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
2075         if (timeo < 0)
2076                 return timeo;
2077
2078         if (!skb_queue_empty(&ctx->rx_list)) {
2079                 skb = __skb_dequeue(&ctx->rx_list);
2080         } else {
2081                 struct tls_decrypt_arg darg;
2082
2083                 err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
2084                                       timeo);
2085                 if (err <= 0)
2086                         goto splice_read_end;
2087
2088                 memset(&darg.inargs, 0, sizeof(darg.inargs));
2089
2090                 err = tls_rx_one_record(sk, NULL, &darg);
2091                 if (err < 0) {
2092                         tls_err_abort(sk, -EBADMSG);
2093                         goto splice_read_end;
2094                 }
2095
2096                 tls_rx_rec_done(ctx);
2097                 skb = darg.skb;
2098         }
2099
2100         rxm = strp_msg(skb);
2101         tlm = tls_msg(skb);
2102
2103         /* splice does not support reading control messages */
2104         if (tlm->control != TLS_RECORD_TYPE_DATA) {
2105                 err = -EINVAL;
2106                 goto splice_requeue;
2107         }
2108
2109         chunk = min_t(unsigned int, rxm->full_len, len);
2110         copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
2111         if (copied < 0)
2112                 goto splice_requeue;
2113
2114         if (chunk < rxm->full_len) {
2115                 rxm->offset += len;
2116                 rxm->full_len -= len;
2117                 goto splice_requeue;
2118         }
2119
2120         consume_skb(skb);
2121
2122 splice_read_end:
2123         tls_rx_reader_unlock(sk, ctx);
2124         return copied ? : err;
2125
2126 splice_requeue:
2127         __skb_queue_head(&ctx->rx_list, skb);
2128         goto splice_read_end;
2129 }
2130
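/* The socket is readable if the psock ingress queue, the currently parsed
 * record, or the rx_list of decrypted records has data.
 */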
2131 bool tls_sw_sock_is_readable(struct sock *sk)
2132 {
2133         struct tls_context *tls_ctx = tls_get_ctx(sk);
2134         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2135         bool ingress_empty = true;
2136         struct sk_psock *psock;
2137
2138         rcu_read_lock();
2139         psock = sk_psock(sk);
2140         if (psock)
2141                 ingress_empty = list_empty(&psock->ingress_msg);
2142         rcu_read_unlock();
2143
2144         return !ingress_empty || ctx->recv_pkt ||
2145                 !skb_queue_empty(&ctx->rx_list);
2146 }
2147
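/* strparser ->parse_msg() callback: validate the TLS record header and return
 * the total record length, 0 if more data is needed, or a negative error
 * (which also aborts the connection).
 */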
2148 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
2149 {
2150         struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2151         struct tls_prot_info *prot = &tls_ctx->prot_info;
2152         char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2153         struct strp_msg *rxm = strp_msg(skb);
2154         struct tls_msg *tlm = tls_msg(skb);
2155         size_t cipher_overhead;
2156         size_t data_len = 0;
2157         int ret;
2158
2159         /* Verify that we have a full TLS header, or wait for more data */
2160         if (rxm->offset + prot->prepend_size > skb->len)
2161                 return 0;
2162
2163         /* Sanity-check size of on-stack buffer. */
2164         if (WARN_ON(prot->prepend_size > sizeof(header))) {
2165                 ret = -EINVAL;
2166                 goto read_failure;
2167         }
2168
2169         /* Linearize header to local buffer */
2170         ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
2171         if (ret < 0)
2172                 goto read_failure;
2173
2174         tlm->control = header[0];
2175
2176         data_len = ((header[4] & 0xFF) | (header[3] << 8));
2177
2178         cipher_overhead = prot->tag_size;
2179         if (prot->version != TLS_1_3_VERSION &&
2180             prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
2181                 cipher_overhead += prot->iv_size;
2182
2183         if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2184             prot->tail_size) {
2185                 ret = -EMSGSIZE;
2186                 goto read_failure;
2187         }
2188         if (data_len < cipher_overhead) {
2189                 ret = -EBADMSG;
2190                 goto read_failure;
2191         }
2192
2193         /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
2194         if (header[1] != TLS_1_2_VERSION_MINOR ||
2195             header[2] != TLS_1_2_VERSION_MAJOR) {
2196                 ret = -EINVAL;
2197                 goto read_failure;
2198         }
2199
2200         tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2201                                      TCP_SKB_CB(skb)->seq + rxm->offset);
2202         return data_len + TLS_HEADER_SIZE;
2203
2204 read_failure:
2205         tls_err_abort(strp->sk, ret);
2206
2207         return ret;
2208 }
2209
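/* strparser ->rcv_msg() callback: stash the complete record and pause the
 * parser until the reader has consumed it.
 */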
2210 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
2211 {
2212         struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2213         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2214
2215         ctx->recv_pkt = skb;
2216         strp_pause(strp);
2217
2218         ctx->saved_data_ready(strp->sk);
2219 }
2220
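/* Replacement for sk->sk_data_ready: feed new data to the strparser and, if
 * BPF-redirected messages are queued on the psock, also run the original
 * callback.
 */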
2221 static void tls_data_ready(struct sock *sk)
2222 {
2223         struct tls_context *tls_ctx = tls_get_ctx(sk);
2224         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2225         struct sk_psock *psock;
2226
2227         strp_data_ready(&ctx->strp);
2228
2229         psock = sk_psock_get(sk);
2230         if (psock) {
2231                 if (!list_empty(&psock->ingress_msg))
2232                         ctx->saved_data_ready(sk);
2233                 sk_psock_put(sk, psock);
2234         }
2235 }
2236
2237 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2238 {
2239         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2240
2241         set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2242         set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2243         cancel_delayed_work_sync(&ctx->tx_work.work);
2244 }
2245
2246 void tls_sw_release_resources_tx(struct sock *sk)
2247 {
2248         struct tls_context *tls_ctx = tls_get_ctx(sk);
2249         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2250         struct tls_rec *rec, *tmp;
2251         int pending;
2252
2253         /* Wait for any pending async encryptions to complete */
2254         spin_lock_bh(&ctx->encrypt_compl_lock);
2255         ctx->async_notify = true;
2256         pending = atomic_read(&ctx->encrypt_pending);
2257         spin_unlock_bh(&ctx->encrypt_compl_lock);
2258
2259         if (pending)
2260                 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2261
2262         tls_tx_records(sk, -1);
2263
2264         /* Free up unsent records in tx_list. First, free
2265          * the partially sent record, if any, at the head of tx_list.
2266          */
2267         if (tls_ctx->partially_sent_record) {
2268                 tls_free_partial_record(sk, tls_ctx);
2269                 rec = list_first_entry(&ctx->tx_list,
2270                                        struct tls_rec, list);
2271                 list_del(&rec->list);
2272                 sk_msg_free(sk, &rec->msg_plaintext);
2273                 kfree(rec);
2274         }
2275
2276         list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2277                 list_del(&rec->list);
2278                 sk_msg_free(sk, &rec->msg_encrypted);
2279                 sk_msg_free(sk, &rec->msg_plaintext);
2280                 kfree(rec);
2281         }
2282
2283         crypto_free_aead(ctx->aead_send);
2284         tls_free_open_rec(sk);
2285 }
2286
2287 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2288 {
2289         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2290
2291         kfree(ctx);
2292 }
2293
2294 void tls_sw_release_resources_rx(struct sock *sk)
2295 {
2296         struct tls_context *tls_ctx = tls_get_ctx(sk);
2297         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2298
2299         kfree(tls_ctx->rx.rec_seq);
2300         kfree(tls_ctx->rx.iv);
2301
2302         if (ctx->aead_recv) {
2303                 kfree_skb(ctx->recv_pkt);
2304                 ctx->recv_pkt = NULL;
2305                 __skb_queue_purge(&ctx->rx_list);
2306                 crypto_free_aead(ctx->aead_recv);
2307                 strp_stop(&ctx->strp);
2308                 /* If tls_sw_strparser_arm() was not called (cleanup paths)
2309                  * we still want to strp_stop(), but sk->sk_data_ready was
2310                  * never swapped.
2311                  */
2312                 if (ctx->saved_data_ready) {
2313                         write_lock_bh(&sk->sk_callback_lock);
2314                         sk->sk_data_ready = ctx->saved_data_ready;
2315                         write_unlock_bh(&sk->sk_callback_lock);
2316                 }
2317         }
2318 }
2319
2320 void tls_sw_strparser_done(struct tls_context *tls_ctx)
2321 {
2322         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2323
2324         strp_done(&ctx->strp);
2325 }
2326
2327 void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2328 {
2329         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2330
2331         kfree(ctx);
2332 }
2333
2334 void tls_sw_free_resources_rx(struct sock *sk)
2335 {
2336         struct tls_context *tls_ctx = tls_get_ctx(sk);
2337
2338         tls_sw_release_resources_rx(sk);
2339         tls_sw_free_ctx_rx(tls_ctx);
2340 }
2341
2342 /* The work handler to transmit the encrypted records in tx_list */
2343 static void tx_work_handler(struct work_struct *work)
2344 {
2345         struct delayed_work *delayed_work = to_delayed_work(work);
2346         struct tx_work *tx_work = container_of(delayed_work,
2347                                                struct tx_work, work);
2348         struct sock *sk = tx_work->sk;
2349         struct tls_context *tls_ctx = tls_get_ctx(sk);
2350         struct tls_sw_context_tx *ctx;
2351
2352         if (unlikely(!tls_ctx))
2353                 return;
2354
2355         ctx = tls_sw_ctx_tx(tls_ctx);
2356         if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2357                 return;
2358
2359         if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2360                 return;
2361         mutex_lock(&tls_ctx->tx_lock);
2362         lock_sock(sk);
2363         tls_tx_records(sk, -1);
2364         release_sock(sk);
2365         mutex_unlock(&tls_ctx->tx_lock);
2366 }
2367
2368 static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
2369 {
2370         struct tls_rec *rec;
2371
2372         rec = list_first_entry(&ctx->tx_list, struct tls_rec, list);
2373         if (!rec)
2374                 return false;
2375
2376         return READ_ONCE(rec->tx_ready);
2377 }
2378
2379 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2380 {
2381         struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2382
2383         /* Schedule the transmission if tx list is ready */
2384         if (tls_is_tx_ready(tx_ctx) &&
2385             !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2386                 schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2387 }
2388
2389 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2390 {
2391         struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2392
2393         write_lock_bh(&sk->sk_callback_lock);
2394         rx_ctx->saved_data_ready = sk->sk_data_ready;
2395         sk->sk_data_ready = tls_data_ready;
2396         write_unlock_bh(&sk->sk_callback_lock);
2397
2398         strp_check_rcv(&rx_ctx->strp);
2399 }
2400
2401 void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
2402 {
2403         struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2404
2405         rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
2406                 tls_ctx->prot_info.version != TLS_1_3_VERSION;
2407 }
2408
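/* Set up the software crypto state for one direction (@tx != 0 for TX).
 * Allocates the per-direction context, derives the protocol sizes from the
 * chosen cipher, copies the IV/salt and record sequence number, allocates
 * and keys the AEAD transform, and for RX initializes the strparser.
 */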
2409 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2410 {
2411         struct tls_context *tls_ctx = tls_get_ctx(sk);
2412         struct tls_prot_info *prot = &tls_ctx->prot_info;
2413         struct tls_crypto_info *crypto_info;
2414         struct tls_sw_context_tx *sw_ctx_tx = NULL;
2415         struct tls_sw_context_rx *sw_ctx_rx = NULL;
2416         struct cipher_context *cctx;
2417         struct crypto_aead **aead;
2418         struct strp_callbacks cb;
2419         u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2420         struct crypto_tfm *tfm;
2421         char *iv, *rec_seq, *key, *salt, *cipher_name;
2422         size_t keysize;
2423         int rc = 0;
2424
2425         if (!ctx) {
2426                 rc = -EINVAL;
2427                 goto out;
2428         }
2429
2430         if (tx) {
2431                 if (!ctx->priv_ctx_tx) {
2432                         sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2433                         if (!sw_ctx_tx) {
2434                                 rc = -ENOMEM;
2435                                 goto out;
2436                         }
2437                         ctx->priv_ctx_tx = sw_ctx_tx;
2438                 } else {
2439                         sw_ctx_tx =
2440                                 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2441                 }
2442         } else {
2443                 if (!ctx->priv_ctx_rx) {
2444                         sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2445                         if (!sw_ctx_rx) {
2446                                 rc = -ENOMEM;
2447                                 goto out;
2448                         }
2449                         ctx->priv_ctx_rx = sw_ctx_rx;
2450                 } else {
2451                         sw_ctx_rx =
2452                                 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2453                 }
2454         }
2455
2456         if (tx) {
2457                 crypto_init_wait(&sw_ctx_tx->async_wait);
2458                 spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2459                 crypto_info = &ctx->crypto_send.info;
2460                 cctx = &ctx->tx;
2461                 aead = &sw_ctx_tx->aead_send;
2462                 INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2463                 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2464                 sw_ctx_tx->tx_work.sk = sk;
2465         } else {
2466                 crypto_init_wait(&sw_ctx_rx->async_wait);
2467                 spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2468                 init_waitqueue_head(&sw_ctx_rx->wq);
2469                 crypto_info = &ctx->crypto_recv.info;
2470                 cctx = &ctx->rx;
2471                 skb_queue_head_init(&sw_ctx_rx->rx_list);
2472                 aead = &sw_ctx_rx->aead_recv;
2473         }
2474
2475         switch (crypto_info->cipher_type) {
2476         case TLS_CIPHER_AES_GCM_128: {
2477                 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2478
2479                 gcm_128_info = (void *)crypto_info;
2480                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2481                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
2482                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2483                 iv = gcm_128_info->iv;
2484                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
2485                 rec_seq = gcm_128_info->rec_seq;
2486                 keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2487                 key = gcm_128_info->key;
2488                 salt = gcm_128_info->salt;
2489                 salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2490                 cipher_name = "gcm(aes)";
2491                 break;
2492         }
2493         case TLS_CIPHER_AES_GCM_256: {
2494                 struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2495
2496                 gcm_256_info = (void *)crypto_info;
2497                 nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2498                 tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2499                 iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2500                 iv = gcm_256_info->iv;
2501                 rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2502                 rec_seq = gcm_256_info->rec_seq;
2503                 keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2504                 key = gcm_256_info->key;
2505                 salt = gcm_256_info->salt;
2506                 salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2507                 cipher_name = "gcm(aes)";
2508                 break;
2509         }
2510         case TLS_CIPHER_AES_CCM_128: {
2511                 struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
2512
2513                 ccm_128_info = (void *)crypto_info;
2514                 nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2515                 tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2516                 iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2517                 iv = ccm_128_info->iv;
2518                 rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2519                 rec_seq = ccm_128_info->rec_seq;
2520                 keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2521                 key = ccm_128_info->key;
2522                 salt = ccm_128_info->salt;
2523                 salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2524                 cipher_name = "ccm(aes)";
2525                 break;
2526         }
2527         case TLS_CIPHER_CHACHA20_POLY1305: {
2528                 struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
2529
2530                 chacha20_poly1305_info = (void *)crypto_info;
2531                 nonce_size = 0;
2532                 tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
2533                 iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
2534                 iv = chacha20_poly1305_info->iv;
2535                 rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
2536                 rec_seq = chacha20_poly1305_info->rec_seq;
2537                 keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
2538                 key = chacha20_poly1305_info->key;
2539                 salt = chacha20_poly1305_info->salt;
2540                 salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
2541                 cipher_name = "rfc7539(chacha20,poly1305)";
2542                 break;
2543         }
2544         case TLS_CIPHER_SM4_GCM: {
2545                 struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;
2546
2547                 sm4_gcm_info = (void *)crypto_info;
2548                 nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2549                 tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
2550                 iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2551                 iv = sm4_gcm_info->iv;
2552                 rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
2553                 rec_seq = sm4_gcm_info->rec_seq;
2554                 keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
2555                 key = sm4_gcm_info->key;
2556                 salt = sm4_gcm_info->salt;
2557                 salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
2558                 cipher_name = "gcm(sm4)";
2559                 break;
2560         }
2561         case TLS_CIPHER_SM4_CCM: {
2562                 struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;
2563
2564                 sm4_ccm_info = (void *)crypto_info;
2565                 nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2566                 tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
2567                 iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2568                 iv = sm4_ccm_info->iv;
2569                 rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
2570                 rec_seq = sm4_ccm_info->rec_seq;
2571                 keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
2572                 key = sm4_ccm_info->key;
2573                 salt = sm4_ccm_info->salt;
2574                 salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
2575                 cipher_name = "ccm(sm4)";
2576                 break;
2577         }
2578         default:
2579                 rc = -EINVAL;
2580                 goto free_priv;
2581         }
2582
2583         if (crypto_info->version == TLS_1_3_VERSION) {
2584                 nonce_size = 0;
2585                 prot->aad_size = TLS_HEADER_SIZE;
2586                 prot->tail_size = 1;
2587         } else {
2588                 prot->aad_size = TLS_AAD_SPACE_SIZE;
2589                 prot->tail_size = 0;
2590         }
2591
2592         /* Sanity-check the sizes for stack allocations. */
2593         if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2594             rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
2595             prot->aad_size > TLS_MAX_AAD_SIZE) {
2596                 rc = -EINVAL;
2597                 goto free_priv;
2598         }
2599
2600         prot->version = crypto_info->version;
2601         prot->cipher_type = crypto_info->cipher_type;
2602         prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2603         prot->tag_size = tag_size;
2604         prot->overhead_size = prot->prepend_size +
2605                               prot->tag_size + prot->tail_size;
2606         prot->iv_size = iv_size;
2607         prot->salt_size = salt_size;
2608         cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
2609         if (!cctx->iv) {
2610                 rc = -ENOMEM;
2611                 goto free_priv;
2612         }
2613         /* Note: 128 & 256 bit salt are the same size */
2614         prot->rec_seq_size = rec_seq_size;
2615         memcpy(cctx->iv, salt, salt_size);
2616         memcpy(cctx->iv + salt_size, iv, iv_size);
2617         cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
2618         if (!cctx->rec_seq) {
2619                 rc = -ENOMEM;
2620                 goto free_iv;
2621         }
2622
2623         if (!*aead) {
2624                 *aead = crypto_alloc_aead(cipher_name, 0, 0);
2625                 if (IS_ERR(*aead)) {
2626                         rc = PTR_ERR(*aead);
2627                         *aead = NULL;
2628                         goto free_rec_seq;
2629                 }
2630         }
2631
2632         ctx->push_pending_record = tls_sw_push_pending_record;
2633
2634         rc = crypto_aead_setkey(*aead, key, keysize);
2635
2636         if (rc)
2637                 goto free_aead;
2638
2639         rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2640         if (rc)
2641                 goto free_aead;
2642
2643         if (sw_ctx_rx) {
2644                 tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2645
2646                 tls_update_rx_zc_capable(ctx);
2647                 sw_ctx_rx->async_capable =
2648                         crypto_info->version != TLS_1_3_VERSION &&
2649                         !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
2650
2651                 /* Set up strparser */
2652                 memset(&cb, 0, sizeof(cb));
2653                 cb.rcv_msg = tls_queue;
2654                 cb.parse_msg = tls_read_size;
2655
2656                 strp_init(&sw_ctx_rx->strp, sk, &cb);
2657         }
2658
2659         goto out;
2660
2661 free_aead:
2662         crypto_free_aead(*aead);
2663         *aead = NULL;
2664 free_rec_seq:
2665         kfree(cctx->rec_seq);
2666         cctx->rec_seq = NULL;
2667 free_iv:
2668         kfree(cctx->iv);
2669         cctx->iv = NULL;
2670 free_priv:
2671         if (tx) {
2672                 kfree(ctx->priv_ctx_tx);
2673                 ctx->priv_ctx_tx = NULL;
2674         } else {
2675                 kfree(ctx->priv_ctx_rx);
2676                 ctx->priv_ctx_rx = NULL;
2677         }
2678 out:
2679         return rc;
2680 }