ip6mr: ip6mr_cache_report() changes
authorEric Dumazet <edumazet@google.com>
Thu, 23 Jun 2022 04:34:40 +0000 (04:34 +0000)
committerDavid S. Miller <davem@davemloft.net>
Fri, 24 Jun 2022 10:34:37 +0000 (11:34 +0100)
ip6mr_cache_report() first argument can be marked const, and we change
the caller convention about which lock needs to be held.

Instead of read_lock(&mrt_lock), we can use rcu_read_lock().

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/ipv6/ip6mr.c

index 44cb3d8..a6d9795 100644 (file)
@@ -91,11 +91,11 @@ static void ip6mr_free_table(struct mr_table *mrt);
 static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
                           struct net_device *dev, struct sk_buff *skb,
                           struct mfc6_cache *cache);
-static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
+static int ip6mr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt,
                              mifi_t mifi, int assert);
 static void mr6_netlink_event(struct mr_table *mrt, struct mfc6_cache *mfc,
                              int cmd);
-static void mrt6msg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt);
+static void mrt6msg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt);
 static int ip6mr_rtm_dumproute(struct sk_buff *skb,
                               struct netlink_callback *cb);
 static void mroute_clean_tables(struct mr_table *mrt, int flags);
@@ -608,11 +608,12 @@ static netdev_tx_t reg_vif_xmit(struct sk_buff *skb,
        if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0)
                goto tx_err;
 
-       read_lock(&mrt_lock);
        dev->stats.tx_bytes += skb->len;
        dev->stats.tx_packets++;
-       ip6mr_cache_report(mrt, skb, mrt->mroute_reg_vif_num, MRT6MSG_WHOLEPKT);
-       read_unlock(&mrt_lock);
+       rcu_read_lock();
+       ip6mr_cache_report(mrt, skb, READ_ONCE(mrt->mroute_reg_vif_num),
+                          MRT6MSG_WHOLEPKT);
+       rcu_read_unlock();
        kfree_skb(skb);
        return NETDEV_TX_OK;
 
@@ -718,8 +719,10 @@ static int mif6_delete(struct mr_table *mrt, int vifi, int notify,
        RCU_INIT_POINTER(v->dev, NULL);
 
 #ifdef CONFIG_IPV6_PIMSM_V2
-       if (vifi == mrt->mroute_reg_vif_num)
-               mrt->mroute_reg_vif_num = -1;
+       if (vifi == mrt->mroute_reg_vif_num) {
+               /* Pairs with READ_ONCE() in ip6mr_cache_report() and reg_vif_xmit() */
+               WRITE_ONCE(mrt->mroute_reg_vif_num, -1);
+       }
 #endif
 
        if (vifi + 1 == mrt->maxvif) {
@@ -922,7 +925,7 @@ static int mif6_add(struct net *net, struct mr_table *mrt,
        netdev_tracker_alloc(dev, &v->dev_tracker, GFP_ATOMIC);
 #ifdef CONFIG_IPV6_PIMSM_V2
        if (v->flags & MIFF_REGISTER)
-               mrt->mroute_reg_vif_num = vifi;
+               WRITE_ONCE(mrt->mroute_reg_vif_num, vifi);
 #endif
        if (vifi + 1 > mrt->maxvif)
                mrt->maxvif = vifi + 1;
@@ -1033,10 +1036,10 @@ static void ip6mr_cache_resolve(struct net *net, struct mr_table *mrt,
 /*
  *     Bounce a cache query up to pim6sd and netlink.
  *
- *     Called under mrt_lock.
+ *     Called under rcu_read_lock()
  */
 
-static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
+static int ip6mr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt,
                              mifi_t mifi, int assert)
 {
        struct sock *mroute6_sk;
@@ -1077,7 +1080,7 @@ static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
                if (assert == MRT6MSG_WRMIFWHOLE)
                        msg->im6_mif = mifi;
                else
-                       msg->im6_mif = mrt->mroute_reg_vif_num;
+                       msg->im6_mif = READ_ONCE(mrt->mroute_reg_vif_num);
                msg->im6_pad = 0;
                msg->im6_src = ipv6_hdr(pkt)->saddr;
                msg->im6_dst = ipv6_hdr(pkt)->daddr;
@@ -1112,10 +1115,8 @@ static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
        skb->ip_summed = CHECKSUM_UNNECESSARY;
        }
 
-       rcu_read_lock();
        mroute6_sk = rcu_dereference(mrt->mroute_sk);
        if (!mroute6_sk) {
-               rcu_read_unlock();
                kfree_skb(skb);
                return -EINVAL;
        }
@@ -1124,7 +1125,7 @@ static int ip6mr_cache_report(struct mr_table *mrt, struct sk_buff *pkt,
 
        /* Deliver to user space multicast routing algorithms */
        ret = sock_queue_rcv_skb(mroute6_sk, skb);
-       rcu_read_unlock();
+
        if (ret < 0) {
                net_warn_ratelimited("mroute6: pending queue full, dropping entries\n");
                kfree_skb(skb);
@@ -2042,7 +2043,9 @@ static int ip6mr_forward2(struct net *net, struct mr_table *mrt,
                vif->bytes_out += skb->len;
                vif_dev->stats.tx_bytes += skb->len;
                vif_dev->stats.tx_packets++;
+               rcu_read_lock();
                ip6mr_cache_report(mrt, skb, vifi, MRT6MSG_WHOLEPKT);
+               rcu_read_unlock();
                goto out_free;
        }
 #endif
@@ -2155,10 +2158,12 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
                               c->_c.mfc_un.res.last_assert +
                               MFC_ASSERT_THRESH)) {
                        c->_c.mfc_un.res.last_assert = jiffies;
+                       rcu_read_lock();
                        ip6mr_cache_report(mrt, skb, true_vifi, MRT6MSG_WRONGMIF);
                        if (mrt->mroute_do_wrvifwhole)
                                ip6mr_cache_report(mrt, skb, true_vifi,
                                                   MRT6MSG_WRMIFWHOLE);
+                       rcu_read_unlock();
                }
                goto dont_forward;
        }
@@ -2465,7 +2470,7 @@ static size_t mrt6msg_netlink_msgsize(size_t payloadlen)
        return len;
 }
 
-static void mrt6msg_netlink_event(struct mr_table *mrt, struct sk_buff *pkt)
+static void mrt6msg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt)
 {
        struct net *net = read_pnet(&mrt->net);
        struct nlmsghdr *nlh;