Merge tag 'for-5.12/io_uring-2021-02-17' of git://git.kernel.dk/linux-block
[linux-2.6-microblaze.git] / net / ipv4 / tcp.c
index 32545ec..a3422e4 100644 (file)
 #include <asm/ioctls.h>
 #include <net/busy_poll.h>
 
+/* Track pending CMSGs. */
+enum {
+       TCP_CMSG_INQ = 1,
+       TCP_CMSG_TS = 2
+};
+
 struct percpu_counter tcp_orphan_count;
 EXPORT_SYMBOL_GPL(tcp_orphan_count);
 
@@ -475,19 +481,11 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags)
        }
 }
 
-static inline bool tcp_stream_is_readable(const struct tcp_sock *tp,
-                                         int target, struct sock *sk)
+static bool tcp_stream_is_readable(struct sock *sk, int target)
 {
-       int avail = READ_ONCE(tp->rcv_nxt) - READ_ONCE(tp->copied_seq);
-
-       if (avail > 0) {
-               if (avail >= target)
-                       return true;
-               if (tcp_rmem_pressure(sk))
-                       return true;
-               if (tcp_receive_window(tp) <= inet_csk(sk)->icsk_ack.rcv_mss)
-                       return true;
-       }
+       if (tcp_epollin_ready(sk, target))
+               return true;
+
        if (sk->sk_prot->stream_memory_read)
                return sk->sk_prot->stream_memory_read(sk);
        return false;
@@ -562,7 +560,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait)
                    tp->urg_data)
                        target++;
 
-               if (tcp_stream_is_readable(tp, target, sk))
+               if (tcp_stream_is_readable(sk, target))
                        mask |= EPOLLIN | EPOLLRDNORM;
 
                if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
@@ -1010,7 +1008,7 @@ new_segment:
        }
 
        if (!(flags & MSG_NO_SHARED_FRAGS))
-               skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
+               skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG;
 
        skb->len += copy;
        skb->data_len += copy;
@@ -1217,7 +1215,7 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
 
        if (flags & MSG_ZEROCOPY && size && sock_flag(sk, SOCK_ZEROCOPY)) {
                skb = tcp_write_queue_tail(sk);
-               uarg = sock_zerocopy_realloc(sk, size, skb_zcopy(skb));
+               uarg = msg_zerocopy_realloc(sk, size, skb_zcopy(skb));
                if (!uarg) {
                        err = -ENOBUFS;
                        goto out_err;
@@ -1429,7 +1427,7 @@ out:
                tcp_push(sk, flags, mss_now, tp->nonagle, size_goal);
        }
 out_nopush:
-       sock_zerocopy_put(uarg);
+       net_zcopy_put(uarg);
        return copied + copied_syn;
 
 do_error:
@@ -1440,7 +1438,7 @@ do_fault:
        if (copied + copied_syn)
                goto out;
 out_err:
-       sock_zerocopy_put_abort(uarg, true);
+       net_zcopy_put_abort(uarg, true);
        err = sk_stream_error(sk, flags, err);
        /* make sure we wake any epoll edge trigger waiter */
        if (unlikely(tcp_rtx_and_write_queues_empty(sk) && err == -EAGAIN)) {
@@ -1739,6 +1737,20 @@ int tcp_set_rcvlowat(struct sock *sk, int val)
 }
 EXPORT_SYMBOL(tcp_set_rcvlowat);
 
+static void tcp_update_recv_tstamps(struct sk_buff *skb,
+                                   struct scm_timestamping_internal *tss)
+{
+       if (skb->tstamp)
+               tss->ts[0] = ktime_to_timespec64(skb->tstamp);
+       else
+               tss->ts[0] = (struct timespec64) {0};
+
+       if (skb_hwtstamps(skb)->hwtstamp)
+               tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
+       else
+               tss->ts[2] = (struct timespec64) {0};
+}
+
 #ifdef CONFIG_MMU
 static const struct vm_operations_struct tcp_vm_ops = {
 };
@@ -1842,13 +1854,13 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
                              struct scm_timestamping_internal *tss,
                              int *cmsg_flags);
 static int receive_fallback_to_copy(struct sock *sk,
-                                   struct tcp_zerocopy_receive *zc, int inq)
+                                   struct tcp_zerocopy_receive *zc, int inq,
+                                   struct scm_timestamping_internal *tss)
 {
        unsigned long copy_address = (unsigned long)zc->copybuf_address;
-       struct scm_timestamping_internal tss_unused;
-       int err, cmsg_flags_unused;
        struct msghdr msg = {};
        struct iovec iov;
+       int err;
 
        zc->length = 0;
        zc->recv_skip_hint = 0;
@@ -1862,7 +1874,7 @@ static int receive_fallback_to_copy(struct sock *sk,
                return err;
 
        err = tcp_recvmsg_locked(sk, &msg, inq, /*nonblock=*/1, /*flags=*/0,
-                                &tss_unused, &cmsg_flags_unused);
+                                tss, &zc->msg_flags);
        if (err < 0)
                return err;
 
@@ -1903,21 +1915,27 @@ static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc,
        return (__s32)copylen;
 }
 
