Merge tag '9p-for-5.13-rc1' of git://github.com/martinetd/linux
[linux-2.6-microblaze.git] / net / core / sock_map.c
index 42d7972..6f1b82b 100644 (file)
@@ -156,6 +156,8 @@ static void sock_map_del_link(struct sock *sk,
                                strp_stop = true;
                        if (psock->saved_data_ready && stab->progs.stream_verdict)
                                verdict_stop = true;
+                       if (psock->saved_data_ready && stab->progs.skb_verdict)
+                               verdict_stop = true;
                        list_del(&link->list);
                        sk_psock_free_link(link);
                }
@@ -183,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
 
 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
-       struct proto *prot;
-
-       switch (sk->sk_type) {
-       case SOCK_STREAM:
-               prot = tcp_bpf_get_proto(sk, psock);
-               break;
-
-       case SOCK_DGRAM:
-               prot = udp_bpf_get_proto(sk, psock);
-               break;
-
-       default:
+       if (!sk->sk_prot->psock_update_sk_prot)
                return -EINVAL;
-       }
-
-       if (IS_ERR(prot))
-               return PTR_ERR(prot);
-
-       sk_psock_update_proto(sk, psock, prot);
-       return 0;
+       psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
+       return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
@@ -232,6 +218,7 @@ static int sock_map_link(struct bpf_map *map, struct sock *sk)
        struct sk_psock_progs *progs = sock_map_progs(map);
        struct bpf_prog *stream_verdict = NULL;
        struct bpf_prog *stream_parser = NULL;
+       struct bpf_prog *skb_verdict = NULL;
        struct bpf_prog *msg_parser = NULL;
        struct sk_psock *psock;
        int ret;
@@ -268,6 +255,15 @@ static int sock_map_link(struct bpf_map *map, struct sock *sk)
                }
        }
 
+       skb_verdict = READ_ONCE(progs->skb_verdict);
+       if (skb_verdict) {
+               skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
+               if (IS_ERR(skb_verdict)) {
+                       ret = PTR_ERR(skb_verdict);
+                       goto out_put_msg_parser;
+               }
+       }
+
 no_progs:
        psock = sock_map_psock_get_checked(sk);
        if (IS_ERR(psock)) {
@@ -278,6 +274,9 @@ no_progs:
        if (psock) {
                if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
                    (stream_parser  && READ_ONCE(psock->progs.stream_parser)) ||
+                   (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
+                   (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
+                   (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
                    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
                        sk_psock_put(sk, psock);
                        ret = -EBUSY;
@@ -309,6 +308,9 @@ no_progs:
        } else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
                psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
                sk_psock_start_verdict(sk,psock);
+       } else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
+               psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
+               sk_psock_start_verdict(sk, psock);
        }
        write_unlock_bh(&sk->sk_callback_lock);
        return 0;
@@ -317,6 +319,9 @@ out_unlock_drop:
 out_drop:
        sk_psock_put(sk, psock);
 out_progs:
+       if (skb_verdict)
+               bpf_prog_put(skb_verdict);
+out_put_msg_parser:
        if (msg_parser)
                bpf_prog_put(msg_parser);
 out_put_stream_parser:
@@ -530,12 +535,15 @@ static bool sk_is_udp(const struct sock *sk)
 
 static bool sock_map_redirect_allowed(const struct sock *sk)
 {
-       return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
+       if (sk_is_tcp(sk))
+               return sk->sk_state != TCP_LISTEN;
+       else
+               return sk->sk_state == TCP_ESTABLISHED;
 }
 
 static bool sock_map_sk_is_suitable(const struct sock *sk)
 {
-       return sk_is_tcp(sk) || sk_is_udp(sk);
+       return !!sk->sk_prot->psock_update_sk_prot;
 }
 
 static bool sock_map_sk_state_allowed(const struct sock *sk)
@@ -1442,8 +1450,15 @@ static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
                break;
 #endif
        case BPF_SK_SKB_STREAM_VERDICT:
+               if (progs->skb_verdict)
+                       return -EBUSY;
                pprog = &progs->stream_verdict;
                break;
+       case BPF_SK_SKB_VERDICT:
+               if (progs->stream_verdict)
+                       return -EBUSY;
+               pprog = &progs->skb_verdict;
+               break;
        default:
                return -EOPNOTSUPP;
        }
@@ -1506,7 +1521,7 @@ void sock_map_close(struct sock *sk, long timeout)
 
        lock_sock(sk);
        rcu_read_lock();
-       psock = sk_psock(sk);
+       psock = sk_psock_get(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                release_sock(sk);
@@ -1517,6 +1532,7 @@ void sock_map_close(struct sock *sk, long timeout)
        sock_map_remove_links(sk, psock);
        rcu_read_unlock();
        sk_psock_stop(psock, true);
+       sk_psock_put(sk, psock);
        release_sock(sk);
        saved_close(sk, timeout);
 }