inet: implement lockless IP_MTU_DISCOVER
authorEric Dumazet <edumazet@google.com>
Fri, 22 Sep 2023 03:42:15 +0000 (03:42 +0000)
committerDavid S. Miller <davem@davemloft.net>
Sun, 1 Oct 2023 18:39:18 +0000 (19:39 +0100)
inet->pmtudisc can be read locklessly.

Implement proper lockless reads and writes to inet->pmtudisc

ip_sock_set_mtu_discover() can now be called from arbitrary
contexts.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/ip.h
net/ipv4/ip_output.c
net/ipv4/ip_sockglue.c
net/ipv4/ping.c
net/ipv4/raw.c
net/ipv4/udp.c
net/netfilter/ipvs/ip_vs_sync.c

index 3489a1c..46933a0 100644 (file)
@@ -434,19 +434,22 @@ int ip_dont_fragment(const struct sock *sk, const struct dst_entry *dst)
 
 static inline bool ip_sk_accept_pmtu(const struct sock *sk)
 {
-       return inet_sk(sk)->pmtudisc != IP_PMTUDISC_INTERFACE &&
-              inet_sk(sk)->pmtudisc != IP_PMTUDISC_OMIT;
+       u8 pmtudisc = READ_ONCE(inet_sk(sk)->pmtudisc);
+
+       return pmtudisc != IP_PMTUDISC_INTERFACE &&
+              pmtudisc != IP_PMTUDISC_OMIT;
 }
 
 static inline bool ip_sk_use_pmtu(const struct sock *sk)
 {
-       return inet_sk(sk)->pmtudisc < IP_PMTUDISC_PROBE;
+       return READ_ONCE(inet_sk(sk)->pmtudisc) < IP_PMTUDISC_PROBE;
 }
 
 static inline bool ip_sk_ignore_df(const struct sock *sk)
 {
-       return inet_sk(sk)->pmtudisc < IP_PMTUDISC_DO ||
-              inet_sk(sk)->pmtudisc == IP_PMTUDISC_OMIT;
+       u8 pmtudisc = READ_ONCE(inet_sk(sk)->pmtudisc);
+
+       return pmtudisc < IP_PMTUDISC_DO || pmtudisc == IP_PMTUDISC_OMIT;
 }
 
 static inline unsigned int ip_dst_mtu_maybe_forward(const struct dst_entry *dst,
index f07ce05..9fc7be2 100644 (file)
@@ -1387,8 +1387,8 @@ struct sk_buff *__ip_make_skb(struct sock *sk,
        struct ip_options *opt = NULL;
        struct rtable *rt = (struct rtable *)cork->dst;
        struct iphdr *iph;
+       u8 pmtudisc, ttl;
        __be16 df = 0;
-       __u8 ttl;
 
        skb = __skb_dequeue(queue);
        if (!skb)
@@ -1418,8 +1418,9 @@ struct sk_buff *__ip_make_skb(struct sock *sk,
        /* DF bit is set when we want to see DF on outgoing frames.
         * If ignore_df is set too, we still allow to fragment this frame
         * locally. */
-       if (inet->pmtudisc == IP_PMTUDISC_DO ||
-           inet->pmtudisc == IP_PMTUDISC_PROBE ||
+       pmtudisc = READ_ONCE(inet->pmtudisc);
+       if (pmtudisc == IP_PMTUDISC_DO ||
+           pmtudisc == IP_PMTUDISC_PROBE ||
            (skb->len <= dst_mtu(&rt->dst) &&
             ip_dont_fragment(sk, &rt->dst)))
                df = htons(IP_DF);
index 4ad3003..6d874cc 100644 (file)
@@ -622,9 +622,7 @@ int ip_sock_set_mtu_discover(struct sock *sk, int val)
 {
        if (val < IP_PMTUDISC_DONT || val > IP_PMTUDISC_OMIT)
                return -EINVAL;
-       lock_sock(sk);
-       inet_sk(sk)->pmtudisc = val;
-       release_sock(sk);
+       WRITE_ONCE(inet_sk(sk)->pmtudisc, val);
        return 0;
 }
 EXPORT_SYMBOL(ip_sock_set_mtu_discover);
@@ -1050,6 +1048,8 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
                        return -EINVAL;
                WRITE_ONCE(inet->mc_ttl, val);
                return 0;
+       case IP_MTU_DISCOVER:
+               return ip_sock_set_mtu_discover(sk, val);
        }
 
        err = 0;
@@ -1107,11 +1107,6 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
        case IP_TOS:    /* This sets both TOS and Precedence */
                __ip_sock_set_tos(sk, val);
                break;
-       case IP_MTU_DISCOVER:
-               if (val < IP_PMTUDISC_DONT || val > IP_PMTUDISC_OMIT)
-                       goto e_inval;
-               inet->pmtudisc = val;
-               break;
        case IP_UNICAST_IF:
        {
                struct net_device *dev = NULL;
@@ -1595,6 +1590,9 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
        case IP_MULTICAST_TTL:
                val = READ_ONCE(inet->mc_ttl);
                goto copyval;
+       case IP_MTU_DISCOVER:
+               val = READ_ONCE(inet->pmtudisc);
+               goto copyval;
        }
 
        if (needs_rtnl)
@@ -1634,9 +1632,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
        case IP_TOS:
                val = inet->tos;
                break;
-       case IP_MTU_DISCOVER:
-               val = inet->pmtudisc;
-               break;
        case IP_MTU:
        {
                struct dst_entry *dst;
index 4dd809b..50d12b0 100644 (file)
@@ -551,7 +551,7 @@ void ping_err(struct sk_buff *skb, int offset, u32 info)
                case ICMP_DEST_UNREACH:
                        if (code == ICMP_FRAG_NEEDED) { /* Path MTU discovery */
                                ipv4_sk_update_pmtu(skb, sk, info);
-                               if (inet_sock->pmtudisc != IP_PMTUDISC_DONT) {
+                               if (READ_ONCE(inet_sock->pmtudisc) != IP_PMTUDISC_DONT) {
                                        err = EMSGSIZE;
                                        harderr = 1;
                                        break;
index 4b5db5d..ade1aec 100644 (file)
@@ -239,7 +239,7 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info)
                if (code > NR_ICMP_UNREACH)
                        break;
                if (code == ICMP_FRAG_NEEDED) {
-                       harderr = inet->pmtudisc != IP_PMTUDISC_DONT;
+                       harderr = READ_ONCE(inet->pmtudisc) != IP_PMTUDISC_DONT;
                        err = EMSGSIZE;
                } else {
                        err = icmp_err_convert[code].errno;
index c3ff984..731a723 100644 (file)
@@ -750,7 +750,7 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
        case ICMP_DEST_UNREACH:
                if (code == ICMP_FRAG_NEEDED) { /* Path MTU discovery */
                        ipv4_sk_update_pmtu(skb, sk, info);
-                       if (inet->pmtudisc != IP_PMTUDISC_DONT) {
+                       if (READ_ONCE(inet->pmtudisc) != IP_PMTUDISC_DONT) {
                                err = EMSGSIZE;
                                harderr = 1;
                                break;
index 3eed167..4f6c795 100644 (file)
@@ -1335,7 +1335,7 @@ static void set_mcast_pmtudisc(struct sock *sk, int val)
 
        /* setsockopt(sock, SOL_IP, IP_MTU_DISCOVER, &val, sizeof(val)); */
        lock_sock(sk);
-       inet->pmtudisc = val;
+       WRITE_ONCE(inet->pmtudisc, val);
 #ifdef CONFIG_IP_VS_IPV6
        if (sk->sk_family == AF_INET6) {
                struct ipv6_pinfo *np = inet6_sk(sk);