-static int tcp_zerocopy_handle_leftover_data(struct tcp_zerocopy_receive *zc,
-                                            struct sock *sk,
-                                            struct sk_buff *skb,
-                                            u32 *seq,
-                                            s32 copybuf_len)
+static int tcp_zc_handle_leftover(struct tcp_zerocopy_receive *zc,
+                                 struct sock *sk,
+                                 struct sk_buff *skb,
+                                 u32 *seq,
+                                 s32 copybuf_len,
+                                 struct scm_timestamping_internal *tss)
 {
        u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint);
 
        if (!copylen)
                return 0;
        /* skb is null if inq < PAGE_SIZE. */
-       if (skb)
+       if (skb) {
                offset = *seq - TCP_SKB_CB(skb)->seq;
-       else
+       } else {
                skb = tcp_recv_skb(sk, *seq, &offset);
+               if (TCP_SKB_CB(skb)->has_rxtstamp) {
+                       tcp_update_recv_tstamps(skb, tss);
+                       zc->msg_flags |= TCP_CMSG_TS;
+               }
+       }
 
        zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset,
                                                  seq);
@@ -2004,9 +2022,38 @@ static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma,
                err);
 }
 
+#define TCP_VALID_ZC_MSG_FLAGS   (TCP_CMSG_TS)
+static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
+                              struct scm_timestamping_internal *tss);
+static void tcp_zc_finalize_rx_tstamp(struct sock *sk,
+                                     struct tcp_zerocopy_receive *zc,
+                                     struct scm_timestamping_internal *tss)
+{
+       unsigned long msg_control_addr;
+       struct msghdr cmsg_dummy;
+
+       msg_control_addr = (unsigned long)zc->msg_control;
+       cmsg_dummy.msg_control = (void *)msg_control_addr;
+       cmsg_dummy.msg_controllen =
+               (__kernel_size_t)zc->msg_controllen;
+       cmsg_dummy.msg_flags = in_compat_syscall()
+               ? MSG_CMSG_COMPAT : 0;
+       zc->msg_flags = 0;
+       if (zc->msg_control == msg_control_addr &&
+           zc->msg_controllen == cmsg_dummy.msg_controllen) {
+               tcp_recv_timestamp(&cmsg_dummy, sk, tss);
+               zc->msg_control = (__u64)
+                       ((uintptr_t)cmsg_dummy.msg_control);
+               zc->msg_controllen =
+                       (__u64)cmsg_dummy.msg_controllen;
+               zc->msg_flags = (__u32)cmsg_dummy.msg_flags;
+       }
+}
+
 #define TCP_ZEROCOPY_PAGE_BATCH_SIZE 32
 static int tcp_zerocopy_receive(struct sock *sk,
-                               struct tcp_zerocopy_receive *zc)
+                               struct tcp_zerocopy_receive *zc,
+                               struct scm_timestamping_internal *tss)
 {
        u32 length = 0, offset, vma_len, avail_len, copylen = 0;
        unsigned long address = (unsigned long)zc->address;
@@ -2023,6 +2070,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
        int ret;
 
        zc->copybuf_len = 0;
+       zc->msg_flags = 0;
 
        if (address & (PAGE_SIZE - 1) || address != zc->address)
                return -EINVAL;
@@ -2033,7 +2081,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
        sock_rps_record_flow(sk);
 
        if (inq && inq <= copybuf_len)
-               return receive_fallback_to_copy(sk, zc, inq);
+               return receive_fallback_to_copy(sk, zc, inq, tss);
 
        if (inq < PAGE_SIZE) {
                zc->length = 0;
@@ -2078,6 +2126,11 @@ static int tcp_zerocopy_receive(struct sock *sk,
                        } else {
                                skb = tcp_recv_skb(sk, seq, &offset);
                        }
+
+                       if (TCP_SKB_CB(skb)->has_rxtstamp) {
+                               tcp_update_recv_tstamps(skb, tss);
+                               zc->msg_flags |= TCP_CMSG_TS;
+                       }
                        zc->recv_skip_hint = skb->len - offset;
                        frags = skb_advance_to_frag(skb, offset, &offset_frag);
                        if (!frags || offset_frag)
@@ -2120,8 +2173,7 @@ out:
        mmap_read_unlock(current->mm);
        /* Try to copy straggler data. */
        if (!ret)
-               copylen = tcp_zerocopy_handle_leftover_data(zc, sk, skb, &seq,
-                                                           copybuf_len);
+               copylen = tcp_zc_handle_leftover(zc, sk, skb, &seq, copybuf_len, tss);
 
        if (length + copylen) {
                WRITE_ONCE(tp->copied_seq, seq);
@@ -2142,20 +2194,6 @@ out:
 }
 #endif
 
-static void tcp_update_recv_tstamps(struct sk_buff *skb,
-                                   struct scm_timestamping_internal *tss)
-{
-       if (skb->tstamp)
-               tss->ts[0] = ktime_to_timespec64(skb->tstamp);
-       else
-               tss->ts[0] = (struct timespec64) {0};
-
-       if (skb_hwtstamps(skb)->hwtstamp)
-               tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp);
-       else
-               tss->ts[2] = (struct timespec64) {0};
-}
-
 /* Similar to __sock_recv_timestamp, but does not require an skb */
 static void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk,
                               struct scm_timestamping_internal *tss)
