Merge tag 'edac_urgent_for_v5.9_rc3' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-microblaze.git] / net / xfrm / espintcp.c
index 5a0ff66..827ccdf 100644 (file)
@@ -6,12 +6,16 @@
 #include <net/espintcp.h>
 #include <linux/skmsg.h>
 #include <net/inet_common.h>
+#if IS_ENABLED(CONFIG_IPV6)
+#include <net/ipv6_stubs.h>
+#endif
 
 static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
                          struct sock *sk)
 {
        if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
            !sk_rmem_schedule(sk, skb, skb->truesize)) {
+               XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
                kfree_skb(skb);
                return;
        }
@@ -31,7 +35,12 @@ static void handle_esp(struct sk_buff *skb, struct sock *sk)
        rcu_read_lock();
        skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
        local_bh_disable();
-       xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
+#if IS_ENABLED(CONFIG_IPV6)
+       if (sk->sk_family == AF_INET6)
+               ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
+       else
+#endif
+               xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
        local_bh_enable();
        rcu_read_unlock();
 }
@@ -41,23 +50,51 @@ static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
        struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
                                                strp);
        struct strp_msg *rxm = strp_msg(skb);
+       int len = rxm->full_len - 2;
        u32 nonesp_marker;
        int err;
 
+       /* keepalive packet? */
+       if (unlikely(len == 1)) {
+               u8 data;
+
+               err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
+               if (err < 0) {
+                       XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
+                       kfree_skb(skb);
+                       return;
+               }
+
+               if (data == 0xff) {
+                       kfree_skb(skb);
+                       return;
+               }
+       }
+
+       /* drop other short messages */
+       if (unlikely(len <= sizeof(nonesp_marker))) {
+               XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
+               kfree_skb(skb);
+               return;
+       }
+
        err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
                            sizeof(nonesp_marker));
        if (err < 0) {
+               XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
                kfree_skb(skb);
                return;
        }
 
        /* remove header, leave non-ESP marker/SPI */
        if (!__pskb_pull(skb, rxm->offset + 2)) {
+               XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
                kfree_skb(skb);
                return;
        }
 
        if (pskb_trim(skb, rxm->full_len - 2) != 0) {
+               XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
                kfree_skb(skb);
                return;
        }
@@ -83,7 +120,7 @@ static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
                return err;
 
        len = be16_to_cpu(blen);
-       if (len < 6)
+       if (len < 2)
                return -EINVAL;
 
        return len;
@@ -101,8 +138,11 @@ static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        flags |= nonblock ? MSG_DONTWAIT : 0;
 
        skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
-       if (!skb)
+       if (!skb) {
+               if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
+                       return 0;
                return err;
+       }
 
        copied = len;
        if (copied > skb->len)
@@ -205,7 +245,7 @@ retry:
        return 0;
 }
 
