sock_map: Introduce BPF_SK_SKB_VERDICT
[linux-2.6-microblaze.git] / net / core / sock_map.c
index d758fb8..c2a0411 100644 (file)
@@ -24,6 +24,10 @@ struct bpf_stab {
 #define SOCK_CREATE_FLAG_MASK                          \
        (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
 
+static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
+                               struct bpf_prog *old, u32 which);
+static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);
+
 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
 {
        struct bpf_stab *stab;
@@ -148,9 +152,11 @@ static void sock_map_del_link(struct sock *sk,
                        struct bpf_map *map = link->map;
                        struct bpf_stab *stab = container_of(map, struct bpf_stab,
                                                             map);
-                       if (psock->parser.enabled && stab->progs.skb_parser)
+                       if (psock->saved_data_ready && stab->progs.stream_parser)
                                strp_stop = true;
-                       if (psock->parser.enabled && stab->progs.skb_verdict)
+                       if (psock->saved_data_ready && stab->progs.stream_verdict)
+                               verdict_stop = true;
+                       if (psock->saved_data_ready && stab->progs.skb_verdict)
                                verdict_stop = true;
                        list_del(&link->list);
                        sk_psock_free_link(link);
@@ -221,26 +227,38 @@ out:
        return psock;
 }
 
-static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
-                        struct sock *sk)
+static bool sock_map_redirect_allowed(const struct sock *sk);
+
+static int sock_map_link(struct bpf_map *map, struct sock *sk)
 {
-       struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
+       struct sk_psock_progs *progs = sock_map_progs(map);
+       struct bpf_prog *stream_verdict = NULL;
+       struct bpf_prog *stream_parser = NULL;
+       struct bpf_prog *skb_verdict = NULL;
+       struct bpf_prog *msg_parser = NULL;
        struct sk_psock *psock;
        int ret;
 
-       skb_verdict = READ_ONCE(progs->skb_verdict);
-       if (skb_verdict) {
-               skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
-               if (IS_ERR(skb_verdict))
-                       return PTR_ERR(skb_verdict);
+       /* Only sockets we can redirect into/from in BPF need to hold
+        * refs to parser/verdict progs and have their sk_data_ready
+        * and sk_write_space callbacks overridden.
+        */
+       if (!sock_map_redirect_allowed(sk))
+               goto no_progs;
+
+       stream_verdict = READ_ONCE(progs->stream_verdict);
+       if (stream_verdict) {
+               stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
+               if (IS_ERR(stream_verdict))
+                       return PTR_ERR(stream_verdict);
        }
 
-       skb_parser = READ_ONCE(progs->skb_parser);
-       if (skb_parser) {
-               skb_parser = bpf_prog_inc_not_zero(skb_parser);
-               if (IS_ERR(skb_parser)) {
-                       ret = PTR_ERR(skb_parser);
-                       goto out_put_skb_verdict;
+       stream_parser = READ_ONCE(progs->stream_parser);
+       if (stream_parser) {
+               stream_parser = bpf_prog_inc_not_zero(stream_parser);
+               if (IS_ERR(stream_parser)) {
+                       ret = PTR_ERR(stream_parser);
+                       goto out_put_stream_verdict;
                }
        }
 
@@ -249,10 +267,20 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                msg_parser = bpf_prog_inc_not_zero(msg_parser);
                if (IS_ERR(msg_parser)) {
                        ret = PTR_ERR(msg_parser);
-                       goto out_put_skb_parser;
+                       goto out_put_stream_parser;
                }
        }
 
+       skb_verdict = READ_ONCE(progs->skb_verdict);
+       if (skb_verdict) {
+               skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
+               if (IS_ERR(skb_verdict)) {
+                       ret = PTR_ERR(skb_verdict);
+                       goto out_put_msg_parser;
+               }
+       }
+
+no_progs:
        psock = sock_map_psock_get_checked(sk);
        if (IS_ERR(psock)) {
                ret = PTR_ERR(psock);
@@ -261,8 +289,11 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
 
        if (psock) {
                if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
-                   (skb_parser  && READ_ONCE(psock->progs.skb_parser)) ||
-                   (skb_verdict && READ_ONCE(psock->progs.skb_verdict))) {
+                   (stream_parser  && READ_ONCE(psock->progs.stream_parser)) ||
+                   (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
+                   (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
+                   (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
+                   (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
                        sk_psock_put(sk, psock);
                        ret = -EBUSY;
                        goto out_progs;
@@ -283,16 +314,19 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                goto out_drop;
 
        write_lock_bh(&sk->sk_callback_lock);
-       if (skb_parser && skb_verdict && !psock->parser.enabled) {
+       if (stream_parser && stream_verdict && !psock->saved_data_ready) {
                ret = sk_psock_init_strp(sk, psock);
                if (ret)
                        goto out_unlock_drop;
-               psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
-               psock_set_prog(&psock->progs.skb_parser, skb_parser);
+               psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
+               psock_set_prog(&psock->progs.stream_parser, stream_parser);
                sk_psock_start_strp(sk, psock);
-       } else if (!skb_parser && skb_verdict && !psock->parser.enabled) {
-               psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
+       } else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
+               psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
                sk_psock_start_verdict(sk,psock);
+       } else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
+               psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
+               sk_psock_start_verdict(sk, psock);
        }
        write_unlock_bh(&sk->sk_callback_lock);
        return 0;
@@ -301,35 +335,17 @@ out_unlock_drop:
 out_drop:
        sk_psock_put(sk, psock);
 out_progs:
-       if (msg_parser)
-               bpf_prog_put(msg_parser);
-out_put_skb_parser:
-       if (skb_parser)
-               bpf_prog_put(skb_parser);
-out_put_skb_verdict:
        if (skb_verdict)
                bpf_prog_put(skb_verdict);
-       return ret;
-}
-
-static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
-{
-       struct sk_psock *psock;
-       int ret;
-
-       psock = sock_map_psock_get_checked(sk);
-       if (IS_ERR(psock))
-               return PTR_ERR(psock);
-
-       if (!psock) {
-               psock = sk_psock_init(sk, map->numa_node);
-               if (IS_ERR(psock))
-                       return PTR_ERR(psock);
-       }
-
-       ret = sock_map_init_proto(sk, psock);
-       if (ret < 0)
-               sk_psock_put(sk, psock);
+out_put_msg_parser:
+       if (msg_parser)
+               bpf_prog_put(msg_parser);
+out_put_stream_parser:
+       if (stream_parser)
+               bpf_prog_put(stream_parser);
+out_put_stream_verdict:
+       if (stream_verdict)
+               bpf_prog_put(stream_verdict);
        return ret;
 }
 
@@ -463,8 +479,6 @@ static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
        return 0;
 }
 
-static bool sock_map_redirect_allowed(const struct sock *sk);
-
 static int sock_map_update_common(struct bpf_map *map, u32 idx,
                                  struct sock *sk, u64 flags)
 {
@@ -484,14 +498,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
        if (!link)
                return -ENOMEM;
 
-       /* Only sockets we can redirect into/from in BPF need to hold
-        * refs to parser/verdict progs and have their sk_data_ready
-        * and sk_write_space callbacks overridden.
-        */
-       if (sock_map_redirect_allowed(sk))
-               ret = sock_map_link(map, &stab->progs, sk);
-       else
-               ret = sock_map_link_no_progs(map, sk);
+       ret = sock_map_link(map, sk);
        if (ret < 0)
                goto out_free;
 
@@ -657,7 +664,6 @@ const struct bpf_func_proto bpf_sock_map_update_proto = {
 BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
           struct bpf_map *, map, u32, key, u64, flags)
 {
-       struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
        struct sock *sk;
 
        if (unlikely(flags & ~(BPF_F_INGRESS)))
@@ -667,8 +673,7 @@ BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;
 
-       tcb->bpf.flags = flags;
-       tcb->bpf.sk_redir = sk;
+       skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
        return SK_PASS;
 }
 
@@ -998,14 +1003,7 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
        if (!link)
                return -ENOMEM;
 
-       /* Only sockets we can redirect into/from in BPF need to hold
-        * refs to parser/verdict progs and have their sk_data_ready
-        * and sk_write_space callbacks overridden.
-        */
-       if (sock_map_redirect_allowed(sk))
-               ret = sock_map_link(map, &htab->progs, sk);
-       else
-               ret = sock_map_link_no_progs(map, sk);
+       ret = sock_map_link(map, sk);
        if (ret < 0)
                goto out_free;
 
@@ -1250,7 +1248,6 @@ const struct bpf_func_proto bpf_sock_hash_update_proto = {
 BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
           struct bpf_map *, map, void *, key, u64, flags)
 {
-       struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
        struct sock *sk;
 
        if (unlikely(flags & ~(BPF_F_INGRESS)))
@@ -1260,8 +1257,7 @@ BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;
 
-       tcb->bpf.flags = flags;
-       tcb->bpf.sk_redir = sk;
+       skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
        return SK_PASS;
 }
 
@@ -1448,8 +1444,8 @@ static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
        return NULL;
 }
 
-int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
-                        struct bpf_prog *old, u32 which)
+static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
+                               struct bpf_prog *old, u32 which)
 {
        struct sk_psock_progs *progs = sock_map_progs(map);
        struct bpf_prog **pprog;
@@ -1461,10 +1457,19 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
        case BPF_SK_MSG_VERDICT:
                pprog = &progs->msg_parser;
                break;
+#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
        case BPF_SK_SKB_STREAM_PARSER:
-               pprog = &progs->skb_parser;
+               pprog = &progs->stream_parser;
                break;
+#endif
        case BPF_SK_SKB_STREAM_VERDICT:
+               if (progs->skb_verdict)
+                       return -EBUSY;
+               pprog = &progs->stream_verdict;
+               break;
+       case BPF_SK_SKB_VERDICT:
+               if (progs->stream_verdict)
+                       return -EBUSY;
                pprog = &progs->skb_verdict;
                break;
        default:
@@ -1539,6 +1544,7 @@ void sock_map_close(struct sock *sk, long timeout)
        saved_close = psock->saved_close;
        sock_map_remove_links(sk, psock);
        rcu_read_unlock();
+       sk_psock_stop(psock, true);
        release_sock(sk);
        saved_close(sk, timeout);
 }