Merge https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
[linux-2.6-microblaze.git] / net / core / filter.c
index 4fae919..4ef77ec 100644 (file)
@@ -6559,10 +6559,21 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                                           ifindex, proto, netns_id, flags);
 
        if (sk) {
-               sk = sk_to_full_sk(sk);
-               if (!sk_fullsock(sk)) {
+               struct sock *sk2 = sk_to_full_sk(sk);
+
+               /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
+                * sock refcnt is decremented to prevent a request_sock leak.
+                */
+               if (!sk_fullsock(sk2))
+                       sk2 = NULL;
+               if (sk2 != sk) {
                        sock_gen_put(sk);
-                       return NULL;
+                       /* Ensure there is no need to bump sk2 refcnt */
+                       if (unlikely(sk2 && !sock_flag(sk2, SOCK_RCU_FREE))) {
+                               WARN_ONCE(1, "Found non-RCU, unreferenced socket!");
+                               return NULL;
+                       }
+                       sk = sk2;
                }
        }
 
@@ -6596,10 +6607,21 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                                         flags);
 
        if (sk) {
-               sk = sk_to_full_sk(sk);
-               if (!sk_fullsock(sk)) {
+               struct sock *sk2 = sk_to_full_sk(sk);
+
+               /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
+                * sock refcnt is decremented to prevent a request_sock leak.
+                */
+               if (!sk_fullsock(sk2))
+                       sk2 = NULL;
+               if (sk2 != sk) {
                        sock_gen_put(sk);
-                       return NULL;
+                       /* Ensure there is no need to bump sk2 refcnt */
+                       if (unlikely(sk2 && !sock_flag(sk2, SOCK_RCU_FREE))) {
+                               WARN_ONCE(1, "Found non-RCU, unreferenced socket!");
+                               return NULL;
+                       }
+                       sk = sk2;
                }
        }