ASoC: qdsp6: Suggest more generic node names
[linux-2.6-microblaze.git] / net / xfrm / espintcp.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <net/tcp.h>
3 #include <net/strparser.h>
4 #include <net/xfrm.h>
5 #include <net/esp.h>
6 #include <net/espintcp.h>
7 #include <linux/skmsg.h>
8 #include <net/inet_common.h>
9
10 static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
11                           struct sock *sk)
12 {
13         if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
14             !sk_rmem_schedule(sk, skb, skb->truesize)) {
15                 kfree_skb(skb);
16                 return;
17         }
18
19         skb_set_owner_r(skb, sk);
20
21         memset(skb->cb, 0, sizeof(skb->cb));
22         skb_queue_tail(&ctx->ike_queue, skb);
23         ctx->saved_data_ready(sk);
24 }
25
26 static void handle_esp(struct sk_buff *skb, struct sock *sk)
27 {
28         skb_reset_transport_header(skb);
29         memset(skb->cb, 0, sizeof(skb->cb));
30
31         rcu_read_lock();
32         skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
33         local_bh_disable();
34         xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
35         local_bh_enable();
36         rcu_read_unlock();
37 }
38
39 static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
40 {
41         struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
42                                                 strp);
43         struct strp_msg *rxm = strp_msg(skb);
44         u32 nonesp_marker;
45         int err;
46
47         err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
48                             sizeof(nonesp_marker));
49         if (err < 0) {
50                 kfree_skb(skb);
51                 return;
52         }
53
54         /* remove header, leave non-ESP marker/SPI */
55         if (!__pskb_pull(skb, rxm->offset + 2)) {
56                 kfree_skb(skb);
57                 return;
58         }
59
60         if (pskb_trim(skb, rxm->full_len - 2) != 0) {
61                 kfree_skb(skb);
62                 return;
63         }
64
65         if (nonesp_marker == 0)
66                 handle_nonesp(ctx, skb, strp->sk);
67         else
68                 handle_esp(skb, strp->sk);
69 }
70
71 static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
72 {
73         struct strp_msg *rxm = strp_msg(skb);
74         __be16 blen;
75         u16 len;
76         int err;
77
78         if (skb->len < rxm->offset + 2)
79                 return 0;
80
81         err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
82         if (err < 0)
83                 return err;
84
85         len = be16_to_cpu(blen);
86         if (len < 6)
87                 return -EINVAL;
88
89         return len;
90 }
91
92 static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
93                             int nonblock, int flags, int *addr_len)
94 {
95         struct espintcp_ctx *ctx = espintcp_getctx(sk);
96         struct sk_buff *skb;
97         int err = 0;
98         int copied;
99         int off = 0;
100
101         flags |= nonblock ? MSG_DONTWAIT : 0;
102
103         skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
104         if (!skb)
105                 return err;
106
107         copied = len;
108         if (copied > skb->len)
109                 copied = skb->len;
110         else if (copied < skb->len)
111                 msg->msg_flags |= MSG_TRUNC;
112
113         err = skb_copy_datagram_msg(skb, 0, msg, copied);
114         if (unlikely(err)) {
115                 kfree_skb(skb);
116                 return err;
117         }
118
119         if (flags & MSG_TRUNC)
120                 copied = skb->len;
121         kfree_skb(skb);
122         return copied;
123 }
124
125 int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
126 {
127         struct espintcp_ctx *ctx = espintcp_getctx(sk);
128
129         if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
130                 return -ENOBUFS;
131
132         __skb_queue_tail(&ctx->out_queue, skb);
133
134         return 0;
135 }
136 EXPORT_SYMBOL_GPL(espintcp_queue_out);
137
138 /* espintcp length field is 2B and length includes the length field's size */
139 #define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
140
141 static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
142                                    int flags)
143 {
144         do {
145                 int ret;
146
147                 ret = skb_send_sock_locked(sk, emsg->skb,
148                                            emsg->offset, emsg->len);
149                 if (ret < 0)
150                         return ret;
151
152                 emsg->len -= ret;
153                 emsg->offset += ret;
154         } while (emsg->len > 0);
155
156         kfree_skb(emsg->skb);
157         memset(emsg, 0, sizeof(*emsg));
158
159         return 0;
160 }
161
162 static int espintcp_sendskmsg_locked(struct sock *sk,
163                                      struct espintcp_msg *emsg, int flags)
164 {
165         struct sk_msg *skmsg = &emsg->skmsg;
166         struct scatterlist *sg;
167         int done = 0;
168         int ret;
169
170         flags |= MSG_SENDPAGE_NOTLAST;
171         sg = &skmsg->sg.data[skmsg->sg.start];
172         do {
173                 size_t size = sg->length - emsg->offset;
174                 int offset = sg->offset + emsg->offset;
175                 struct page *p;
176
177                 emsg->offset = 0;
178
179                 if (sg_is_last(sg))
180                         flags &= ~MSG_SENDPAGE_NOTLAST;
181
182                 p = sg_page(sg);
183 retry:
184                 ret = do_tcp_sendpages(sk, p, offset, size, flags);
185                 if (ret < 0) {
186                         emsg->offset = offset - sg->offset;
187                         skmsg->sg.start += done;
188                         return ret;
189                 }
190
191                 if (ret != size) {
192                         offset += ret;
193                         size -= ret;
194                         goto retry;
195                 }
196
197                 done++;
198                 put_page(p);
199                 sk_mem_uncharge(sk, sg->length);
200                 sg = sg_next(sg);
201         } while (sg);
202
203         memset(emsg, 0, sizeof(*emsg));
204
205         return 0;
206 }
207
208 static int espintcp_push_msgs(struct sock *sk)
209 {
210         struct espintcp_ctx *ctx = espintcp_getctx(sk);
211         struct espintcp_msg *emsg = &ctx->partial;
212         int err;
213
214         if (!emsg->len)
215                 return 0;
216
217         if (ctx->tx_running)
218                 return -EAGAIN;
219         ctx->tx_running = 1;
220
221         if (emsg->skb)
222                 err = espintcp_sendskb_locked(sk, emsg, 0);
223         else
224                 err = espintcp_sendskmsg_locked(sk, emsg, 0);
225         if (err == -EAGAIN) {
226                 ctx->tx_running = 0;
227                 return 0;
228         }
229         if (!err)
230                 memset(emsg, 0, sizeof(*emsg));
231
232         ctx->tx_running = 0;
233
234         return err;
235 }
236
237 int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
238 {
239         struct espintcp_ctx *ctx = espintcp_getctx(sk);
240         struct espintcp_msg *emsg = &ctx->partial;
241         unsigned int len;
242         int offset;
243
244         if (sk->sk_state != TCP_ESTABLISHED) {
245                 kfree_skb(skb);
246                 return -ECONNRESET;
247         }
248
249         offset = skb_transport_offset(skb);
250         len = skb->len - offset;
251
252         espintcp_push_msgs(sk);
253
254         if (emsg->len) {
255                 kfree_skb(skb);
256                 return -ENOBUFS;
257         }
258
259         skb_set_owner_w(skb, sk);
260
261         emsg->offset = offset;
262         emsg->len = len;
263         emsg->skb = skb;
264
265         espintcp_push_msgs(sk);
266
267         return 0;
268 }
269 EXPORT_SYMBOL_GPL(espintcp_push_skb);
270
271 static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
272 {
273         long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
274         struct espintcp_ctx *ctx = espintcp_getctx(sk);
275         struct espintcp_msg *emsg = &ctx->partial;
276         struct iov_iter pfx_iter;
277         struct kvec pfx_iov = {};
278         size_t msglen = size + 2;
279         char buf[2] = {0};
280         int err, end;
281
282         if (msg->msg_flags)
283                 return -EOPNOTSUPP;
284
285         if (size > MAX_ESPINTCP_MSG)
286                 return -EMSGSIZE;
287
288         if (msg->msg_controllen)
289                 return -EOPNOTSUPP;
290
291         lock_sock(sk);
292
293         err = espintcp_push_msgs(sk);
294         if (err < 0) {
295                 err = -ENOBUFS;
296                 goto unlock;
297         }
298
299         sk_msg_init(&emsg->skmsg);
300         while (1) {
301                 /* only -ENOMEM is possible since we don't coalesce */
302                 err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
303                 if (!err)
304                         break;
305
306                 err = sk_stream_wait_memory(sk, &timeo);
307                 if (err)
308                         goto fail;
309         }
310
311         *((__be16 *)buf) = cpu_to_be16(msglen);
312         pfx_iov.iov_base = buf;
313         pfx_iov.iov_len = sizeof(buf);
314         iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
315
316         err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
317                                        pfx_iov.iov_len);
318         if (err < 0)
319                 goto fail;
320
321         err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
322         if (err < 0)
323                 goto fail;
324
325         end = emsg->skmsg.sg.end;
326         emsg->len = size;
327         sk_msg_iter_var_prev(end);
328         sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
329
330         tcp_rate_check_app_limited(sk);
331
332         err = espintcp_push_msgs(sk);
333         /* this message could be partially sent, keep it */
334         if (err < 0)
335                 goto unlock;
336         release_sock(sk);
337
338         return size;
339
340 fail:
341         sk_msg_free(sk, &emsg->skmsg);
342         memset(emsg, 0, sizeof(*emsg));
343 unlock:
344         release_sock(sk);
345         return err;
346 }
347
348 static struct proto espintcp_prot __ro_after_init;
349 static struct proto_ops espintcp_ops __ro_after_init;
350
351 static void espintcp_data_ready(struct sock *sk)
352 {
353         struct espintcp_ctx *ctx = espintcp_getctx(sk);
354
355         strp_data_ready(&ctx->strp);
356 }
357
358 static void espintcp_tx_work(struct work_struct *work)
359 {
360         struct espintcp_ctx *ctx = container_of(work,
361                                                 struct espintcp_ctx, work);
362         struct sock *sk = ctx->strp.sk;
363
364         lock_sock(sk);
365         if (!ctx->tx_running)
366                 espintcp_push_msgs(sk);
367         release_sock(sk);
368 }
369
370 static void espintcp_write_space(struct sock *sk)
371 {
372         struct espintcp_ctx *ctx = espintcp_getctx(sk);
373
374         schedule_work(&ctx->work);
375         ctx->saved_write_space(sk);
376 }
377
378 static void espintcp_destruct(struct sock *sk)
379 {
380         struct espintcp_ctx *ctx = espintcp_getctx(sk);
381
382         kfree(ctx);
383 }
384
385 bool tcp_is_ulp_esp(struct sock *sk)
386 {
387         return sk->sk_prot == &espintcp_prot;
388 }
389 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
390
391 static int espintcp_init_sk(struct sock *sk)
392 {
393         struct inet_connection_sock *icsk = inet_csk(sk);
394         struct strp_callbacks cb = {
395                 .rcv_msg = espintcp_rcv,
396                 .parse_msg = espintcp_parse,
397         };
398         struct espintcp_ctx *ctx;
399         int err;
400
401         /* sockmap is not compatible with espintcp */
402         if (sk->sk_user_data)
403                 return -EBUSY;
404
405         ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
406         if (!ctx)
407                 return -ENOMEM;
408
409         err = strp_init(&ctx->strp, sk, &cb);
410         if (err)
411                 goto free;
412
413         __sk_dst_reset(sk);
414
415         strp_check_rcv(&ctx->strp);
416         skb_queue_head_init(&ctx->ike_queue);
417         skb_queue_head_init(&ctx->out_queue);
418         sk->sk_prot = &espintcp_prot;
419         sk->sk_socket->ops = &espintcp_ops;
420         ctx->saved_data_ready = sk->sk_data_ready;
421         ctx->saved_write_space = sk->sk_write_space;
422         sk->sk_data_ready = espintcp_data_ready;
423         sk->sk_write_space = espintcp_write_space;
424         sk->sk_destruct = espintcp_destruct;
425         rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
426         INIT_WORK(&ctx->work, espintcp_tx_work);
427
428         /* avoid using task_frag */
429         sk->sk_allocation = GFP_ATOMIC;
430
431         return 0;
432
433 free:
434         kfree(ctx);
435         return err;
436 }
437
438 static void espintcp_release(struct sock *sk)
439 {
440         struct espintcp_ctx *ctx = espintcp_getctx(sk);
441         struct sk_buff_head queue;
442         struct sk_buff *skb;
443
444         __skb_queue_head_init(&queue);
445         skb_queue_splice_init(&ctx->out_queue, &queue);
446
447         while ((skb = __skb_dequeue(&queue)))
448                 espintcp_push_skb(sk, skb);
449
450         tcp_release_cb(sk);
451 }
452
453 static void espintcp_close(struct sock *sk, long timeout)
454 {
455         struct espintcp_ctx *ctx = espintcp_getctx(sk);
456         struct espintcp_msg *emsg = &ctx->partial;
457
458         strp_stop(&ctx->strp);
459
460         sk->sk_prot = &tcp_prot;
461         barrier();
462
463         cancel_work_sync(&ctx->work);
464         strp_done(&ctx->strp);
465
466         skb_queue_purge(&ctx->out_queue);
467         skb_queue_purge(&ctx->ike_queue);
468
469         if (emsg->len) {
470                 if (emsg->skb)
471                         kfree_skb(emsg->skb);
472                 else
473                         sk_msg_free(sk, &emsg->skmsg);
474         }
475
476         tcp_close(sk, timeout);
477 }
478
479 static __poll_t espintcp_poll(struct file *file, struct socket *sock,
480                               poll_table *wait)
481 {
482         __poll_t mask = datagram_poll(file, sock, wait);
483         struct sock *sk = sock->sk;
484         struct espintcp_ctx *ctx = espintcp_getctx(sk);
485
486         if (!skb_queue_empty(&ctx->ike_queue))
487                 mask |= EPOLLIN | EPOLLRDNORM;
488
489         return mask;
490 }
491
492 static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
493         .name = "espintcp",
494         .owner = THIS_MODULE,
495         .init = espintcp_init_sk,
496 };
497
498 void __init espintcp_init(void)
499 {
500         memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot));
501         memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops));
502         espintcp_prot.sendmsg = espintcp_sendmsg;
503         espintcp_prot.recvmsg = espintcp_recvmsg;
504         espintcp_prot.close = espintcp_close;
505         espintcp_prot.release_cb = espintcp_release;
506         espintcp_ops.poll = espintcp_poll;
507
508         tcp_register_ulp(&espintcp_ulp);
509 }