Merge branch 'dmi-for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/jdelvar...
[linux-2.6-microblaze.git] / lib / nlattr.c
index cace9b3..bc5b5cf 100644 (file)
@@ -44,8 +44,22 @@ static const u8 nla_attr_minlen[NLA_TYPE_MAX+1] = {
        [NLA_S64]       = sizeof(s64),
 };
 
+/*
+ * Nested policies might refer back to the original
+ * policy in some cases, and userspace could try to
+ * abuse that and recurse by nesting in the right
+ * ways. Limit recursion to avoid this problem.
+ */
+#define MAX_POLICY_RECURSION_DEPTH     10
+
+static int __nla_validate_parse(const struct nlattr *head, int len, int maxtype,
+                               const struct nla_policy *policy,
+                               unsigned int validate,
+                               struct netlink_ext_ack *extack,
+                               struct nlattr **tb, unsigned int depth);
+
 static int validate_nla_bitfield32(const struct nlattr *nla,
-                                  const u32 *valid_flags_mask)
+                                  const u32 valid_flags_mask)
 {
        const struct nla_bitfield32 *bf = nla_data(nla);
 
@@ -53,11 +67,11 @@ static int validate_nla_bitfield32(const struct nlattr *nla,
                return -EINVAL;
 
        /*disallow invalid bit selector */
-       if (bf->selector & ~*valid_flags_mask)
+       if (bf->selector & ~valid_flags_mask)
                return -EINVAL;
 
        /*disallow invalid bit values */
-       if (bf->value & ~*valid_flags_mask)
+       if (bf->value & ~valid_flags_mask)
                return -EINVAL;
 
        /*disallow valid bit values that are not selected*/
@@ -70,7 +84,7 @@ static int validate_nla_bitfield32(const struct nlattr *nla,
 static int nla_validate_array(const struct nlattr *head, int len, int maxtype,
                              const struct nla_policy *policy,
                              struct netlink_ext_ack *extack,
-                             unsigned int validate)
+                             unsigned int validate, unsigned int depth)
 {
        const struct nlattr *entry;
        int rem;
@@ -87,8 +101,9 @@ static int nla_validate_array(const struct nlattr *head, int len, int maxtype,
                        return -ERANGE;
                }
 
-               ret = __nla_validate(nla_data(entry), nla_len(entry),
-                                    maxtype, policy, validate, extack);
+               ret = __nla_validate_parse(nla_data(entry), nla_len(entry),
+                                          maxtype, policy, validate, extack,
+                                          NULL, depth + 1);
                if (ret < 0)
                        return ret;
        }
@@ -96,17 +111,58 @@ static int nla_validate_array(const struct nlattr *head, int len, int maxtype,
        return 0;
 }
 
-static int nla_validate_int_range(const struct nla_policy *pt,
-                                 const struct nlattr *nla,
-                                 struct netlink_ext_ack *extack)
+void nla_get_range_unsigned(const struct nla_policy *pt,
+                           struct netlink_range_validation *range)
 {
-       bool validate_min, validate_max;
-       s64 value;
+       WARN_ON_ONCE(pt->validation_type != NLA_VALIDATE_RANGE_PTR &&
+                    (pt->min < 0 || pt->max < 0));
 
-       validate_min = pt->validation_type == NLA_VALIDATE_RANGE ||
-                      pt->validation_type == NLA_VALIDATE_MIN;
-       validate_max = pt->validation_type == NLA_VALIDATE_RANGE ||
-                      pt->validation_type == NLA_VALIDATE_MAX;
+       range->min = 0;
+
+       switch (pt->type) {
+       case NLA_U8:
+               range->max = U8_MAX;
+               break;
+       case NLA_U16:
+               range->max = U16_MAX;
+               break;
+       case NLA_U32:
+               range->max = U32_MAX;
+               break;
+       case NLA_U64:
+       case NLA_MSECS:
+               range->max = U64_MAX;
+               break;
+       default:
+               WARN_ON_ONCE(1);
+               return;
+       }
+
+       switch (pt->validation_type) {
+       case NLA_VALIDATE_RANGE:
+               range->min = pt->min;
+               range->max = pt->max;
+               break;
+       case NLA_VALIDATE_RANGE_PTR:
+               *range = *pt->range;
+               break;
+       case NLA_VALIDATE_MIN:
+               range->min = pt->min;
+               break;
+       case NLA_VALIDATE_MAX:
+               range->max = pt->max;
+               break;
+       default:
+               break;
+       }
+}
+
+static int nla_validate_int_range_unsigned(const struct nla_policy *pt,
+                                          const struct nlattr *nla,
+                                          struct netlink_ext_ack *extack)
+{
+       struct netlink_range_validation range;
+       u64 value;
 
        switch (pt->type) {
        case NLA_U8:
@@ -118,6 +174,77 @@ static int nla_validate_int_range(const struct nla_policy *pt,
        case NLA_U32:
                value = nla_get_u32(nla);
                break;
+       case NLA_U64:
+       case NLA_MSECS:
+               value = nla_get_u64(nla);
+               break;
+       default:
+               return -EINVAL;
+       }
+
+       nla_get_range_unsigned(pt, &range);
+
+       if (value < range.min || value > range.max) {
+               NL_SET_ERR_MSG_ATTR(extack, nla,
+                                   "integer out of range");
+               return -ERANGE;
+       }
+
+       return 0;
+}
+
+void nla_get_range_signed(const struct nla_policy *pt,
+                         struct netlink_range_validation_signed *range)
+{
+       switch (pt->type) {
+       case NLA_S8:
+               range->min = S8_MIN;
+               range->max = S8_MAX;
+               break;
+       case NLA_S16:
+               range->min = S16_MIN;
+               range->max = S16_MAX;
+               break;
+       case NLA_S32:
+               range->min = S32_MIN;
+               range->max = S32_MAX;
+               break;
+       case NLA_S64:
+               range->min = S64_MIN;
+               range->max = S64_MAX;
+               break;
+       default:
+               WARN_ON_ONCE(1);
+               return;
+       }
+
+       switch (pt->validation_type) {
+       case NLA_VALIDATE_RANGE:
+               range->min = pt->min;
+               range->max = pt->max;
+               break;
+       case NLA_VALIDATE_RANGE_PTR:
+               *range = *pt->range_signed;
+               break;
+       case NLA_VALIDATE_MIN:
+               range->min = pt->min;
+               break;
+       case NLA_VALIDATE_MAX:
+               range->max = pt->max;
+               break;
+       default:
+               break;
+       }
+}
+
+static int nla_validate_int_range_signed(const struct nla_policy *pt,
+                                        const struct nlattr *nla,
+                                        struct netlink_ext_ack *extack)
+{
+       struct netlink_range_validation_signed range;
+       s64 value;
+
+       switch (pt->type) {
        case NLA_S8:
                value = nla_get_s8(nla);
                break;
@@ -130,22 +257,13 @@ static int nla_validate_int_range(const struct nla_policy *pt,
        case NLA_S64:
                value = nla_get_s64(nla);
                break;
-       case NLA_U64:
-               /* treat this one specially, since it may not fit into s64 */
-               if ((validate_min && nla_get_u64(nla) < pt->min) ||
-                   (validate_max && nla_get_u64(nla) > pt->max)) {
-                       NL_SET_ERR_MSG_ATTR(extack, nla,
-                                           "integer out of range");
-                       return -ERANGE;
-               }
-               return 0;
        default:
-               WARN_ON(1);
                return -EINVAL;
        }
 
-       if ((validate_min && value < pt->min) ||
-           (validate_max && value > pt->max)) {
+       nla_get_range_signed(pt, &range);
+
+       if (value < range.min || value > range.max) {
                NL_SET_ERR_MSG_ATTR(extack, nla,
                                    "integer out of range");
                return -ERANGE;
@@ -154,9 +272,31 @@ static int nla_validate_int_range(const struct nla_policy *pt,
        return 0;
 }
 
+static int nla_validate_int_range(const struct nla_policy *pt,
+                                 const struct nlattr *nla,
+                                 struct netlink_ext_ack *extack)
+{
+       switch (pt->type) {
+       case NLA_U8:
+       case NLA_U16:
+       case NLA_U32:
+       case NLA_U64:
+       case NLA_MSECS:
+               return nla_validate_int_range_unsigned(pt, nla, extack);
+       case NLA_S8:
+       case NLA_S16:
+       case NLA_S32:
+       case NLA_S64:
+               return nla_validate_int_range_signed(pt, nla, extack);
+       default:
+               WARN_ON(1);
+               return -EINVAL;
+       }
+}
+
 static int validate_nla(const struct nlattr *nla, int maxtype,
                        const struct nla_policy *policy, unsigned int validate,
-                       struct netlink_ext_ack *extack)
+                       struct netlink_ext_ack *extack, unsigned int depth)
 {
        u16 strict_start_type = policy[0].strict_start_type;
        const struct nla_policy *pt;
@@ -174,7 +314,9 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
        BUG_ON(pt->type > NLA_TYPE_MAX);
 
        if ((nla_attr_len[pt->type] && attrlen != nla_attr_len[pt->type]) ||
-           (pt->type == NLA_EXACT_LEN_WARN && attrlen != pt->len)) {
+           (pt->type == NLA_EXACT_LEN &&
+            pt->validation_type == NLA_VALIDATE_WARN_TOO_LONG &&
+            attrlen != pt->len)) {
                pr_warn_ratelimited("netlink: '%s': attribute type %d has an invalid length.\n",
                                    current->comm, type);
                if (validate & NL_VALIDATE_STRICT_ATTRS) {
@@ -200,15 +342,10 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
        }
 
        switch (pt->type) {
-       case NLA_EXACT_LEN:
-               if (attrlen != pt->len)
-                       goto out_err;
-               break;
-
        case NLA_REJECT:
-               if (extack && pt->validation_data) {
+               if (extack && pt->reject_message) {
                        NL_SET_BAD_ATTR(extack, nla);
-                       extack->_msg = pt->validation_data;
+                       extack->_msg = pt->reject_message;
                        return -EINVAL;
                }
                err = -EINVAL;
@@ -223,7 +360,7 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
                if (attrlen != sizeof(struct nla_bitfield32))
                        goto out_err;
 
-               err = validate_nla_bitfield32(nla, pt->validation_data);
+               err = validate_nla_bitfield32(nla, pt->bitfield32_valid);
                if (err)
                        goto out_err;
                break;
@@ -268,10 +405,11 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
                        break;
                if (attrlen < NLA_HDRLEN)
                        goto out_err;
-               if (pt->validation_data) {
-                       err = __nla_validate(nla_data(nla), nla_len(nla), pt->len,
-                                            pt->validation_data, validate,
-                                            extack);
+               if (pt->nested_policy) {
+                       err = __nla_validate_parse(nla_data(nla), nla_len(nla),
+                                                  pt->len, pt->nested_policy,
+                                                  validate, extack, NULL,
+                                                  depth + 1);
                        if (err < 0) {
                                /*
                                 * return directly to preserve the inner
@@ -289,12 +427,12 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
                        break;
                if (attrlen < NLA_HDRLEN)
                        goto out_err;
-               if (pt->validation_data) {
+               if (pt->nested_policy) {
                        int err;
 
                        err = nla_validate_array(nla_data(nla), nla_len(nla),
-                                                pt->len, pt->validation_data,
-                                                extack, validate);
+                                                pt->len, pt->nested_policy,
+                                                extack, validate, depth);
                        if (err < 0) {
                                /*
                                 * return directly to preserve the inner
@@ -317,6 +455,13 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
                        goto out_err;
                break;
 
+       case NLA_EXACT_LEN:
+               if (pt->validation_type != NLA_VALIDATE_WARN_TOO_LONG) {
+                       if (attrlen != pt->len)
+                               goto out_err;
+                       break;
+               }
+               /* fall through */
        default:
                if (pt->len)
                        minlen = pt->len;
@@ -332,6 +477,7 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
        case NLA_VALIDATE_NONE:
                /* nothing to do */
                break;
+       case NLA_VALIDATE_RANGE_PTR:
        case NLA_VALIDATE_RANGE:
        case NLA_VALIDATE_MIN:
        case NLA_VALIDATE_MAX:
@@ -358,11 +504,17 @@ static int __nla_validate_parse(const struct nlattr *head, int len, int maxtype,
                                const struct nla_policy *policy,
                                unsigned int validate,
                                struct netlink_ext_ack *extack,
-                               struct nlattr **tb)
+                               struct nlattr **tb, unsigned int depth)
 {
        const struct nlattr *nla;
        int rem;
 
+       if (depth >= MAX_POLICY_RECURSION_DEPTH) {
+               NL_SET_ERR_MSG(extack,
+                              "allowed policy recursion depth exceeded");
+               return -EINVAL;
+       }
+
        if (tb)
                memset(tb, 0, sizeof(struct nlattr *) * (maxtype + 1));
 
@@ -379,7 +531,7 @@ static int __nla_validate_parse(const struct nlattr *head, int len, int maxtype,
                }
                if (policy) {
                        int err = validate_nla(nla, maxtype, policy,
-                                              validate, extack);
+                                              validate, extack, depth);
 
                        if (err < 0)
                                return err;
@@ -421,7 +573,7 @@ int __nla_validate(const struct nlattr *head, int len, int maxtype,
                   struct netlink_ext_ack *extack)
 {
        return __nla_validate_parse(head, len, maxtype, policy, validate,
-                                   extack, NULL);
+                                   extack, NULL, 0);
 }
 EXPORT_SYMBOL(__nla_validate);
 
@@ -476,7 +628,7 @@ int __nla_parse(struct nlattr **tb, int maxtype,
                struct netlink_ext_ack *extack)
 {
        return __nla_validate_parse(head, len, maxtype, policy, validate,
-                                   extack, tb);
+                                   extack, tb, 0);
 }
 EXPORT_SYMBOL(__nla_parse);