Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[linux-2.6-microblaze.git] / net / mpls / af_mpls.c
index 8fbe6cd..f7c5445 100644 (file)
@@ -1209,21 +1209,57 @@ static const struct nla_policy devconf_mpls_policy[NETCONFA_MAX + 1] = {
        [NETCONFA_IFINDEX]      = { .len = sizeof(int) },
 };
 
+static int mpls_netconf_valid_get_req(struct sk_buff *skb,
+                                     const struct nlmsghdr *nlh,
+                                     struct nlattr **tb,
+                                     struct netlink_ext_ack *extack)
+{
+       int i, err;
+
+       if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) {
+               NL_SET_ERR_MSG_MOD(extack,
+                                  "Invalid header for netconf get request");
+               return -EINVAL;
+       }
+
+       if (!netlink_strict_get_check(skb))
+               return nlmsg_parse(nlh, sizeof(struct netconfmsg), tb,
+                                  NETCONFA_MAX, devconf_mpls_policy, extack);
+
+       err = nlmsg_parse_strict(nlh, sizeof(struct netconfmsg), tb,
+                                NETCONFA_MAX, devconf_mpls_policy, extack);
+       if (err)
+               return err;
+
+       for (i = 0; i <= NETCONFA_MAX; i++) {
+               if (!tb[i])
+                       continue;
+
+               switch (i) {
+               case NETCONFA_IFINDEX:
+                       break;
+               default:
+                       NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in netconf get request");
+                       return -EINVAL;
+               }
+       }
+
+       return 0;
+}
+
 static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
                                    struct nlmsghdr *nlh,
                                    struct netlink_ext_ack *extack)
 {
        struct net *net = sock_net(in_skb->sk);
        struct nlattr *tb[NETCONFA_MAX + 1];
-       struct netconfmsg *ncm;
        struct net_device *dev;
        struct mpls_dev *mdev;
        struct sk_buff *skb;
        int ifindex;
        int err;
 
-       err = nlmsg_parse(nlh, sizeof(*ncm), tb, NETCONFA_MAX,
-                         devconf_mpls_policy, NULL);
+       err = mpls_netconf_valid_get_req(in_skb, nlh, tb, extack);
        if (err < 0)
                goto errout;
 
@@ -1263,6 +1299,7 @@ errout:
 static int mpls_netconf_dump_devconf(struct sk_buff *skb,
                                     struct netlink_callback *cb)
 {
+       const struct nlmsghdr *nlh = cb->nlh;
        struct net *net = sock_net(skb->sk);
        struct hlist_head *head;
        struct net_device *dev;
@@ -1270,6 +1307,21 @@ static int mpls_netconf_dump_devconf(struct sk_buff *skb,
        int idx, s_idx;
        int h, s_h;
 
+       if (cb->strict_check) {
+               struct netlink_ext_ack *extack = cb->extack;
+               struct netconfmsg *ncm;
+
+               if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) {
+                       NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf dump request");
+                       return -EINVAL;
+               }
+
+               if (nlmsg_attrlen(nlh, sizeof(*ncm))) {
+                       NL_SET_ERR_MSG_MOD(extack, "Invalid data after header in netconf dump request");
+                       return -EINVAL;
+               }
+       }
+
        s_h = cb->args[0];
        s_idx = idx = cb->args[1];
 
@@ -1286,7 +1338,7 @@ static int mpls_netconf_dump_devconf(struct sk_buff *skb,
                                goto cont;
                        if (mpls_netconf_fill_devconf(skb, mdev,
                                                      NETLINK_CB(cb->skb).portid,
-                                                     cb->nlh->nlmsg_seq,
+                                                     nlh->nlmsg_seq,
                                                      RTM_NEWNETCONF,
                                                      NLM_F_MULTI,
                                                      NETCONFA_ALL) < 0) {
@@ -1822,6 +1874,9 @@ static int rtm_to_route_config(struct sk_buff *skb,
                                goto errout;
                        break;
                }
+               case RTA_GATEWAY:
+                       NL_SET_ERR_MSG(extack, "MPLS does not support RTA_GATEWAY attribute");
+                       goto errout;
                case RTA_VIA:
                {
                        if (nla_get_via(nla, &cfg->rc_via_alen,
@@ -2015,30 +2070,140 @@ nla_put_failure:
        return -EMSGSIZE;
 }
 
+#if IS_ENABLED(CONFIG_INET)
+static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
+                                  struct fib_dump_filter *filter,
+                                  struct netlink_callback *cb)
+{
+       return ip_valid_fib_dump_req(net, nlh, filter, cb);
+}
+#else
+static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
+                                  struct fib_dump_filter *filter,
+                                  struct netlink_callback *cb)
+{
+       struct netlink_ext_ack *extack = cb->extack;
+       struct nlattr *tb[RTA_MAX + 1];
+       struct rtmsg *rtm;
+       int err, i;
+
+       if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) {
+               NL_SET_ERR_MSG_MOD(extack, "Invalid header for FIB dump request");
+               return -EINVAL;
+       }
+
+       rtm = nlmsg_data(nlh);
+       if (rtm->rtm_dst_len || rtm->rtm_src_len  || rtm->rtm_tos   ||
+           rtm->rtm_table   || rtm->rtm_scope    || rtm->rtm_type  ||
+           rtm->rtm_flags) {
+               NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for FIB dump request");
+               return -EINVAL;
+       }
+
+       if (rtm->rtm_protocol) {
+               filter->protocol = rtm->rtm_protocol;
+               filter->filter_set = 1;
+               cb->answer_flags = NLM_F_DUMP_FILTERED;
+       }
+
+       err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX,
+                                rtm_mpls_policy, extack);
+       if (err < 0)
+               return err;
+
+       for (i = 0; i <= RTA_MAX; ++i) {
+               int ifindex;
+
+               if (i == RTA_OIF) {
+                       ifindex = nla_get_u32(tb[i]);
+                       filter->dev = __dev_get_by_index(net, ifindex);
+                       if (!filter->dev)
+                               return -ENODEV;
+                       filter->filter_set = 1;
+               } else if (tb[i]) {
+                       NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in dump request");
+                       return -EINVAL;
+               }
+       }
+
+       return 0;
+}
+#endif
+
+static bool mpls_rt_uses_dev(struct mpls_route *rt,
+                            const struct net_device *dev)
+{
+       struct net_device *nh_dev;
+
+       if (rt->rt_nhn == 1) {
+               struct mpls_nh *nh = rt->rt_nh;
+
+               nh_dev = rtnl_dereference(nh->nh_dev);
+               if (dev == nh_dev)
+                       return true;
+       } else {
+               for_nexthops(rt) {
+                       nh_dev = rtnl_dereference(nh->nh_dev);
+                       if (nh_dev == dev)
+                               return true;
+               } endfor_nexthops(rt);
+       }
+
+       return false;
+}
+
 static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
 {
+       const struct nlmsghdr *nlh = cb->nlh;
        struct net *net = sock_net(skb->sk);
        struct mpls_route __rcu **platform_label;
+       struct fib_dump_filter filter = {};
+       unsigned int flags = NLM_F_MULTI;
        size_t platform_labels;
        unsigned int index;
 
        ASSERT_RTNL();
 
+       if (cb->strict_check) {
+               int err;
+
+               err = mpls_valid_fib_dump_req(net, nlh, &filter, cb);
+               if (err < 0)
+                       return err;
+
+               /* for MPLS, there is only 1 table with fixed type and flags.
+                * If either are set in the filter then return nothing.
+                */
+               if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) ||
+                   (filter.rt_type && filter.rt_type != RTN_UNICAST) ||
+                    filter.flags)
+                       return skb->len;
+       }
+
        index = cb->args[0];
        if (index < MPLS_LABEL_FIRST_UNRESERVED)
                index = MPLS_LABEL_FIRST_UNRESERVED;
 
        platform_label = rtnl_dereference(net->mpls.platform_label);
        platform_labels = net->mpls.platform_labels;
+
+       if (filter.filter_set)
+               flags |= NLM_F_DUMP_FILTERED;
+
        for (; index < platform_labels; index++) {
                struct mpls_route *rt;
+
                rt = rtnl_dereference(platform_label[index]);
                if (!rt)
                        continue;
 
+               if ((filter.dev && !mpls_rt_uses_dev(rt, filter.dev)) ||
+                   (filter.protocol && rt->rt_protocol != filter.protocol))
+                       continue;
+
                if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
                                    cb->nlh->nlmsg_seq, RTM_NEWROUTE,
-                                   index, rt, NLM_F_MULTI) < 0)
+                                   index, rt, flags) < 0)
                        break;
        }
        cb->args[0] = index;
