ip6mr: Make mroute_sk rcu-based
author Yuval Mintz <yuvalm@mellanox.com>
Wed, 28 Feb 2018 21:29:30 +0000 (23:29 +0200)
committer David S. Miller <davem@davemloft.net>
Thu, 1 Mar 2018 18:13:23 +0000 (13:13 -0500)
In ipmr the mr_table socket is handled under RCU. Introduce the same
for ip6mr.

Signed-off-by: Yuval Mintz <yuvalm@mellanox.com>
Acked-by: Nikolay Aleksandrov <nikolay@cumulusnetworks.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/mroute6.h
net/ipv6/ip6_output.c
net/ipv6/ip6mr.c

index e5e5b82..e1b9fb0 100644 (file)
@@ -111,12 +111,12 @@ extern int ip6mr_get_route(struct net *net, struct sk_buff *skb,
                           struct rtmsg *rtm, u32 portid);
 
 #ifdef CONFIG_IPV6_MROUTE
-extern struct sock *mroute6_socket(struct net *net, struct sk_buff *skb);
+bool mroute6_is_socket(struct net *net, struct sk_buff *skb);
 extern int ip6mr_sk_done(struct sock *sk);
 #else
-static inline struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
+static inline bool mroute6_is_socket(struct net *net, struct sk_buff *skb)
 {
-       return NULL;
+       return false;
 }
 static inline int ip6mr_sk_done(struct sock *sk)
 {
index 997c7f1..a6eb0e6 100644 (file)
@@ -71,7 +71,7 @@ static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *
                struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb));
 
                if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(sk) &&
