net: fib_rules: Add port mask support
authorIdo Schimmel <idosch@nvidia.com>
Mon, 17 Feb 2025 13:41:03 +0000 (15:41 +0200)
committerJakub Kicinski <kuba@kernel.org>
Thu, 20 Feb 2025 02:43:38 +0000 (18:43 -0800)
Add support for configuring and deleting rules that match on source and
destination ports using a mask as well as support for dumping such rules
to user space.

Reviewed-by: Petr Machata <petrm@nvidia.com>
Signed-off-by: Ido Schimmel <idosch@nvidia.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Link: https://patch.msgid.link/20250217134109.311176-3-idosch@nvidia.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
include/net/fib_rules.h
net/core/fib_rules.c

index 710caac..cfeb2fd 100644 (file)
@@ -43,6 +43,8 @@ struct fib_rule {
        struct fib_kuid_range   uid_range;
        struct fib_rule_port_range      sport_range;
        struct fib_rule_port_range      dport_range;
+       u16                     sport_mask;
+       u16                     dport_mask;
        struct rcu_head         rcu;
 };
 
@@ -159,6 +161,12 @@ static inline bool fib_rule_port_range_compare(struct fib_rule_port_range *a,
                a->end == b->end;
 }
 
+static inline bool
+fib_rule_port_is_range(const struct fib_rule_port_range *range)
+{
+       return range->start != range->end;
+}
+
 static inline bool fib_rule_requires_fldissect(struct fib_rule *rule)
 {
        return rule->iifindex != LOOPBACK_IFINDEX && (rule->ip_proto ||
index f5b1900..ba6beaa 100644 (file)
@@ -481,11 +481,17 @@ static struct fib_rule *rule_find(struct fib_rules_ops *ops,
                                                 &rule->sport_range))
                        continue;
 
+               if (rule->sport_mask && r->sport_mask != rule->sport_mask)
+                       continue;
+
                if (fib_rule_port_range_set(&rule->dport_range) &&
                    !fib_rule_port_range_compare(&r->dport_range,
                                                 &rule->dport_range))
                        continue;
 
+               if (rule->dport_mask && r->dport_mask != rule->dport_mask)
+                       continue;
+
                if (!ops->compare(r, frh, tb))
                        continue;
                return r;
@@ -515,6 +521,33 @@ static int fib_nl2rule_l3mdev(struct nlattr *nla, struct fib_rule *nlrule,
 }
 #endif
 
+static int fib_nl2rule_port_mask(const struct nlattr *mask_attr,
+                                const struct fib_rule_port_range *range,
+                                u16 *port_mask,
+                                struct netlink_ext_ack *extack)
+{
+       if (!fib_rule_port_range_valid(range)) {
+               NL_SET_ERR_MSG_ATTR(extack, mask_attr,
+                                   "Cannot specify port mask without port value");
+               return -EINVAL;
+       }
+
+       if (fib_rule_port_is_range(range)) {
+               NL_SET_ERR_MSG_ATTR(extack, mask_attr,
+                                   "Cannot specify port mask for port range");
+               return -EINVAL;
+       }
+
+       if (range->start & ~nla_get_u16(mask_attr)) {
+               NL_SET_ERR_MSG_ATTR(extack, mask_attr, "Invalid port mask");
+               return -EINVAL;
+       }
+
+       *port_mask = nla_get_u16(mask_attr);
+
+       return 0;
+}
+
 static int fib_nl2rule(struct net *net, struct nlmsghdr *nlh,
                       struct netlink_ext_ack *extack,
                       struct fib_rules_ops *ops,
@@ -644,6 +677,16 @@ static int fib_nl2rule(struct net *net, struct nlmsghdr *nlh,
                        NL_SET_ERR_MSG(extack, "Invalid sport range");
                        goto errout_free;
                }
+               if (!fib_rule_port_is_range(&nlrule->sport_range))
+                       nlrule->sport_mask = U16_MAX;
+       }
+
+       if (tb[FRA_SPORT_MASK]) {
+               err = fib_nl2rule_port_mask(tb[FRA_SPORT_MASK],
+                                           &nlrule->sport_range,
+                                           &nlrule->sport_mask, extack);
+               if (err)
+                       goto errout_free;
        }
 
        if (tb[FRA_DPORT_RANGE]) {
@@ -653,6 +696,16 @@ static int fib_nl2rule(struct net *net, struct nlmsghdr *nlh,
                        NL_SET_ERR_MSG(extack, "Invalid dport range");
                        goto errout_free;
                }
+               if (!fib_rule_port_is_range(&nlrule->dport_range))
+                       nlrule->dport_mask = U16_MAX;
+       }
+
+       if (tb[FRA_DPORT_MASK]) {
+               err = fib_nl2rule_port_mask(tb[FRA_DPORT_MASK],
+                                           &nlrule->dport_range,
+                                           &nlrule->dport_mask, extack);
+               if (err)
+                       goto errout_free;
        }
 
        *rule = nlrule;
@@ -751,10 +804,16 @@ static int rule_exists(struct fib_rules_ops *ops, struct fib_rule_hdr *frh,
                                                 &rule->sport_range))
                        continue;
 
+               if (r->sport_mask != rule->sport_mask)
+                       continue;
+
                if (!fib_rule_port_range_compare(&r->dport_range,
                                                 &rule->dport_range))
                        continue;
 
+               if (r->dport_mask != rule->dport_mask)
+                       continue;
+
                if (!ops->compare(r, frh, tb))
                        continue;
                return 1;
@@ -1051,7 +1110,9 @@ static inline size_t fib_rule_nlmsg_size(struct fib_rules_ops *ops,
                         + nla_total_size(1) /* FRA_PROTOCOL */
                         + nla_total_size(1) /* FRA_IP_PROTO */
                         + nla_total_size(sizeof(struct fib_rule_port_range)) /* FRA_SPORT_RANGE */
-                        + nla_total_size(sizeof(struct fib_rule_port_range)); /* FRA_DPORT_RANGE */
+                        + nla_total_size(sizeof(struct fib_rule_port_range)) /* FRA_DPORT_RANGE */
+                        + nla_total_size(2) /* FRA_SPORT_MASK */
+                        + nla_total_size(2); /* FRA_DPORT_MASK */
 
        if (ops->nlmsg_payload)
                payload += ops->nlmsg_payload(rule);
@@ -1119,8 +1180,12 @@ static int fib_nl_fill_rule(struct sk_buff *skb, struct fib_rule *rule,
             nla_put_uid_range(skb, &rule->uid_range)) ||
            (fib_rule_port_range_set(&rule->sport_range) &&
             nla_put_port_range(skb, FRA_SPORT_RANGE, &rule->sport_range)) ||
+           (rule->sport_mask && nla_put_u16(skb, FRA_SPORT_MASK,
+                                            rule->sport_mask)) ||
            (fib_rule_port_range_set(&rule->dport_range) &&
             nla_put_port_range(skb, FRA_DPORT_RANGE, &rule->dport_range)) ||
+           (rule->dport_mask && nla_put_u16(skb, FRA_DPORT_MASK,
+                                            rule->dport_mask)) ||
            (rule->ip_proto && nla_put_u8(skb, FRA_IP_PROTO, rule->ip_proto)))
                goto nla_put_failure;