netlink: access nlk groups safely in netlink bind and getname
[linux-2.6-microblaze.git] / net / netlink / af_netlink.c
index 94a61e6..3278077 100644 (file)
@@ -955,7 +955,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
        struct net *net = sock_net(sk);
        struct netlink_sock *nlk = nlk_sk(sk);
        struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
-       int err;
+       int err = 0;
        long unsigned int groups = nladdr->nl_groups;
        bool bound;
 
@@ -983,6 +983,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        return -EINVAL;
        }
 
+       netlink_lock_table();
        if (nlk->netlink_bind && groups) {
                int group;
 
@@ -993,7 +994,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        if (!err)
                                continue;
                        netlink_undo_bind(group, groups, sk);
-                       return err;
+                       goto unlock;
                }
        }
 
@@ -1006,12 +1007,13 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        netlink_autobind(sock);
                if (err) {
                        netlink_undo_bind(nlk->ngroups, groups, sk);
-                       return err;
+                       goto unlock;
                }
        }
 
        if (!groups && (nlk->groups == NULL || !(u32)nlk->groups[0]))
-               return 0;
+               goto unlock;
+       netlink_unlock_table();
 
        netlink_table_grab();
        netlink_update_subscriptions(sk, nlk->subscriptions +
@@ -1022,6 +1024,10 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
        netlink_table_ungrab();
 
        return 0;
+
+unlock:
+       netlink_unlock_table();
+       return err;
 }
 
 static int netlink_connect(struct socket *sock, struct sockaddr *addr,
@@ -1079,7 +1085,9 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr,
                nladdr->nl_groups = netlink_group_mask(nlk->dst_group);
        } else {
                nladdr->nl_pid = nlk->portid;
+               netlink_lock_table();
                nladdr->nl_groups = nlk->groups ? nlk->groups[0] : 0;
+               netlink_unlock_table();
        }
        return 0;
 }