raw: use more conventional iterators
authorEric Dumazet <edumazet@google.com>
Sat, 18 Jun 2022 03:47:04 +0000 (20:47 -0700)
committerDavid S. Miller <davem@davemloft.net>
Sun, 19 Jun 2022 09:00:02 +0000 (10:00 +0100)
In order to prepare the following patch,
I change raw v4 & v6 code to use more conventional
iterators.

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/raw.h
include/net/rawv6.h
net/ipv4/raw.c
net/ipv4/raw_diag.c
net/ipv6/raw.c

index 8ad8df5..719d355 100644 (file)
@@ -20,9 +20,8 @@
 extern struct proto raw_prot;
 
 extern struct raw_hashinfo raw_v4_hashinfo;
-struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
-                            unsigned short num, __be32 raddr,
-                            __be32 laddr, int dif, int sdif);
+bool raw_v4_match(struct net *net, struct sock *sk, unsigned short num,
+                 __be32 raddr, __be32 laddr, int dif, int sdif);
 
 int raw_abort(struct sock *sk, int err);
 void raw_icmp_error(struct sk_buff *, int, u32);
index 53d86b6..c48c129 100644 (file)
@@ -5,9 +5,9 @@
 #include <net/protocol.h>
 
 extern struct raw_hashinfo raw_v6_hashinfo;
-struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
-                            unsigned short num, const struct in6_addr *loc_addr,
-                            const struct in6_addr *rmt_addr, int dif, int sdif);
+bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num,
+                 const struct in6_addr *loc_addr,
+                 const struct in6_addr *rmt_addr, int dif, int sdif);
 
 int raw_abort(struct sock *sk, int err);
 
index bbd7178..05e0de4 100644 (file)
@@ -117,24 +117,19 @@ void raw_unhash_sk(struct sock *sk)
 }
 EXPORT_SYMBOL_GPL(raw_unhash_sk);
 
-struct sock *__raw_v4_lookup(struct net *net, struct sock *sk,
-                            unsigned short num, __be32 raddr, __be32 laddr,
-                            int dif, int sdif)
+bool raw_v4_match(struct net *net, struct sock *sk, unsigned short num,
+                 __be32 raddr, __be32 laddr, int dif, int sdif)
 {
-       sk_for_each_from(sk) {
-               struct inet_sock *inet = inet_sk(sk);
-
-               if (net_eq(sock_net(sk), net) && inet->inet_num == num  &&
-                   !(inet->inet_daddr && inet->inet_daddr != raddr)    &&
-                   !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) &&
-                   raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
-                       goto found; /* gotcha */
-       }
-       sk = NULL;
-found:
-       return sk;
+       struct inet_sock *inet = inet_sk(sk);
+
+       if (net_eq(sock_net(sk), net) && inet->inet_num == num  &&
+           !(inet->inet_daddr && inet->inet_daddr != raddr)    &&
+           !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) &&
+           raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
+               return true;
+       return false;
 }
