1 // SPDX-License-Identifier: GPL-2.0
3 #include <net/strparser.h>
6 #include <net/espintcp.h>
7 #include <linux/skmsg.h>
8 #include <net/inet_common.h>
9 #if IS_ENABLED(CONFIG_IPV6)
10 #include <net/ipv6_stubs.h>
/* handle_nonesp - deliver a non-ESP (IKE) message to the userspace reader.
 *
 * If the socket receive buffer is already full, or rmem cannot be
 * scheduled for this skb, the message is accounted as XFRMINERROR
 * (the drop/free path is elided from this view).  Otherwise the skb is
 * charged to the socket, queued on ctx->ike_queue and the reader is
 * woken through the ->sk_data_ready callback saved at ULP init time.
 */
13 static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
16 if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
17 !sk_rmem_schedule(sk, skb, skb->truesize)) {
18 XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
/* charge the skb to the socket so rmem accounting stays accurate */
23 skb_set_owner_r(skb, sk);
/* clear stale control-block state before handing off to the recv queue */
25 memset(skb->cb, 0, sizeof(skb->cb));
26 skb_queue_tail(&ctx->ike_queue, skb);
27 ctx->saved_data_ready(sk);
/* handle_esp - hand a received ESP record to the xfrm input path.
 *
 * Resets the transport header, clears the control block, resolves
 * skb->dev from the recorded incoming interface index, then dispatches
 * to xfrm6_rcv_encap() (IPv6) or xfrm4_rcv_encap() with the
 * TCP_ENCAP_ESPINTCP encap type.
 * NOTE(review): locking around dev_get_by_index_rcu() (RCU read side)
 * is in lines elided from this view.
 */
30 static void handle_esp(struct sk_buff *skb, struct sock *sk)
32 skb_reset_transport_header(skb);
33 memset(skb->cb, 0, sizeof(skb->cb));
/* _rcu lookup: caller context must hold the RCU read lock (elided here) */
36 skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
38 #if IS_ENABLED(CONFIG_IPV6)
39 if (sk->sk_family == AF_INET6)
40 ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
43 xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
/* espintcp_rcv - strparser rcv_msg callback: demux one framed record.
 *
 * Each record starts with a 2-byte length field (included in the
 * length), so the payload length is full_len - 2.  A 1-byte payload is
 * treated as a keepalive (its value is checked in lines elided here —
 * presumably against the 0xFF keepalive byte of RFC 8229; confirm).
 * Records too short to carry the 4-byte non-ESP marker/SPI are dropped
 * as XFRMINHDRERROR.  After stripping the length header and trimming
 * trailing bytes, a zero marker routes the skb to handle_nonesp()
 * (IKE), anything else to handle_esp().
 */
48 static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
50 struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
52 struct strp_msg *rxm = strp_msg(skb);
/* payload length excluding the 2-byte length field itself */
53 int len = rxm->full_len - 2;
57 /* keepalive packet? */
58 if (unlikely(len == 1)) {
61 err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
63 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
74 /* drop other short messages */
75 if (unlikely(len <= sizeof(nonesp_marker))) {
76 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
/* peek the 4 bytes following the length field: non-ESP marker or SPI */
81 err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
82 sizeof(nonesp_marker));
84 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
89 /* remove header, leave non-ESP marker/SPI */
90 if (!__pskb_pull(skb, rxm->offset + 2)) {
91 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
/* drop anything beyond this record (strparser may batch data) */
96 if (pskb_trim(skb, rxm->full_len - 2) != 0) {
97 XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
/* zero marker = IKE message for userspace, otherwise ESP for the stack */
102 if (nonesp_marker == 0)
103 handle_nonesp(ctx, skb, strp->sk);
105 handle_esp(skb, strp->sk);
/* espintcp_parse - strparser parse_msg callback: return the record size.
 *
 * Reads the 2-byte big-endian length prefix at the record offset and
 * returns it as the full record length (return statements are elided
 * from this view).  If fewer than 2 bytes are available yet, the
 * parser must wait for more data.
 */
108 static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
110 struct strp_msg *rxm = strp_msg(skb);
/* not enough bytes buffered yet to read the length field */
115 if (skb->len < rxm->offset + 2)
118 err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
/* length field is network byte order and includes its own 2 bytes */
122 len = be16_to_cpu(blen);
/* espintcp_recvmsg - ->recvmsg replacement: read queued IKE messages.
 *
 * Dequeues one skb from ctx->ike_queue (datagram semantics on top of
 * the TCP socket).  Returns 0 at -EAGAIN when the receive side has been
 * shut down; short reads set MSG_TRUNC.  With MSG_TRUNC the full
 * message length is reported (return paths elided from this view).
 */
129 static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
130 int nonblock, int flags, int *addr_len)
132 struct espintcp_ctx *ctx = espintcp_getctx(sk);
138 flags |= nonblock ? MSG_DONTWAIT : 0;
140 skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
/* EOF rather than EAGAIN once the peer shut down the receive side */
142 if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
/* cap the copy to the message size; flag truncation to the caller */
148 if (copied > skb->len)
150 else if (copied < skb->len)
151 msg->msg_flags |= MSG_TRUNC;
153 err = skb_copy_datagram_msg(skb, 0, msg, copied);
159 if (flags & MSG_TRUNC)
/* espintcp_queue_out - queue an skb for transmission by the TX worker.
 *
 * Drops the packet (error return elided from this view) when the
 * out_queue has already reached netdev_max_backlog, to bound memory
 * consumed by a stalled TCP connection.
 */
165 int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
167 struct espintcp_ctx *ctx = espintcp_getctx(sk);
169 if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
172 __skb_queue_tail(&ctx->out_queue, skb);
176 EXPORT_SYMBOL_GPL(espintcp_queue_out);
178 /* espintcp length field is 2B and length includes the length field's size */
179 #define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
/* espintcp_sendskb_locked - transmit a pending skb-based message.
 *
 * Called with the socket lock held.  Loops skb_send_sock_locked() until
 * emsg->len reaches zero (offset/len bookkeeping is elided from this
 * view), then frees the skb and clears the partial-message state.
 */
181 static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
187 ret = skb_send_sock_locked(sk, emsg->skb,
188 emsg->offset, emsg->len);
194 } while (emsg->len > 0);
/* fully sent: release the skb and reset the partial-message slot */
196 kfree_skb(emsg->skb);
197 memset(emsg, 0, sizeof(*emsg));
/* espintcp_sendskmsg_locked - transmit a pending sk_msg-based message.
 *
 * Called with the socket lock held.  Walks the sk_msg scatterlist from
 * sg.start, sending each fragment's page via do_tcp_sendpages();
 * MSG_SENDPAGE_NOTLAST is set except on the final fragment so TCP can
 * coalesce.  On a partial send, emsg->offset and sg.start record where
 * to resume later.  Memory is uncharged per completed fragment and the
 * partial-message state cleared on full completion.  (Loop structure,
 * error paths and return statements are elided from this view.)
 */
202 static int espintcp_sendskmsg_locked(struct sock *sk,
203 struct espintcp_msg *emsg, int flags)
205 struct sk_msg *skmsg = &emsg->skmsg;
206 struct scatterlist *sg;
210 flags |= MSG_SENDPAGE_NOTLAST;
211 sg = &skmsg->sg.data[skmsg->sg.start];
/* resume mid-fragment if a previous attempt sent only part of it */
213 size_t size = sg->length - emsg->offset;
214 int offset = sg->offset + emsg->offset;
/* last fragment: allow TCP to push the data out */
220 flags &= ~MSG_SENDPAGE_NOTLAST;
224 ret = do_tcp_sendpages(sk, p, offset, size, flags);
/* partial send: remember the resume point inside this fragment */
226 emsg->offset = offset - sg->offset;
227 skmsg->sg.start += done;
239 sk_mem_uncharge(sk, sg->length);
243 memset(emsg, 0, sizeof(*emsg));
/* espintcp_push_msgs - try to flush the current partial message.
 *
 * Dispatches to the skb or sk_msg sender depending on how the partial
 * message was queued.  -EAGAIN is surfaced only for MSG_DONTWAIT
 * callers; blocking callers get 0 and the message stays queued for a
 * later retry.  (Early-exit checks, e.g. nothing pending / TX already
 * running, are elided from this view.)
 */
248 static int espintcp_push_msgs(struct sock *sk, int flags)
250 struct espintcp_ctx *ctx = espintcp_getctx(sk);
251 struct espintcp_msg *emsg = &ctx->partial;
262 err = espintcp_sendskb_locked(sk, emsg, flags);
264 err = espintcp_sendskmsg_locked(sk, emsg, flags);
265 if (err == -EAGAIN) {
/* keep the message queued for blocking callers; retried later */
267 return flags & MSG_DONTWAIT ? -EAGAIN : 0;
270 memset(emsg, 0, sizeof(*emsg));
/* espintcp_push_skb - queue an skb as the partial message and push it.
 *
 * Rejects sockets not in TCP_ESTABLISHED (error path elided from this
 * view).  The payload starts at the transport header — the framing
 * header is expected to already be in place before that offset.
 * The skb becomes the socket's write-side responsibility
 * (skb_set_owner_w) and an immediate best-effort push is attempted.
 */
277 int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
279 struct espintcp_ctx *ctx = espintcp_getctx(sk);
280 struct espintcp_msg *emsg = &ctx->partial;
284 if (sk->sk_state != TCP_ESTABLISHED) {
289 offset = skb_transport_offset(skb);
290 len = skb->len - offset;
/* flush any previous partial message before claiming the slot */
292 espintcp_push_msgs(sk, 0);
299 skb_set_owner_w(skb, sk);
301 emsg->offset = offset;
/* best effort: if this fails, the TX worker retries later */
305 espintcp_push_msgs(sk, 0);
309 EXPORT_SYMBOL_GPL(espintcp_push_skb);
/* espintcp_sendmsg - ->sendmsg replacement: frame and send an IKE message.
 *
 * Userspace payloads are framed with the 2-byte length prefix
 * (msglen = size + 2) required by the ESP-in-TCP encapsulation.  Only
 * MSG_DONTWAIT is accepted in msg_flags; control messages and payloads
 * above MAX_ESPINTCP_MSG are rejected.  The message is staged in
 * ctx->partial as an sk_msg: allocate, copy prefix + payload, mark the
 * scatterlist end, then push.  A partially-sent message is kept queued
 * rather than freed.  (Locking, several error paths and the `buf`
 * declaration are elided from this view.)
 */
311 static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
313 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
314 struct espintcp_ctx *ctx = espintcp_getctx(sk);
315 struct espintcp_msg *emsg = &ctx->partial;
316 struct iov_iter pfx_iter;
317 struct kvec pfx_iov = {};
/* on-the-wire length includes the 2-byte length field itself */
318 size_t msglen = size + 2;
322 if (msg->msg_flags & ~MSG_DONTWAIT)
325 if (size > MAX_ESPINTCP_MSG)
328 if (msg->msg_controllen)
/* a previous partial message must drain before staging a new one */
333 err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
335 if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
340 sk_msg_init(&emsg->skmsg);
342 /* only -ENOMEM is possible since we don't coalesce */
343 err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
347 err = sk_stream_wait_memory(sk, &timeo);
/* build the big-endian length prefix and copy it in first */
352 *((__be16 *)buf) = cpu_to_be16(msglen);
353 pfx_iov.iov_base = buf;
354 pfx_iov.iov_len = sizeof(buf);
355 iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
357 err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
362 err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
/* terminate the scatterlist at the last used element */
366 end = emsg->skmsg.sg.end;
368 sk_msg_iter_var_prev(end);
369 sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
371 tcp_rate_check_app_limited(sk);
373 err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
374 /* this message could be partially sent, keep it */
/* error path: drop the staged message entirely */
381 sk_msg_free(sk, &emsg->skmsg);
382 memset(emsg, 0, sizeof(*emsg));
/* IPv4 variants are built once at module init (espintcp_init); the IPv6
 * variants are built lazily on first use, serialized by tcpv6_prot_mutex,
 * because tcpv6_prot lives in a module.
 */
388 static struct proto espintcp_prot __ro_after_init;
389 static struct proto_ops espintcp_ops __ro_after_init;
390 static struct proto espintcp6_prot;
391 static struct proto_ops espintcp6_ops;
392 static DEFINE_MUTEX(tcpv6_prot_mutex);
394 static void espintcp_data_ready(struct sock *sk)
396 struct espintcp_ctx *ctx = espintcp_getctx(sk);
398 strp_data_ready(&ctx->strp);
/* espintcp_tx_work - deferred TX: retry flushing the partial message.
 *
 * Scheduled from espintcp_write_space() when the socket regains send
 * room.  Skips the push when a sender is already running (tx_running).
 * NOTE(review): the socket locking around the push is in lines elided
 * from this view.
 */
401 static void espintcp_tx_work(struct work_struct *work)
403 struct espintcp_ctx *ctx = container_of(work,
404 struct espintcp_ctx, work);
405 struct sock *sk = ctx->strp.sk;
408 if (!ctx->tx_running)
409 espintcp_push_msgs(sk, 0);
413 static void espintcp_write_space(struct sock *sk)
415 struct espintcp_ctx *ctx = espintcp_getctx(sk);
417 schedule_work(&ctx->work);
418 ctx->saved_write_space(sk);
/* espintcp_destruct - ->sk_destruct replacement: chain to the saved
 * destructor.  NOTE(review): freeing of ctx itself is in a line elided
 * from this view — presumably kfree(ctx) follows; confirm upstream.
 */
421 static void espintcp_destruct(struct sock *sk)
423 struct espintcp_ctx *ctx = espintcp_getctx(sk);
425 ctx->saved_destruct(sk);
429 bool tcp_is_ulp_esp(struct sock *sk)
431 return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
433 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
/* Forward declaration: build_protos() is defined below but needed by
 * espintcp_init_sk() to lazily build the IPv6 proto variants.
 */
435 static void build_protos(struct proto *espintcp_prot,
436 struct proto_ops *espintcp_ops,
437 const struct proto *orig_prot,
438 const struct proto_ops *orig_ops);
/* espintcp_init_sk - tcp_ulp_ops ->init: attach espintcp to a TCP socket.
 *
 * Allocates the per-socket espintcp context, initializes the strparser
 * with the rcv/parse callbacks, swaps in the family-specific proto and
 * proto_ops (building the IPv6 variants lazily under tcpv6_prot_mutex),
 * saves and replaces the data_ready/write_space/destruct callbacks, and
 * publishes the context via icsk_ulp_data.  (Error handling/unwind and
 * return statements are elided from this view.)
 */
439 static int espintcp_init_sk(struct sock *sk)
441 struct inet_connection_sock *icsk = inet_csk(sk);
442 struct strp_callbacks cb = {
443 .rcv_msg = espintcp_rcv,
444 .parse_msg = espintcp_parse,
446 struct espintcp_ctx *ctx;
449 /* sockmap is not compatible with espintcp */
450 if (sk->sk_user_data)
453 ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
457 err = strp_init(&ctx->strp, sk, &cb);
/* parse any data already buffered on the socket */
463 strp_check_rcv(&ctx->strp);
464 skb_queue_head_init(&ctx->ike_queue);
465 skb_queue_head_init(&ctx->out_queue);
467 if (sk->sk_family == AF_INET) {
468 sk->sk_prot = &espintcp_prot;
469 sk->sk_socket->ops = &espintcp_ops;
/* build the IPv6 variants once, on first IPv6 espintcp socket */
471 mutex_lock(&tcpv6_prot_mutex);
472 if (!espintcp6_prot.recvmsg)
473 build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
474 mutex_unlock(&tcpv6_prot_mutex);
476 sk->sk_prot = &espintcp6_prot;
477 sk->sk_socket->ops = &espintcp6_ops;
/* save the original callbacks so espintcp can chain to them */
479 ctx->saved_data_ready = sk->sk_data_ready;
480 ctx->saved_write_space = sk->sk_write_space;
481 ctx->saved_destruct = sk->sk_destruct;
482 sk->sk_data_ready = espintcp_data_ready;
483 sk->sk_write_space = espintcp_write_space;
484 sk->sk_destruct = espintcp_destruct;
485 rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
486 INIT_WORK(&ctx->work, espintcp_tx_work);
488 /* avoid using task_frag */
489 sk->sk_allocation = GFP_ATOMIC;
/* espintcp_release - ->release_cb replacement, run when the socket lock
 * is released: splice out the deferred out_queue onto a private list and
 * push each queued skb.  (Chaining to TCP's own release work is in lines
 * elided from this view.)
 */
498 static void espintcp_release(struct sock *sk)
500 struct espintcp_ctx *ctx = espintcp_getctx(sk);
501 struct sk_buff_head queue;
/* drain atomically into a local list, then send without the queue lock */
504 __skb_queue_head_init(&queue);
505 skb_queue_splice_init(&ctx->out_queue, &queue);
507 while ((skb = __skb_dequeue(&queue)))
508 espintcp_push_skb(sk, skb);
/* espintcp_close - ->close replacement: tear down espintcp state, then
 * perform the normal TCP close.
 *
 * Stops the strparser, restores plain tcp_prot so concurrent paths no
 * longer enter espintcp hooks, cancels the TX worker, purges both
 * queues and frees any partially-sent message (skb or sk_msg form;
 * the emsg->len/skb checks are elided from this view).
 */
513 static void espintcp_close(struct sock *sk, long timeout)
515 struct espintcp_ctx *ctx = espintcp_getctx(sk);
516 struct espintcp_msg *emsg = &ctx->partial;
518 strp_stop(&ctx->strp);
/* detach espintcp hooks before tearing the rest down */
520 sk->sk_prot = &tcp_prot;
523 cancel_work_sync(&ctx->work);
524 strp_done(&ctx->strp);
526 skb_queue_purge(&ctx->out_queue);
527 skb_queue_purge(&ctx->ike_queue);
/* free whichever form the partial message was staged in */
531 kfree_skb(emsg->skb);
533 sk_msg_free(sk, &emsg->skmsg);
536 tcp_close(sk, timeout);
/* espintcp_poll - ->poll replacement: standard datagram poll, plus
 * readability whenever IKE messages are waiting on ike_queue (which
 * datagram_poll, looking at sk_receive_queue, would not see).
 */
539 static __poll_t espintcp_poll(struct file *file, struct socket *sock,
542 __poll_t mask = datagram_poll(file, sock, wait);
543 struct sock *sk = sock->sk;
544 struct espintcp_ctx *ctx = espintcp_getctx(sk);
546 if (!skb_queue_empty(&ctx->ike_queue))
547 mask |= EPOLLIN | EPOLLRDNORM;
552 static void build_protos(struct proto *espintcp_prot,
553 struct proto_ops *espintcp_ops,
554 const struct proto *orig_prot,
555 const struct proto_ops *orig_ops)
557 memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
558 memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
559 espintcp_prot->sendmsg = espintcp_sendmsg;
560 espintcp_prot->recvmsg = espintcp_recvmsg;
561 espintcp_prot->close = espintcp_close;
562 espintcp_prot->release_cb = espintcp_release;
563 espintcp_ops->poll = espintcp_poll;
/* ULP registration: makes espintcp selectable via setsockopt(TCP_ULP).
 * (The .name member line is elided from this view.)
 */
566 static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
568 .owner = THIS_MODULE,
569 .init = espintcp_init_sk,
/* espintcp_init - boot-time setup: build the IPv4 proto/ops variants
 * from tcp_prot/inet_stream_ops and register the espintcp ULP.
 */
572 void __init espintcp_init(void)
574 build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);
576 tcp_register_ulp(&espintcp_ulp);