Merge v5.14-rc3 into usb-next
[linux-2.6-microblaze.git] / net / ipv4 / udp_bpf.c
index 954c459..9f5a5cd 100644 (file)
@@ -21,6 +21,45 @@ static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
 }
 
+static bool udp_sk_has_data(struct sock *sk)
+{
+       return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
+              !skb_queue_empty(&sk->sk_receive_queue);
+}
+
+static bool psock_has_data(struct sk_psock *psock)
+{
+       return !skb_queue_empty(&psock->ingress_skb) ||
+              !sk_psock_queue_empty(psock);
+}
+
+#define udp_msg_has_data(__sk, __psock)        \
+               ({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
+
+static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
+                            long timeo)
+{
+       DEFINE_WAIT_FUNC(wait, woken_wake_function);
+       int ret = 0;
+
+       if (sk->sk_shutdown & RCV_SHUTDOWN)
+               return 1;
+
+       if (!timeo)
+               return ret;
+
+       add_wait_queue(sk_sleep(sk), &wait);
+       sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+       ret = udp_msg_has_data(sk, psock);
+       if (!ret) {
+               wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
+               ret = udp_msg_has_data(sk, psock);
+       }
+       sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+       remove_wait_queue(sk_sleep(sk), &wait);
+       return ret;
+}
+
 static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                           int nonblock, int flags, int *addr_len)
 {
@@ -34,8 +73,7 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        if (unlikely(!psock))
                return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 
-       lock_sock(sk);
-       if (sk_psock_queue_empty(psock)) {
+       if (!psock_has_data(psock)) {
                ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
                goto out;
        }
@@ -43,26 +81,21 @@ static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 msg_bytes_ready:
        copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
        if (!copied) {
-               int data, err = 0;
                long timeo;
+               int data;
 
                timeo = sock_rcvtimeo(sk, nonblock);
-               data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
+               data = udp_msg_wait_data(sk, psock, timeo);
                if (data) {
-                       if (!sk_psock_queue_empty(psock))
+                       if (psock_has_data(psock))
                                goto msg_bytes_ready;
                        ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
                        goto out;
                }
-               if (err) {
-                       ret = err;
-                       goto out;
-               }
                copied = -EAGAIN;
        }
        ret = copied;
 out:
-       release_sock(sk);
        sk_psock_put(sk, psock);
        return ret;
 }
@@ -101,7 +134,7 @@ static int __init udp_bpf_v4_build_proto(void)
        udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
        return 0;
 }
-core_initcall(udp_bpf_v4_build_proto);
+late_initcall(udp_bpf_v4_build_proto);
 
 int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 {