-EXPORT_SYMBOL_GPL(__raw_v4_lookup);
+EXPORT_SYMBOL_GPL(raw_v4_match);
 
 /*
  *     0 - deliver
@@ -168,23 +163,21 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb)
  */
 static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
 {
+       struct net *net = dev_net(skb->dev);
        int sdif = inet_sdif(skb);
        int dif = inet_iif(skb);
-       struct sock *sk;
        struct hlist_head *head;
        int delivered = 0;
-       struct net *net;
+       struct sock *sk;
 
-       read_lock(&raw_v4_hashinfo.lock);
        head = &raw_v4_hashinfo.ht[hash];
        if (hlist_empty(head))
-               goto out;
-
-       net = dev_net(skb->dev);
-       sk = __raw_v4_lookup(net, __sk_head(head), iph->protocol,
-                            iph->saddr, iph->daddr, dif, sdif);
-
-       while (sk) {
+               return 0;
+       read_lock(&raw_v4_hashinfo.lock);
+       sk_for_each(sk, head) {
+               if (!raw_v4_match(net, sk, iph->protocol,
+                                 iph->saddr, iph->daddr, dif, sdif))
+                       continue;
                delivered = 1;
                if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) &&
                    ip_mc_sf_allow(sk, iph->daddr, iph->saddr,
@@ -195,31 +188,16 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
                        if (clone)
                                raw_rcv(sk, clone);
                }
-               sk = __raw_v4_lookup(net, sk_next(sk), iph->protocol,
-                                    iph->saddr, iph->daddr,
-                                    dif, sdif);
        }
-out:
        read_unlock(&raw_v4_hashinfo.lock);
        return delivered;
 }
 
 int raw_local_deliver(struct sk_buff *skb, int protocol)
 {
-       int hash;
-       struct sock *raw_sk;
-
-       hash = protocol & (RAW_HTABLE_SIZE - 1);
-       raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]);
-
-       /* If there maybe a raw socket we must check - if not we
-        * don't care less
-        */
-       if (raw_sk && !raw_v4_input(skb, ip_hdr(skb), hash))
-               raw_sk = NULL;
-
-       return raw_sk != NULL;
+       int hash = protocol & (RAW_HTABLE_SIZE - 1);
 
+       return raw_v4_input(skb, ip_hdr(skb), hash);
 }
 
 static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info)
@@ -286,29 +264,24 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info)
 
 void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info)
 {
-       int hash;
-       struct sock *raw_sk;
+       struct net *net = dev_net(skb->dev);;
+       int dif = skb->dev->ifindex;
+       int sdif = inet_sdif(skb);
+       struct hlist_head *head;
        const struct iphdr *iph;
-       struct net *net;
+       struct sock *sk;
+       int hash;
 
        hash = protocol & (RAW_HTABLE_SIZE - 1);
+       head = &raw_v4_hashinfo.ht[hash];
 
        read_lock(&raw_v4_hashinfo.lock);
-       raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]);
-       if (raw_sk) {
-               int dif = skb->dev->ifindex;
-               int sdif = inet_sdif(skb);
-
+       sk_for_each(sk, head) {
                iph = (const struct iphdr *)skb->data;
-               net = dev_net(skb->dev);
-
-               while ((raw_sk = __raw_v4_lookup(net, raw_sk, protocol,
-                                               iph->daddr, iph->saddr,
-                                               dif, sdif)) != NULL) {
-                       raw_err(raw_sk, skb, info);
-                       raw_sk = sk_next(raw_sk);
-                       iph = (const struct iphdr *)skb->data;
-               }
+               if (!raw_v4_match(net, sk, iph->protocol,
+                                 iph->saddr, iph->daddr, dif, sdif))
+                       continue;
+               raw_err(sk, skb, info);
        }
        read_unlock(&raw_v4_hashinfo.lock);
 }
index ccacbde..b6d92dc 100644 (file)
@@ -34,31 +34,30 @@ raw_get_hashinfo(const struct inet_diag_req_v2 *r)
  * use helper to figure it out.
  */
 
-static struct sock *raw_lookup(struct net *net, struct sock *from,
-                              const struct inet_diag_req_v2 *req)
+static bool raw_lookup(struct net *net, struct sock *sk,
+                      const struct inet_diag_req_v2 *req)
 {
        struct inet_diag_req_raw *r = (void *)req;
-       struct sock *sk = NULL;
 
        if (r->sdiag_family == AF_INET)
-               sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol,
-                                    r->id.idiag_dst[0],
-                                    r->id.idiag_src[0],
-                                    r->id.idiag_if, 0);
+               return raw_v4_match(net, sk, r->sdiag_raw_protocol,
+                                   r->id.idiag_dst[0],
+                                   r->id.idiag_src[0],
+                                   r->id.idiag_if, 0);
 #if IS_ENABLED(CONFIG_IPV6)
        else
