ipmr: add rcu protection over (struct vif_device)->dev
authorEric Dumazet <edumazet@google.com>
Thu, 23 Jun 2022 04:34:32 +0000 (04:34 +0000)
committerDavid S. Miller <davem@davemloft.net>
Fri, 24 Jun 2022 10:34:37 +0000 (11:34 +0100)
We will soon use RCU instead of rwlock in ipmr & ip6mr

This preliminary patch adds proper rcu verbs to read/write
(struct vif_device)->dev

Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/mroute_base.h
net/ipv4/ipmr.c
net/ipv4/ipmr_base.c
net/ipv6/ip6mr.c

index e05ee9f..10d1e4f 100644 (file)
@@ -26,7 +26,7 @@
  * @remote: Remote address for tunnels
  */
 struct vif_device {
-       struct net_device *dev;
+       struct net_device __rcu *dev;
        netdevice_tracker dev_tracker;
        unsigned long bytes_in, bytes_out;
        unsigned long pkt_in, pkt_out;
@@ -52,6 +52,7 @@ static inline int mr_call_vif_notifier(struct notifier_block *nb,
                                       unsigned short family,
                                       enum fib_event_type event_type,
                                       struct vif_device *vif,
+                                      struct net_device *vif_dev,
                                       unsigned short vif_index, u32 tb_id,
                                       struct netlink_ext_ack *extack)
 {
@@ -60,7 +61,7 @@ static inline int mr_call_vif_notifier(struct notifier_block *nb,
                        .family = family,
                        .extack = extack,
                },
-               .dev = vif->dev,
+               .dev = vif_dev,
                .vif_index = vif_index,
                .vif_flags = vif->flags,
                .tb_id = tb_id,
@@ -73,6 +74,7 @@ static inline int mr_call_vif_notifiers(struct net *net,
                                        unsigned short family,
                                        enum fib_event_type event_type,
                                        struct vif_device *vif,
+                                       struct net_device *vif_dev,
                                        unsigned short vif_index, u32 tb_id,
                                        unsigned int *ipmr_seq)
 {
@@ -80,7 +82,7 @@ static inline int mr_call_vif_notifiers(struct net *net,
                .info = {
                        .family = family,
                },
-               .dev = vif->dev,
+               .dev = vif_dev,
                .vif_index = vif_index,
                .vif_flags = vif->flags,
                .tb_id = tb_id,
@@ -98,7 +100,8 @@ static inline int mr_call_vif_notifiers(struct net *net,
 #define MAXVIFS        32
 #endif
 
-#define VIF_EXISTS(_mrt, _idx) (!!((_mrt)->vif_table[_idx].dev))
+/* Note: This helper is deprecated. */
+#define VIF_EXISTS(_mrt, _idx) (!!rcu_access_pointer((_mrt)->vif_table[_idx].dev))
 
 /* mfc_flags:
  * MFC_STATIC - the entry was added statically (not by a routing daemon)
index 8324e54..10371a9 100644 (file)
@@ -79,6 +79,12 @@ struct ipmr_result {
 
 static DEFINE_RWLOCK(mrt_lock);
 
+static struct net_device *vif_dev_read(const struct vif_device *vif)
+{
+       return rcu_dereference_check(vif->dev,
+                                    lockdep_is_held(&mrt_lock));
+}
+
 /* Multicast router control variables */
 
 /* Special spinlock for queue of unresolved entries */
@@ -586,7 +592,7 @@ static int __pim_rcv(struct mr_table *mrt, struct sk_buff *skb,
 
        read_lock(&mrt_lock);
        if (mrt->mroute_reg_vif_num >= 0)
-               reg_dev = mrt->vif_table[mrt->mroute_reg_vif_num].dev;
+               reg_dev = vif_dev_read(&mrt->vif_table[mrt->mroute_reg_vif_num]);
        read_unlock(&mrt_lock);
 
        if (!reg_dev)
@@ -614,10 +620,11 @@ static struct net_device *ipmr_reg_vif(struct net *net, struct mr_table *mrt)
 static int call_ipmr_vif_entry_notifiers(struct net *net,
                                         enum fib_event_type event_type,
                                         struct vif_device *vif,
+                                        struct net_device *vif_dev,
                                         vifi_t vif_index, u32 tb_id)
 {
        return mr_call_vif_notifiers(net, RTNL_FAMILY_IPMR, event_type,
-                                    vif, vif_index, tb_id,
+                                    vif, vif_dev, vif_index, tb_id,
                                     &net->ipv4.ipmr_seq);
 }
 
@@ -649,18 +656,14 @@ static int vif_delete(struct mr_table *mrt, int vifi, int notify,
 
        v = &mrt->vif_table[vifi];
 
-       if (VIF_EXISTS(mrt, vifi))
-               call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_DEL, v, vifi,
-                                             mrt->id);
+       dev = rtnl_dereference(v->dev);
+       if (!dev)
+               return -EADDRNOTAVAIL;
 
        write_lock_bh(&mrt_lock);
-       dev = v->dev;
-       v->dev = NULL;
-
-       if (!dev) {
-               write_unlock_bh(&mrt_lock);
-               return -EADDRNOTAVAIL;
-       }
+       call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_DEL, v, dev,
+                                     vifi, mrt->id);
+       RCU_INIT_POINTER(v->dev, NULL);
 
        if (vifi == mrt->mroute_reg_vif_num)
                mrt->mroute_reg_vif_num = -1;
@@ -890,14 +893,15 @@ static int vif_add(struct net *net, struct mr_table *mrt,
 
        /* And finish update writing critical data */
        write_lock_bh(&mrt_lock);
-       v->dev = dev;
+       rcu_assign_pointer(v->dev, dev);
        netdev_tracker_alloc(dev, &v->dev_tracker, GFP_ATOMIC);
        if (v->flags & VIFF_REGISTER)
                mrt->mroute_reg_vif_num = vifi;
        if (vifi+1 > mrt->maxvif)
                mrt->maxvif = vifi+1;
        write_unlock_bh(&mrt_lock);
-       call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD, v, vifi, mrt->id);
+       call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD, v, dev,
+                                     vifi, mrt->id);
        return 0;
 }
 
@@ -1726,7 +1730,7 @@ static int ipmr_device_event(struct notifier_block *this, unsigned long event, v
        ipmr_for_each_table(mrt, net) {
                v = &mrt->vif_table[0];
                for (ct = 0; ct < mrt->maxvif; ct++, v++) {
-                       if (v->dev == dev)
+                       if (rcu_access_pointer(v->dev) == dev)
                                vif_delete(mrt, ct, 1, NULL);
                }
        }
@@ -1811,19 +1815,21 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt,
 {
        const struct iphdr *iph = ip_hdr(skb);
        struct vif_device *vif = &mrt->vif_table[vifi];
+       struct net_device *vif_dev;
        struct net_device *dev;
        struct rtable *rt;
        struct flowi4 fl4;
        int    encap = 0;
 
-       if (!vif->dev)
+       vif_dev = vif_dev_read(vif);
+       if (!vif_dev)
                goto out_free;
 
        if (vif->flags & VIFF_REGISTER) {
                vif->pkt_out++;
                vif->bytes_out += skb->len;
-               vif->dev->stats.tx_bytes += skb->len;
-               vif->dev->stats.tx_packets++;
+               vif_dev->stats.tx_bytes += skb->len;
+               vif_dev->stats.tx_packets++;
                ipmr_cache_report(mrt, skb, vifi, IGMPMSG_WHOLEPKT);
                goto out_free;
        }
@@ -1881,8 +1887,8 @@ static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt,
        if (vif->flags & VIFF_TUNNEL) {
                ip_encap(net, skb, vif->local, vif->remote);
                /* FIXME: extra output firewall step used to be here. --RR */
-               vif->dev->stats.tx_packets++;
-               vif->dev->stats.tx_bytes += skb->len;
+               vif_dev->stats.tx_packets++;
+               vif_dev->stats.tx_bytes += skb->len;
        }
 
        IPCB(skb)->flags |= IPSKB_FORWARDED;
@@ -1911,7 +1917,7 @@ static int ipmr_find_vif(struct mr_table *mrt, struct net_device *dev)
        int ct;
 
        for (ct = mrt->maxvif-1; ct >= 0; ct--) {
-               if (mrt->vif_table[ct].dev == dev)
+               if (rcu_access_pointer(mrt->vif_table[ct].dev) == dev)
                        break;
        }
        return ct;
@@ -1944,7 +1950,7 @@ static void ip_mr_forward(struct net *net, struct mr_table *mrt,
        }
 
        /* Wrong interface: drop packet and (maybe) send PIM assert. */
-       if (mrt->vif_table[vif].dev != dev) {
+       if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) {
                if (rt_is_output_route(skb_rtable(skb))) {
                        /* It is our own packet, looped back.
                         * Very complicated situation...
@@ -2744,18 +2750,21 @@ static bool ipmr_fill_table(struct mr_table *mrt, struct sk_buff *skb)
 
 static bool ipmr_fill_vif(struct mr_table *mrt, u32 vifid, struct sk_buff *skb)
 {
+       struct net_device *vif_dev;
        struct nlattr *vif_nest;
        struct vif_device *vif;
 
+       vif = &mrt->vif_table[vifid];
+       vif_dev = vif_dev_read(vif);
        /* if the VIF doesn't exist just continue */
-       if (!VIF_EXISTS(mrt, vifid))
+       if (!vif_dev)
                return true;
 
-       vif = &mrt->vif_table[vifid];
        vif_nest = nla_nest_start_noflag(skb, IPMRA_VIF);
        if (!vif_nest)
                return false;
-       if (nla_put_u32(skb, IPMRA_VIFA_IFINDEX, vif->dev->ifindex) ||
+
+       if (nla_put_u32(skb, IPMRA_VIFA_IFINDEX, vif_dev->ifindex) ||
            nla_put_u32(skb, IPMRA_VIFA_VIF_ID, vifid) ||
            nla_put_u16(skb, IPMRA_VIFA_FLAGS, vif->flags) ||
            nla_put_u64_64bit(skb, IPMRA_VIFA_BYTES_IN, vif->bytes_in,
@@ -2919,9 +2928,11 @@ static int ipmr_vif_seq_show(struct seq_file *seq, void *v)
                         "Interface      BytesIn  PktsIn  BytesOut PktsOut Flags Local    Remote\n");
        } else {
                const struct vif_device *vif = v;
-               const char *name =  vif->dev ?
-                                   vif->dev->name : "none";
+               const struct net_device *vif_dev;
+               const char *name;
 
+               vif_dev = vif_dev_read(vif);
+               name = vif_dev ? vif_dev->name : "none";
                seq_printf(seq,
                           "%2td %-10s %8ld %7ld  %8ld %7ld %05X %08X %08X\n",
                           vif - mrt->vif_table,
index aa8738a..59f62b9 100644 (file)
@@ -13,7 +13,7 @@ void vif_device_init(struct vif_device *v,
                     unsigned short flags,
                     unsigned short get_iflink_mask)
 {
-       v->dev = NULL;
+       RCU_INIT_POINTER(v->dev, NULL);
        v->bytes_in = 0;
        v->bytes_out = 0;
        v->pkt_in = 0;
@@ -208,6 +208,7 @@ EXPORT_SYMBOL(mr_mfc_seq_next);
 int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb,
                   struct mr_mfc *c, struct rtmsg *rtm)
 {
+       struct net_device *vif_dev;
        struct rta_mfc_stats mfcs;
        struct nlattr *mp_attr;
        struct rtnexthop *nhp;
@@ -220,10 +221,13 @@ int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb,
                return -ENOENT;
        }
 
-       if (VIF_EXISTS(mrt, c->mfc_parent) &&
-           nla_put_u32(skb, RTA_IIF,
-                       mrt->vif_table[c->mfc_parent].dev->ifindex) < 0)
+       rcu_read_lock();
+       vif_dev = rcu_dereference(mrt->vif_table[c->mfc_parent].dev);
+       if (vif_dev && nla_put_u32(skb, RTA_IIF, vif_dev->ifindex) < 0) {
+               rcu_read_unlock();
                return -EMSGSIZE;
+       }
+       rcu_read_unlock();
 
        if (c->mfc_flags & MFC_OFFLOAD)
                rtm->rtm_flags |= RTNH_F_OFFLOAD;
@@ -232,23 +236,27 @@ int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb,
        if (!mp_attr)
                return -EMSGSIZE;
 
+       rcu_read_lock();
        for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) {
-               if (VIF_EXISTS(mrt, ct) && c->mfc_un.res.ttls[ct] < 255) {
-                       struct vif_device *vif;
+               struct vif_device *vif = &mrt->vif_table[ct];
+
+               vif_dev = rcu_dereference(vif->dev);
+               if (vif_dev && c->mfc_un.res.ttls[ct] < 255) {
 
                        nhp = nla_reserve_nohdr(skb, sizeof(*nhp));
                        if (!nhp) {
+                               rcu_read_unlock();
                                nla_nest_cancel(skb, mp_attr);
                                return -EMSGSIZE;
                        }
 
                        nhp->rtnh_flags = 0;
                        nhp->rtnh_hops = c->mfc_un.res.ttls[ct];
-                       vif = &mrt->vif_table[ct];
-                       nhp->rtnh_ifindex = vif->dev->ifindex;
+                       nhp->rtnh_ifindex = vif_dev->ifindex;
                        nhp->rtnh_len = sizeof(*nhp);
                }
        }
+       rcu_read_unlock();
 
        nla_nest_end(skb, mp_attr);
 
@@ -275,13 +283,14 @@ static bool mr_mfc_uses_dev(const struct mr_table *mrt,
        int ct;
 
        for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) {
-               if (VIF_EXISTS(mrt, ct) && c->mfc_un.res.ttls[ct] < 255) {
-                       const struct vif_device *vif;
-
-                       vif = &mrt->vif_table[ct];
-                       if (vif->dev == dev)
-                               return true;
-               }
+               const struct net_device *vif_dev;
+               const struct vif_device *vif;
+
+               vif = &mrt->vif_table[ct];
+               vif_dev = rcu_access_pointer(vif->dev);
+               if (vif_dev && c->mfc_un.res.ttls[ct] < 255 &&
+                   vif_dev == dev)
+                       return true;
        }
        return false;
 }
@@ -402,18 +411,22 @@ int mr_dump(struct net *net, struct notifier_block *nb, unsigned short family,
 
        for (mrt = mr_iter(net, NULL); mrt; mrt = mr_iter(net, mrt)) {
                struct vif_device *v = &mrt->vif_table[0];
+               struct net_device *vif_dev;
                struct mr_mfc *mfc;
                int vifi;
 
                /* Notifiy on table VIF entries */
                read_lock(mrt_lock);
                for (vifi = 0; vifi < mrt->maxvif; vifi++, v++) {
-                       if (!v->dev)
+                       vif_dev = rcu_dereference_check(v->dev,
+                                                       lockdep_is_held(mrt_lock));
+                       if (!vif_dev)
                                continue;
 
                        err = mr_call_vif_notifier(nb, family,
-                                                  FIB_EVENT_VIF_ADD,
-                                                  v, vifi, mrt->id, extack);
+                                                  FIB_EVENT_VIF_ADD, v,
+                                                  vif_dev, vifi,
+                                                  mrt->id, extack);
                        if (err)
                                break;
                }
index aa66c03..44cb3d8 100644 (file)
@@ -64,6 +64,12 @@ struct ip6mr_result {
 
 static DEFINE_RWLOCK(mrt_lock);
 
+static struct net_device *vif_dev_read(const struct vif_device *vif)
+{
+       return rcu_dereference_check(vif->dev,
+                                    lockdep_is_held(&mrt_lock));
+}
+
 /* Multicast router control variables */
 
 /* Special spinlock for queue of unresolved entries */
@@ -430,7 +436,11 @@ static int ip6mr_vif_seq_show(struct seq_file *seq, void *v)
                         "Interface      BytesIn  PktsIn  BytesOut PktsOut Flags\n");
        } else {
                const struct vif_device *vif = v;
-               const char *name = vif->dev ? vif->dev->name : "none";
+               const struct net_device *vif_dev;
+               const char *name;
+
+               vif_dev = vif_dev_read(vif);
+               name = vif_dev ? vif_dev->name : "none";
 
                seq_printf(seq,
                           "%2td %-10s %8ld %7ld  %8ld %7ld %05X\n",
@@ -553,7 +563,7 @@ static int pim6_rcv(struct sk_buff *skb)
 
        read_lock(&mrt_lock);
        if (reg_vif_num >= 0)
-               reg_dev = mrt->vif_table[reg_vif_num].dev;
+               reg_dev = vif_dev_read(&mrt->vif_table[reg_vif_num]);
        read_unlock(&mrt_lock);
 
        if (!reg_dev)
@@ -668,10 +678,11 @@ failure:
 static int call_ip6mr_vif_entry_notifiers(struct net *net,
                                          enum fib_event_type event_type,
                                          struct vif_device *vif,
+                                         struct net_device *vif_dev,
                                          mifi_t vif_index, u32 tb_id)
 {
        return mr_call_vif_notifiers(net, RTNL_FAMILY_IP6MR, event_type,
-                                    vif, vif_index, tb_id,
+                                    vif, vif_dev, vif_index, tb_id,
                                     &net->ipv6.ipmr_seq);
 }
 
@@ -696,19 +707,15 @@ static int mif6_delete(struct mr_table *mrt, int vifi, int notify,
 
        v = &mrt->vif_table[vifi];
 
-       if (VIF_EXISTS(mrt, vifi))
-               call_ip6mr_vif_entry_notifiers(read_pnet(&mrt->net),
-                                              FIB_EVENT_VIF_DEL, v, vifi,
-                                              mrt->id);
+       dev = rtnl_dereference(v->dev);
+       if (!dev)
+               return -EADDRNOTAVAIL;
 
+       call_ip6mr_vif_entry_notifiers(read_pnet(&mrt->net),
+                                      FIB_EVENT_VIF_DEL, v, dev,
+                                      vifi, mrt->id);
        write_lock_bh(&mrt_lock);
-       dev = v->dev;
-       v->dev = NULL;
-
-       if (!dev) {
-               write_unlock_bh(&mrt_lock);
-               return -EADDRNOTAVAIL;
-       }
+       RCU_INIT_POINTER(v->dev, NULL);
 
 #ifdef CONFIG_IPV6_PIMSM_V2
        if (vifi == mrt->mroute_reg_vif_num)
@@ -911,7 +918,7 @@ static int mif6_add(struct net *net, struct mr_table *mrt,
 
        /* And finish update writing critical data */
        write_lock_bh(&mrt_lock);
-       v->dev = dev;
+       rcu_assign_pointer(v->dev, dev);
        netdev_tracker_alloc(dev, &v->dev_tracker, GFP_ATOMIC);
 #ifdef CONFIG_IPV6_PIMSM_V2
        if (v->flags & MIFF_REGISTER)
@@ -921,7 +928,7 @@ static int mif6_add(struct net *net, struct mr_table *mrt,
                mrt->maxvif = vifi + 1;
        write_unlock_bh(&mrt_lock);
        call_ip6mr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD,
-                                      v, vifi, mrt->id);
+                                      v, dev, vifi, mrt->id);
        return 0;
 }
 
@@ -1241,7 +1248,7 @@ static int ip6mr_device_event(struct notifier_block *this,
        ip6mr_for_each_table(mrt, net) {
                v = &mrt->vif_table[0];
                for (ct = 0; ct < mrt->maxvif; ct++, v++) {
-                       if (v->dev == dev)
+                       if (rcu_access_pointer(v->dev) == dev)
                                mif6_delete(mrt, ct, 1, NULL);
                }
        }
@@ -2019,21 +2026,22 @@ static inline int ip6mr_forward2_finish(struct net *net, struct sock *sk, struct
 static int ip6mr_forward2(struct net *net, struct mr_table *mrt,
                          struct sk_buff *skb, int vifi)
 {
-       struct ipv6hdr *ipv6h;
        struct vif_device *vif = &mrt->vif_table[vifi];
-       struct net_device *dev;
+       struct net_device *vif_dev;
+       struct ipv6hdr *ipv6h;
        struct dst_entry *dst;
        struct flowi6 fl6;
 
-       if (!vif->dev)
+       vif_dev = vif_dev_read(vif);
+       if (!vif_dev)
                goto out_free;
 
 #ifdef CONFIG_IPV6_PIMSM_V2
        if (vif->flags & MIFF_REGISTER) {
                vif->pkt_out++;
                vif->bytes_out += skb->len;
-               vif->dev->stats.tx_bytes += skb->len;
-               vif->dev->stats.tx_packets++;
+               vif_dev->stats.tx_bytes += skb->len;
+               vif_dev->stats.tx_packets++;
                ip6mr_cache_report(mrt, skb, vifi, MRT6MSG_WHOLEPKT);
                goto out_free;
        }
@@ -2066,14 +2074,13 @@ static int ip6mr_forward2(struct net *net, struct mr_table *mrt,
         * not mrouter) cannot join to more than one interface - it will
         * result in receiving multiple packets.
         */
-       dev = vif->dev;
-       skb->dev = dev;
+       skb->dev = vif_dev;
        vif->pkt_out++;
        vif->bytes_out += skb->len;
 
        /* We are about to write */
        /* XXX: extension headers? */
-       if (skb_cow(skb, sizeof(*ipv6h) + LL_RESERVED_SPACE(dev)))
+       if (skb_cow(skb, sizeof(*ipv6h) + LL_RESERVED_SPACE(vif_dev)))
                goto out_free;
 
        ipv6h = ipv6_hdr(skb);
@@ -2082,7 +2089,7 @@ static int ip6mr_forward2(struct net *net, struct mr_table *mrt,
        IP6CB(skb)->flags |= IP6SKB_FORWARDED;
 
        return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD,
-                      net, NULL, skb, skb->dev, dev,
+                      net, NULL, skb, skb->dev, vif_dev,
                       ip6mr_forward2_finish);
 
 out_free:
@@ -2095,7 +2102,7 @@ static int ip6mr_find_vif(struct mr_table *mrt, struct net_device *dev)
        int ct;
 
        for (ct = mrt->maxvif - 1; ct >= 0; ct--) {
-               if (mrt->vif_table[ct].dev == dev)
+               if (rcu_access_pointer(mrt->vif_table[ct].dev) == dev)
                        break;
        }
        return ct;
@@ -2133,7 +2140,7 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
        /*
         * Wrong interface: drop packet and (maybe) send PIM assert.
         */
-       if (mrt->vif_table[vif].dev != dev) {
+       if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) {
                c->_c.mfc_un.res.wrong_if++;
 
                if (true_vifi >= 0 && mrt->mroute_do_assert &&