-static int espintcp_push_msgs(struct sock *sk)
+static int espintcp_push_msgs(struct sock *sk, int flags)
 {
        struct espintcp_ctx *ctx = espintcp_getctx(sk);
        struct espintcp_msg *emsg = &ctx->partial;
@@ -219,12 +259,12 @@ static int espintcp_push_msgs(struct sock *sk)
        ctx->tx_running = 1;
 
        if (emsg->skb)
-               err = espintcp_sendskb_locked(sk, emsg, 0);
+               err = espintcp_sendskb_locked(sk, emsg, flags);
        else
-               err = espintcp_sendskmsg_locked(sk, emsg, 0);
+               err = espintcp_sendskmsg_locked(sk, emsg, flags);
        if (err == -EAGAIN) {
                ctx->tx_running = 0;
-               return 0;
+               return flags & MSG_DONTWAIT ? -EAGAIN : 0;
        }
        if (!err)
                memset(emsg, 0, sizeof(*emsg));
@@ -249,7 +289,7 @@ int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
        offset = skb_transport_offset(skb);
        len = skb->len - offset;
 
-       espintcp_push_msgs(sk);
+       espintcp_push_msgs(sk, 0);
 
        if (emsg->len) {
                kfree_skb(skb);
@@ -262,7 +302,7 @@ int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
        emsg->len = len;
        emsg->skb = skb;
 
-       espintcp_push_msgs(sk);
+       espintcp_push_msgs(sk, 0);
 
        return 0;
 }
@@ -279,7 +319,7 @@ static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
        char buf[2] = {0};
        int err, end;
 
-       if (msg->msg_flags)
+       if (msg->msg_flags & ~MSG_DONTWAIT)
                return -EOPNOTSUPP;
 
        if (size > MAX_ESPINTCP_MSG)
@@ -290,9 +330,10 @@ static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 
        lock_sock(sk);
 
-       err = espintcp_push_msgs(sk);
+       err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
        if (err < 0) {
-               err = -ENOBUFS;
+               if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
+                       err = -ENOBUFS;
                goto unlock;
        }
 
@@ -329,10 +370,9 @@ static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 
        tcp_rate_check_app_limited(sk);
 
-       err = espintcp_push_msgs(sk);
+       err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
        /* this message could be partially sent, keep it */
-       if (err < 0)
-               goto unlock;
+
        release_sock(sk);
 
        return size;
@@ -347,6 +387,9 @@ unlock:
 
 static struct proto espintcp_prot __ro_after_init;
 static struct proto_ops espintcp_ops __ro_after_init;
+static struct proto espintcp6_prot;
+static struct proto_ops espintcp6_ops;
+static DEFINE_MUTEX(tcpv6_prot_mutex);
 
 static void espintcp_data_ready(struct sock *sk)
 {
@@ -363,7 +406,7 @@ static void espintcp_tx_work(struct work_struct *work)
 
        lock_sock(sk);
        if (!ctx->tx_running)
-               espintcp_push_msgs(sk);
+               espintcp_push_msgs(sk, 0);
        release_sock(sk);
 }
 
@@ -385,10 +428,14 @@ static void espintcp_destruct(struct sock *sk)
 
 bool tcp_is_ulp_esp(struct sock *sk)
 {
-       return sk->sk_prot == &espintcp_prot;
+       return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
 }
 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
 
+static void build_protos(struct proto *espintcp_prot,
+                        struct proto_ops *espintcp_ops,
+                        const struct proto *orig_prot,
+                        const struct proto_ops *orig_ops);
 static int espintcp_init_sk(struct sock *sk)
 {
        struct inet_connection_sock *icsk = inet_csk(sk);
@@ -416,8 +463,19 @@ static int espintcp_init_sk(struct sock *sk)
        strp_check_rcv(&ctx->strp);
        skb_queue_head_init(&ctx->ike_queue);
        skb_queue_head_init(&ctx->out_queue);
-       sk->sk_prot = &espintcp_prot;
-       sk->sk_socket->ops = &espintcp_ops;
+
+       if (sk->sk_family == AF_INET) {
+               sk->sk_prot = &espintcp_prot;
+               sk->sk_socket->ops = &espintcp_ops;
+       } else {
+               mutex_lock(&tcpv6_prot_mutex);
+               if (!espintcp6_prot.recvmsg)
+                       build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
+               mutex_unlock(&tcpv6_prot_mutex);
+
+               sk->sk_prot = &espintcp6_prot;
+               sk->sk_socket->ops = &espintcp6_ops;
+       }
        ctx->saved_data_ready = sk->sk_data_ready;
        ctx->saved_write_space = sk->sk_write_space;
        ctx->saved_destruct = sk->sk_destruct;
@@ -491,6 +549,20 @@ static __poll_t espintcp_poll(struct file *file, struct socket *sock,
        return mask;
 }
 
+static void build_protos(struct proto *espintcp_prot,
+                        struct proto_ops *espintcp_ops,
+                        const struct proto *orig_prot,
+                        const struct proto_ops *orig_ops)
+{
+       memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
+       memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
+       espintcp_prot->sendmsg = espintcp_sendmsg;
+       espintcp_prot->recvmsg = espintcp_recvmsg;
+       espintcp_prot->close = espintcp_close;
+       espintcp_prot->release_cb = espintcp_release;
+       espintcp_ops->poll = espintcp_poll;
+}
+
 static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
        .name = "espintcp",
        .owner = THIS_MODULE,
@@ -499,13 +571,7 @@ static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
 
 void __init espintcp_init(void)
 {
-       memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot));
-       memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops));
-       espintcp_prot.sendmsg = espintcp_sendmsg;
-       espintcp_prot.recvmsg = espintcp_recvmsg;
-       espintcp_prot.close = espintcp_close;
-       espintcp_prot.release_cb = espintcp_release;
-       espintcp_ops.poll = espintcp_poll;
+       build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);
 
        tcp_register_ulp(&espintcp_ulp);
 }