1 // SPDX-License-Identifier: GPL-2.0
3 #include <net/strparser.h>
6 #include <net/espintcp.h>
7 #include <linux/skmsg.h>
8 #include <net/inet_common.h>
/*
 * handle_nonesp - queue a non-ESP (IKE) message for userspace.
 * @ctx: espintcp context; the message is appended to ctx->ike_queue
 * @skb: the message, with its 2-byte length field already stripped by
 *       espintcp_rcv()
 * @sk:  socket passed by espintcp_rcv() (strp->sk); its parameter
 *       declaration is on a line not visible in this listing
 *
 * The skb is charged to @sk's receive memory and queued on ike_queue, which
 * espintcp_recvmsg() reads. Readers are then woken through the data_ready
 * callback saved in espintcp_init_sk().
 */
10 static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
/*
 * The receive buffer is full, or the memory cannot be charged. The branch
 * body is not visible here (presumably the skb is dropped — confirm).
 */
13 if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
14 !sk_rmem_schedule(sk, skb, skb->truesize)) {
19 skb_set_owner_r(skb, sk);
/* The strparser keeps its per-message state (strp_msg()) in skb->cb; reset it. */
21 memset(skb->cb, 0, sizeof(skb->cb));
22 skb_queue_tail(&ctx->ike_queue, skb);
23 ctx->saved_data_ready(sk);
/*
 * handle_esp - pass an ESP packet received over TCP to the IPv4 xfrm
 * input path.
 * @skb: packet whose data now starts at the SPI (the length field was
 *       stripped by espintcp_rcv())
 * @sk:  espintcp socket; used for its network namespace
 *
 * NOTE(review): only xfrm4_rcv_encap() is visible here. Confirm how IPv6
 * espintcp sockets are handled in the full file.
 */
26 static void handle_esp(struct sk_buff *skb, struct sock *sk)
/* The ESP header starts at the current data pointer. */
28 skb_reset_transport_header(skb);
/* Drop the strparser state kept in skb->cb before handing off to xfrm. */
29 memset(skb->cb, 0, sizeof(skb->cb));
/*
 * Attribute the packet to the device whose index is recorded in
 * skb->skb_iif. dev_get_by_index_rcu() must be called inside an RCU
 * read-side critical section. The surrounding lines are not visible;
 * confirm they take rcu_read_lock()/rcu_read_unlock().
 */
32 skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
/* An SPI argument of 0 presumably lets xfrm parse it from the packet — confirm. */
34 xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
/*
 * espintcp_rcv - strparser rcv_msg callback (registered in espintcp_init_sk()).
 * @strp: the strparser embedded in the espintcp context
 * @skb:  one complete message as framed by espintcp_parse(); rxm->offset
 *        and rxm->full_len locate it within @skb
 *
 * Steps:
 *  - Read the marker that follows the 2-byte length field.
 *  - Strip the length field and trim the skb to the message.
 *  - Dispatch: a zero marker means a non-ESP (IKE) message for userspace;
 *    anything else is handled as ESP.
 *
 * Error handling for the copy/pull/trim steps is on lines not visible here.
 */
39 static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
41 struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
43 struct strp_msg *rxm = strp_msg(skb);
/* Peek at the non-ESP marker / SPI right after the 2-byte length field. */
47 err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
48 sizeof(nonesp_marker));
54 /* remove header, leave non-ESP marker/SPI */
55 if (!__pskb_pull(skb, rxm->offset + 2)) {
/*
 * full_len presumably equals the length field's value. That value includes
 * the field itself (see MAX_ESPINTCP_MSG), so subtract the 2 bytes just
 * pulled.
 */
60 if (pskb_trim(skb, rxm->full_len - 2) != 0) {
65 if (nonesp_marker == 0)
66 handle_nonesp(ctx, skb, strp->sk);
/* The line between these two is not visible (presumably an 'else'). */
68 handle_esp(skb, strp->sk);
/*
 * espintcp_parse - strparser parse_msg callback (registered in espintcp_init_sk()).
 * @strp: the strparser
 * @skb:  buffered stream data; the next message starts at rxm->offset
 *
 * Each message begins with a 2-byte big-endian length that includes the
 * length field itself (see MAX_ESPINTCP_MSG). The return statements and any
 * validation of len are on lines not visible here. Under the strparser
 * contract, 0 means "need more data" and a positive value is the message
 * length — confirm against the full function.
 */
71 static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
73 struct strp_msg *rxm = strp_msg(skb);
/* The 2-byte length field has not fully arrived yet. */
78 if (skb->len < rxm->offset + 2)
81 err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
/* The length is in network (big-endian) byte order. */
85 len = be16_to_cpu(blen);
/*
 * espintcp_recvmsg - proto.recvmsg override (installed by espintcp_init()).
 *
 * Userspace only reads non-ESP (IKE) messages. Each call takes one message
 * from ctx->ike_queue, with datagram semantics. If the number of bytes to
 * copy (presumably @len) is smaller than the message, only the first part
 * is copied and MSG_TRUNC is set in msg->msg_flags.
 */
92 static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
93 int nonblock, int flags, int *addr_len)
95 struct espintcp_ctx *ctx = espintcp_getctx(sk);
/* Fold the nonblock argument into flags so the dequeue honours it. */
101 flags |= nonblock ? MSG_DONTWAIT : 0;
103 skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
/*
 * Clamp the copy to the message size. copied presumably starts out as len;
 * its initialisation is not visible — confirm.
 */
108 if (copied > skb->len)
110 else if (copied < skb->len)
111 msg->msg_flags |= MSG_TRUNC;
113 err = skb_copy_datagram_msg(skb, 0, msg, copied);
/*
 * The caller passed MSG_TRUNC. The branch body is not visible; presumably
 * it reports the full message length, as UDP does — confirm.
 */
119 if (flags & MSG_TRUNC)
/*
 * espintcp_queue_out - defer an outgoing ESP skb on ctx->out_queue.
 * @sk:  espintcp socket
 * @skb: packet to send later
 *
 * Queued skbs are transmitted by espintcp_release() (installed as
 * proto.release_cb), which pushes each one with espintcp_push_skb().
 * The queue is capped at netdev_max_backlog entries. The return value for
 * a full queue is on a line not visible here.
 *
 * __skb_queue_tail() takes no queue lock, so callers must serialise access
 * to out_queue — presumably via the socket lock; confirm.
 * NOTE(review): netdev_max_backlog (a sysctl) is read locklessly here;
 * consider READ_ONCE().
 */
125 int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
127 struct espintcp_ctx *ctx = espintcp_getctx(sk);
129 if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
132 __skb_queue_tail(&ctx->out_queue, skb);
136 EXPORT_SYMBOL_GPL(espintcp_queue_out);
138 /* The espintcp length field is 2 bytes and its value includes the length
 * field itself, so the largest payload that fits in one frame is
 * 0xffff - 2 = 65533 bytes. espintcp_sendmsg() rejects larger payloads.
 */
139 #define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
/*
 * espintcp_sendskb_locked - transmit the skb pending in @emsg.
 * @sk:   socket; presumably called with the socket lock held, as the
 *        _locked suffix suggests — confirm
 * @emsg: pending message: emsg->skb, sent from emsg->offset for
 *        emsg->len bytes
 * (a further parameter is declared on a line not visible here;
 *  espintcp_push_msgs() passes 0 for it)
 *
 * Calls skb_send_sock_locked() repeatedly until emsg->len reaches 0. The
 * progress bookkeeping and error returns are on lines not visible. Once the
 * skb is fully sent it is freed and @emsg is cleared.
 */
141 static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
147 ret = skb_send_sock_locked(sk, emsg->skb,
148 emsg->offset, emsg->len);
154 } while (emsg->len > 0);
156 kfree_skb(emsg->skb);
157 memset(emsg, 0, sizeof(*emsg));
/*
 * espintcp_sendskmsg_locked - transmit the sk_msg pending in @emsg.
 * @sk:    socket (presumably locked by the caller — confirm)
 * @emsg:  pending message; emsg->skmsg holds the scatterlist and
 *         emsg->offset is the byte offset into the entry at sg.start
 * @flags: base MSG_* flags for do_tcp_sendpages()
 *
 * Sends the scatterlist entry by entry with do_tcp_sendpages(), resuming
 * where a previous partial send stopped.
 *
 * MSG_SENDPAGE_NOTLAST is set up front. It is cleared under a condition not
 * visible here, presumably for the last entry, which espintcp_sendmsg()
 * marks with sg_mark_end().
 *
 * After a partial send, the resume point is saved in emsg->offset and
 * sg.start. The final memset presumably runs once everything is sent.
 */
