Merge tag 'linux-watchdog-5.9-rc1' of git://www.linux-watchdog.org/linux-watchdog
[linux-2.6-microblaze.git] / net / xfrm / espintcp.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <net/tcp.h>
3 #include <net/strparser.h>
4 #include <net/xfrm.h>
5 #include <net/esp.h>
6 #include <net/espintcp.h>
7 #include <linux/skmsg.h>
8 #include <net/inet_common.h>
9 #if IS_ENABLED(CONFIG_IPV6)
10 #include <net/ipv6_stubs.h>
11 #endif
12
13 static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
14                           struct sock *sk)
15 {
16         if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
17             !sk_rmem_schedule(sk, skb, skb->truesize)) {
18                 XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
19                 kfree_skb(skb);
20                 return;
21         }
22
23         skb_set_owner_r(skb, sk);
24
25         memset(skb->cb, 0, sizeof(skb->cb));
26         skb_queue_tail(&ctx->ike_queue, skb);
27         ctx->saved_data_ready(sk);
28 }
29
30 static void handle_esp(struct sk_buff *skb, struct sock *sk)
31 {
32         skb_reset_transport_header(skb);
33         memset(skb->cb, 0, sizeof(skb->cb));
34
35         rcu_read_lock();
36         skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
37         local_bh_disable();
38 #if IS_ENABLED(CONFIG_IPV6)
39         if (sk->sk_family == AF_INET6)
40                 ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
41         else
42 #endif
43                 xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
44         local_bh_enable();
45         rcu_read_unlock();
46 }
47
48 static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
49 {
50         struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
51                                                 strp);
52         struct strp_msg *rxm = strp_msg(skb);
53         int len = rxm->full_len - 2;
54         u32 nonesp_marker;
55         int err;
56
57         /* keepalive packet? */
58         if (unlikely(len == 1)) {
59                 u8 data;
60
61                 err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
62                 if (err < 0) {
63                         XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
64                         kfree_skb(skb);
65                         return;
66                 }
67
68                 if (data == 0xff) {
69                         kfree_skb(skb);
70                         return;
71                 }
72         }
73
74         /* drop other short messages */
75         if (unlikely(len <= sizeof(nonesp_marker))) {
76                 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
77                 kfree_skb(skb);
78                 return;
79         }
80
81         err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
82                             sizeof(nonesp_marker));
83         if (err < 0) {
84                 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
85                 kfree_skb(skb);
86                 return;
87         }
88
89         /* remove header, leave non-ESP marker/SPI */
90         if (!__pskb_pull(skb, rxm->offset + 2)) {
91                 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
92                 kfree_skb(skb);
93                 return;
94         }
95
96         if (pskb_trim(skb, rxm->full_len - 2) != 0) {
97                 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
98                 kfree_skb(skb);
99                 return;
100         }
101
102         if (nonesp_marker == 0)
103                 handle_nonesp(ctx, skb, strp->sk);
104         else
105                 handle_esp(skb, strp->sk);
106 }
107
108 static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
109 {
110         struct strp_msg *rxm = strp_msg(skb);
111         __be16 blen;
112         u16 len;
113         int err;
114
115         if (skb->len < rxm->offset + 2)
116                 return 0;
117
118         err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
119         if (err < 0)
120                 return err;
121
122         len = be16_to_cpu(blen);
123         if (len < 2)
124                 return -EINVAL;
125
126         return len;
127 }
128
129 static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
130                             int nonblock, int flags, int *addr_len)
131 {
132         struct espintcp_ctx *ctx = espintcp_getctx(sk);
133         struct sk_buff *skb;
134         int err = 0;
135         int copied;
136         int off = 0;
137
138         flags |= nonblock ? MSG_DONTWAIT : 0;
139
140         skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
141         if (!skb) {
142                 if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
143                         return 0;
144                 return err;
145         }
146
147         copied = len;
148         if (copied > skb->len)
149                 copied = skb->len;
150         else if (copied < skb->len)
151                 msg->msg_flags |= MSG_TRUNC;
152
153         err = skb_copy_datagram_msg(skb, 0, msg, copied);
154         if (unlikely(err)) {
155                 kfree_skb(skb);
156                 return err;
157         }
158
159         if (flags & MSG_TRUNC)
160                 copied = skb->len;
161         kfree_skb(skb);
162         return copied;
163 }
164
165 int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
166 {
167         struct espintcp_ctx *ctx = espintcp_getctx(sk);
168
169         if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
170                 return -ENOBUFS;
171
172         __skb_queue_tail(&ctx->out_queue, skb);
173
174         return 0;
175 }
176 EXPORT_SYMBOL_GPL(espintcp_queue_out);
177
178 /* espintcp length field is 2B and length includes the length field's size */
179 #define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
180
181 static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
182                                    int flags)
183 {
184         do {
185                 int ret;
186
187                 ret = skb_send_sock_locked(sk, emsg->skb,
188                                            emsg->offset, emsg->len);
189                 if (ret < 0)
190                         return ret;
191
192                 emsg->len -= ret;
193                 emsg->offset += ret;
194         } while (emsg->len > 0);
195
196         kfree_skb(emsg->skb);
197         memset(emsg, 0, sizeof(*emsg));
198
199         return 0;
200 }
201
202 static int espintcp_sendskmsg_locked(struct sock *sk,
203                                      struct espintcp_msg *emsg, int flags)
204 {
205         struct sk_msg *skmsg = &emsg->skmsg;
206         struct scatterlist *sg;
207         int done = 0;
208         int ret;
209
210         flags |= MSG_SENDPAGE_NOTLAST;
211         sg = &skmsg->sg.data[skmsg->sg.start];
212         do {
213                 size_t size = sg->length - emsg->offset;
214                 int offset = sg->offset + emsg->offset;
215                 struct page *p;
216
217                 emsg->offset = 0;
218
219                 if (sg_is_last(sg))
220                         flags &= ~MSG_SENDPAGE_NOTLAST;
221
222                 p = sg_page(sg);
223 retry:
224                 ret = do_tcp_sendpages(sk, p, offset, size, flags);
225                 if (ret < 0) {
226                         emsg->offset = offset - sg->offset;
227                         skmsg->sg.start += done;
228                         return ret;
229                 }
230
231                 if (ret != size) {
232                         offset += ret;
233                         size -= ret;
234                         goto retry;
235                 }
236
237                 done++;
238                 put_page(p);
239                 sk_mem_uncharge(sk, sg->length);
240                 sg = sg_next(sg);
241         } while (sg);
242
243         memset(emsg, 0, sizeof(*emsg));
244
245         return 0;
246 }
247
248 static int espintcp_push_msgs(struct sock *sk, int flags)
249 {
250         struct espintcp_ctx *ctx = espintcp_getctx(sk);
251         struct espintcp_msg *emsg = &ctx->partial;
252         int err;
253
254         if (!emsg->len)
255                 return 0;
256
257         if (ctx->tx_running)
258                 return -EAGAIN;
259         ctx->tx_running = 1;
260
261         if (emsg->skb)
262                 err = espintcp_sendskb_locked(sk, emsg, flags);
263         else
264                 err = espintcp_sendskmsg_locked(sk, emsg, flags);
265         if (err == -EAGAIN) {
266                 ctx->tx_running = 0;
267                 return flags & MSG_DONTWAIT ? -EAGAIN : 0;
268         }
269         if (!err)
270                 memset(emsg, 0, sizeof(*emsg));
271
272         ctx->tx_running = 0;
273
274         return err;
275 }
276
277 int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
278 {
279         struct espintcp_ctx *ctx = espintcp_getctx(sk);
280         struct espintcp_msg *emsg = &ctx->partial;
281         unsigned int len;
282         int offset;
283
284         if (sk->sk_state != TCP_ESTABLISHED) {
285                 kfree_skb(skb);
286                 return -ECONNRESET;
287         }
288
289         offset = skb_transport_offset(skb);
290         len = skb->len - offset;
291
292         espintcp_push_msgs(sk, 0);
293
294         if (emsg->len) {
295                 kfree_skb(skb);
296                 return -ENOBUFS;
297         }
298
299         skb_set_owner_w(skb, sk);
300
301         emsg->offset = offset;
302         emsg->len = len;
303         emsg->skb = skb;
304
305         espintcp_push_msgs(sk, 0);
306
307         return 0;
308 }
309 EXPORT_SYMBOL_GPL(espintcp_push_skb);
310
311 static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
312 {
313         long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
314         struct espintcp_ctx *ctx = espintcp_getctx(sk);
315         struct espintcp_msg *emsg = &ctx->partial;
316         struct iov_iter pfx_iter;
317         struct kvec pfx_iov = {};
318         size_t msglen = size + 2;
319         char buf[2] = {0};
320         int err, end;
321
322         if (msg->msg_flags & ~MSG_DONTWAIT)
323                 return -EOPNOTSUPP;
324
325         if (size > MAX_ESPINTCP_MSG)
326                 return -EMSGSIZE;
327
328         if (msg->msg_controllen)
329                 return -EOPNOTSUPP;
330
331         lock_sock(sk);
332
333         err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
334         if (err < 0) {
335                 if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
336                         err = -ENOBUFS;
337                 goto unlock;
338         }
339
340         sk_msg_init(&emsg->skmsg);
341         while (1) {
342                 /* only -ENOMEM is possible since we don't coalesce */
343                 err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
344                 if (!err)
345                         break;
346
347                 err = sk_stream_wait_memory(sk, &timeo);
348                 if (err)
349                         goto fail;
350         }
351
352         *((__be16 *)buf) = cpu_to_be16(msglen);
353         pfx_iov.iov_base = buf;
354         pfx_iov.iov_len = sizeof(buf);
355         iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
356
357         err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
358                                        pfx_iov.iov_len);
359         if (err < 0)
360                 goto fail;
361
362         err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
363         if (err < 0)
364                 goto fail;
365
366         end = emsg->skmsg.sg.end;
367         emsg->len = size;
368         sk_msg_iter_var_prev(end);
369         sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
370
371         tcp_rate_check_app_limited(sk);
372
373         err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
374         /* this message could be partially sent, keep it */
375
376         release_sock(sk);
377
378         return size;
379
380 fail:
381         sk_msg_free(sk, &emsg->skmsg);
382         memset(emsg, 0, sizeof(*emsg));
383 unlock:
384         release_sock(sk);
385         return err;
386 }
387
388 static struct proto espintcp_prot __ro_after_init;
389 static struct proto_ops espintcp_ops __ro_after_init;
390 static struct proto espintcp6_prot;
391 static struct proto_ops espintcp6_ops;
392 static DEFINE_MUTEX(tcpv6_prot_mutex);
393
394 static void espintcp_data_ready(struct sock *sk)
395 {
396         struct espintcp_ctx *ctx = espintcp_getctx(sk);
397
398         strp_data_ready(&ctx->strp);
399 }
400
401 static void espintcp_tx_work(struct work_struct *work)
402 {
403         struct espintcp_ctx *ctx = container_of(work,
404                                                 struct espintcp_ctx, work);
405         struct sock *sk = ctx->strp.sk;
406
407         lock_sock(sk);
408         if (!ctx->tx_running)
409                 espintcp_push_msgs(sk, 0);
410         release_sock(sk);
411 }
412
413 static void espintcp_write_space(struct sock *sk)
414 {
415         struct espintcp_ctx *ctx = espintcp_getctx(sk);
416
417         schedule_work(&ctx->work);
418         ctx->saved_write_space(sk);
419 }
420
421 static void espintcp_destruct(struct sock *sk)
422 {
423         struct espintcp_ctx *ctx = espintcp_getctx(sk);
424
425         ctx->saved_destruct(sk);
426         kfree(ctx);
427 }
428
429 bool tcp_is_ulp_esp(struct sock *sk)
430 {
431         return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
432 }
433 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
434
435 static void build_protos(struct proto *espintcp_prot,
436                          struct proto_ops *espintcp_ops,
437                          const struct proto *orig_prot,
438                          const struct proto_ops *orig_ops);
439 static int espintcp_init_sk(struct sock *sk)
440 {
441         struct inet_connection_sock *icsk = inet_csk(sk);
442         struct strp_callbacks cb = {
443                 .rcv_msg = espintcp_rcv,
444                 .parse_msg = espintcp_parse,
445         };
446         struct espintcp_ctx *ctx;
447         int err;
448
449         /* sockmap is not compatible with espintcp */
450         if (sk->sk_user_data)
451                 return -EBUSY;
452
453         ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
454         if (!ctx)
455                 return -ENOMEM;
456
457         err = strp_init(&ctx->strp, sk, &cb);
458         if (err)
459                 goto free;
460
461         __sk_dst_reset(sk);
462
463         strp_check_rcv(&ctx->strp);
464         skb_queue_head_init(&ctx->ike_queue);
465         skb_queue_head_init(&ctx->out_queue);
466
467         if (sk->sk_family == AF_INET) {
468                 sk->sk_prot = &espintcp_prot;
469                 sk->sk_socket->ops = &espintcp_ops;
470         } else {
471                 mutex_lock(&tcpv6_prot_mutex);
472                 if (!espintcp6_prot.recvmsg)
473                         build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
474                 mutex_unlock(&tcpv6_prot_mutex);
475
476                 sk->sk_prot = &espintcp6_prot;
477                 sk->sk_socket->ops = &espintcp6_ops;
478         }
479         ctx->saved_data_ready = sk->sk_data_ready;
480         ctx->saved_write_space = sk->sk_write_space;
481         ctx->saved_destruct = sk->sk_destruct;
482         sk->sk_data_ready = espintcp_data_ready;
483         sk->sk_write_space = espintcp_write_space;
484         sk->sk_destruct = espintcp_destruct;
485         rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
486         INIT_WORK(&ctx->work, espintcp_tx_work);
487
488         /* avoid using task_frag */
489         sk->sk_allocation = GFP_ATOMIC;
490
491         return 0;
492
493 free:
494         kfree(ctx);
495         return err;
496 }
497
498 static void espintcp_release(struct sock *sk)
499 {
500         struct espintcp_ctx *ctx = espintcp_getctx(sk);
501         struct sk_buff_head queue;
502         struct sk_buff *skb;
503
504         __skb_queue_head_init(&queue);
505         skb_queue_splice_init(&ctx->out_queue, &queue);
506
507         while ((skb = __skb_dequeue(&queue)))
508                 espintcp_push_skb(sk, skb);
509
510         tcp_release_cb(sk);
511 }
512
513 static void espintcp_close(struct sock *sk, long timeout)
514 {
515         struct espintcp_ctx *ctx = espintcp_getctx(sk);
516         struct espintcp_msg *emsg = &ctx->partial;
517
518         strp_stop(&ctx->strp);
519
520         sk->sk_prot = &tcp_prot;
521         barrier();
522
523         cancel_work_sync(&ctx->work);
524         strp_done(&ctx->strp);
525
526         skb_queue_purge(&ctx->out_queue);
527         skb_queue_purge(&ctx->ike_queue);
528
529         if (emsg->len) {
530                 if (emsg->skb)
531                         kfree_skb(emsg->skb);
532                 else
533                         sk_msg_free(sk, &emsg->skmsg);
534         }
535
536         tcp_close(sk, timeout);
537 }
538
539 static __poll_t espintcp_poll(struct file *file, struct socket *sock,
540                               poll_table *wait)
541 {
542         __poll_t mask = datagram_poll(file, sock, wait);
543         struct sock *sk = sock->sk;
544         struct espintcp_ctx *ctx = espintcp_getctx(sk);
545
546         if (!skb_queue_empty(&ctx->ike_queue))
547                 mask |= EPOLLIN | EPOLLRDNORM;
548
549         return mask;
550 }
551
552 static void build_protos(struct proto *espintcp_prot,
553                          struct proto_ops *espintcp_ops,
554                          const struct proto *orig_prot,
555                          const struct proto_ops *orig_ops)
556 {
557         memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
558         memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
559         espintcp_prot->sendmsg = espintcp_sendmsg;
560         espintcp_prot->recvmsg = espintcp_recvmsg;
561         espintcp_prot->close = espintcp_close;
562         espintcp_prot->release_cb = espintcp_release;
563         espintcp_ops->poll = espintcp_poll;
564 }
565
566 static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
567         .name = "espintcp",
568         .owner = THIS_MODULE,
569         .init = espintcp_init_sk,
570 };
571
572 void __init espintcp_init(void)
573 {
574         build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);
575
576         tcp_register_ulp(&espintcp_ulp);
577 }