-                   ((mroute6_socket(net, skb) &&
+                   ((mroute6_is_socket(net, skb) &&
                     !(IP6CB(skb)->flags & IP6SKB_FORWARDED)) ||
                     ipv6_chk_mcast_addr(dev, &ipv6_hdr(skb)->daddr,
                                         &ipv6_hdr(skb)->saddr))) {
index e397990..a0e297d 100644 (file)
@@ -58,7 +58,7 @@ struct mr6_table {
        struct list_head        list;
        possible_net_t          net;
        u32                     id;
-       struct sock             *mroute6_sk;
+       struct sock __rcu       *mroute6_sk;
        struct timer_list       ipmr_expire_timer;
        struct list_head        mfc6_unres_queue;
        struct list_head        mfc6_cache_array[MFC6_LINES];
@@ -1121,6 +1121,7 @@ static void ip6mr_cache_resolve(struct net *net, struct mr6_table *mrt,
 static int ip6mr_cache_report(struct mr6_table *mrt, struct sk_buff *pkt,
                              mifi_t mifi, int assert)
 {
+       struct sock *mroute6_sk;
        struct sk_buff *skb;
        struct mrt6msg *msg;
        int ret;
@@ -1190,17 +1191,19 @@ static int ip6mr_cache_report(struct mr6_table *mrt, struct sk_buff *pkt,
        skb->ip_summed = CHECKSUM_UNNECESSARY;
        }
 
-       if (!mrt->mroute6_sk) {
+       rcu_read_lock();
+       mroute6_sk = rcu_dereference(mrt->mroute6_sk);
+       if (!mroute6_sk) {
+               rcu_read_unlock();
                kfree_skb(skb);
                return -EINVAL;
        }
 
        mrt6msg_netlink_event(mrt, skb);
 
-       /*
-        *      Deliver to user space multicast routing algorithms
-        */
-       ret = sock_queue_rcv_skb(mrt->mroute6_sk, skb);
+       /* Deliver to user space multicast routing algorithms */
+       ret = sock_queue_rcv_skb(mroute6_sk, skb);
+       rcu_read_unlock();
        if (ret < 0) {
                net_warn_ratelimited("mroute6: pending queue full, dropping entries\n");
                kfree_skb(skb);
@@ -1584,11 +1587,11 @@ static int ip6mr_sk_init(struct mr6_table *mrt, struct sock *sk)
 
        rtnl_lock();
        write_lock_bh(&mrt_lock);
-       if (likely(mrt->mroute6_sk == NULL)) {
-               mrt->mroute6_sk = sk;
-               net->ipv6.devconf_all->mc_forwarding++;
-       } else {
+       if (rtnl_dereference(mrt->mroute6_sk)) {
                err = -EADDRINUSE;
+       } else {
+               rcu_assign_pointer(mrt->mroute6_sk, sk);
+               net->ipv6.devconf_all->mc_forwarding++;
        }
        write_unlock_bh(&mrt_lock);
 
@@ -1614,9 +1617,9 @@ int ip6mr_sk_done(struct sock *sk)
 
        rtnl_lock();
        ip6mr_for_each_table(mrt, net) {
-               if (sk == mrt->mroute6_sk) {
+               if (sk == rtnl_dereference(mrt->mroute6_sk)) {
                        write_lock_bh(&mrt_lock);
-                       mrt->mroute6_sk = NULL;
+                       RCU_INIT_POINTER(mrt->mroute6_sk, NULL);
                        net->ipv6.devconf_all->mc_forwarding--;
                        write_unlock_bh(&mrt_lock);
                        inet6_netconf_notify_devconf(net, RTM_NEWNETCONF,
@@ -1630,11 +1633,12 @@ int ip6mr_sk_done(struct sock *sk)
                }
        }
        rtnl_unlock();
+       synchronize_rcu();
 
        return err;
 }
 
-struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
+bool mroute6_is_socket(struct net *net, struct sk_buff *skb)
 {
        struct mr6_table *mrt;
        struct flowi6 fl6 = {
@@ -1646,8 +1650,9 @@ struct sock *mroute6_socket(struct net *net, struct sk_buff *skb)
        if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0)
                return NULL;
 
-       return mrt->mroute6_sk;
+       return rcu_access_pointer(mrt->mroute6_sk);
 }
+EXPORT_SYMBOL(mroute6_is_socket);
 
 /*
  *     Socket options and virtual interface manipulation. The whole
@@ -1674,7 +1679,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
                return -ENOENT;
 
        if (optname != MRT6_INIT) {
-               if (sk != mrt->mroute6_sk && !ns_capable(net->user_ns, CAP_NET_ADMIN))
+               if (sk != rcu_access_pointer(mrt->mroute6_sk) &&
+                   !ns_capable(net->user_ns, CAP_NET_ADMIN))
                        return -EACCES;
        }
 
@@ -1696,7 +1702,8 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
                if (vif.mif6c_mifi >= MAXMIFS)
                        return -ENFILE;
                rtnl_lock();
-               ret = mif6_add(net, mrt, &vif, sk == mrt->mroute6_sk);
+               ret = mif6_add(net, mrt, &vif,
+                              sk == rtnl_dereference(mrt->mroute6_sk));
                rtnl_unlock();
                return ret;
 
@@ -1731,7 +1738,9 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
                        ret = ip6mr_mfc_delete(mrt, &mfc, parent);
                else
                        ret = ip6mr_mfc_add(net, mrt, &mfc,
-                                           sk == mrt->mroute6_sk, parent);
+                                           sk ==
+                                           rtnl_dereference(mrt->mroute6_sk),
+                                           parent);
                rtnl_unlock();
                return ret;
 
@@ -1783,7 +1792,7 @@ int ip6_mroute_setsockopt(struct sock *sk, int optname, char __user *optval, uns
                /* "pim6reg%u" should not exceed 16 bytes (IFNAMSIZ) */
                if (v != RT_TABLE_DEFAULT && v >= 100000000)
                        return -EINVAL;
-               if (sk == mrt->mroute6_sk)
+               if (sk == rcu_access_pointer(mrt->mroute6_sk))
                        return -EBUSY;
 
                rtnl_lock();