162 static int espintcp_sendskmsg_locked(struct sock *sk,
163 struct espintcp_msg *emsg, int flags)
165 struct sk_msg *skmsg = &emsg->skmsg;
166 struct scatterlist *sg;
170 flags |= MSG_SENDPAGE_NOTLAST;
171 sg = &skmsg->sg.data[skmsg->sg.start];
/* Bytes still unsent in this entry, and where they start in its page. */
173 size_t size = sg->length - emsg->offset;
174 int offset = sg->offset + emsg->offset;
180 flags &= ~MSG_SENDPAGE_NOTLAST;
184 ret = do_tcp_sendpages(sk, p, offset, size, flags);
/*
 * Partial send (the condition is not visible): remember the offset within
 * the current entry, and skip entries already fully sent. 'done'
 * presumably counts those entries.
 */
186 emsg->offset = offset - sg->offset;
187 skmsg->sg.start += done;
/* Release the socket memory charged for this entry. */
199 sk_mem_uncharge(sk, sg->length);
203 memset(emsg, 0, sizeof(*emsg));
/*
 * espintcp_push_msgs - try to finish sending the message in ctx->partial.
 * @sk: espintcp socket
 *
 * ctx->partial holds either an skb (queued by espintcp_push_skb()) or an
 * sk_msg (built by espintcp_sendmsg()). The condition choosing between the
 * two senders is not visible; presumably it tests whether emsg->skb is
 * set — confirm.
 *
 * -EAGAIN is treated specially. The branch body is not visible; presumably
 * the message is kept so espintcp_tx_work() can retry once write space
 * frees up.
 *
 * The memset clears the slot on a path whose condition is not visible,
 * presumably successful completion — confirm.
 */
