tls: Fix write space handling
[linux-2.6-microblaze.git] / net/tls/tls_device.c
1 /* Copyright (c) 2018, Mellanox Technologies All rights reserved.
2  *
3  * This software is available to you under a choice of one of two
4  * licenses.  You may choose to be licensed under the terms of the GNU
5  * General Public License (GPL) Version 2, available from the file
6  * COPYING in the main directory of this source tree, or the
7  * OpenIB.org BSD license below:
8  *
9  *     Redistribution and use in source and binary forms, with or
10  *     without modification, are permitted provided that the following
11  *     conditions are met:
12  *
13  *      - Redistributions of source code must retain the above
14  *        copyright notice, this list of conditions and the following
15  *        disclaimer.
16  *
17  *      - Redistributions in binary form must reproduce the above
18  *        copyright notice, this list of conditions and the following
19  *        disclaimer in the documentation and/or other materials
20  *        provided with the distribution.
21  *
22  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
26  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
27  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
28  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29  * SOFTWARE.
30  */
31
32 #include <crypto/aead.h>
33 #include <linux/highmem.h>
34 #include <linux/module.h>
35 #include <linux/netdevice.h>
36 #include <net/dst.h>
37 #include <net/inet_connection_sock.h>
38 #include <net/tcp.h>
39 #include <net/tls.h>
40
41 /* device_offload_lock is used to synchronize tls_dev_add
42  * against NETDEV_DOWN notifications.
43  */
44 static DECLARE_RWSEM(device_offload_lock);
45
46 static void tls_device_gc_task(struct work_struct *work);
47
48 static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
49 static LIST_HEAD(tls_device_gc_list);
50 static LIST_HEAD(tls_device_list);
51 static DEFINE_SPINLOCK(tls_device_lock);
52
53 static void tls_device_free_ctx(struct tls_context *ctx)
54 {
55         if (ctx->tx_conf == TLS_HW)
56                 kfree(tls_offload_ctx_tx(ctx));
57
58         if (ctx->rx_conf == TLS_HW)
59                 kfree(tls_offload_ctx_rx(ctx));
60
61         kfree(ctx);
62 }
63
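/* Garbage-collection worker: drain tls_device_gc_list, tell the device to
 * drop its TX state (tls_dev_del), release the netdev reference and free
 * each context.  Running from a workqueue keeps the driver callback out of
 * the socket-destructor path.
 */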
64 static void tls_device_gc_task(struct work_struct *work)
65 {
66         struct tls_context *ctx, *tmp;
67         unsigned long flags;
68         LIST_HEAD(gc_list);
69
70         spin_lock_irqsave(&tls_device_lock, flags);
71         list_splice_init(&tls_device_gc_list, &gc_list);
72         spin_unlock_irqrestore(&tls_device_lock, flags);
73
74         list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
75                 struct net_device *netdev = ctx->netdev;
76
77                 if (netdev && ctx->tx_conf == TLS_HW) {
78                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
79                                                         TLS_OFFLOAD_CTX_DIR_TX);
80                         dev_put(netdev);
81                         ctx->netdev = NULL;
82                 }
83
84                 list_del(&ctx->list);
85                 tls_device_free_ctx(ctx);
86         }
87 }
88
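/* Take a reference on the netdev, link the context into tls_device_list so
 * tls_device_down() can find it, and install tls_device_sk_destruct as the
 * socket destructor; this is done only once per socket.
 */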
89 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
90                               struct net_device *netdev)
91 {
92         if (sk->sk_destruct != tls_device_sk_destruct) {
93                 refcount_set(&ctx->refcount, 1);
94                 dev_hold(netdev);
95                 ctx->netdev = netdev;
96                 spin_lock_irq(&tls_device_lock);
97                 list_add_tail(&ctx->list, &tls_device_list);
98                 spin_unlock_irq(&tls_device_lock);
99
100                 ctx->sk_destruct = sk->sk_destruct;
101                 sk->sk_destruct = tls_device_sk_destruct;
102         }
103 }
104
105 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
106 {
107         unsigned long flags;
108
109         spin_lock_irqsave(&tls_device_lock, flags);
110         list_move_tail(&ctx->list, &tls_device_gc_list);
111
112         /* schedule_work inside the spinlock
113          * to make sure tls_device_down waits for that work.
114          */
115         schedule_work(&tls_device_gc_work);
116
117         spin_unlock_irqrestore(&tls_device_lock, flags);
118 }
119
120 /* We assume that the socket is already connected */
121 static struct net_device *get_netdev_for_sock(struct sock *sk)
122 {
123         struct dst_entry *dst = sk_dst_get(sk);
124         struct net_device *netdev = NULL;
125
126         if (likely(dst)) {
127                 netdev = dst->dev;
128                 dev_hold(netdev);
129         }
130
131         dst_release(dst);
132
133         return netdev;
134 }
135
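/* Drop the page references held by the record's frags and free the record. */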
136 static void destroy_record(struct tls_record_info *record)
137 {
138         int nr_frags = record->num_frags;
139         skb_frag_t *frag;
140
141         while (nr_frags-- > 0) {
142                 frag = &record->frags[nr_frags];
143                 __skb_frag_unref(frag);
144         }
145         kfree(record);
146 }
147
148 static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
149 {
150         struct tls_record_info *info, *temp;
151
152         list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
153                 list_del(&info->list);
154                 destroy_record(info);
155         }
156
157         offload_ctx->retransmit_hint = NULL;
158 }
159
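/* Registered via clean_acked_data_enable(): free every record that has been
 * fully acknowledged (end_seq at or before acked_seq) and advance
 * unacked_record_sn by the number of records freed.
 */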
160 static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
161 {
162         struct tls_context *tls_ctx = tls_get_ctx(sk);
163         struct tls_record_info *info, *temp;
164         struct tls_offload_context_tx *ctx;
165         u64 deleted_records = 0;
166         unsigned long flags;
167
168         if (!tls_ctx)
169                 return;
170
171         ctx = tls_offload_ctx_tx(tls_ctx);
172
173         spin_lock_irqsave(&ctx->lock, flags);
174         info = ctx->retransmit_hint;
175         if (info && !before(acked_seq, info->end_seq)) {
176                 ctx->retransmit_hint = NULL;
177                 list_del(&info->list);
178                 destroy_record(info);
179                 deleted_records++;
180         }
181
182         list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
183                 if (before(acked_seq, info->end_seq))
184                         break;
185                 list_del(&info->list);
186
187                 destroy_record(info);
188                 deleted_records++;
189         }
190
191         ctx->unacked_record_sn += deleted_records;
192         spin_unlock_irqrestore(&ctx->lock, flags);
193 }
194
195 /* At this point, there should be no references on this
196  * socket and no in-flight SKBs associated with this
197  * socket, so it is safe to free all the resources.
198  */
199 void tls_device_sk_destruct(struct sock *sk)
200 {
201         struct tls_context *tls_ctx = tls_get_ctx(sk);
202         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
203
204         tls_ctx->sk_destruct(sk);
205
206         if (tls_ctx->tx_conf == TLS_HW) {
207                 if (ctx->open_record)
208                         destroy_record(ctx->open_record);
209                 delete_all_records(ctx);
210                 crypto_free_aead(ctx->aead_send);
211                 clean_acked_data_disable(inet_csk(sk));
212         }
213
214         if (refcount_dec_and_test(&tls_ctx->refcount))
215                 tls_device_queue_ctx_destruction(tls_ctx);
216 }
217 EXPORT_SYMBOL(tls_device_sk_destruct);
218
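/* Append @size bytes taken from @pfrag to the open record, merging into the
 * last frag when the new data is contiguous with it.
 */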
219 static void tls_append_frag(struct tls_record_info *record,
220                             struct page_frag *pfrag,
221                             int size)
222 {
223         skb_frag_t *frag;
224
225         frag = &record->frags[record->num_frags - 1];
226         if (frag->page.p == pfrag->page &&
227             frag->page_offset + frag->size == pfrag->offset) {
228                 frag->size += size;
229         } else {
230                 ++frag;
231                 frag->page.p = pfrag->page;
232                 frag->page_offset = pfrag->offset;
233                 frag->size = size;
234                 ++record->num_frags;
235                 get_page(pfrag->page);
236         }
237
238         pfrag->offset += size;
239         record->len += size;
240 }
241
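/* Close the open record: write the TLS header into the first frag, append a
 * dummy frag for the authentication tag (the device fills in the real tag),
 * queue the record on records_list for retransmission lookups, advance the
 * record sequence number and hand the frags to the transport via
 * tls_push_sg().
 */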
242 static int tls_push_record(struct sock *sk,
243                            struct tls_context *ctx,
244                            struct tls_offload_context_tx *offload_ctx,
245                            struct tls_record_info *record,
246                            struct page_frag *pfrag,
247                            int flags,
248                            unsigned char record_type)
249 {
250         struct tls_prot_info *prot = &ctx->prot_info;
251         struct tcp_sock *tp = tcp_sk(sk);
252         struct page_frag dummy_tag_frag;
253         skb_frag_t *frag;
254         int i;
255
256         /* fill prepend */
257         frag = &record->frags[0];
258         tls_fill_prepend(ctx,
259                          skb_frag_address(frag),
260                          record->len - prot->prepend_size,
261                          record_type,
262                          ctx->crypto_send.info.version);
263
264         /* HW doesn't care about the data in the tag, because it fills it. */
265         dummy_tag_frag.page = skb_frag_page(frag);
266         dummy_tag_frag.offset = 0;
267
268         tls_append_frag(record, &dummy_tag_frag, prot->tag_size);
269         record->end_seq = tp->write_seq + record->len;
270         spin_lock_irq(&offload_ctx->lock);
271         list_add_tail(&record->list, &offload_ctx->records_list);
272         spin_unlock_irq(&offload_ctx->lock);
273         offload_ctx->open_record = NULL;
274         tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version);
275
276         for (i = 0; i < record->num_frags; i++) {
277                 frag = &record->frags[i];
278                 sg_unmark_end(&offload_ctx->sg_tx_data[i]);
279                 sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
280                             frag->size, frag->page_offset);
281                 sk_mem_charge(sk, frag->size);
282                 get_page(skb_frag_page(frag));
283         }
284         sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
285
286         /* all ready, send */
287         return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
288 }
289
290 static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
291                                  struct page_frag *pfrag,
292                                  size_t prepend_size)
293 {
294         struct tls_record_info *record;
295         skb_frag_t *frag;
296
297         record = kmalloc(sizeof(*record), GFP_KERNEL);
298         if (!record)
299                 return -ENOMEM;
300
301         frag = &record->frags[0];
302         __skb_frag_set_page(frag, pfrag->page);
303         frag->page_offset = pfrag->offset;
304         skb_frag_size_set(frag, prepend_size);
305
306         get_page(pfrag->page);
307         pfrag->offset += prepend_size;
308
309         record->num_frags = 1;
310         record->len = prepend_size;
311         offload_ctx->open_record = record;
312         return 0;
313 }
314
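/* Make sure there is an open record and that the socket's page frag has room
 * for more payload; returns -ENOMEM if the caller needs to wait for memory.
 */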
315 static int tls_do_allocation(struct sock *sk,
316                              struct tls_offload_context_tx *offload_ctx,
317                              struct page_frag *pfrag,
318                              size_t prepend_size)
319 {
320         int ret;
321
322         if (!offload_ctx->open_record) {
323                 if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
324                                                    sk->sk_allocation))) {
325                         sk->sk_prot->enter_memory_pressure(sk);
326                         sk_stream_moderate_sndbuf(sk);
327                         return -ENOMEM;
328                 }
329
330                 ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
331                 if (ret)
332                         return ret;
333
334                 if (pfrag->size > pfrag->offset)
335                         return 0;
336         }
337
338         if (!sk_page_frag_refill(sk, pfrag))
339                 return -ENOMEM;
340
341         return 0;
342 }
343
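/* Common TX path for sendmsg and sendpage: copy data from @msg_iter into the
 * open record's page frags and push completed records to the transport.
 * Returns the number of bytes queued, or a negative error if nothing was
 * consumed.
 */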
344 static int tls_push_data(struct sock *sk,
345                          struct iov_iter *msg_iter,
346                          size_t size, int flags,
347                          unsigned char record_type)
348 {
349         struct tls_context *tls_ctx = tls_get_ctx(sk);
350         struct tls_prot_info *prot = &tls_ctx->prot_info;
351         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
352         int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
353         int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
354         struct tls_record_info *record = ctx->open_record;
355         struct page_frag *pfrag;
356         size_t orig_size = size;
357         u32 max_open_record_len;
358         int copy, rc = 0;
359         bool done = false;
360         long timeo;
361
362         if (flags &
363             ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
364                 return -ENOTSUPP;
365
366         if (sk->sk_err)
367                 return -sk->sk_err;
368
369         timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
370         if (tls_is_partially_sent_record(tls_ctx)) {
371                 rc = tls_push_partial_record(sk, tls_ctx, flags);
372                 if (rc < 0)
373                         return rc;
374         }
375
376         pfrag = sk_page_frag(sk);
377
378         /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
379          * we need to leave room for an authentication tag.
380          */
381         max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
382                               prot->prepend_size;
383         do {
384                 rc = tls_do_allocation(sk, ctx, pfrag,
385                                        prot->prepend_size);
386                 if (rc) {
387                         rc = sk_stream_wait_memory(sk, &timeo);
388                         if (!rc)
389                                 continue;
390
391                         record = ctx->open_record;
392                         if (!record)
393                                 break;
394 handle_error:
395                         if (record_type != TLS_RECORD_TYPE_DATA) {
396                                 /* avoid sending partial
397                                  * record with type !=
398                                  * application_data
399                                  */
400                                 size = orig_size;
401                                 destroy_record(record);
402                                 ctx->open_record = NULL;
403                         } else if (record->len > prot->prepend_size) {
404                                 goto last_record;
405                         }
406
407                         break;
408                 }
409
410                 record = ctx->open_record;
411                 copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
412                 copy = min_t(size_t, copy, (max_open_record_len - record->len));
413
414                 if (copy_from_iter_nocache(page_address(pfrag->page) +
415                                                pfrag->offset,
416                                            copy, msg_iter) != copy) {
417                         rc = -EFAULT;
418                         goto handle_error;
419                 }
420                 tls_append_frag(record, pfrag, copy);
421
422                 size -= copy;
423                 if (!size) {
424 last_record:
425                         tls_push_record_flags = flags;
426                         if (more) {
427                                 tls_ctx->pending_open_record_frags =
428                                                 !!record->num_frags;
429                                 break;
430                         }
431
432                         done = true;
433                 }
434
435                 if (done || record->len >= max_open_record_len ||
436                     (record->num_frags >= MAX_SKB_FRAGS - 1)) {
437                         rc = tls_push_record(sk,
438                                              tls_ctx,
439                                              ctx,
440                                              record,
441                                              pfrag,
442                                              tls_push_record_flags,
443                                              record_type);
444                         if (rc < 0)
445                                 break;
446                 }
447         } while (!done);
448
449         if (orig_size - size > 0)
450                 rc = orig_size - size;
451
452         return rc;
453 }
454
455 int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
456 {
457         unsigned char record_type = TLS_RECORD_TYPE_DATA;
458         int rc;
459
460         lock_sock(sk);
461
462         if (unlikely(msg->msg_controllen)) {
463                 rc = tls_proccess_cmsg(sk, msg, &record_type);
464                 if (rc)
465                         goto out;
466         }
467
468         rc = tls_push_data(sk, &msg->msg_iter, size,
469                            msg->msg_flags, record_type);
470
471 out:
472         release_sock(sk);
473         return rc;
474 }
475
476 int tls_device_sendpage(struct sock *sk, struct page *page,
477                         int offset, size_t size, int flags)
478 {
479         struct iov_iter msg_iter;
480         char *kaddr = kmap(page);
481         struct kvec iov;
482         int rc;
483
484         if (flags & MSG_SENDPAGE_NOTLAST)
485                 flags |= MSG_MORE;
486
487         lock_sock(sk);
488
489         if (flags & MSG_OOB) {
490                 rc = -ENOTSUPP;
491                 goto out;
492         }
493
494         iov.iov_base = kaddr + offset;
495         iov.iov_len = size;
496         iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
497         rc = tls_push_data(sk, &msg_iter, size,
498                            flags, TLS_RECORD_TYPE_DATA);
499         kunmap(page);
500
501 out:
502         release_sock(sk);
503         return rc;
504 }
505
506 struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
507                                        u32 seq, u64 *p_record_sn)
508 {
509         u64 record_sn = context->hint_record_sn;
510         struct tls_record_info *info;
511
512         info = context->retransmit_hint;
513         if (!info ||
514             before(seq, info->end_seq - info->len)) {
515                 /* if retransmit_hint is irrelevant, start
516                  * from the beginning of the list
517                  */
518                 info = list_first_entry(&context->records_list,
519                                         struct tls_record_info, list);
520                 record_sn = context->unacked_record_sn;
521         }
522
523         list_for_each_entry_from(info, &context->records_list, list) {
524                 if (before(seq, info->end_seq)) {
525                         if (!context->retransmit_hint ||
526                             after(info->end_seq,
527                                   context->retransmit_hint->end_seq)) {
528                                 context->hint_record_sn = record_sn;
529                                 context->retransmit_hint = info;
530                         }
531                         *p_record_sn = record_sn;
532                         return info;
533                 }
534                 record_sn++;
535         }
536
537         return NULL;
538 }
539 EXPORT_SYMBOL(tls_get_record);
540
541 static int tls_device_push_pending_record(struct sock *sk, int flags)
542 {
543         struct iov_iter msg_iter;
544
545         iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
546         return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
547 }
548
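/* Write-space hook for device offload: if a record was only partially pushed
 * to the transport, retry it here first (under GFP_ATOMIC, since this may be
 * called from a context where sleeping is not allowed) and propagate write
 * space to the upper layer only if that push did not fail.
 */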
549 void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
550 {
551         int rc = 0;
552
553         if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) {
554                 gfp_t sk_allocation = sk->sk_allocation;
555
556                 sk->sk_allocation = GFP_ATOMIC;
557                 rc = tls_push_partial_record(sk, ctx,
558                                              MSG_DONTWAIT | MSG_NOSIGNAL);
559                 sk->sk_allocation = sk_allocation;
560         }
561
562         if (!rc)
563                 ctx->sk_write_space(sk);
564 }
565
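/* RX resync: the driver requests a resync by posting the TCP sequence it
 * lost track of in resync_req.  When the software RX path reaches a record
 * starting at that sequence, hand the sequence and record number back to the
 * device so it can resume offloaded decryption.
 */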
566 void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
567 {
568         struct tls_context *tls_ctx = tls_get_ctx(sk);
569         struct net_device *netdev = tls_ctx->netdev;
570         struct tls_offload_context_rx *rx_ctx;
571         u32 is_req_pending;
572         s64 resync_req;
573         u32 req_seq;
574
575         if (tls_ctx->rx_conf != TLS_HW)
576                 return;
577
578         rx_ctx = tls_offload_ctx_rx(tls_ctx);
579         resync_req = atomic64_read(&rx_ctx->resync_req);
580         req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
581         is_req_pending = resync_req;
582
583         if (unlikely(is_req_pending) && req_seq == seq &&
584             atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
585                 netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
586                                                       seq + TLS_HEADER_SIZE - 1,
587                                                       rcd_sn);
588 }
589
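/* Called for records the device only partially decrypted.  Running the
 * software cipher over the skb into a scratch buffer XORs the AES-GCM
 * keystream back into the segments that are already plaintext, i.e. it
 * re-encrypts them; copy those segments back so the record is ciphertext
 * again and the regular software path can decrypt it in full.  The
 * authentication tag check is expected to fail (-EBADMSG) and is ignored.
 */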
590 static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
591 {
592         struct strp_msg *rxm = strp_msg(skb);
593         int err = 0, offset = rxm->offset, copy, nsg;
594         struct sk_buff *skb_iter, *unused;
595         struct scatterlist sg[1];
596         char *orig_buf, *buf;
597
598         orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
599                            TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
600         if (!orig_buf)
601                 return -ENOMEM;
602         buf = orig_buf;
603
604         nsg = skb_cow_data(skb, 0, &unused);
605         if (unlikely(nsg < 0)) {
606                 err = nsg;
607                 goto free_buf;
608         }
609
610         sg_init_table(sg, 1);
611         sg_set_buf(&sg[0], buf,
612                    rxm->full_len + TLS_HEADER_SIZE +
613                    TLS_CIPHER_AES_GCM_128_IV_SIZE);
614         skb_copy_bits(skb, offset, buf,
615                       TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
616
617         /* We are interested only in the decrypted data, not the auth tag */
618         err = decrypt_skb(sk, skb, sg);
619         if (err != -EBADMSG)
620                 goto free_buf;
621         else
622                 err = 0;
623
624         copy = min_t(int, skb_pagelen(skb) - offset,
625                      rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);
626
627         if (skb->decrypted)
628                 skb_store_bits(skb, offset, buf, copy);
629
630         offset += copy;
631         buf += copy;
632
633         skb_walk_frags(skb, skb_iter) {
634                 copy = min_t(int, skb_iter->len,
635                              rxm->full_len - offset + rxm->offset -
636                              TLS_CIPHER_AES_GCM_128_TAG_SIZE);
637
638                 if (skb_iter->decrypted)
639                         skb_store_bits(skb_iter, offset, buf, copy);
640
641                 offset += copy;
642                 buf += copy;
643         }
644
645 free_buf:
646         kfree(orig_buf);
647         return err;
648 }
649
650 int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
651 {
652         struct tls_context *tls_ctx = tls_get_ctx(sk);
653         struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
654         int is_decrypted = skb->decrypted;
655         int is_encrypted = !is_decrypted;
656         struct sk_buff *skb_iter;
657
658         /* Skip if it is already decrypted */
659         if (ctx->sw.decrypted)
660                 return 0;
661
662         /* Check if all the data is decrypted already */
663         skb_walk_frags(skb, skb_iter) {
664                 is_decrypted &= skb_iter->decrypted;
665                 is_encrypted &= !skb_iter->decrypted;
666         }
667
668         ctx->sw.decrypted |= is_decrypted;
669
670         /* Return immediately if the record is either entirely plaintext or
671          * entirely ciphertext. Otherwise re-encrypt the partially decrypted
672          * record.
673          */
674         return (is_encrypted || is_decrypted) ? 0 :
675                 tls_device_reencrypt(sk, skb);
676 }
677
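/* Set up TLS TX offload for a connected socket: allocate the offload
 * context, derive the record layout from the negotiated cipher (only
 * AES-GCM-128 is supported here), insert a zero-length start-marker record
 * at the current write_seq, initialize the software fallback, and finally
 * register the socket with the netdev under device_offload_lock.
 */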
678 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
679 {
680         u16 nonce_size, tag_size, iv_size, rec_seq_size;
681         struct tls_context *tls_ctx = tls_get_ctx(sk);
682         struct tls_prot_info *prot = &tls_ctx->prot_info;
683         struct tls_record_info *start_marker_record;
684         struct tls_offload_context_tx *offload_ctx;
685         struct tls_crypto_info *crypto_info;
686         struct net_device *netdev;
687         char *iv, *rec_seq;
688         struct sk_buff *skb;
689         int rc = -EINVAL;
690         __be64 rcd_sn;
691
692         if (!ctx)
693                 goto out;
694
695         if (ctx->priv_ctx_tx) {
696                 rc = -EEXIST;
697                 goto out;
698         }
699
700         start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
701         if (!start_marker_record) {
702                 rc = -ENOMEM;
703                 goto out;
704         }
705
706         offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
707         if (!offload_ctx) {
708                 rc = -ENOMEM;
709                 goto free_marker_record;
710         }
711
712         crypto_info = &ctx->crypto_send.info;
713         switch (crypto_info->cipher_type) {
714         case TLS_CIPHER_AES_GCM_128:
715                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
716                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
717                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
718                 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
719                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
720                 rec_seq =
721                  ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
722                 break;
723         default:
724                 rc = -EINVAL;
725                 goto free_offload_ctx;
726         }
727
728         prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
729         prot->tag_size = tag_size;
730         prot->overhead_size = prot->prepend_size + prot->tag_size;
731         prot->iv_size = iv_size;
732         ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
733                              GFP_KERNEL);
734         if (!ctx->tx.iv) {
735                 rc = -ENOMEM;
736                 goto free_offload_ctx;
737         }
738
739         memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
740
741         prot->rec_seq_size = rec_seq_size;
742         ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
743         if (!ctx->tx.rec_seq) {
744                 rc = -ENOMEM;
745                 goto free_iv;
746         }
747
748         rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
749         if (rc)
750                 goto free_rec_seq;
751
752         /* start at rec_seq - 1 to account for the start marker record */
753         memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
754         offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
755
756         start_marker_record->end_seq = tcp_sk(sk)->write_seq;
757         start_marker_record->len = 0;
758         start_marker_record->num_frags = 0;
759
760         INIT_LIST_HEAD(&offload_ctx->records_list);
761         list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
762         spin_lock_init(&offload_ctx->lock);
763         sg_init_table(offload_ctx->sg_tx_data,
764                       ARRAY_SIZE(offload_ctx->sg_tx_data));
765
766         clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
767         ctx->push_pending_record = tls_device_push_pending_record;
768
769         /* TLS offload is greatly simplified if we don't send
770          * SKBs where only part of the payload needs to be encrypted.
771          * So mark the last skb in the write queue as end of record.
772          */
773         skb = tcp_write_queue_tail(sk);
774         if (skb)
775                 TCP_SKB_CB(skb)->eor = 1;
776
777         /* We support starting offload on multiple sockets
778          * concurrently, so we only need a read lock here.
779          * This lock must precede get_netdev_for_sock to prevent races between
780          * NETDEV_DOWN and setsockopt.
781          */
782         down_read(&device_offload_lock);
783         netdev = get_netdev_for_sock(sk);
784         if (!netdev) {
785                 pr_err_ratelimited("%s: netdev not found\n", __func__);
786                 rc = -EINVAL;
787                 goto release_lock;
788         }
789
790         if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
791                 rc = -ENOTSUPP;
792                 goto release_netdev;
793         }
794
795         /* Avoid offloading if the device is down.
796          * We don't want to offload new flows after
797          * the NETDEV_DOWN event.
798          */
799         if (!(netdev->flags & IFF_UP)) {
800                 rc = -EINVAL;
801                 goto release_netdev;
802         }
803
804         ctx->priv_ctx_tx = offload_ctx;
805         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
806                                              &ctx->crypto_send.info,
807                                              tcp_sk(sk)->write_seq);
808         if (rc)
809                 goto release_netdev;
810
811         tls_device_attach(ctx, sk, netdev);
812
813         /* Following this assignment, tls_is_sk_tx_device_offloaded
814          * will return true and the context might be accessed
815          * by the netdev's xmit function.
816          */
817         smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
818         dev_put(netdev);
819         up_read(&device_offload_lock);
820         goto out;
821
822 release_netdev:
823         dev_put(netdev);
824 release_lock:
825         up_read(&device_offload_lock);
826         clean_acked_data_disable(inet_csk(sk));
827         crypto_free_aead(offload_ctx->aead_send);
828 free_rec_seq:
829         kfree(ctx->tx.rec_seq);
830 free_iv:
831         kfree(ctx->tx.iv);
832 free_offload_ctx:
833         kfree(offload_ctx);
834         ctx->priv_ctx_tx = NULL;
835 free_marker_record:
836         kfree(start_marker_record);
837 out:
838         return rc;
839 }
840
841 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
842 {
843         struct tls_offload_context_rx *context;
844         struct net_device *netdev;
845         int rc = 0;
846
847         /* We support starting offload on multiple sockets
848          * concurrently, so we only need a read lock here.
849          * This lock must precede get_netdev_for_sock to prevent races between
850          * NETDEV_DOWN and setsockopt.
851          */
852         down_read(&device_offload_lock);
853         netdev = get_netdev_for_sock(sk);
854         if (!netdev) {
855                 pr_err_ratelimited("%s: netdev not found\n", __func__);
856                 rc = -EINVAL;
857                 goto release_lock;
858         }
859
860         if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
861                 pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
862                                    __func__, netdev->name);
863                 rc = -ENOTSUPP;
864                 goto release_netdev;
865         }
866
867         /* Avoid offloading if the device is down.
868          * We don't want to offload new flows after
869          * the NETDEV_DOWN event.
870          */
871         if (!(netdev->flags & IFF_UP)) {
872                 rc = -EINVAL;
873                 goto release_netdev;
874         }
875
876         context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
877         if (!context) {
878                 rc = -ENOMEM;
879                 goto release_netdev;
880         }
881
882         ctx->priv_ctx_rx = context;
883         rc = tls_set_sw_offload(sk, ctx, 0);
884         if (rc)
885                 goto release_ctx;
886
887         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
888                                              &ctx->crypto_recv.info,
889                                              tcp_sk(sk)->copied_seq);
890         if (rc) {
891                 pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
892                                    __func__);
893                 goto free_sw_resources;
894         }
895
896         tls_device_attach(ctx, sk, netdev);
897         goto release_netdev;
898
899 free_sw_resources:
900         tls_sw_free_resources_rx(sk);
901 release_ctx:
902         ctx->priv_ctx_rx = NULL;
903 release_netdev:
904         dev_put(netdev);
905 release_lock:
906         up_read(&device_offload_lock);
907         return rc;
908 }
909
910 void tls_device_offload_cleanup_rx(struct sock *sk)
911 {
912         struct tls_context *tls_ctx = tls_get_ctx(sk);
913         struct net_device *netdev;
914
915         down_read(&device_offload_lock);
916         netdev = tls_ctx->netdev;
917         if (!netdev)
918                 goto out;
919
920         if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
921                 pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
922                                    __func__);
923                 goto out;
924         }
925
926         netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
927                                         TLS_OFFLOAD_CTX_DIR_RX);
928
929         if (tls_ctx->tx_conf != TLS_HW) {
930                 dev_put(netdev);
931                 tls_ctx->netdev = NULL;
932         }
933 out:
934         up_read(&device_offload_lock);
935         kfree(tls_ctx->rx.rec_seq);
936         kfree(tls_ctx->rx.iv);
937         tls_sw_release_resources_rx(sk);
938 }
939
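/* NETDEV_DOWN handler: with new offload attempts blocked by the write lock,
 * detach every context bound to this netdev, tell the device to drop its
 * TX/RX state and release the device reference; contexts whose refcount
 * drops to zero here are freed immediately.
 */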
940 static int tls_device_down(struct net_device *netdev)
941 {
942         struct tls_context *ctx, *tmp;
943         unsigned long flags;
944         LIST_HEAD(list);
945
946         /* Request a write lock to block new offload attempts */
947         down_write(&device_offload_lock);
948
949         spin_lock_irqsave(&tls_device_lock, flags);
950         list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
951                 if (ctx->netdev != netdev ||
952                     !refcount_inc_not_zero(&ctx->refcount))
953                         continue;
954
955                 list_move(&ctx->list, &list);
956         }
957         spin_unlock_irqrestore(&tls_device_lock, flags);
958
959         list_for_each_entry_safe(ctx, tmp, &list, list) {
960                 if (ctx->tx_conf == TLS_HW)
961                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
962                                                         TLS_OFFLOAD_CTX_DIR_TX);
963                 if (ctx->rx_conf == TLS_HW)
964                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
965                                                         TLS_OFFLOAD_CTX_DIR_RX);
966                 ctx->netdev = NULL;
967                 dev_put(netdev);
968                 list_del_init(&ctx->list);
969
970                 if (refcount_dec_and_test(&ctx->refcount))
971                         tls_device_free_ctx(ctx);
972         }
973
974         up_write(&device_offload_lock);
975
976         flush_work(&tls_device_gc_work);
977
978         return NOTIFY_DONE;
979 }
980
981 static int tls_dev_event(struct notifier_block *this, unsigned long event,
982                          void *ptr)
983 {
984         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
985
986         if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
987                 return NOTIFY_DONE;
988
989         switch (event) {
990         case NETDEV_REGISTER:
991         case NETDEV_FEAT_CHANGE:
992                 if ((dev->features & NETIF_F_HW_TLS_RX) &&
993                     !dev->tlsdev_ops->tls_dev_resync_rx)
994                         return NOTIFY_BAD;
995
996                 if  (dev->tlsdev_ops &&
997                      dev->tlsdev_ops->tls_dev_add &&
998                      dev->tlsdev_ops->tls_dev_del)
999                         return NOTIFY_DONE;
1000                 else
1001                         return NOTIFY_BAD;
1002         case NETDEV_DOWN:
1003                 return tls_device_down(dev);
1004         }
1005         return NOTIFY_DONE;
1006 }
1007
1008 static struct notifier_block tls_dev_notifier = {
1009         .notifier_call  = tls_dev_event,
1010 };
1011
1012 void __init tls_device_init(void)
1013 {
1014         register_netdevice_notifier(&tls_dev_notifier);
1015 }
1016
1017 void __exit tls_device_cleanup(void)
1018 {
1019         unregister_netdevice_notifier(&tls_dev_notifier);
1020         flush_work(&tls_device_gc_work);
1021 }