@@ -2110,6 +2275,64 @@ errout:
                rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
 }
 
+static int mpls_valid_getroute_req(struct sk_buff *skb,
+                                  const struct nlmsghdr *nlh,
+                                  struct nlattr **tb,
+                                  struct netlink_ext_ack *extack)
+{
+       struct rtmsg *rtm;
+       int i, err;
+
+       if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) {
+               NL_SET_ERR_MSG_MOD(extack,
+                                  "Invalid header for get route request");
+               return -EINVAL;
+       }
+
+       if (!netlink_strict_get_check(skb))
+               return nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX,
+                                  rtm_mpls_policy, extack);
+
+       rtm = nlmsg_data(nlh);
+       if ((rtm->rtm_dst_len && rtm->rtm_dst_len != 20) ||
+           rtm->rtm_src_len || rtm->rtm_tos || rtm->rtm_table ||
+           rtm->rtm_protocol || rtm->rtm_scope || rtm->rtm_type) {
+               NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get route request");
+               return -EINVAL;
+       }
+       if (rtm->rtm_flags & ~RTM_F_FIB_MATCH) {
+               NL_SET_ERR_MSG_MOD(extack,
+                                  "Invalid flags for get route request");
+               return -EINVAL;
+       }
+
+       err = nlmsg_parse_strict(nlh, sizeof(*rtm), tb, RTA_MAX,
+                                rtm_mpls_policy, extack);
+       if (err)
+               return err;
+
+       if ((tb[RTA_DST] || tb[RTA_NEWDST]) && !rtm->rtm_dst_len) {
+               NL_SET_ERR_MSG_MOD(extack, "rtm_dst_len must be 20 for MPLS");
+               return -EINVAL;
+       }
+
+       for (i = 0; i <= RTA_MAX; i++) {
+               if (!tb[i])
+                       continue;
+
+               switch (i) {
+               case RTA_DST:
+               case RTA_NEWDST:
+                       break;
+               default:
+                       NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get route request");
+                       return -EINVAL;
+               }
+       }
+
+       return 0;
+}
+
 static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
                         struct netlink_ext_ack *extack)
 {
@@ -2129,8 +2352,7 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
        u8 n_labels;
        int err;
 
-       err = nlmsg_parse(in_nlh, sizeof(*rtm), tb, RTA_MAX,
-                         rtm_mpls_policy, extack);
+       err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
        if (err < 0)
                goto errout;