net/tls: fix the IV leaks
net/tls/tls_device.c
/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <crypto/aead.h>
#include <linux/highmem.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/dst.h>
#include <net/inet_connection_sock.h>
#include <net/tcp.h>
#include <net/tls.h>

/* device_offload_lock is used to synchronize tls_dev_add
 * against NETDEV_DOWN notifications.
 */
static DECLARE_RWSEM(device_offload_lock);

static void tls_device_gc_task(struct work_struct *work);

static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
static LIST_HEAD(tls_device_gc_list);
static LIST_HEAD(tls_device_list);
static DEFINE_SPINLOCK(tls_device_lock);

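/* Free the TX/RX offload contexts and crypto state attached to a TLS
 * context once no socket or device can reference it any more.
 */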
static void tls_device_free_ctx(struct tls_context *ctx)
{
        if (ctx->tx_conf == TLS_HW) {
                kfree(tls_offload_ctx_tx(ctx));
                kfree(ctx->tx.rec_seq);
                kfree(ctx->tx.iv);
        }

        if (ctx->rx_conf == TLS_HW)
                kfree(tls_offload_ctx_rx(ctx));

        kfree(ctx);
}

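/* Deferred teardown of offloaded contexts: detach each queued context
 * from its netdev (dropping the reference taken at attach time) and
 * free it.  Runs from the tls_device_gc_work workqueue item.
 */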
static void tls_device_gc_task(struct work_struct *work)
{
        struct tls_context *ctx, *tmp;
        unsigned long flags;
        LIST_HEAD(gc_list);

        spin_lock_irqsave(&tls_device_lock, flags);
        list_splice_init(&tls_device_gc_list, &gc_list);
        spin_unlock_irqrestore(&tls_device_lock, flags);

        list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
                struct net_device *netdev = ctx->netdev;

                if (netdev && ctx->tx_conf == TLS_HW) {
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_TX);
                        dev_put(netdev);
                        ctx->netdev = NULL;
                }

                list_del(&ctx->list);
                tls_device_free_ctx(ctx);
        }
}

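/* Bind the TLS context to the netdev and take over sk_destruct so the
 * offload state is released when the socket is destroyed.  Done only
 * once per socket; enabling a second offload direction reuses the
 * existing list entry and reference.
 */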
static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
                              struct net_device *netdev)
{
        if (sk->sk_destruct != tls_device_sk_destruct) {
                refcount_set(&ctx->refcount, 1);
                dev_hold(netdev);
                ctx->netdev = netdev;
                spin_lock_irq(&tls_device_lock);
                list_add_tail(&ctx->list, &tls_device_list);
                spin_unlock_irq(&tls_device_lock);

                ctx->sk_destruct = sk->sk_destruct;
                sk->sk_destruct = tls_device_sk_destruct;
        }
}

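/* Move the context to the GC list and kick the GC work item; the
 * actual teardown happens in tls_device_gc_task().
 */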
static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
        unsigned long flags;

        spin_lock_irqsave(&tls_device_lock, flags);
        list_move_tail(&ctx->list, &tls_device_gc_list);

        /* schedule_work inside the spinlock
         * to make sure tls_device_down waits for that work.
         */
        schedule_work(&tls_device_gc_work);

        spin_unlock_irqrestore(&tls_device_lock, flags);
}

/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
        struct dst_entry *dst = sk_dst_get(sk);
        struct net_device *netdev = NULL;

        if (likely(dst)) {
                netdev = dst->dev;
                dev_hold(netdev);
        }

        dst_release(dst);

        return netdev;
}

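/* Drop the page references held by a TX record's frags and free the
 * record itself.
 */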
static void destroy_record(struct tls_record_info *record)
{
        int nr_frags = record->num_frags;
        skb_frag_t *frag;

        while (nr_frags-- > 0) {
                frag = &record->frags[nr_frags];
                __skb_frag_unref(frag);
        }
        kfree(record);
}

static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
{
        struct tls_record_info *info, *temp;

        list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
                list_del(&info->list);
                destroy_record(info);
        }

        offload_ctx->retransmit_hint = NULL;
}

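/* clean_acked_data callback: once TCP has received an ACK covering a
 * record's end sequence number the record can no longer be
 * retransmitted, so release it and advance unacked_record_sn.
 */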
static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_record_info *info, *temp;
        struct tls_offload_context_tx *ctx;
        u64 deleted_records = 0;
        unsigned long flags;

        if (!tls_ctx)
                return;

        ctx = tls_offload_ctx_tx(tls_ctx);

        spin_lock_irqsave(&ctx->lock, flags);
        info = ctx->retransmit_hint;
        if (info && !before(acked_seq, info->end_seq)) {
                ctx->retransmit_hint = NULL;
                list_del(&info->list);
                destroy_record(info);
                deleted_records++;
        }

        list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
                if (before(acked_seq, info->end_seq))
                        break;
                list_del(&info->list);

                destroy_record(info);
                deleted_records++;
        }

        ctx->unacked_record_sn += deleted_records;
        spin_unlock_irqrestore(&ctx->lock, flags);
}

/* At this point, there should be no references on this
 * socket and no in-flight SKBs associated with this
 * socket, so it is safe to free all the resources.
 */
void tls_device_sk_destruct(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);

        tls_ctx->sk_destruct(sk);

        if (tls_ctx->tx_conf == TLS_HW) {
                if (ctx->open_record)
                        destroy_record(ctx->open_record);
                delete_all_records(ctx);
                crypto_free_aead(ctx->aead_send);
                clean_acked_data_disable(inet_csk(sk));
        }

        if (refcount_dec_and_test(&tls_ctx->refcount))
                tls_device_queue_ctx_destruction(tls_ctx);
}
EXPORT_SYMBOL(tls_device_sk_destruct);

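/* Append @size bytes of the current page fragment to the open record,
 * extending the last frag in place when the new data is contiguous
 * with it and taking a new page reference otherwise.
 */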
static void tls_append_frag(struct tls_record_info *record,
                            struct page_frag *pfrag,
                            int size)
{
        skb_frag_t *frag;

        frag = &record->frags[record->num_frags - 1];
        if (frag->page.p == pfrag->page &&
            frag->page_offset + frag->size == pfrag->offset) {
                frag->size += size;
        } else {
                ++frag;
                frag->page.p = pfrag->page;
                frag->page_offset = pfrag->offset;
                frag->size = size;
                ++record->num_frags;
                get_page(pfrag->page);
        }

        pfrag->offset += size;
        record->len += size;
}

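/* Close the open record: write the TLS header into the first frag,
 * reserve room for the authentication tag (the device fills it in),
 * queue the record for possible retransmission and hand the
 * scatterlist to tls_push_sg() for transmission over TCP.
 */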
static int tls_push_record(struct sock *sk,
                           struct tls_context *ctx,
                           struct tls_offload_context_tx *offload_ctx,
                           struct tls_record_info *record,
                           struct page_frag *pfrag,
                           int flags,
                           unsigned char record_type)
{
        struct tls_prot_info *prot = &ctx->prot_info;
        struct tcp_sock *tp = tcp_sk(sk);
        struct page_frag dummy_tag_frag;
        skb_frag_t *frag;
        int i;

        /* fill prepend */
        frag = &record->frags[0];
        tls_fill_prepend(ctx,
                         skb_frag_address(frag),
                         record->len - prot->prepend_size,
                         record_type,
                         ctx->crypto_send.info.version);

        /* HW doesn't care about the data in the tag, because it fills it. */
        dummy_tag_frag.page = skb_frag_page(frag);
        dummy_tag_frag.offset = 0;

        tls_append_frag(record, &dummy_tag_frag, prot->tag_size);
        record->end_seq = tp->write_seq + record->len;
        spin_lock_irq(&offload_ctx->lock);
        list_add_tail(&record->list, &offload_ctx->records_list);
        spin_unlock_irq(&offload_ctx->lock);
        offload_ctx->open_record = NULL;
        tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version);

        for (i = 0; i < record->num_frags; i++) {
                frag = &record->frags[i];
                sg_unmark_end(&offload_ctx->sg_tx_data[i]);
                sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
                            frag->size, frag->page_offset);
                sk_mem_charge(sk, frag->size);
                get_page(skb_frag_page(frag));
        }
        sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);

        /* all ready, send */
        return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
}

static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
                                 struct page_frag *pfrag,
                                 size_t prepend_size)
{
        struct tls_record_info *record;
        skb_frag_t *frag;

        record = kmalloc(sizeof(*record), GFP_KERNEL);
        if (!record)
                return -ENOMEM;

        frag = &record->frags[0];
        __skb_frag_set_page(frag, pfrag->page);
        frag->page_offset = pfrag->offset;
        skb_frag_size_set(frag, prepend_size);

        get_page(pfrag->page);
        pfrag->offset += prepend_size;

        record->num_frags = 1;
        record->len = prepend_size;
        offload_ctx->open_record = record;
        return 0;
}

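/* Make sure there is an open record and page-frag space available to
 * copy payload into; returns -ENOMEM when the caller should wait for
 * memory and retry.
 */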
static int tls_do_allocation(struct sock *sk,
                             struct tls_offload_context_tx *offload_ctx,
                             struct page_frag *pfrag,
                             size_t prepend_size)
{
        int ret;

        if (!offload_ctx->open_record) {
                if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
                                                   sk->sk_allocation))) {
                        sk->sk_prot->enter_memory_pressure(sk);
                        sk_stream_moderate_sndbuf(sk);
                        return -ENOMEM;
                }

                ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
                if (ret)
                        return ret;

                if (pfrag->size > pfrag->offset)
                        return 0;
        }

        if (!sk_page_frag_refill(sk, pfrag))
                return -ENOMEM;

        return 0;
}

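/* Core TX path for device offload: copy the caller's data into page
 * frags of the open record and push completed records to TCP as
 * plaintext (the device encrypts on transmit).  Returns the number of
 * bytes consumed, or a negative error if nothing was queued.
 */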
static int tls_push_data(struct sock *sk,
                         struct iov_iter *msg_iter,
                         size_t size, int flags,
                         unsigned char record_type)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_prot_info *prot = &tls_ctx->prot_info;
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
        int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
        int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
        struct tls_record_info *record = ctx->open_record;
        struct page_frag *pfrag;
        size_t orig_size = size;
        u32 max_open_record_len;
        int copy, rc = 0;
        bool done = false;
        long timeo;

        if (flags &
            ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
                return -ENOTSUPP;

        if (sk->sk_err)
                return -sk->sk_err;

        timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
        if (tls_is_partially_sent_record(tls_ctx)) {
                rc = tls_push_partial_record(sk, tls_ctx, flags);
                if (rc < 0)
                        return rc;
        }

        pfrag = sk_page_frag(sk);

        /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
         * we need to leave room for an authentication tag.
         */
        max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
                              prot->prepend_size;
        do {
                rc = tls_do_allocation(sk, ctx, pfrag,
                                       prot->prepend_size);
                if (rc) {
                        rc = sk_stream_wait_memory(sk, &timeo);
                        if (!rc)
                                continue;

                        record = ctx->open_record;
                        if (!record)
                                break;
handle_error:
                        if (record_type != TLS_RECORD_TYPE_DATA) {
                                /* avoid sending partial
                                 * record with type !=
                                 * application_data
                                 */
                                size = orig_size;
                                destroy_record(record);
                                ctx->open_record = NULL;
                        } else if (record->len > prot->prepend_size) {
                                goto last_record;
                        }

                        break;
                }

                record = ctx->open_record;
                copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
                copy = min_t(size_t, copy, (max_open_record_len - record->len));

                if (copy_from_iter_nocache(page_address(pfrag->page) +
                                               pfrag->offset,
                                           copy, msg_iter) != copy) {
                        rc = -EFAULT;
                        goto handle_error;
                }
                tls_append_frag(record, pfrag, copy);

                size -= copy;
                if (!size) {
last_record:
                        tls_push_record_flags = flags;
                        if (more) {
                                tls_ctx->pending_open_record_frags =
                                                !!record->num_frags;
                                break;
                        }

                        done = true;
                }

                if (done || record->len >= max_open_record_len ||
                    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
                        rc = tls_push_record(sk,
                                             tls_ctx,
                                             ctx,
                                             record,
                                             pfrag,
                                             tls_push_record_flags,
                                             record_type);
                        if (rc < 0)
                                break;
                }
        } while (!done);

        if (orig_size - size > 0)
                rc = orig_size - size;

        return rc;
}

int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
        int rc;

        lock_sock(sk);

        if (unlikely(msg->msg_controllen)) {
                rc = tls_proccess_cmsg(sk, msg, &record_type);
                if (rc)
                        goto out;
        }

        rc = tls_push_data(sk, &msg->msg_iter, size,
                           msg->msg_flags, record_type);

out:
        release_sock(sk);
        return rc;
}

int tls_device_sendpage(struct sock *sk, struct page *page,
                        int offset, size_t size, int flags)
{
        struct iov_iter msg_iter;
        char *kaddr = kmap(page);
        struct kvec iov;
        int rc;

        if (flags & MSG_SENDPAGE_NOTLAST)
                flags |= MSG_MORE;

        lock_sock(sk);

        if (flags & MSG_OOB) {
                rc = -ENOTSUPP;
                goto out;
        }

        iov.iov_base = kaddr + offset;
        iov.iov_len = size;
        iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
        rc = tls_push_data(sk, &msg_iter, size,
                           flags, TLS_RECORD_TYPE_DATA);
        kunmap(page);

out:
        release_sock(sk);
        return rc;
}

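/* Look up the TX record that contains TCP sequence number @seq (used
 * by drivers on retransmission) and report its record sequence number
 * through @p_record_sn.  retransmit_hint caches the last hit so
 * consecutive retransmissions avoid a full list walk.
 */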
struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
                                       u32 seq, u64 *p_record_sn)
{
        u64 record_sn = context->hint_record_sn;
        struct tls_record_info *info;

        info = context->retransmit_hint;
        if (!info ||
            before(seq, info->end_seq - info->len)) {
                /* if retransmit_hint is irrelevant start
                 * from the beginning of the list
                 */
                info = list_first_entry(&context->records_list,
                                        struct tls_record_info, list);
                record_sn = context->unacked_record_sn;
        }

        list_for_each_entry_from(info, &context->records_list, list) {
                if (before(seq, info->end_seq)) {
                        if (!context->retransmit_hint ||
                            after(info->end_seq,
                                  context->retransmit_hint->end_seq)) {
                                context->hint_record_sn = record_sn;
                                context->retransmit_hint = info;
                        }
                        *p_record_sn = record_sn;
                        return info;
                }
                record_sn++;
        }

        return NULL;
}
EXPORT_SYMBOL(tls_get_record);

static int tls_device_push_pending_record(struct sock *sk, int flags)
{
        struct iov_iter msg_iter;

        iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
        return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
}

void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
{
        int rc = 0;

        if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) {
                gfp_t sk_allocation = sk->sk_allocation;

                sk->sk_allocation = GFP_ATOMIC;
                rc = tls_push_partial_record(sk, ctx,
                                             MSG_DONTWAIT | MSG_NOSIGNAL);
                sk->sk_allocation = sk_allocation;
        }
}

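/* RX resync: if the device posted a resync request and it matches the
 * record boundary currently being processed, report the TCP sequence
 * and record sequence number back to the driver so it can resume
 * decrypting in hardware.
 */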
void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct net_device *netdev = tls_ctx->netdev;
        struct tls_offload_context_rx *rx_ctx;
        u32 is_req_pending;
        s64 resync_req;
        u32 req_seq;

        if (tls_ctx->rx_conf != TLS_HW)
                return;

        rx_ctx = tls_offload_ctx_rx(tls_ctx);
        resync_req = atomic64_read(&rx_ctx->resync_req);
        req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
        is_req_pending = resync_req;

        if (unlikely(is_req_pending) && req_seq == seq &&
            atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
                netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
                                                      seq + TLS_HEADER_SIZE - 1,
                                                      rcd_sn);
}

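/* Software fallback for RX: decrypt the record with the kernel AEAD
 * into a temporary buffer and copy the plaintext back over the parts
 * of the skb the device did not already decrypt.
 */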
static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
{
        struct strp_msg *rxm = strp_msg(skb);
        int err = 0, offset = rxm->offset, copy, nsg;
        struct sk_buff *skb_iter, *unused;
        struct scatterlist sg[1];
        char *orig_buf, *buf;

        orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
                           TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
        if (!orig_buf)
                return -ENOMEM;
        buf = orig_buf;

        nsg = skb_cow_data(skb, 0, &unused);
        if (unlikely(nsg < 0)) {
                err = nsg;
                goto free_buf;
        }

        sg_init_table(sg, 1);
        sg_set_buf(&sg[0], buf,
                   rxm->full_len + TLS_HEADER_SIZE +
                   TLS_CIPHER_AES_GCM_128_IV_SIZE);
        skb_copy_bits(skb, offset, buf,
                      TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);

        /* We are interested only in the decrypted data, not the auth tag */
        err = decrypt_skb(sk, skb, sg);
        if (err != -EBADMSG)
                goto free_buf;
        else
                err = 0;

        copy = min_t(int, skb_pagelen(skb) - offset,
                     rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);

        if (skb->decrypted)
                skb_store_bits(skb, offset, buf, copy);

        offset += copy;
        buf += copy;

        skb_walk_frags(skb, skb_iter) {
                copy = min_t(int, skb_iter->len,
                             rxm->full_len - offset + rxm->offset -
                             TLS_CIPHER_AES_GCM_128_TAG_SIZE);

                if (skb_iter->decrypted)
                        skb_store_bits(skb_iter, offset, buf, copy);

                offset += copy;
                buf += copy;
        }

free_buf:
        kfree(orig_buf);
        return err;
}

int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
        int is_decrypted = skb->decrypted;
        int is_encrypted = !is_decrypted;
        struct sk_buff *skb_iter;

        /* Skip if it is already decrypted */
        if (ctx->sw.decrypted)
                return 0;

        /* Check if all the data is decrypted already */
        skb_walk_frags(skb, skb_iter) {
                is_decrypted &= skb_iter->decrypted;
                is_encrypted &= !skb_iter->decrypted;
        }

        ctx->sw.decrypted |= is_decrypted;

        /* Return immediately if the record is either entirely plaintext or
         * entirely ciphertext.  Otherwise reencrypt the partially decrypted
         * record.
         */
        return (is_encrypted || is_decrypted) ? 0 :
                tls_device_reencrypt(sk, skb);
}

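/* Enable TX offload for @sk: derive the protocol parameters from the
 * userspace-provided crypto_info, set up the TX offload context and
 * the software fallback, and ask the netdev to install the key
 * starting at the current TCP write sequence number.
 */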
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
        u16 nonce_size, tag_size, iv_size, rec_seq_size;
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_prot_info *prot = &tls_ctx->prot_info;
        struct tls_record_info *start_marker_record;
        struct tls_offload_context_tx *offload_ctx;
        struct tls_crypto_info *crypto_info;
        struct net_device *netdev;
        char *iv, *rec_seq;
        struct sk_buff *skb;
        int rc = -EINVAL;
        __be64 rcd_sn;

        if (!ctx)
                goto out;

        if (ctx->priv_ctx_tx) {
                rc = -EEXIST;
                goto out;
        }

        start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
        if (!start_marker_record) {
                rc = -ENOMEM;
                goto out;
        }

        offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
        if (!offload_ctx) {
                rc = -ENOMEM;
                goto free_marker_record;
        }

        crypto_info = &ctx->crypto_send.info;
        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128:
                nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
                iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
                rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
                rec_seq =
                 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
                break;
        default:
                rc = -EINVAL;
                goto free_offload_ctx;
        }

        prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
        prot->tag_size = tag_size;
        prot->overhead_size = prot->prepend_size + prot->tag_size;
        prot->iv_size = iv_size;
        ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                             GFP_KERNEL);
        if (!ctx->tx.iv) {
                rc = -ENOMEM;
                goto free_offload_ctx;
        }

        memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);

        prot->rec_seq_size = rec_seq_size;
        ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
        if (!ctx->tx.rec_seq) {
                rc = -ENOMEM;
                goto free_iv;
        }

        rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
        if (rc)
                goto free_rec_seq;

        /* start at rec_seq - 1 to account for the start marker record */
        memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
        offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;

        start_marker_record->end_seq = tcp_sk(sk)->write_seq;
        start_marker_record->len = 0;
        start_marker_record->num_frags = 0;

        INIT_LIST_HEAD(&offload_ctx->records_list);
        list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
        spin_lock_init(&offload_ctx->lock);
        sg_init_table(offload_ctx->sg_tx_data,
                      ARRAY_SIZE(offload_ctx->sg_tx_data));

        clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
        ctx->push_pending_record = tls_device_push_pending_record;

        /* TLS offload is greatly simplified if we don't send
         * SKBs where only part of the payload needs to be encrypted.
         * So mark the last skb in the write queue as end of record.
         */
        skb = tcp_write_queue_tail(sk);
        if (skb)
                TCP_SKB_CB(skb)->eor = 1;

        /* We support starting offload on multiple sockets
         * concurrently, so we only need a read lock here.
         * This lock must precede get_netdev_for_sock to prevent races between
         * NETDEV_DOWN and setsockopt.
         */
        down_read(&device_offload_lock);
        netdev = get_netdev_for_sock(sk);
        if (!netdev) {
                pr_err_ratelimited("%s: netdev not found\n", __func__);
                rc = -EINVAL;
                goto release_lock;
        }

        if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
                rc = -ENOTSUPP;
                goto release_netdev;
        }

        /* Avoid offloading if the device is down
         * We don't want to offload new flows after
         * the NETDEV_DOWN event
         */
        if (!(netdev->flags & IFF_UP)) {
                rc = -EINVAL;
                goto release_netdev;
        }

        ctx->priv_ctx_tx = offload_ctx;
        rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
                                             &ctx->crypto_send.info,
                                             tcp_sk(sk)->write_seq);
        if (rc)
                goto release_netdev;

        tls_device_attach(ctx, sk, netdev);

        /* following this assignment tls_is_sk_tx_device_offloaded
         * will return true and the context might be accessed
         * by the netdev's xmit function.
         */
        smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
        dev_put(netdev);
        up_read(&device_offload_lock);
        goto out;

release_netdev:
        dev_put(netdev);
release_lock:
        up_read(&device_offload_lock);
        clean_acked_data_disable(inet_csk(sk));
        crypto_free_aead(offload_ctx->aead_send);
free_rec_seq:
        kfree(ctx->tx.rec_seq);
free_iv:
        kfree(ctx->tx.iv);
free_offload_ctx:
        kfree(offload_ctx);
        ctx->priv_ctx_tx = NULL;
free_marker_record:
        kfree(start_marker_record);
out:
        return rc;
}

int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
{
        struct tls_offload_context_rx *context;
        struct net_device *netdev;
        int rc = 0;

        /* We support starting offload on multiple sockets
         * concurrently, so we only need a read lock here.
         * This lock must precede get_netdev_for_sock to prevent races between
         * NETDEV_DOWN and setsockopt.
         */
        down_read(&device_offload_lock);
        netdev = get_netdev_for_sock(sk);
        if (!netdev) {
                pr_err_ratelimited("%s: netdev not found\n", __func__);
                rc = -EINVAL;
                goto release_lock;
        }

        if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
                pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
                                   __func__, netdev->name);
                rc = -ENOTSUPP;
                goto release_netdev;
        }

        /* Avoid offloading if the device is down
         * We don't want to offload new flows after
         * the NETDEV_DOWN event
         */
        if (!(netdev->flags & IFF_UP)) {
                rc = -EINVAL;
                goto release_netdev;
        }

        context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
        if (!context) {
                rc = -ENOMEM;
                goto release_netdev;
        }

        ctx->priv_ctx_rx = context;
        rc = tls_set_sw_offload(sk, ctx, 0);
        if (rc)
                goto release_ctx;

        rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
                                             &ctx->crypto_recv.info,
                                             tcp_sk(sk)->copied_seq);
        if (rc) {
                pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
                                   __func__);
                goto free_sw_resources;
        }

        tls_device_attach(ctx, sk, netdev);
        goto release_netdev;

free_sw_resources:
        tls_sw_free_resources_rx(sk);
release_ctx:
        ctx->priv_ctx_rx = NULL;
release_netdev:
        dev_put(netdev);
release_lock:
        up_read(&device_offload_lock);
        return rc;
}

void tls_device_offload_cleanup_rx(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct net_device *netdev;

        down_read(&device_offload_lock);
        netdev = tls_ctx->netdev;
        if (!netdev)
                goto out;

        if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
                pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
                                   __func__);
                goto out;
        }

        netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
                                        TLS_OFFLOAD_CTX_DIR_RX);

        if (tls_ctx->tx_conf != TLS_HW) {
                dev_put(netdev);
                tls_ctx->netdev = NULL;
        }
out:
        up_read(&device_offload_lock);
        kfree(tls_ctx->rx.rec_seq);
        kfree(tls_ctx->rx.iv);
        tls_sw_release_resources_rx(sk);
}

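/* NETDEV_DOWN handler: detach every offloaded context from @netdev,
 * telling the device to release its TX/RX state, and free contexts
 * that have no remaining users.  Holding the write lock blocks new
 * offload attempts while this runs.
 */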
static int tls_device_down(struct net_device *netdev)
{
        struct tls_context *ctx, *tmp;
        unsigned long flags;
        LIST_HEAD(list);

        /* Request a write lock to block new offload attempts */
        down_write(&device_offload_lock);

        spin_lock_irqsave(&tls_device_lock, flags);
        list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
                if (ctx->netdev != netdev ||
                    !refcount_inc_not_zero(&ctx->refcount))
                        continue;

                list_move(&ctx->list, &list);
        }
        spin_unlock_irqrestore(&tls_device_lock, flags);

        list_for_each_entry_safe(ctx, tmp, &list, list) {
                if (ctx->tx_conf == TLS_HW)
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_TX);
                if (ctx->rx_conf == TLS_HW)
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_RX);
                ctx->netdev = NULL;
                dev_put(netdev);
                list_del_init(&ctx->list);

                if (refcount_dec_and_test(&ctx->refcount))
                        tls_device_free_ctx(ctx);
        }

        up_write(&device_offload_lock);

        flush_work(&tls_device_gc_work);

        return NOTIFY_DONE;
}

static int tls_dev_event(struct notifier_block *this, unsigned long event,
                         void *ptr)
{
        struct net_device *dev = netdev_notifier_info_to_dev(ptr);

        if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
                return NOTIFY_DONE;

        switch (event) {
        case NETDEV_REGISTER:
        case NETDEV_FEAT_CHANGE:
                if ((dev->features & NETIF_F_HW_TLS_RX) &&
                    !dev->tlsdev_ops->tls_dev_resync_rx)
                        return NOTIFY_BAD;

                if (dev->tlsdev_ops &&
                    dev->tlsdev_ops->tls_dev_add &&
                    dev->tlsdev_ops->tls_dev_del)
                        return NOTIFY_DONE;
                else
                        return NOTIFY_BAD;
        case NETDEV_DOWN:
                return tls_device_down(dev);
        }
        return NOTIFY_DONE;
}

static struct notifier_block tls_dev_notifier = {
        .notifier_call  = tls_dev_event,
};

void __init tls_device_init(void)
{
        register_netdevice_notifier(&tls_dev_notifier);
}

void __exit tls_device_cleanup(void)
{
        unregister_netdevice_notifier(&tls_dev_notifier);
        flush_work(&tls_device_gc_work);
}