set_mcast_msfilter(): take the guts of setsockopt(MCAST_MSFILTER) into a helper
authorAl Viro <viro@zeniv.linux.org.uk>
Mon, 30 Mar 2020 02:37:56 +0000 (22:37 -0400)
committerAl Viro <viro@zeniv.linux.org.uk>
Thu, 21 May 2020 00:31:28 +0000 (20:31 -0400)
Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>
net/ipv4/ip_sockglue.c

index 65a30e7..cc04411 100644 (file)
@@ -587,6 +587,43 @@ static bool setsockopt_needs_rtnl(int optname)
        return false;
 }
 
+static int set_mcast_msfilter(struct sock *sk, int ifindex,
+                             int numsrc, int fmode,
+                             struct sockaddr_storage *group,
+                             struct sockaddr_storage *list)
+{
+       int msize = IP_MSFILTER_SIZE(numsrc);
+       struct ip_msfilter *msf;
+       struct sockaddr_in *psin;
+       int err, i;
+
+       msf = kmalloc(msize, GFP_KERNEL);
+       if (!msf)
+               return -ENOBUFS;
+
+       psin = (struct sockaddr_in *)group;
+       if (psin->sin_family != AF_INET)
+               goto Eaddrnotavail;
+       msf->imsf_multiaddr = psin->sin_addr.s_addr;
+       msf->imsf_interface = 0;
+       msf->imsf_fmode = fmode;
+       msf->imsf_numsrc = numsrc;
+       for (i = 0; i < numsrc; ++i) {
+               psin = (struct sockaddr_in *)&list[i];
+
+               if (psin->sin_family != AF_INET)
+                       goto Eaddrnotavail;
+               msf->imsf_slist[i] = psin->sin_addr.s_addr;
+       }
+       err = ip_mc_msfilter(sk, msf, ifindex);
+       kfree(msf);
+       return err;
+
+Eaddrnotavail:
+       kfree(msf);
+       return -EADDRNOTAVAIL;
+}
+
 static int do_ip_setsockopt(struct sock *sk, int level,
                            int optname, char __user *optval, unsigned int optlen)
 {
@@ -1079,10 +1116,7 @@ static int do_ip_setsockopt(struct sock *sk, int level,
        }
        case MCAST_MSFILTER:
        {
-               struct sockaddr_in *psin;
-               struct ip_msfilter *msf = NULL;
                struct group_filter *gsf = NULL;
-               int msize, i, ifindex;
 
                if (optlen < GROUP_FILTER_SIZE(0))
                        goto e_inval;
@@ -1095,7 +1129,6 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                        err = PTR_ERR(gsf);
                        break;
                }
-
                /* numsrc >= (4G-140)/128 overflow in 32 bits */
                if (gsf->gf_numsrc >= 0x1ffffff ||
                    gsf->gf_numsrc > net->ipv4.sysctl_igmp_max_msf) {
@@ -1106,36 +1139,10 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                        err = -EINVAL;
                        goto mc_msf_out;
                }
-               msize = IP_MSFILTER_SIZE(gsf->gf_numsrc);
-               msf = kmalloc(msize, GFP_KERNEL);
-               if (!msf) {
-                       err = -ENOBUFS;
-                       goto mc_msf_out;
-               }
-               ifindex = gsf->gf_interface;
-               psin = (struct sockaddr_in *)&gsf->gf_group;
-               if (psin->sin_family != AF_INET) {
-                       err = -EADDRNOTAVAIL;
-                       goto mc_msf_out;
-               }
-               msf->imsf_multiaddr = psin->sin_addr.s_addr;
-               msf->imsf_interface = 0;
-               msf->imsf_fmode = gsf->gf_fmode;
-               msf->imsf_numsrc = gsf->gf_numsrc;
-               err = -EADDRNOTAVAIL;
-               for (i = 0; i < gsf->gf_numsrc; ++i) {
-                       psin = (struct sockaddr_in *)&gsf->gf_slist[i];
-
-                       if (psin->sin_family != AF_INET)
-                               goto mc_msf_out;
-                       msf->imsf_slist[i] = psin->sin_addr.s_addr;
-               }
-               kfree(gsf);
-               gsf = NULL;
-
-               err = ip_mc_msfilter(sk, msf, ifindex);
+               err = set_mcast_msfilter(sk, gsf->gf_interface,
+                                        gsf->gf_numsrc, gsf->gf_fmode,
+                                        &gsf->gf_group, gsf->gf_slist);
 mc_msf_out:
-               kfree(msf);
                kfree(gsf);
                break;
        }