fou: Support binding FoU socket
[linux-2.6-microblaze.git] / net / ipv4 / fou.c
index 79e98e2..100e63f 100644 (file)
@@ -499,15 +499,45 @@ out_unlock:
        return err;
 }
 
-static int fou_add_to_port_list(struct net *net, struct fou *fou)
+static bool fou_cfg_cmp(struct fou *fou, struct fou_cfg *cfg)
+{
+       struct sock *sk = fou->sock->sk;
+       struct udp_port_cfg *udp_cfg = &cfg->udp_config;
+
+       if (fou->family != udp_cfg->family ||
+           fou->port != udp_cfg->local_udp_port ||
+           sk->sk_dport != udp_cfg->peer_udp_port ||
+           sk->sk_bound_dev_if != udp_cfg->bind_ifindex)
+               return false;
+
+       if (fou->family == AF_INET) {
+               if (sk->sk_rcv_saddr != udp_cfg->local_ip.s_addr ||
+                   sk->sk_daddr != udp_cfg->peer_ip.s_addr)
+                       return false;
+               else
+                       return true;
+#if IS_ENABLED(CONFIG_IPV6)
+       } else {
+               if (ipv6_addr_cmp(&sk->sk_v6_rcv_saddr, &udp_cfg->local_ip6) ||
+                   ipv6_addr_cmp(&sk->sk_v6_daddr, &udp_cfg->peer_ip6))
+                       return false;
+               else
+                       return true;
+#endif
+       }
+
+       return false;
+}
+
+static int fou_add_to_port_list(struct net *net, struct fou *fou,
+                               struct fou_cfg *cfg)
 {
        struct fou_net *fn = net_generic(net, fou_net_id);
        struct fou *fout;
 
        mutex_lock(&fn->fou_lock);
        list_for_each_entry(fout, &fn->fou_list, list) {
-               if (fou->port == fout->port &&
-                   fou->family == fout->family) {
+               if (fou_cfg_cmp(fout, cfg)) {
                        mutex_unlock(&fn->fou_lock);
                        return -EALREADY;
                }
@@ -585,7 +615,7 @@ static int fou_create(struct net *net, struct fou_cfg *cfg,
 
        sk->sk_allocation = GFP_ATOMIC;
 
-       err = fou_add_to_port_list(net, fou);
+       err = fou_add_to_port_list(net, fou, cfg);
        if (err)
                goto error;
 
@@ -605,14 +635,12 @@ error:
 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
 {
        struct fou_net *fn = net_generic(net, fou_net_id);
-       __be16 port = cfg->udp_config.local_udp_port;
-       u8 family = cfg->udp_config.family;
        int err = -EINVAL;
        struct fou *fou;
 
        mutex_lock(&fn->fou_lock);
        list_for_each_entry(fou, &fn->fou_list, list) {
-               if (fou->port == port && fou->family == family) {
+               if (fou_cfg_cmp(fou, cfg)) {
                        fou_release(fou);
                        err = 0;
                        break;
@@ -626,16 +654,27 @@ static int fou_destroy(struct net *net, struct fou_cfg *cfg)
 static struct genl_family fou_nl_family;
 
 static const struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
-       [FOU_ATTR_PORT] = { .type = NLA_U16, },
-       [FOU_ATTR_AF] = { .type = NLA_U8, },
-       [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
-       [FOU_ATTR_TYPE] = { .type = NLA_U8, },
-       [FOU_ATTR_REMCSUM_NOPARTIAL] = { .type = NLA_FLAG, },
+       [FOU_ATTR_PORT]                 = { .type = NLA_U16, },
+       [FOU_ATTR_AF]                   = { .type = NLA_U8, },
+       [FOU_ATTR_IPPROTO]              = { .type = NLA_U8, },
+       [FOU_ATTR_TYPE]                 = { .type = NLA_U8, },
+       [FOU_ATTR_REMCSUM_NOPARTIAL]    = { .type = NLA_FLAG, },
+       [FOU_ATTR_LOCAL_V4]             = { .type = NLA_U32, },
+       [FOU_ATTR_PEER_V4]              = { .type = NLA_U32, },
+       [FOU_ATTR_LOCAL_V6]             = { .type = sizeof(struct in6_addr), },
+       [FOU_ATTR_PEER_V6]              = { .type = sizeof(struct in6_addr), },
+       [FOU_ATTR_PEER_PORT]            = { .type = NLA_U16, },
+       [FOU_ATTR_IFINDEX]              = { .type = NLA_S32, },
 };
 
 static int parse_nl_config(struct genl_info *info,
                           struct fou_cfg *cfg)
 {
+       bool has_local = false, has_peer = false;
+       struct nlattr *attr;
+       int ifindex;
+       __be16 port;
+
        memset(cfg, 0, sizeof(*cfg));
 
        cfg->udp_config.family = AF_INET;
@@ -657,8 +696,7 @@ static int parse_nl_config(struct genl_info *info,
        }
 
        if (info->attrs[FOU_ATTR_PORT]) {
-               __be16 port = nla_get_be16(info->attrs[FOU_ATTR_PORT]);
-
+               port = nla_get_be16(info->attrs[FOU_ATTR_PORT]);
                cfg->udp_config.local_udp_port = port;
        }
 
@@ -671,6 +709,52 @@ static int parse_nl_config(struct genl_info *info,
        if (info->attrs[FOU_ATTR_REMCSUM_NOPARTIAL])
                cfg->flags |= FOU_F_REMCSUM_NOPARTIAL;
 
+       if (cfg->udp_config.family == AF_INET) {
+               if (info->attrs[FOU_ATTR_LOCAL_V4]) {
+                       attr = info->attrs[FOU_ATTR_LOCAL_V4];
+                       cfg->udp_config.local_ip.s_addr = nla_get_in_addr(attr);
+                       has_local = true;
+               }
+
+               if (info->attrs[FOU_ATTR_PEER_V4]) {
+                       attr = info->attrs[FOU_ATTR_PEER_V4];
+                       cfg->udp_config.peer_ip.s_addr = nla_get_in_addr(attr);
+                       has_peer = true;
+               }
+#if IS_ENABLED(CONFIG_IPV6)
+       } else {
+               if (info->attrs[FOU_ATTR_LOCAL_V6]) {
+                       attr = info->attrs[FOU_ATTR_LOCAL_V6];
+                       cfg->udp_config.local_ip6 = nla_get_in6_addr(attr);
+                       has_local = true;
+               }
+
+               if (info->attrs[FOU_ATTR_PEER_V6]) {
+                       attr = info->attrs[FOU_ATTR_PEER_V6];
+                       cfg->udp_config.peer_ip6 = nla_get_in6_addr(attr);
+                       has_peer = true;
+               }
+#endif
+       }
+
+       if (has_peer) {
+               if (info->attrs[FOU_ATTR_PEER_PORT]) {
+                       port = nla_get_be16(info->attrs[FOU_ATTR_PEER_PORT]);
+                       cfg->udp_config.peer_udp_port = port;
+               } else {
+                       return -EINVAL;
+               }
+       }
+
+       if (info->attrs[FOU_ATTR_IFINDEX]) {
+               if (!has_local)
+                       return -EINVAL;
+
+               ifindex = nla_get_s32(info->attrs[FOU_ATTR_IFINDEX]);
+
+               cfg->udp_config.bind_ifindex = ifindex;
+       }
+
        return 0;
 }
 
@@ -702,15 +786,37 @@ static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
 
 static int fou_fill_info(struct fou *fou, struct sk_buff *msg)
 {
+       struct sock *sk = fou->sock->sk;
+
        if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) ||
            nla_put_be16(msg, FOU_ATTR_PORT, fou->port) ||
+           nla_put_be16(msg, FOU_ATTR_PEER_PORT, sk->sk_dport) ||
            nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) ||
-           nla_put_u8(msg, FOU_ATTR_TYPE, fou->type))
+           nla_put_u8(msg, FOU_ATTR_TYPE, fou->type) ||
+           nla_put_s32(msg, FOU_ATTR_IFINDEX, sk->sk_bound_dev_if))
                return -1;
 
        if (fou->flags & FOU_F_REMCSUM_NOPARTIAL)
                if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL))
                        return -1;
+
+       if (fou->sock->sk->sk_family == AF_INET) {
+               if (nla_put_in_addr(msg, FOU_ATTR_LOCAL_V4, sk->sk_rcv_saddr))
+                       return -1;
+
+               if (nla_put_in_addr(msg, FOU_ATTR_PEER_V4, sk->sk_daddr))
+                       return -1;
+#if IS_ENABLED(CONFIG_IPV6)
+       } else {
+               if (nla_put_in6_addr(msg, FOU_ATTR_LOCAL_V6,
+                                    &sk->sk_v6_rcv_saddr))
+                       return -1;
+
+               if (nla_put_in6_addr(msg, FOU_ATTR_PEER_V6, &sk->sk_v6_daddr))
+                       return -1;
+#endif
+       }
+
        return 0;
 }
 
@@ -763,7 +869,7 @@ static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
        ret = -ESRCH;
        mutex_lock(&fn->fou_lock);
        list_for_each_entry(fout, &fn->fou_list, list) {
-               if (port == fout->port && family == fout->family) {
+               if (fou_cfg_cmp(fout, &cfg)) {
                        ret = fou_dump_info(fout, info->snd_portid,
                                            info->snd_seq, 0, msg,
                                            info->genlhdr->cmd);
@@ -808,20 +914,17 @@ static const struct genl_ops fou_nl_ops[] = {
        {
                .cmd = FOU_CMD_ADD,
                .doit = fou_nl_cmd_add_port,
-               .policy = fou_nl_policy,
                .flags = GENL_ADMIN_PERM,
        },
        {
                .cmd = FOU_CMD_DEL,
                .doit = fou_nl_cmd_rm_port,
-               .policy = fou_nl_policy,
                .flags = GENL_ADMIN_PERM,
        },
        {
                .cmd = FOU_CMD_GET,
                .doit = fou_nl_cmd_get_port,
                .dumpit = fou_nl_dump,
-               .policy = fou_nl_policy,
        },
 };
 
@@ -830,6 +933,7 @@ static struct genl_family fou_nl_family __ro_after_init = {
        .name           = FOU_GENL_NAME,
        .version        = FOU_GENL_VERSION,
        .maxattr        = FOU_ATTR_MAX,
+       .policy = fou_nl_policy,
        .netnsok        = true,
        .module         = THIS_MODULE,
        .ops            = fou_nl_ops,