bpf: Migrate cgroup_bpf to internal cgroup_bpf_attach_type enum
[linux-2.6-microblaze.git] / net / ipv4 / udp.c
index 6268280..8851c94 100644 (file)
@@ -645,10 +645,12 @@ static struct sock *__udp4_lib_err_encap(struct net *net,
                                         const struct iphdr *iph,
                                         struct udphdr *uh,
                                         struct udp_table *udptable,
+                                        struct sock *sk,
                                         struct sk_buff *skb, u32 info)
 {
+       int (*lookup)(struct sock *sk, struct sk_buff *skb);
        int network_offset, transport_offset;
-       struct sock *sk;
+       struct udp_sock *up;
 
        network_offset = skb_network_offset(skb);
        transport_offset = skb_transport_offset(skb);
@@ -659,18 +661,28 @@ static struct sock *__udp4_lib_err_encap(struct net *net,
        /* Transport header needs to point to the UDP header */
        skb_set_transport_header(skb, iph->ihl << 2);
 
+       if (sk) {
+               up = udp_sk(sk);
+
+               lookup = READ_ONCE(up->encap_err_lookup);
+               if (lookup && lookup(sk, skb))
+                       sk = NULL;
+
+               goto out;
+       }
+
        sk = __udp4_lib_lookup(net, iph->daddr, uh->source,
                               iph->saddr, uh->dest, skb->dev->ifindex, 0,
                               udptable, NULL);
        if (sk) {
-               int (*lookup)(struct sock *sk, struct sk_buff *skb);
-               struct udp_sock *up = udp_sk(sk);
+               up = udp_sk(sk);
 
                lookup = READ_ONCE(up->encap_err_lookup);
                if (!lookup || lookup(sk, skb))
                        sk = NULL;
        }
 
+out:
        if (!sk)
                sk = ERR_PTR(__udp4_lib_err_encap_no_sk(skb, info));
 
@@ -707,15 +719,16 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
        sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
                               iph->saddr, uh->source, skb->dev->ifindex,
                               inet_sdif(skb), udptable, NULL);
+
        if (!sk || udp_sk(sk)->encap_type) {
                /* No socket for error: try tunnels before discarding */
-               sk = ERR_PTR(-ENOENT);
                if (static_branch_unlikely(&udp_encap_needed_key)) {
-                       sk = __udp4_lib_err_encap(net, iph, uh, udptable, skb,
+                       sk = __udp4_lib_err_encap(net, iph, uh, udptable, sk, skb,
                                                  info);
                        if (!sk)
                                return 0;
-               }
+               } else
+                       sk = ERR_PTR(-ENOENT);
 
                if (IS_ERR(sk)) {
                        __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
@@ -1102,7 +1115,7 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        }
 
        ipcm_init_sk(&ipc, inet);
-       ipc.gso_size = up->gso_size;
+       ipc.gso_size = READ_ONCE(up->gso_size);
 
        if (msg->msg_controllen) {
                err = udp_cmsg_send(sk, msg, &ipc.gso_size);
@@ -1130,7 +1143,7 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                rcu_read_unlock();
        }
 
-       if (cgroup_bpf_enabled(BPF_CGROUP_UDP4_SENDMSG) && !connected) {
+       if (cgroup_bpf_enabled(CGROUP_UDP4_SENDMSG) && !connected) {
                err = BPF_CGROUP_RUN_PROG_UDP4_SENDMSG_LOCK(sk,
                                            (struct sockaddr *)usin, &ipc.addr);
                if (err)
@@ -2695,7 +2708,7 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
        case UDP_SEGMENT:
                if (val < 0 || val > USHRT_MAX)
                        return -EINVAL;
-               up->gso_size = val;
+               WRITE_ONCE(up->gso_size, val);
                break;
 
        case UDP_GRO:
@@ -2790,7 +2803,7 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname,
                break;
 
        case UDP_SEGMENT:
-               val = up->gso_size;
+               val = READ_ONCE(up->gso_size);
                break;
 
        case UDP_GRO: