skmsg: Introduce a spinlock to protect ingress_msg
authorCong Wang <cong.wang@bytedance.com>
Wed, 31 Mar 2021 02:32:23 +0000 (19:32 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Thu, 1 Apr 2021 17:56:13 +0000 (10:56 -0700)
Currently we rely on lock_sock to protect ingress_msg,
it is too big for this, we can actually just use a spinlock
to protect this list like protecting other skb queues.

__tcp_bpf_recvmsg() is still special because of peeking,
it still has to use lock_sock.

Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Jakub Sitnicki <jakub@cloudflare.com>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20210331023237.41094-3-xiyou.wangcong@gmail.com
include/linux/skmsg.h
net/core/skmsg.c
net/ipv4/tcp_bpf.c

index 6c09d94..f2d45a7 100644 (file)
@@ -89,6 +89,7 @@ struct sk_psock {
 #endif
        struct sk_buff_head             ingress_skb;
        struct list_head                ingress_msg;
+       spinlock_t                      ingress_lock;
        unsigned long                   state;
        struct list_head                link;
        spinlock_t                      link_lock;
@@ -284,7 +285,45 @@ static inline struct sk_psock *sk_psock(const struct sock *sk)
 static inline void sk_psock_queue_msg(struct sk_psock *psock,
                                      struct sk_msg *msg)
 {
+       spin_lock_bh(&psock->ingress_lock);
        list_add_tail(&msg->list, &psock->ingress_msg);
+       spin_unlock_bh(&psock->ingress_lock);
+}
+
+static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
+{
+       struct sk_msg *msg;
+
+       spin_lock_bh(&psock->ingress_lock);
+       msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
+       if (msg)
+               list_del(&msg->list);
+       spin_unlock_bh(&psock->ingress_lock);
+       return msg;
+}
+
+static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
+{
+       struct sk_msg *msg;
+
+       spin_lock_bh(&psock->ingress_lock);
+       msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
+       spin_unlock_bh(&psock->ingress_lock);
+       return msg;
+}
+
+static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
+                                              struct sk_msg *msg)
+{
+       struct sk_msg *ret;
+
+       spin_lock_bh(&psock->ingress_lock);
+       if (list_is_last(&msg->list, &psock->ingress_msg))
+               ret = NULL;
+       else
+               ret = list_next_entry(msg, list);
+       spin_unlock_bh(&psock->ingress_lock);
+       return ret;
 }
 
 static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
@@ -292,6 +331,13 @@ static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
        return psock ? list_empty(&psock->ingress_msg) : true;
 }
 
+static inline void kfree_sk_msg(struct sk_msg *msg)
+{
+       if (msg->skb)
+               consume_skb(msg->skb);
+       kfree(msg);
+}
+
 static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 {
        struct sock *sk = psock->sk;
index bebf84e..305dddc 100644 (file)
@@ -592,6 +592,7 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
 
        INIT_WORK(&psock->work, sk_psock_backlog);
        INIT_LIST_HEAD(&psock->ingress_msg);
+       spin_lock_init(&psock->ingress_lock);
        skb_queue_head_init(&psock->ingress_skb);
 
        sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
@@ -638,7 +639,9 @@ static void sk_psock_zap_ingress(struct sk_psock *psock)
                skb_bpf_redirect_clear(skb);
                kfree_skb(skb);
        }
+       spin_lock_bh(&psock->ingress_lock);
        __sk_psock_purge_ingress_msg(psock);
+       spin_unlock_bh(&psock->ingress_lock);
 }
 
 static void sk_psock_link_destroy(struct sk_psock *psock)
index 17c322b..ae98071 100644 (file)
@@ -18,9 +18,7 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
        struct sk_msg *msg_rx;
        int i, copied = 0;
 
-       msg_rx = list_first_entry_or_null(&psock->ingress_msg,
-                                         struct sk_msg, list);
-
+       msg_rx = sk_psock_peek_msg(psock);
        while (copied != len) {
                struct scatterlist *sge;
 
@@ -68,22 +66,18 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
                } while (i != msg_rx->sg.end);
 
                if (unlikely(peek)) {
-                       if (msg_rx == list_last_entry(&psock->ingress_msg,
-                                                     struct sk_msg, list))
+                       msg_rx = sk_psock_next_msg(psock, msg_rx);
+                       if (!msg_rx)
                                break;
-                       msg_rx = list_next_entry(msg_rx, list);
                        continue;
                }
 
                msg_rx->sg.start = i;
                if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
-                       list_del(&msg_rx->list);
-                       if (msg_rx->skb)
-                               consume_skb(msg_rx->skb);
-                       kfree(msg_rx);
+                       msg_rx = sk_psock_dequeue_msg(psock);
+                       kfree_sk_msg(msg_rx);
                }
-               msg_rx = list_first_entry_or_null(&psock->ingress_msg,
-                                                 struct sk_msg, list);
+               msg_rx = sk_psock_peek_msg(psock);
        }
 
        return copied;