-               sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol,
-                                    (const struct in6_addr *)r->id.idiag_src,
-                                    (const struct in6_addr *)r->id.idiag_dst,
-                                    r->id.idiag_if, 0);
+               return raw_v6_match(net, sk, r->sdiag_raw_protocol,
+                                   (const struct in6_addr *)r->id.idiag_src,
+                                   (const struct in6_addr *)r->id.idiag_dst,
+                                   r->id.idiag_if, 0);
 #endif
-       return sk;
+       return false;
 }
 
 static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r)
 {
        struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
-       struct sock *sk = NULL, *s;
+       struct sock *sk;
        int slot;
 
        if (IS_ERR(hashinfo))
@@ -66,9 +65,8 @@ static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2
 
        read_lock(&hashinfo->lock);
        for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) {
-               sk_for_each(s, &hashinfo->ht[slot]) {
-                       sk = raw_lookup(net, s, r);
-                       if (sk) {
+               sk_for_each(sk, &hashinfo->ht[slot]) {
+                       if (raw_lookup(net, sk, r)) {
                                /*
                                 * Grab it and keep until we fill
                                 * diag meaage to be reported, so
@@ -81,10 +79,11 @@ static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2
                        }
                }
        }
+       sk = ERR_PTR(-ENOENT);
 out_unlock:
        read_unlock(&hashinfo->lock);
 
-       return sk ? sk : ERR_PTR(-ENOENT);
+       return sk;
 }
 
 static int raw_diag_dump_one(struct netlink_callback *cb,
index 3b7cbd5..c0f2e34 100644 (file)
@@ -66,41 +66,27 @@ struct raw_hashinfo raw_v6_hashinfo = {
 };
 EXPORT_SYMBOL_GPL(raw_v6_hashinfo);
 
-struct sock *__raw_v6_lookup(struct net *net, struct sock *sk,
-               unsigned short num, const struct in6_addr *loc_addr,
-               const struct in6_addr *rmt_addr, int dif, int sdif)
+bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num,
+                 const struct in6_addr *loc_addr,
+                 const struct in6_addr *rmt_addr, int dif, int sdif)
 {
-       bool is_multicast = ipv6_addr_is_multicast(loc_addr);
-
-       sk_for_each_from(sk)
-               if (inet_sk(sk)->inet_num == num) {
-
-                       if (!net_eq(sock_net(sk), net))
-                               continue;
-
-                       if (!ipv6_addr_any(&sk->sk_v6_daddr) &&
-                           !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr))
-                               continue;
-
-                       if (!raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if,
-                                                dif, sdif))
-                               continue;
-
-                       if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) {
-                               if (ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr))
-                                       goto found;
-                               if (is_multicast &&
-                                   inet6_mc_check(sk, loc_addr, rmt_addr))
-                                       goto found;
-                               continue;
-                       }
-                       goto found;
-               }
-       sk = NULL;
-found:
-       return sk;
+       if (inet_sk(sk)->inet_num != num ||
+           !net_eq(sock_net(sk), net) ||
+           (!ipv6_addr_any(&sk->sk_v6_daddr) &&
+            !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) ||
+           !raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if,
+                                dif, sdif))
+               return false;
+
+       if (ipv6_addr_any(&sk->sk_v6_rcv_saddr) ||
+           ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr) ||
+           (ipv6_addr_is_multicast(loc_addr) &&
+            inet6_mc_check(sk, loc_addr, rmt_addr)))
+               return true;
+
+       return false;
 }
-EXPORT_SYMBOL_GPL(__raw_v6_lookup);
+EXPORT_SYMBOL_GPL(raw_v6_match);
 
 /*
  *     0 - deliver
@@ -156,31 +142,28 @@ EXPORT_SYMBOL(rawv6_mh_filter_unregister);
  */
 static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
 {
+       struct net *net = dev_net(skb->dev);
        const struct in6_addr *saddr;
        const struct in6_addr *daddr;
+       struct hlist_head *head;
        struct sock *sk;
        bool delivered = false;
        __u8 hash;
-       struct net *net;
 
        saddr = &ipv6_hdr(skb)->saddr;
        daddr = saddr + 1;
 
        hash = nexthdr & (RAW_HTABLE_SIZE - 1);
-
+       head = &raw_v6_hashinfo.ht[hash];
+       if (hlist_empty(head))
+               return false;
        read_lock(&raw_v6_hashinfo.lock);
-       sk = sk_head(&raw_v6_hashinfo.ht[hash]);
-
-       if (!sk)
-               goto out;
-
-       net = dev_net(skb->dev);
-       sk = __raw_v6_lookup(net, sk, nexthdr, daddr, saddr,
-                            inet6_iif(skb), inet6_sdif(skb));
-
-       while (sk) {
+       sk_for_each(sk, head) {
                int filtered;
 
+               if (!raw_v6_match(net, sk, nexthdr, daddr, saddr,
+                                 inet6_iif(skb), inet6_sdif(skb)))
+                       continue;
                delivered = true;
                switch (nexthdr) {
                case IPPROTO_ICMPV6:
@@ -219,23 +202,14 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
                                rawv6_rcv(sk, clone);
                        }
                }
-               sk = __raw_v6_lookup(net, sk_next(sk), nexthdr, daddr, saddr,
-                                    inet6_iif(skb), inet6_sdif(skb));
        }
-out:
        read_unlock(&raw_v6_hashinfo.lock);
        return delivered;
 }
 
 bool raw6_local_deliver(struct sk_buff *skb, int nexthdr)
 {
-       struct sock *raw_sk;
-
-       raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (RAW_HTABLE_SIZE - 1)]);
-       if (raw_sk && !ipv6_raw_deliver(skb, nexthdr))
-               raw_sk = NULL;
-
-       return raw_sk != NULL;
+       return ipv6_raw_deliver(skb, nexthdr);
 }
 
 /* This cleans up af_inet6 a bit. -DaveM */
@@ -361,28 +335,25 @@ static void rawv6_err(struct sock *sk, struct sk_buff *skb,
 void raw6_icmp_error(struct sk_buff *skb, int nexthdr,
                u8 type, u8 code, int inner_offset, __be32 info)
 {
+       const struct in6_addr *saddr, *daddr;
+       struct net *net = dev_net(skb->dev);
+       struct hlist_head *head;
        struct sock *sk;
        int hash;
-       const struct in6_addr *saddr, *daddr;
-       struct net *net;
 
        hash = nexthdr & (RAW_HTABLE_SIZE - 1);
-
+       head = &raw_v6_hashinfo.ht[hash];
        read_lock(&raw_v6_hashinfo.lock);
-       sk = sk_head(&raw_v6_hashinfo.ht[hash]);
-       if (sk) {
+       sk_for_each(sk, head) {
                /* Note: ipv6_hdr(skb) != skb->data */
                const struct ipv6hdr *ip6h = (const struct ipv6hdr *)skb->data;
                saddr = &ip6h->saddr;
                daddr = &ip6h->daddr;
-               net = dev_net(skb->dev);
 
-               while ((sk = __raw_v6_lookup(net, sk, nexthdr, saddr, daddr,
-                                            inet6_iif(skb), inet6_iif(skb)))) {
-                       rawv6_err(sk, skb, NULL, type, code,
-                                       inner_offset, info);
-                       sk = sk_next(sk);
-               }
+               if (!raw_v6_match(net, sk, nexthdr, &ip6h->saddr, &ip6h->daddr,
+                                 inet6_iif(skb), inet6_iif(skb)))
+                       continue;
+               rawv6_err(sk, skb, NULL, type, code, inner_offset, info);
        }
        read_unlock(&raw_v6_hashinfo.lock);
 }