Merge tag 'nds32-for-linus-4.18' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-microblaze.git] / kernel / bpf / sockmap.c
index cf7b6a6..98fb793 100644 (file)
@@ -312,10 +312,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
        struct smap_psock *psock;
        struct sock *osk;
 
+       lock_sock(sk);
        rcu_read_lock();
        psock = smap_psock_sk(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
+               release_sock(sk);
                return sk->sk_prot->close(sk, timeout);
        }
 
@@ -371,6 +373,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
                e = psock_map_pop(sk, psock);
        }
        rcu_read_unlock();
+       release_sock(sk);
        close_fun(sk, timeout);
 }
 
@@ -568,7 +571,8 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
        while (sg[i].length) {
                free += sg[i].length;
                sk_mem_uncharge(sk, sg[i].length);
-               put_page(sg_page(&sg[i]));
+               if (!md->skb)
+                       put_page(sg_page(&sg[i]));
                sg[i].length = 0;
                sg[i].page_link = 0;
                sg[i].offset = 0;
@@ -577,6 +581,8 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
                if (i == MAX_SKB_FRAGS)
                        i = 0;
        }
+       if (md->skb)
+               consume_skb(md->skb);
 
        return free;
 }
@@ -1230,7 +1236,7 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
         */
        TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
        skb->sk = psock->sock;
-       bpf_compute_data_pointers(skb);
+       bpf_compute_data_end_sk_skb(skb);
        preempt_disable();
        rc = (*prog->bpf_func)(skb, prog->insnsi);
        preempt_enable();
@@ -1485,7 +1491,7 @@ static int smap_parse_func_strparser(struct strparser *strp,
         * any socket yet.
         */
        skb->sk = psock->sock;
-       bpf_compute_data_pointers(skb);
+       bpf_compute_data_end_sk_skb(skb);
        rc = (*prog->bpf_func)(skb, prog->insnsi);
        skb->sk = NULL;
        rcu_read_unlock();
@@ -1896,7 +1902,7 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
                e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
                if (!e) {
                        err = -ENOMEM;
-                       goto out_progs;
+                       goto out_free;
                }
        }
 
@@ -2069,7 +2075,13 @@ static int sock_map_update_elem(struct bpf_map *map,
                return -EOPNOTSUPP;
        }
 
+       lock_sock(skops.sk);
+       preempt_disable();
+       rcu_read_lock();
        err = sock_map_ctx_update_elem(&skops, map, key, flags);
+       rcu_read_unlock();
+       preempt_enable();
+       release_sock(skops.sk);
        fput(socket->file);
        return err;
 }
@@ -2342,7 +2354,10 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
        if (err)
                goto err;
 
-       /* bpf_map_update_elem() can be called in_irq() */
+       /* psock is valid here because otherwise above *ctx_update_elem would
+        * have thrown an error. It is safe to skip error check.
+        */
+       psock = smap_psock_sk(sock);
        raw_spin_lock_bh(&b->lock);
        l_old = lookup_elem_raw(head, hash, key, key_size);
        if (l_old && map_flags == BPF_NOEXIST) {
@@ -2360,12 +2375,6 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
                goto bucket_err;
        }
 
-       psock = smap_psock_sk(sock);
-       if (unlikely(!psock)) {
-               err = -EINVAL;
-               goto bucket_err;
-       }
-
        rcu_assign_pointer(e->hash_link, l_new);
        rcu_assign_pointer(e->htab,
                           container_of(map, struct bpf_htab, map));
@@ -2388,12 +2397,10 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
        raw_spin_unlock_bh(&b->lock);
        return 0;
 bucket_err:
+       smap_release_sock(psock, sock);
        raw_spin_unlock_bh(&b->lock);
 err:
        kfree(e);
-       psock = smap_psock_sk(sock);
-       if (psock)
-               smap_release_sock(psock, sock);
        return err;
 }
 
@@ -2415,7 +2422,13 @@ static int sock_hash_update_elem(struct bpf_map *map,
                return -EINVAL;
        }
 
+       lock_sock(skops.sk);
+       preempt_disable();
+       rcu_read_lock();
        err = sock_hash_ctx_update_elem(&skops, map, key, flags);
+       rcu_read_unlock();
+       preempt_enable();
+       release_sock(skops.sk);
        fput(socket->file);
        return err;
 }
@@ -2472,10 +2485,8 @@ struct sock  *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
        b = __select_bucket(htab, hash);
        head = &b->head;
 
-       raw_spin_lock_bh(&b->lock);
        l = lookup_elem_raw(head, hash, key, key_size);
        sk = l ? l->sk : NULL;
-       raw_spin_unlock_bh(&b->lock);
        return sk;
 }