netfilter: nft_socket: add support for cgroupsv2
[linux-2.6-microblaze.git] / net / netfilter / nft_socket.c
index c9b8a2b..9c169d1 100644 (file)
@@ -9,6 +9,7 @@
 
 struct nft_socket {
        enum nft_socket_keys            key:8;
+       u8                              level;
        union {
                u8                      dreg;
        };
@@ -33,6 +34,26 @@ static void nft_socket_wildcard(const struct nft_pktinfo *pkt,
        }
 }
 
+#ifdef CONFIG_CGROUPS
+static noinline bool
+nft_sock_get_eval_cgroupv2(u32 *dest, const struct nft_pktinfo *pkt, u32 level)
+{
+       struct sock *sk = skb_to_full_sk(pkt->skb);
+       struct cgroup *cgrp;
+
+       if (!sk || !sk_fullsock(sk) || !net_eq(nft_net(pkt), sock_net(sk)))
+               return false;
+
+       cgrp = sock_cgroup_ptr(&sk->sk_cgrp_data);
+       if (level > cgrp->level)
+               return false;
+
+       memcpy(dest, &cgrp->ancestor_ids[level], sizeof(u64));
+
+       return true;
+}
+#endif
+
 static void nft_socket_eval(const struct nft_expr *expr,
                            struct nft_regs *regs,
                            const struct nft_pktinfo *pkt)
@@ -85,6 +106,14 @@ static void nft_socket_eval(const struct nft_expr *expr,
                }
                nft_socket_wildcard(pkt, regs, sk, dest);
                break;
+#ifdef CONFIG_CGROUPS
+       case NFT_SOCKET_CGROUPV2:
+               if (!nft_sock_get_eval_cgroupv2(dest, pkt, priv->level)) {
+                       regs->verdict.code = NFT_BREAK;
+                       return;
+               }
+               break;
+#endif
        default:
                WARN_ON(1);
                regs->verdict.code = NFT_BREAK;
@@ -97,6 +126,7 @@ static void nft_socket_eval(const struct nft_expr *expr,
 static const struct nla_policy nft_socket_policy[NFTA_SOCKET_MAX + 1] = {
        [NFTA_SOCKET_KEY]               = { .type = NLA_U32 },
        [NFTA_SOCKET_DREG]              = { .type = NLA_U32 },
+       [NFTA_SOCKET_LEVEL]             = { .type = NLA_U32 },
 };
 
 static int nft_socket_init(const struct nft_ctx *ctx,
@@ -104,7 +134,7 @@ static int nft_socket_init(const struct nft_ctx *ctx,
                           const struct nlattr * const tb[])
 {
        struct nft_socket *priv = nft_expr_priv(expr);
-       unsigned int len;
+       unsigned int len, level;
 
        if (!tb[NFTA_SOCKET_DREG] || !tb[NFTA_SOCKET_KEY])
                return -EINVAL;
@@ -129,6 +159,19 @@ static int nft_socket_init(const struct nft_ctx *ctx,
        case NFT_SOCKET_MARK:
                len = sizeof(u32);
                break;
+#ifdef CONFIG_CGROUPS
+       case NFT_SOCKET_CGROUPV2:
+               if (!tb[NFTA_SOCKET_LEVEL])
+                       return -EINVAL;
+
+               level = ntohl(nla_get_u32(tb[NFTA_SOCKET_LEVEL]));
+               if (level > 255)
+                       return -EOPNOTSUPP;
+
+               priv->level = level;
+               len = sizeof(u64);
+               break;
+#endif
        default:
                return -EOPNOTSUPP;
        }
@@ -146,6 +189,9 @@ static int nft_socket_dump(struct sk_buff *skb,
                return -1;
        if (nft_dump_register(skb, NFTA_SOCKET_DREG, priv->dreg))
                return -1;
+       if (priv->key == NFT_SOCKET_CGROUPV2 &&
+           nla_put_u32(skb, NFTA_SOCKET_LEVEL, htonl(priv->level)))
+               return -1;
        return 0;
 }