208 static int espintcp_push_msgs(struct sock *sk)
210 struct espintcp_ctx *ctx = espintcp_getctx(sk);
211 struct espintcp_msg *emsg = &ctx->partial;
222 err = espintcp_sendskb_locked(sk, emsg, 0);
224 err = espintcp_sendskmsg_locked(sk, emsg, 0);
225 if (err == -EAGAIN) {
230 memset(emsg, 0, sizeof(*emsg));
/*
 * espintcp_push_skb - send an ESP skb over the espintcp socket.
 * @sk:  espintcp socket
 * @skb: packet; only the bytes from its transport header onward are sent
 *
 * Sockets not in TCP_ESTABLISHED are handled by a branch not visible here
 * (presumably the skb is dropped with an error — confirm).
 *
 * Otherwise, any message still pending in ctx->partial is flushed first,
 * with the return value ignored. The skb is then charged to the socket's
 * write memory, recorded in ctx->partial and pushed. Handling of a partial
 * slot that is still busy is on lines not visible.
 */
237 int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
239 struct espintcp_ctx *ctx = espintcp_getctx(sk);
240 struct espintcp_msg *emsg = &ctx->partial;
244 if (sk->sk_state != TCP_ESTABLISHED) {
/* Send from the transport header to the end of the skb. */
249 offset = skb_transport_offset(skb);
250 len = skb->len - offset;
252 espintcp_push_msgs(sk);
259 skb_set_owner_w(skb, sk);
261 emsg->offset = offset;
265 espintcp_push_msgs(sk);
269 EXPORT_SYMBOL_GPL(espintcp_push_skb);
/*
 * espintcp_sendmsg - proto.sendmsg override (installed by espintcp_init()).
 * @sk:   espintcp socket
 * @msg:  user message; requests with control data (msg_controllen) are
 *        refused by a check whose error code is not visible
 * @size: payload length; larger than MAX_ESPINTCP_MSG is refused
 *
 * Every call sends exactly one frame: a 2-byte big-endian length followed
 * by the @size payload bytes. The length value, msglen, includes the 2
 * length bytes.
 *
 * A message still pending in ctx->partial is pushed first. The new frame is
 * built in ctx->partial.skmsg. If it is only partially transmitted, it is
 * kept there for a later push. The trailing error path frees the sk_msg and
 * clears ctx->partial.
 */
271 static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
273 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
274 struct espintcp_ctx *ctx = espintcp_getctx(sk);
275 struct espintcp_msg *emsg = &ctx->partial;
276 struct iov_iter pfx_iter;
277 struct kvec pfx_iov = {};
/* Frame length on the wire: payload plus the 2-byte length field. */
278 size_t msglen = size + 2;
285 if (size > MAX_ESPINTCP_MSG)
288 if (msg->msg_controllen)
293 err = espintcp_push_msgs(sk);
299 sk_msg_init(&emsg->skmsg);
301 /* only -ENOMEM is possible since we don't coalesce */
302 err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
/*
 * The allocation failed: wait for send memory, bounded by timeo. The retry
 * structure is not visible.
 */
306 err = sk_stream_wait_memory(sk, &timeo);
/*
 * Write the length prefix into buf, then copy it into the sk_msg ahead of
 * the payload. buf is declared on a line not visible; presumably it is
 * 2 bytes.
 */
311 *((__be16 *)buf) = cpu_to_be16(msglen);
312 pfx_iov.iov_base = buf;
313 pfx_iov.iov_len = sizeof(buf);
314 iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
316 err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
/* Then the user payload. */
321 err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
/* sg.end is one past the last used entry; step back and mark that entry as the end. */
325 end = emsg->skmsg.sg.end;
327 sk_msg_iter_var_prev(end);
328 sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
330 tcp_rate_check_app_limited(sk);
332 err = espintcp_push_msgs(sk);
333 /* this message could be partially sent, keep it */
/* Error path: discard the frame and clear the partial slot. */
341 sk_msg_free(sk, &emsg->skmsg);
342 memset(emsg, 0, sizeof(*emsg));
/*
 * Copies of tcp_prot and inet_stream_ops with espintcp's overrides. They
 * are filled in once by espintcp_init() and are read-only after init.
 */
348 static struct proto espintcp_prot __ro_after_init;
349 static struct proto_ops espintcp_ops __ro_after_init;
/*
 * sk_data_ready replacement (installed by espintcp_init_sk()). Hands newly
 * arrived TCP data to the strparser, which frames it with espintcp_parse()
 * and delivers each message to espintcp_rcv().
 */
351 static void espintcp_data_ready(struct sock *sk)
353 struct espintcp_ctx *ctx = espintcp_getctx(sk);
355 strp_data_ready(&ctx->strp);
/*
 * Deferred transmit, scheduled by espintcp_write_space(). Retries pushing
 * the message pending in ctx->partial.
 *
 * Nothing is pushed while ctx->tx_running is set (presumably a push is
 * already in progress — confirm). Locking around the push is on lines not
 * visible. espintcp_push_msgs() uses the *_locked send helpers, so confirm
 * the socket lock is taken here.
 */
358 static void espintcp_tx_work(struct work_struct *work)
360 struct espintcp_ctx *ctx = container_of(work,
361 struct espintcp_ctx, work);
/* The strparser was bound to the socket by strp_init() in espintcp_init_sk(). */
362 struct sock *sk = ctx->strp.sk;
365 if (!ctx->tx_running)
366 espintcp_push_msgs(sk);
/*
 * sk_write_space replacement (installed by espintcp_init_sk()). Schedules
 * espintcp_tx_work() to flush pending output, then chains to the TCP
 * write_space callback saved at init.
 */
370 static void espintcp_write_space(struct sock *sk)
372 struct espintcp_ctx *ctx = espintcp_getctx(sk);
374 schedule_work(&ctx->work);
375 ctx->saved_write_space(sk);
/*
 * sk_destruct replacement installed by espintcp_init_sk(). Only its first
 * lines are visible in this listing.
 *
 * NOTE(review): espintcp_init_sk() overwrites sk->sk_destruct without
 * saving the previous destructor, unlike sk_data_ready and sk_write_space.
 * Confirm this function frees ctx and still performs the socket's original
 * destruction (e.g. inet_sock_destruct()).
 */
378 static void espintcp_destruct(struct sock *sk)
380 struct espintcp_ctx *ctx = espintcp_getctx(sk);
/*
 * tcp_is_ulp_esp - does @sk currently use espintcp?
 *
 * Returns true while the socket's proto is espintcp_prot. That holds from
 * espintcp_init_sk() until espintcp_close() restores tcp_prot.
 */
385 bool tcp_is_ulp_esp(struct sock *sk)
387 return sk->sk_prot == &espintcp_prot;
389 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
/*
 * espintcp_init_sk - tcp_ulp_ops.init hook: attach espintcp to a TCP socket.
 * @sk: the TCP socket
 *
 * Steps:
 *  - Allocate the espintcp context.
 *  - Attach a strparser: espintcp_parse frames messages, espintcp_rcv
 *    consumes them.
 *  - Switch the socket to espintcp_prot and espintcp_ops.
 *  - Install espintcp's data_ready, write_space and destruct callbacks,
 *    saving the first two for chaining.
 *
 * Error returns are on lines not visible here.
 */
391 static int espintcp_init_sk(struct sock *sk)
393 struct inet_connection_sock *icsk = inet_csk(sk);
394 struct strp_callbacks cb = {
395 .rcv_msg = espintcp_rcv,
396 .parse_msg = espintcp_parse,
398 struct espintcp_ctx *ctx;
401 /* sk_user_data already in use (e.g. by sockmap), which is not compatible with espintcp */
402 if (sk->sk_user_data)
405 ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
409 err = strp_init(&ctx->strp, sk, &cb);
/*
 * Kick the strparser so data already queued on the socket gets parsed.
 * NOTE(review): this precedes the ike_queue/out_queue initialisation below.
 * Presumably safe because the parse work is deferred and runs under the
 * socket lock held here — confirm.
 */
415 strp_check_rcv(&ctx->strp);
416 skb_queue_head_init(&ctx->ike_queue);
417 skb_queue_head_init(&ctx->out_queue);
418 sk->sk_prot = &espintcp_prot;
419 sk->sk_socket->ops = &espintcp_ops;
/* Keep TCP's callbacks so espintcp's replacements can chain to them. */
420 ctx->saved_data_ready = sk->sk_data_ready;
421 ctx->saved_write_space = sk->sk_write_space;
422 sk->sk_data_ready = espintcp_data_ready;
423 sk->sk_write_space = espintcp_write_space;
/* NOTE(review): the previous sk_destruct is overwritten without being saved. */
424 sk->sk_destruct = espintcp_destruct;
/* Publish ctx; espintcp_getctx() presumably reads it back from icsk_ulp_data — confirm. */
425 rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
/*
 * NOTE(review): espintcp_write_space(), installed above, schedules
 * ctx->work, which is only initialised here. Confirm no write_space
 * callback can run in between.
 */
426 INIT_WORK(&ctx->work, espintcp_tx_work);
428 /* avoid using task_frag: with a GFP_ATOMIC sk_allocation, sk_page_frag()
 * uses the per-socket page frag instead. This is presumably needed because
 * data is also pushed outside the owning task's context (e.g. from
 * espintcp_tx_work()) — confirm.
 */
429 sk->sk_allocation = GFP_ATOMIC;
/*
 * espintcp_release - proto.release_cb override (installed by espintcp_init()).
 *
 * Transmits the skbs deferred by espintcp_queue_out(). out_queue is first
 * moved to a local list, then each skb is sent with espintcp_push_skb();
 * the return values are ignored.
 *
 * The remaining lines are not visible; presumably they chain to
 * tcp_release_cb() — confirm.
 */
438 static void espintcp_release(struct sock *sk)
440 struct espintcp_ctx *ctx = espintcp_getctx(sk);
441 struct sk_buff_head queue;
444 __skb_queue_head_init(&queue);
445 skb_queue_splice_init(&ctx->out_queue, &queue);
447 while ((skb = __skb_dequeue(&queue)))
448 espintcp_push_skb(sk, skb);
/*
 * espintcp_close - proto.close override (installed by espintcp_init()).
 * @sk:      espintcp socket
 * @timeout: passed through to tcp_close()
 *
 * Tears down espintcp state, then closes the socket as plain TCP:
 *  - stop the strparser;
 *  - restore tcp_prot, so tcp_is_ulp_esp() is false from here on;
 *  - wait for pending tx work and release the strparser;
 *  - drop queued outgoing and IKE messages;
 *  - free whatever is pending in ctx->partial;
 *  - call tcp_close().
 */
453 static void espintcp_close(struct sock *sk, long timeout)
455 struct espintcp_ctx *ctx = espintcp_getctx(sk);
456 struct espintcp_msg *emsg = &ctx->partial;
458 strp_stop(&ctx->strp);
460 sk->sk_prot = &tcp_prot;
/* Wait for a running espintcp_tx_work() to finish and cancel any pending run. */
463 cancel_work_sync(&ctx->work);
464 strp_done(&ctx->strp);
466 skb_queue_purge(&ctx->out_queue);
467 skb_queue_purge(&ctx->ike_queue);
/* The partial slot holds an skb or an sk_msg; the selecting conditions are not visible. */
471 kfree_skb(emsg->skb);
473 sk_msg_free(sk, &emsg->skmsg);
476 tcp_close(sk, timeout);
/*
 * espintcp_poll - proto_ops.poll override (installed by espintcp_init()).
 *
 * Starts from datagram_poll(). It additionally reports the socket as
 * readable whenever an IKE message is waiting on ctx->ike_queue, the queue
 * espintcp_recvmsg() reads from. The remaining parameter and the return
 * line are not visible here.
 */
479 static __poll_t espintcp_poll(struct file *file, struct socket *sock,
482 __poll_t mask = datagram_poll(file, sock, wait);
483 struct sock *sk = sock->sk;
484 struct espintcp_ctx *ctx = espintcp_getctx(sk);
486 if (!skb_queue_empty(&ctx->ike_queue))
487 mask |= EPOLLIN | EPOLLRDNORM;
/*
 * TCP ULP descriptor registered by espintcp_init(). Its .init hook attaches
 * espintcp to a socket. The name field is on a line not visible here.
 */
492 static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
494 .owner = THIS_MODULE,
495 .init = espintcp_init_sk,
/*
 * espintcp_init - init-time setup.
 *
 * Builds espintcp_prot and espintcp_ops as copies of tcp_prot and
 * inet_stream_ops, overriding sendmsg, recvmsg, close, release_cb and poll.
 * Then registers the espintcp ULP.
 *
 * NOTE(review): only the IPv4 tables are cloned in the visible code; confirm
 * how IPv6 sockets are supported. If tcp_register_ulp() can fail, its
 * result is not checked here.
 */
498 void __init espintcp_init(void)
500 memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot));
501 memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops));
502 espintcp_prot.sendmsg = espintcp_sendmsg;
503 espintcp_prot.recvmsg = espintcp_recvmsg;
504 espintcp_prot.close = espintcp_close;
505 espintcp_prot.release_cb = espintcp_release;
506 espintcp_ops.poll = espintcp_poll;
508 tcp_register_ulp(&espintcp_ulp);