tcp: fix a signed-integer-overflow bug in tcp_add_backlog()
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 0c83780..87d440f 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -199,16 +199,18 @@ static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr,
 /* This will initiate an outgoing connection. */
 int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 {
+       struct inet_bind_hashbucket *prev_addr_hashbucket = NULL;
        struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
+       struct inet_timewait_death_row *tcp_death_row;
+       __be32 daddr, nexthop, prev_sk_rcv_saddr;
        struct inet_sock *inet = inet_sk(sk);
        struct tcp_sock *tp = tcp_sk(sk);
+       struct ip_options_rcu *inet_opt;
+       struct net *net = sock_net(sk);
        __be16 orig_sport, orig_dport;
-       __be32 daddr, nexthop;
        struct flowi4 *fl4;
        struct rtable *rt;
        int err;
-       struct ip_options_rcu *inet_opt;
-       struct inet_timewait_death_row *tcp_death_row = sock_net(sk)->ipv4.tcp_death_row;
 
        if (addr_len < sizeof(struct sockaddr_in))
                return -EINVAL;
@@ -234,7 +236,7 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
        if (IS_ERR(rt)) {
                err = PTR_ERR(rt);
                if (err == -ENETUNREACH)
-                       IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTNOROUTES);
+                       IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES);
                return err;
        }
 
@@ -246,10 +248,29 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
        if (!inet_opt || !inet_opt->opt.srr)
                daddr = fl4->daddr;
 
-       if (!inet->inet_saddr)
+       tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row;
+
+       if (!inet->inet_saddr) {
+               if (inet_csk(sk)->icsk_bind2_hash) {
+                       prev_addr_hashbucket = inet_bhashfn_portaddr(tcp_death_row->hashinfo,
+                                                                    sk, net, inet->inet_num);
+                       prev_sk_rcv_saddr = sk->sk_rcv_saddr;
+               }
                inet->inet_saddr = fl4->saddr;
+       }
+
        sk_rcv_saddr_set(sk, inet->inet_saddr);
 
+       if (prev_addr_hashbucket) {
+               err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+               if (err) {
+                       inet->inet_saddr = 0;
+                       sk_rcv_saddr_set(sk, prev_sk_rcv_saddr);
+                       ip_rt_put(rt);
+                       return err;
+               }
+       }
+
        if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
                /* Reset inherited state */
                tp->rx_opt.ts_recent       = 0;
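
Note: the hunk above saves the previous bind bucket and sk_rcv_saddr before inet_bhash2_update_saddr() runs, so the source-address change can be undone if the rehash fails. A user-space sketch of that save-and-rollback pattern, using hypothetical names rather than kernel APIs:

/* Sketch only: demo_set_saddr() mimics the pattern above -- remember the
 * old state, attempt the table update, undo the address change on failure.
 */
#include <stdio.h>

struct demo_sock {
	unsigned int saddr;	/* plays the role of inet->inet_saddr */
	unsigned int rcv_saddr;	/* plays the role of sk->sk_rcv_saddr */
};

/* pretend the rehash can fail, e.g. under memory pressure */
static int demo_rehash(struct demo_sock *sk, int fail)
{
	(void)sk;
	return fail ? -12 /* -ENOMEM */ : 0;
}

static int demo_set_saddr(struct demo_sock *sk, unsigned int new_saddr, int fail)
{
	unsigned int prev_rcv_saddr = sk->rcv_saddr;
	int err;

	sk->saddr = new_saddr;
	sk->rcv_saddr = new_saddr;

	err = demo_rehash(sk, fail);
	if (err) {
		/* same rollback as the hunk: clear saddr, restore rcv_saddr */
		sk->saddr = 0;
		sk->rcv_saddr = prev_rcv_saddr;
	}
	return err;
}

int main(void)
{
	struct demo_sock ok_sk = { 0, 0 }, fail_sk = { 0, 0 };

	printf("ok:   err=%d saddr=%u\n", demo_set_saddr(&ok_sk, 0x0a000001, 0), ok_sk.saddr);
	printf("fail: err=%d saddr=%u\n", demo_set_saddr(&fail_sk, 0x0a000002, 1), fail_sk.saddr);
	return 0;
}
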
@@ -298,12 +319,11 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
                                                  inet->inet_daddr,
                                                  inet->inet_sport,
                                                  usin->sin_port));
-               tp->tsoffset = secure_tcp_ts_off(sock_net(sk),
-                                                inet->inet_saddr,
+               tp->tsoffset = secure_tcp_ts_off(net, inet->inet_saddr,
                                                 inet->inet_daddr);
        }
 
-       inet->inet_id = prandom_u32();
+       inet->inet_id = get_random_u16();
 
        if (tcp_fastopen_defer_connect(sk, &err))
                return err;
@@ -475,9 +495,9 @@ int tcp_v4_err(struct sk_buff *skb, u32 info)
        int err;
        struct net *net = dev_net(skb->dev);
 
-       sk = __inet_lookup_established(net, &tcp_hashinfo, iph->daddr,
-                                      th->dest, iph->saddr, ntohs(th->source),
-                                      inet_iif(skb), 0);
+       sk = __inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo,
+                                      iph->daddr, th->dest, iph->saddr,
+                                      ntohs(th->source), inet_iif(skb), 0);
        if (!sk) {
                __ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
                return -ENOENT;
@@ -740,8 +760,8 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb)
                 * Incoming packet is checked with md5 hash with finding key,
                 * no RST generated if md5 hash doesn't match.
                 */
-               sk1 = __inet_lookup_listener(net, &tcp_hashinfo, NULL, 0,
-                                            ip_hdr(skb)->saddr,
+               sk1 = __inet_lookup_listener(net, net->ipv4.tcp_death_row.hashinfo,
+                                            NULL, 0, ip_hdr(skb)->saddr,
                                             th->source, ip_hdr(skb)->daddr,
                                             ntohs(th->source), dif, sdif);
                /* don't send rst if it can't find key */
@@ -1523,7 +1543,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
        inet_csk(newsk)->icsk_ext_hdr_len = 0;
        if (inet_opt)
                inet_csk(newsk)->icsk_ext_hdr_len = inet_opt->opt.optlen;
-       newinet->inet_id = prandom_u32();
+       newinet->inet_id = get_random_u16();
 
        /* Set ToS of the new socket based upon the value of incoming SYN.
         * ECT bits are set later in tcp_init_transfer().
@@ -1709,6 +1729,7 @@ EXPORT_SYMBOL(tcp_v4_do_rcv);
 
 int tcp_v4_early_demux(struct sk_buff *skb)
 {
+       struct net *net = dev_net(skb->dev);
        const struct iphdr *iph;
        const struct tcphdr *th;
        struct sock *sk;
@@ -1725,7 +1746,7 @@ int tcp_v4_early_demux(struct sk_buff *skb)
        if (th->doff < sizeof(struct tcphdr) / 4)
                return 0;
 
-       sk = __inet_lookup_established(dev_net(skb->dev), &tcp_hashinfo,
+       sk = __inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo,
                                       iph->saddr, th->source,
                                       iph->daddr, ntohs(th->dest),
                                       skb->skb_iif, inet_sdif(skb));
@@ -1853,11 +1874,13 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb,
        __skb_push(skb, hdrlen);
 
 no_coalesce:
+       limit = (u32)READ_ONCE(sk->sk_rcvbuf) + (u32)(READ_ONCE(sk->sk_sndbuf) >> 1);
+
        /* Only socket owner can try to collapse/prune rx queues
         * to reduce memory overhead, so add a little headroom here.
         * Few sockets backlog are possibly concurrently non empty.
         */
-       limit = READ_ONCE(sk->sk_rcvbuf) + READ_ONCE(sk->sk_sndbuf) + 64*1024;
+       limit += 64 * 1024;
 
        if (unlikely(sk_add_backlog(sk, skb, limit))) {
                bh_unlock_sock(sk);
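
Note: this is the hunk named in the subject line. sk_rcvbuf and sk_sndbuf are plain ints and each can be pushed close to INT_MAX, so the old int-typed sum rcvbuf + sndbuf + 64K could overflow signed arithmetic. The new code computes the limit in u32 and only charges half of sk_sndbuf, keeping the worst case below UINT_MAX. A small user-space sketch of the arithmetic (the INT_MAX buffer sizes are a hypothetical worst case):

/* Sketch: why the backlog limit is now computed in unsigned 32-bit math. */
#include <limits.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
	int rcvbuf = INT_MAX;	/* hypothetical huge receive buffer */
	int sndbuf = INT_MAX;	/* hypothetical huge send buffer */

	/* Old formula: int + int + 64K overflows int (undefined behaviour).
	 * Compute it in long long here just to show the needed magnitude.
	 */
	long long old_limit = (long long)rcvbuf + sndbuf + 64 * 1024;
	printf("old formula needs %lld, but int holds at most %d\n",
	       old_limit, INT_MAX);

	/* New formula: u32 arithmetic, and only half of sndbuf is charged. */
	uint32_t limit = (uint32_t)rcvbuf + (uint32_t)(sndbuf >> 1);
	limit += 64 * 1024;
	printf("new formula fits in u32: %u\n", limit);

	return 0;
}
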
@@ -1951,7 +1974,8 @@ int tcp_v4_rcv(struct sk_buff *skb)
        th = (const struct tcphdr *)skb->data;
        iph = ip_hdr(skb);
 lookup:
-       sk = __inet_lookup_skb(&tcp_hashinfo, skb, __tcp_hdrlen(th), th->source,
+       sk = __inet_lookup_skb(net->ipv4.tcp_death_row.hashinfo,
+                              skb, __tcp_hdrlen(th), th->source,
                               th->dest, sdif, &refcounted);
        if (!sk)
                goto no_tcp_socket;
@@ -2133,9 +2157,9 @@ do_time_wait:
        }
        switch (tcp_timewait_state_process(inet_twsk(sk), skb, th)) {
        case TCP_TW_SYN: {
-               struct sock *sk2 = inet_lookup_listener(dev_net(skb->dev),
-                                                       &tcp_hashinfo, skb,
-                                                       __tcp_hdrlen(th),
+               struct sock *sk2 = inet_lookup_listener(net,
+                                                       net->ipv4.tcp_death_row.hashinfo,
+                                                       skb, __tcp_hdrlen(th),
                                                        iph->saddr, th->source,
                                                        iph->daddr, th->dest,
                                                        inet_iif(skb),
@@ -2285,15 +2309,16 @@ static bool seq_sk_match(struct seq_file *seq, const struct sock *sk)
  */
 static void *listening_get_first(struct seq_file *seq)
 {
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct tcp_iter_state *st = seq->private;
 
        st->offset = 0;
-       for (; st->bucket <= tcp_hashinfo.lhash2_mask; st->bucket++) {
+       for (; st->bucket <= hinfo->lhash2_mask; st->bucket++) {
                struct inet_listen_hashbucket *ilb2;
                struct hlist_nulls_node *node;
                struct sock *sk;
 
-               ilb2 = &tcp_hashinfo.lhash2[st->bucket];
+               ilb2 = &hinfo->lhash2[st->bucket];
                if (hlist_nulls_empty(&ilb2->nulls_head))
                        continue;
 
@@ -2318,6 +2343,7 @@ static void *listening_get_next(struct seq_file *seq, void *cur)
        struct tcp_iter_state *st = seq->private;
        struct inet_listen_hashbucket *ilb2;
        struct hlist_nulls_node *node;
+       struct inet_hashinfo *hinfo;
        struct sock *sk = cur;
 
        ++st->num;
@@ -2329,7 +2355,8 @@ static void *listening_get_next(struct seq_file *seq, void *cur)
                        return sk;
        }
 
-       ilb2 = &tcp_hashinfo.lhash2[st->bucket];
+       hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
+       ilb2 = &hinfo->lhash2[st->bucket];
        spin_unlock(&ilb2->lock);
        ++st->bucket;
        return listening_get_first(seq);
@@ -2351,9 +2378,10 @@ static void *listening_get_idx(struct seq_file *seq, loff_t *pos)
        return rc;
 }
 
-static inline bool empty_bucket(const struct tcp_iter_state *st)
+static inline bool empty_bucket(struct inet_hashinfo *hinfo,
+                               const struct tcp_iter_state *st)
 {
-       return hlist_nulls_empty(&tcp_hashinfo.ehash[st->bucket].chain);
+       return hlist_nulls_empty(&hinfo->ehash[st->bucket].chain);
 }
 
 /*
@@ -2362,20 +2390,21 @@ static inline bool empty_bucket(const struct tcp_iter_state *st)
  */
 static void *established_get_first(struct seq_file *seq)
 {
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct tcp_iter_state *st = seq->private;
 
        st->offset = 0;
-       for (; st->bucket <= tcp_hashinfo.ehash_mask; ++st->bucket) {
+       for (; st->bucket <= hinfo->ehash_mask; ++st->bucket) {
                struct sock *sk;
                struct hlist_nulls_node *node;
-               spinlock_t *lock = inet_ehash_lockp(&tcp_hashinfo, st->bucket);
+               spinlock_t *lock = inet_ehash_lockp(hinfo, st->bucket);
 
                /* Lockless fast path for the common case of empty buckets */
-               if (empty_bucket(st))
+               if (empty_bucket(hinfo, st))
                        continue;
 
                spin_lock_bh(lock);
-               sk_nulls_for_each(sk, node, &tcp_hashinfo.ehash[st->bucket].chain) {
+               sk_nulls_for_each(sk, node, &hinfo->ehash[st->bucket].chain) {
                        if (seq_sk_match(seq, sk))
                                return sk;
                }
@@ -2387,9 +2416,10 @@ static void *established_get_first(struct seq_file *seq)
 
 static void *established_get_next(struct seq_file *seq, void *cur)
 {
-       struct sock *sk = cur;
-       struct hlist_nulls_node *node;
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct tcp_iter_state *st = seq->private;
+       struct hlist_nulls_node *node;
+       struct sock *sk = cur;
 
        ++st->num;
        ++st->offset;
@@ -2401,7 +2431,7 @@ static void *established_get_next(struct seq_file *seq, void *cur)
                        return sk;
        }
 
-       spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket));
+       spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
        ++st->bucket;
        return established_get_first(seq);
 }
@@ -2439,6 +2469,7 @@ static void *tcp_get_idx(struct seq_file *seq, loff_t pos)
 
 static void *tcp_seek_last_pos(struct seq_file *seq)
 {
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct tcp_iter_state *st = seq->private;
        int bucket = st->bucket;
        int offset = st->offset;
@@ -2447,7 +2478,7 @@ static void *tcp_seek_last_pos(struct seq_file *seq)
 
        switch (st->state) {
        case TCP_SEQ_STATE_LISTENING:
-               if (st->bucket > tcp_hashinfo.lhash2_mask)
+               if (st->bucket > hinfo->lhash2_mask)
                        break;
                st->state = TCP_SEQ_STATE_LISTENING;
                rc = listening_get_first(seq);
@@ -2459,7 +2490,7 @@ static void *tcp_seek_last_pos(struct seq_file *seq)
                st->state = TCP_SEQ_STATE_ESTABLISHED;
                fallthrough;
        case TCP_SEQ_STATE_ESTABLISHED:
-               if (st->bucket > tcp_hashinfo.ehash_mask)
+               if (st->bucket > hinfo->ehash_mask)
                        break;
                rc = established_get_first(seq);
                while (offset-- && rc && bucket == st->bucket)
@@ -2527,16 +2558,17 @@ EXPORT_SYMBOL(tcp_seq_next);
 
 void tcp_seq_stop(struct seq_file *seq, void *v)
 {
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct tcp_iter_state *st = seq->private;
 
        switch (st->state) {
        case TCP_SEQ_STATE_LISTENING:
                if (v != SEQ_START_TOKEN)
-                       spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock);
+                       spin_unlock(&hinfo->lhash2[st->bucket].lock);
                break;
        case TCP_SEQ_STATE_ESTABLISHED:
                if (v)
-                       spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket));
+                       spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
                break;
        }
 }
@@ -2731,6 +2763,7 @@ static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter,
 static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq,
                                                 struct sock *start_sk)
 {
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct bpf_tcp_iter_state *iter = seq->private;
        struct tcp_iter_state *st = &iter->state;
        struct hlist_nulls_node *node;
@@ -2750,7 +2783,7 @@ static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq,
                        expected++;
                }
        }
-       spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock);
+       spin_unlock(&hinfo->lhash2[st->bucket].lock);
 
        return expected;
 }
@@ -2758,6 +2791,7 @@ static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq,
 static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq,
                                                   struct sock *start_sk)
 {
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct bpf_tcp_iter_state *iter = seq->private;
        struct tcp_iter_state *st = &iter->state;
        struct hlist_nulls_node *node;
@@ -2777,13 +2811,14 @@ static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq,
                        expected++;
                }
        }
-       spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket));
+       spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket));
 
        return expected;
 }
 
 static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
 {
+       struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo;
        struct bpf_tcp_iter_state *iter = seq->private;
        struct tcp_iter_state *st = &iter->state;
        unsigned int expected;
@@ -2799,7 +2834,7 @@ static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
                st->offset = 0;
                st->bucket++;
                if (st->state == TCP_SEQ_STATE_LISTENING &&
-                   st->bucket > tcp_hashinfo.lhash2_mask) {
+                   st->bucket > hinfo->lhash2_mask) {
                        st->state = TCP_SEQ_STATE_ESTABLISHED;
                        st->bucket = 0;
                }
@@ -3064,7 +3099,7 @@ struct proto tcp_prot = {
        .slab_flags             = SLAB_TYPESAFE_BY_RCU,
        .twsk_prot              = &tcp_timewait_sock_ops,
        .rsk_prot               = &tcp_request_sock_ops,
-       .h.hashinfo             = &tcp_hashinfo,
+       .h.hashinfo             = NULL,
        .no_autobind            = true,
        .diag_destroy           = tcp_abort,
 };
@@ -3072,19 +3107,43 @@ EXPORT_SYMBOL(tcp_prot);
 
 static void __net_exit tcp_sk_exit(struct net *net)
 {
-       struct inet_timewait_death_row *tcp_death_row = net->ipv4.tcp_death_row;
-
        if (net->ipv4.tcp_congestion_control)
                bpf_module_put(net->ipv4.tcp_congestion_control,
                               net->ipv4.tcp_congestion_control->owner);
-       if (refcount_dec_and_test(&tcp_death_row->tw_refcount))
-               kfree(tcp_death_row);
 }
 
-static int __net_init tcp_sk_init(struct net *net)
+static void __net_init tcp_set_hashinfo(struct net *net)
 {
-       int cnt;
+       struct inet_hashinfo *hinfo;
+       unsigned int ehash_entries;
+       struct net *old_net;
+
+       if (net_eq(net, &init_net))
+               goto fallback;
 
+       old_net = current->nsproxy->net_ns;
+       ehash_entries = READ_ONCE(old_net->ipv4.sysctl_tcp_child_ehash_entries);
+       if (!ehash_entries)
+               goto fallback;
+
+       ehash_entries = roundup_pow_of_two(ehash_entries);
+       hinfo = inet_pernet_hashinfo_alloc(&tcp_hashinfo, ehash_entries);
+       if (!hinfo) {
+               pr_warn("Failed to allocate TCP ehash (entries: %u) "
+                       "for a netns, fallback to the global one\n",
+                       ehash_entries);
+fallback:
+               hinfo = &tcp_hashinfo;
+               ehash_entries = tcp_hashinfo.ehash_mask + 1;
+       }
+
+       net->ipv4.tcp_death_row.hashinfo = hinfo;
+       net->ipv4.tcp_death_row.sysctl_max_tw_buckets = ehash_entries / 2;
+       net->ipv4.sysctl_max_syn_backlog = max(128U, ehash_entries / 128);
+}
+
+static int __net_init tcp_sk_init(struct net *net)
+{
        net->ipv4.sysctl_tcp_ecn = 2;
        net->ipv4.sysctl_tcp_ecn_fallback = 1;
 
@@ -3110,15 +3169,9 @@ static int __net_init tcp_sk_init(struct net *net)
        net->ipv4.sysctl_tcp_tw_reuse = 2;
        net->ipv4.sysctl_tcp_no_ssthresh_metrics_save = 1;
 
-       net->ipv4.tcp_death_row = kzalloc(sizeof(struct inet_timewait_death_row), GFP_KERNEL);
-       if (!net->ipv4.tcp_death_row)
-               return -ENOMEM;
-       refcount_set(&net->ipv4.tcp_death_row->tw_refcount, 1);
-       cnt = tcp_hashinfo.ehash_mask + 1;
-       net->ipv4.tcp_death_row->sysctl_max_tw_buckets = cnt / 2;
-       net->ipv4.tcp_death_row->hashinfo = &tcp_hashinfo;
+       refcount_set(&net->ipv4.tcp_death_row.tw_refcount, 1);
+       tcp_set_hashinfo(net);
 
-       net->ipv4.sysctl_max_syn_backlog = max(128, cnt / 128);
        net->ipv4.sysctl_tcp_sack = 1;
        net->ipv4.sysctl_tcp_window_scaling = 1;
        net->ipv4.sysctl_tcp_timestamps = 1;
@@ -3139,8 +3192,10 @@ static int __net_init tcp_sk_init(struct net *net)
        net->ipv4.sysctl_tcp_tso_win_divisor = 3;
        /* Default TSQ limit of 16 TSO segments */
        net->ipv4.sysctl_tcp_limit_output_bytes = 16 * 65536;
-       /* rfc5961 challenge ack rate limiting */
-       net->ipv4.sysctl_tcp_challenge_ack_limit = 1000;
+
+       /* rfc5961 challenge ack rate limiting, per net-ns, disabled by default. */
+       net->ipv4.sysctl_tcp_challenge_ack_limit = INT_MAX;
+
        net->ipv4.sysctl_tcp_min_tso_segs = 2;
        net->ipv4.sysctl_tcp_tso_rtt_log = 9;  /* 2^9 = 512 usec */
        net->ipv4.sysctl_tcp_min_rtt_wlen = 300;
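
Note: the new tcp_set_hashinfo() above replaces the old cnt-based sizing removed in this hunk. The requested child ehash size is rounded up to a power of two, max_tw_buckets becomes half the entry count, and max_syn_backlog becomes entries/128 with a floor of 128; if the sysctl is 0 or the allocation fails, the netns falls back to the global tcp_hashinfo. A quick user-space sketch of the sizing arithmetic, using an arbitrary requested value of 1000:

/* Sketch of the sizing math in tcp_set_hashinfo(); roundup_pow2() is a
 * stand-in for the kernel's roundup_pow_of_two().
 */
#include <stdio.h>

static unsigned int roundup_pow2(unsigned int n)
{
	unsigned int p = 1;

	while (p < n)
		p <<= 1;
	return p;
}

int main(void)
{
	unsigned int requested = 1000;	/* hypothetical tcp_child_ehash_entries */
	unsigned int ehash_entries = roundup_pow2(requested);
	unsigned int max_tw_buckets = ehash_entries / 2;
	unsigned int max_syn_backlog = ehash_entries / 128;

	if (max_syn_backlog < 128)
		max_syn_backlog = 128;

	printf("ehash_entries   = %u\n", ehash_entries);	/* 1024 */
	printf("max_tw_buckets  = %u\n", max_tw_buckets);	/* 512 */
	printf("max_syn_backlog = %u\n", max_syn_backlog);	/* 128 */
	return 0;
}
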
@@ -3178,10 +3233,13 @@ static void __net_exit tcp_sk_exit_batch(struct list_head *net_exit_list)
 {
        struct net *net;
 
-       inet_twsk_purge(&tcp_hashinfo, AF_INET);
+       tcp_twsk_purge(net_exit_list, AF_INET);
 
-       list_for_each_entry(net, net_exit_list, exit_list)
+       list_for_each_entry(net, net_exit_list, exit_list) {
+               inet_pernet_hashinfo_free(net->ipv4.tcp_death_row.hashinfo);
+               WARN_ON_ONCE(!refcount_dec_and_test(&net->ipv4.tcp_death_row.tw_refcount));
                tcp_fastopen_ctx_destroy(net);
+       }
 }
 
 static struct pernet_operations __net_initdata tcp_sk_ops = {