@@ -2272,7 +2310,7 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
                goto out;
 
        if (tp->recvmsg_inq)
-               *cmsg_flags = 1;
+               *cmsg_flags = TCP_CMSG_INQ;
        timeo = sock_rcvtimeo(sk, nonblock);
 
        /* Urgent data needs to be handled specially. */
@@ -2453,7 +2491,7 @@ skip_copy:
 
                if (TCP_SKB_CB(skb)->has_rxtstamp) {
                        tcp_update_recv_tstamps(skb, tss);
-                       *cmsg_flags |= 2;
+                       *cmsg_flags |= TCP_CMSG_TS;
                }
 
                if (used + offset < skb->len)
@@ -2513,9 +2551,9 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock,
        release_sock(sk);
 
        if (cmsg_flags && ret >= 0) {
-               if (cmsg_flags & 2)
+               if (cmsg_flags & TCP_CMSG_TS)
                        tcp_recv_timestamp(msg, sk, &tss);
-               if (cmsg_flags & 1) {
+               if (cmsg_flags & TCP_CMSG_INQ) {
                        inq = tcp_inq_hint(sk);
                        put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq);
                }
@@ -3767,11 +3805,24 @@ static size_t tcp_opt_stats_get_size(void)
                nla_total_size(sizeof(u16)) + /* TCP_NLA_TIMEOUT_REHASH */
                nla_total_size(sizeof(u32)) + /* TCP_NLA_BYTES_NOTSENT */
                nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_EDT */
+               nla_total_size(sizeof(u8)) + /* TCP_NLA_TTL */
                0;
 }
 
+/* Returns TTL or hop limit of an incoming packet from skb. */
+static u8 tcp_skb_ttl_or_hop_limit(const struct sk_buff *skb)
+{
+       if (skb->protocol == htons(ETH_P_IP))
+               return ip_hdr(skb)->ttl;
+       else if (skb->protocol == htons(ETH_P_IPV6))
+               return ipv6_hdr(skb)->hop_limit;
+       else
+               return 0;
+}
+
 struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk,
-                                              const struct sk_buff *orig_skb)
+                                              const struct sk_buff *orig_skb,
+                                              const struct sk_buff *ack_skb)
 {
        const struct tcp_sock *tp = tcp_sk(sk);
        struct sk_buff *stats;
@@ -3827,6 +3878,9 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk,
                    max_t(int, 0, tp->write_seq - tp->snd_nxt));
        nla_put_u64_64bit(stats, TCP_NLA_EDT, orig_skb->skb_mstamp_ns,
                          TCP_NLA_PAD);
+       if (ack_skb)
+               nla_put_u8(stats, TCP_NLA_TTL,
+                          tcp_skb_ttl_or_hop_limit(ack_skb));
 
        return stats;
 }
@@ -4083,6 +4137,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
        }
 #ifdef CONFIG_MMU
        case TCP_ZEROCOPY_RECEIVE: {
+               struct scm_timestamping_internal tss;
                struct tcp_zerocopy_receive zc = {};
                int err;
 
@@ -4090,19 +4145,36 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                        return -EFAULT;
                if (len < offsetofend(struct tcp_zerocopy_receive, length))
                        return -EINVAL;
-               if (len > sizeof(zc)) {
+               if (unlikely(len > sizeof(zc))) {
+                       err = check_zeroed_user(optval + sizeof(zc),
+                                               len - sizeof(zc));
+                       if (err < 1)
+                               return err == 0 ? -EINVAL : err;
                        len = sizeof(zc);
                        if (put_user(len, optlen))
                                return -EFAULT;
                }
                if (copy_from_user(&zc, optval, len))
                        return -EFAULT;
+               if (zc.reserved)
+                       return -EINVAL;
+               if (zc.msg_flags &  ~(TCP_VALID_ZC_MSG_FLAGS))
+                       return -EINVAL;
                lock_sock(sk);
-               err = tcp_zerocopy_receive(sk, &zc);
+               err = tcp_zerocopy_receive(sk, &zc, &tss);
+               err = BPF_CGROUP_RUN_PROG_GETSOCKOPT_KERN(sk, level, optname,
+                                                         &zc, &len, err);
                release_sock(sk);
-               if (len >= offsetofend(struct tcp_zerocopy_receive, err))
-                       goto zerocopy_rcv_sk_err;
+               if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags))
+                       goto zerocopy_rcv_cmsg;
                switch (len) {
+               case offsetofend(struct tcp_zerocopy_receive, msg_flags):
+                       goto zerocopy_rcv_cmsg;
+               case offsetofend(struct tcp_zerocopy_receive, msg_controllen):
+               case offsetofend(struct tcp_zerocopy_receive, msg_control):
+               case offsetofend(struct tcp_zerocopy_receive, flags):
+               case offsetofend(struct tcp_zerocopy_receive, copybuf_len):
+               case offsetofend(struct tcp_zerocopy_receive, copybuf_address):
                case offsetofend(struct tcp_zerocopy_receive, err):
                        goto zerocopy_rcv_sk_err;
                case offsetofend(struct tcp_zerocopy_receive, inq):
@@ -4111,6 +4183,11 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                default:
                        goto zerocopy_rcv_out;
                }
+zerocopy_rcv_cmsg:
+               if (zc.msg_flags & TCP_CMSG_TS)
+                       tcp_zc_finalize_rx_tstamp(sk, &zc, &tss);
+               else
+                       zc.msg_flags = 0;
 zerocopy_rcv_sk_err:
                if (!err)
                        zc.err = sock_error(sk);
@@ -4133,6 +4210,18 @@ zerocopy_rcv_out:
        return 0;
 }
 
+bool tcp_bpf_bypass_getsockopt(int level, int optname)
+{
+       /* TCP do_tcp_getsockopt has optimized getsockopt implementation
+        * to avoid extra socket lock for TCP_ZEROCOPY_RECEIVE.
+        */
+       if (level == SOL_TCP && optname == TCP_ZEROCOPY_RECEIVE)
+               return true;
+
+       return false;
+}
+EXPORT_SYMBOL(tcp_bpf_bypass_getsockopt);
+
 int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
                   int __user *optlen)
 {