Merge branch 'inet-exceptions-less-predictable'
[linux-2.6-microblaze.git] / net / ipv4 / tcp_ipv4.c
index e66ad6b..2e62e0d 100644 (file)
@@ -342,7 +342,7 @@ void tcp_v4_mtu_reduced(struct sock *sk)
 
        if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE))
                return;
-       mtu = tcp_sk(sk)->mtu_info;
+       mtu = READ_ONCE(tcp_sk(sk)->mtu_info);
        dst = inet_csk_update_pmtu(sk, mtu);
        if (!dst)
                return;
@@ -546,7 +546,7 @@ int tcp_v4_err(struct sk_buff *skb, u32 info)
                        if (sk->sk_state == TCP_LISTEN)
                                goto out;
 
-                       tp->mtu_info = info;
+                       WRITE_ONCE(tp->mtu_info, info);
                        if (!sock_owned_by_user(sk)) {
                                tcp_v4_mtu_reduced(sk);
                        } else {
@@ -2277,51 +2277,72 @@ EXPORT_SYMBOL(tcp_v4_destroy_sock);
 #ifdef CONFIG_PROC_FS
 /* Proc filesystem TCP sock list dumping. */
 
-/*
- * Get next listener socket follow cur.  If cur is NULL, get first socket
- * starting from bucket given in st->bucket; when st->bucket is zero the
- * very first socket in the hash table is returned.
+static unsigned short seq_file_family(const struct seq_file *seq);
+
+static bool seq_sk_match(struct seq_file *seq, const struct sock *sk)
+{
+       unsigned short family = seq_file_family(seq);
+
+       /* AF_UNSPEC is used as a match all */
+       return ((family == AF_UNSPEC || family == sk->sk_family) &&
+               net_eq(sock_net(sk), seq_file_net(seq)));
+}
+
+/* Find a non empty bucket (starting from st->bucket)
+ * and return the first sk from it.
  */
-static void *listening_get_next(struct seq_file *seq, void *cur)
+static void *listening_get_first(struct seq_file *seq)
 {
-       struct tcp_seq_afinfo *afinfo;
        struct tcp_iter_state *st = seq->private;
-       struct net *net = seq_file_net(seq);
-       struct inet_listen_hashbucket *ilb;
-       struct hlist_nulls_node *node;
-       struct sock *sk = cur;
 
-       if (st->bpf_seq_afinfo)
-               afinfo = st->bpf_seq_afinfo;
-       else
-               afinfo = PDE_DATA(file_inode(seq->file));
+       st->offset = 0;
+       for (; st->bucket <= tcp_hashinfo.lhash2_mask; st->bucket++) {
+               struct inet_listen_hashbucket *ilb2;
+               struct inet_connection_sock *icsk;
+               struct sock *sk;
 
-       if (!sk) {
-get_head:
-               ilb = &tcp_hashinfo.listening_hash[st->bucket];
-               spin_lock(&ilb->lock);
-               sk = sk_nulls_head(&ilb->nulls_head);
-               st->offset = 0;
-               goto get_sk;
+               ilb2 = &tcp_hashinfo.lhash2[st->bucket];
+               if (hlist_empty(&ilb2->head))
+                       continue;
+
+               spin_lock(&ilb2->lock);
+               inet_lhash2_for_each_icsk(icsk, &ilb2->head) {
+                       sk = (struct sock *)icsk;
+                       if (seq_sk_match(seq, sk))
+                               return sk;
+               }
+               spin_unlock(&ilb2->lock);
        }
-       ilb = &tcp_hashinfo.listening_hash[st->bucket];
+
+       return NULL;
+}
+
+/* Find the next sk of "cur" within the same bucket (i.e. st->bucket).
+ * If "cur" is the last one in the st->bucket,
+ * call listening_get_first() to return the first sk of the next
+ * non empty bucket.
+ */
+static void *listening_get_next(struct seq_file *seq, void *cur)
+{
+       struct tcp_iter_state *st = seq->private;
+       struct inet_listen_hashbucket *ilb2;
+       struct inet_connection_sock *icsk;
+       struct sock *sk = cur;
+
        ++st->num;
        ++st->offset;
 
-       sk = sk_nulls_next(sk);
-get_sk:
-       sk_nulls_for_each_from(sk, node) {
-               if (!net_eq(sock_net(sk), net))
-                       continue;
-               if (afinfo->family == AF_UNSPEC ||
-                   sk->sk_family == afinfo->family)
+       icsk = inet_csk(sk);
+       inet_lhash2_for_each_icsk_continue(icsk) {
+               sk = (struct sock *)icsk;
+               if (seq_sk_match(seq, sk))
                        return sk;
        }
-       spin_unlock(&ilb->lock);
-       st->offset = 0;
-       if (++st->bucket < INET_LHTABLE_SIZE)
-               goto get_head;
-       return NULL;
+
+       ilb2 = &tcp_hashinfo.lhash2[st->bucket];
+       spin_unlock(&ilb2->lock);
+       ++st->bucket;
+       return listening_get_first(seq);
 }
 
 static void *listening_get_idx(struct seq_file *seq, loff_t *pos)
@@ -2331,7 +2352,7 @@ static void *listening_get_idx(struct seq_file *seq, loff_t *pos)
 
        st->bucket = 0;
        st->offset = 0;
-       rc = listening_get_next(seq, NULL);
+       rc = listening_get_first(seq);
 
        while (rc && *pos) {
                rc = listening_get_next(seq, rc);
@@ -2351,15 +2372,7 @@ static inline bool empty_bucket(const struct tcp_iter_state *st)
  */
 static void *established_get_first(struct seq_file *seq)
 {
-       struct tcp_seq_afinfo *afinfo;
        struct tcp_iter_state *st = seq->private;
-       struct net *net = seq_file_net(seq);
-       void *rc = NULL;
-
-       if (st->bpf_seq_afinfo)
-               afinfo = st->bpf_seq_afinfo;
-       else
-               afinfo = PDE_DATA(file_inode(seq->file));
 
        st->offset = 0;
        for (; st->bucket <= tcp_hashinfo.ehash_mask; ++st->bucket) {
@@ -2373,32 +2386,20 @@ static void *established_get_first(struct seq_file *seq)
 
                spin_lock_bh(lock);
                sk_nulls_for_each(sk, node, &tcp_hashinfo.ehash[st->bucket].chain) {
-                       if ((afinfo->family != AF_UNSPEC &&
-                            sk->sk_family != afinfo->family) ||
-                           !net_eq(sock_net(sk), net)) {
-                               continue;
-                       }
-                       rc = sk;
-                       goto out;
+                       if (seq_sk_match(seq, sk))
+                               return sk;
                }
                spin_unlock_bh(lock);
        }
-out:
-       return rc;
+
+       return NULL;
 }
 
 static void *established_get_next(struct seq_file *seq, void *cur)
 {
-       struct tcp_seq_afinfo *afinfo;
        struct sock *sk = cur;
        struct hlist_nulls_node *node;
        struct tcp_iter_state *st = seq->private;
-       struct net *net = seq_file_net(seq);
-
-       if (st->bpf_seq_afinfo)
-               afinfo = st->bpf_seq_afinfo;
-       else
-               afinfo = PDE_DATA(file_inode(seq->file));
 
        ++st->num;
        ++st->offset;
@@ -2406,9 +2407,7 @@ static void *established_get_next(struct seq_file *seq, void *cur)
        sk = sk_nulls_next(sk);
 
        sk_nulls_for_each_from(sk, node) {
-               if ((afinfo->family == AF_UNSPEC ||
-                    sk->sk_family == afinfo->family) &&
-                   net_eq(sock_net(sk), net))
+               if (seq_sk_match(seq, sk))
                        return sk;
        }
 
@@ -2451,17 +2450,18 @@ static void *tcp_get_idx(struct seq_file *seq, loff_t pos)
 static void *tcp_seek_last_pos(struct seq_file *seq)
 {
        struct tcp_iter_state *st = seq->private;
+       int bucket = st->bucket;
        int offset = st->offset;
        int orig_num = st->num;
        void *rc = NULL;
 
        switch (st->state) {
        case TCP_SEQ_STATE_LISTENING:
-               if (st->bucket >= INET_LHTABLE_SIZE)
+               if (st->bucket > tcp_hashinfo.lhash2_mask)
                        break;
                st->state = TCP_SEQ_STATE_LISTENING;
-               rc = listening_get_next(seq, NULL);
-               while (offset-- && rc)
+               rc = listening_get_first(seq);
+               while (offset-- && rc && bucket == st->bucket)
                        rc = listening_get_next(seq, rc);
                if (rc)
                        break;
@@ -2472,7 +2472,7 @@ static void *tcp_seek_last_pos(struct seq_file *seq)
                if (st->bucket > tcp_hashinfo.ehash_mask)
                        break;
                rc = established_get_first(seq);
-               while (offset-- && rc)
+               while (offset-- && rc && bucket == st->bucket)
                        rc = established_get_next(seq, rc);
        }
 
@@ -2542,7 +2542,7 @@ void tcp_seq_stop(struct seq_file *seq, void *v)
        switch (st->state) {
        case TCP_SEQ_STATE_LISTENING:
                if (v != SEQ_START_TOKEN)
-                       spin_unlock(&tcp_hashinfo.listening_hash[st->bucket].lock);
+                       spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock);
                break;
        case TCP_SEQ_STATE_ESTABLISHED:
                if (v)
@@ -2687,6 +2687,15 @@ out:
 }
 
 #ifdef CONFIG_BPF_SYSCALL
+struct bpf_tcp_iter_state {
+       struct tcp_iter_state state;
+       unsigned int cur_sk;
+       unsigned int end_sk;
+       unsigned int max_sk;
+       struct sock **batch;
+       bool st_bucket_done;
+};
+
 struct bpf_iter__tcp {
        __bpf_md_ptr(struct bpf_iter_meta *, meta);
        __bpf_md_ptr(struct sock_common *, sk_common);
@@ -2705,16 +2714,204 @@ static int tcp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
        return bpf_iter_run_prog(prog, &ctx);
 }
 
+static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter)
+{
+       while (iter->cur_sk < iter->end_sk)
+               sock_put(iter->batch[iter->cur_sk++]);
+}
+
+static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter,
+                                     unsigned int new_batch_sz)
+{
+       struct sock **new_batch;
+
+       new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz,
+                            GFP_USER | __GFP_NOWARN);
+       if (!new_batch)
+               return -ENOMEM;
+
+       bpf_iter_tcp_put_batch(iter);
+       kvfree(iter->batch);
+       iter->batch = new_batch;
+       iter->max_sk = new_batch_sz;
+
+       return 0;
+}
+
+static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq,
+                                                struct sock *start_sk)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       struct inet_connection_sock *icsk;
+       unsigned int expected = 1;
+       struct sock *sk;
+
+       sock_hold(start_sk);
+       iter->batch[iter->end_sk++] = start_sk;
+
+       icsk = inet_csk(start_sk);
+       inet_lhash2_for_each_icsk_continue(icsk) {
+               sk = (struct sock *)icsk;
+               if (seq_sk_match(seq, sk)) {
+                       if (iter->end_sk < iter->max_sk) {
+                               sock_hold(sk);
+                               iter->batch[iter->end_sk++] = sk;
+                       }
+                       expected++;
+               }
+       }
+       spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock);
+
+       return expected;
+}
+
+static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq,
+                                                  struct sock *start_sk)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       struct hlist_nulls_node *node;
+       unsigned int expected = 1;
+       struct sock *sk;
+
+       sock_hold(start_sk);
+       iter->batch[iter->end_sk++] = start_sk;
+
+       sk = sk_nulls_next(start_sk);
+       sk_nulls_for_each_from(sk, node) {
+               if (seq_sk_match(seq, sk)) {
+                       if (iter->end_sk < iter->max_sk) {
+                               sock_hold(sk);
+                               iter->batch[iter->end_sk++] = sk;
+                       }
+                       expected++;
+               }
+       }
+       spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket));
+
+       return expected;
+}
+
+static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       unsigned int expected;
+       bool resized = false;
+       struct sock *sk;
+
+       /* The st->bucket is done.  Directly advance to the next
+        * bucket instead of having the tcp_seek_last_pos() to skip
+        * one by one in the current bucket and eventually find out
+        * it has to advance to the next bucket.
+        */
+       if (iter->st_bucket_done) {
+               st->offset = 0;
+               st->bucket++;
+               if (st->state == TCP_SEQ_STATE_LISTENING &&
+                   st->bucket > tcp_hashinfo.lhash2_mask) {
+                       st->state = TCP_SEQ_STATE_ESTABLISHED;
+                       st->bucket = 0;
+               }
+       }
+
+again:
+       /* Get a new batch */
+       iter->cur_sk = 0;
+       iter->end_sk = 0;
+       iter->st_bucket_done = false;
+
+       sk = tcp_seek_last_pos(seq);
+       if (!sk)
+               return NULL; /* Done */
+
+       if (st->state == TCP_SEQ_STATE_LISTENING)
+               expected = bpf_iter_tcp_listening_batch(seq, sk);
+       else
+               expected = bpf_iter_tcp_established_batch(seq, sk);
+
+       if (iter->end_sk == expected) {
+               iter->st_bucket_done = true;
+               return sk;
+       }
+
+       if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) {
+               resized = true;
+               goto again;
+       }
+
+       return sk;
+}
+
+static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos)
+{
+       /* bpf iter does not support lseek, so it always
+        * continue from where it was stop()-ped.
+        */
+       if (*pos)
+               return bpf_iter_tcp_batch(seq);
+
+       return SEQ_START_TOKEN;
+}
+
+static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       struct sock *sk;
+
+       /* Whenever seq_next() is called, the iter->cur_sk is
+        * done with seq_show(), so advance to the next sk in
+        * the batch.
+        */
+       if (iter->cur_sk < iter->end_sk) {
+               /* Keeping st->num consistent in tcp_iter_state.
+                * bpf_iter_tcp does not use st->num.
+                * meta.seq_num is used instead.
+                */
+               st->num++;
+               /* Move st->offset to the next sk in the bucket such that
+                * the future start() will resume at st->offset in
+                * st->bucket.  See tcp_seek_last_pos().
+                */
+               st->offset++;
+               sock_put(iter->batch[iter->cur_sk++]);
+       }
+
+       if (iter->cur_sk < iter->end_sk)
+               sk = iter->batch[iter->cur_sk];
+       else
+               sk = bpf_iter_tcp_batch(seq);
+
+       ++*pos;
+       /* Keeping st->last_pos consistent in tcp_iter_state.
+        * bpf iter does not do lseek, so st->last_pos always equals to *pos.
+        */
+       st->last_pos = *pos;
+       return sk;
+}
+
 static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v)
 {
        struct bpf_iter_meta meta;
        struct bpf_prog *prog;
        struct sock *sk = v;
+       bool slow;
        uid_t uid;
+       int ret;
 
        if (v == SEQ_START_TOKEN)
                return 0;
 
+       if (sk_fullsock(sk))
+               slow = lock_sock_fast(sk);
+
+       if (unlikely(sk_unhashed(sk))) {
+               ret = SEQ_SKIP;
+               goto unlock;
+       }
+
        if (sk->sk_state == TCP_TIME_WAIT) {
                uid = 0;
        } else if (sk->sk_state == TCP_NEW_SYN_RECV) {
@@ -2728,11 +2925,18 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v)
 
        meta.seq = seq;
        prog = bpf_iter_get_info(&meta, false);
-       return tcp_prog_seq_show(prog, &meta, v, uid);
+       ret = tcp_prog_seq_show(prog, &meta, v, uid);
+
+unlock:
+       if (sk_fullsock(sk))
+               unlock_sock_fast(sk, slow);
+       return ret;
+
 }
 
 static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v)
 {
+       struct bpf_tcp_iter_state *iter = seq->private;
        struct bpf_iter_meta meta;
        struct bpf_prog *prog;
 
@@ -2743,17 +2947,34 @@ static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v)
                        (void)tcp_prog_seq_show(prog, &meta, v, 0);
        }
 
-       tcp_seq_stop(seq, v);
+       if (iter->cur_sk < iter->end_sk) {
+               bpf_iter_tcp_put_batch(iter);
+               iter->st_bucket_done = false;
+       }
 }
 
 static const struct seq_operations bpf_iter_tcp_seq_ops = {
        .show           = bpf_iter_tcp_seq_show,
-       .start          = tcp_seq_start,
-       .next           = tcp_seq_next,
+       .start          = bpf_iter_tcp_seq_start,
+       .next           = bpf_iter_tcp_seq_next,
        .stop           = bpf_iter_tcp_seq_stop,
 };
+#endif
+static unsigned short seq_file_family(const struct seq_file *seq)
+{
+       const struct tcp_seq_afinfo *afinfo;
+
+#ifdef CONFIG_BPF_SYSCALL
+       /* Iterated from bpf_iter.  Let the bpf prog to filter instead. */
+       if (seq->op == &bpf_iter_tcp_seq_ops)
+               return AF_UNSPEC;
 #endif
 
+       /* Iterated from proc fs */
+       afinfo = PDE_DATA(file_inode(seq->file));
+       return afinfo->family;
+}
+
 static const struct seq_operations tcp4_seq_ops = {
        .show           = tcp4_seq_show,
        .start          = tcp_seq_start,
@@ -2964,8 +3185,7 @@ static int __net_init tcp_sk_init(struct net *net)
        net->ipv4.sysctl_tcp_comp_sack_slack_ns = 100 * NSEC_PER_USEC;
        net->ipv4.sysctl_tcp_comp_sack_nr = 44;
        net->ipv4.sysctl_tcp_fastopen = TFO_CLIENT_ENABLE;
-       spin_lock_init(&net->ipv4.tcp_fastopen_ctx_lock);
-       net->ipv4.sysctl_tcp_fastopen_blackhole_timeout = 60 * 60;
+       net->ipv4.sysctl_tcp_fastopen_blackhole_timeout = 0;
        atomic_set(&net->ipv4.tfo_active_disable_times, 0);
 
        /* Reno is always built in */
@@ -3003,39 +3223,55 @@ static struct pernet_operations __net_initdata tcp_sk_ops = {
 DEFINE_BPF_ITER_FUNC(tcp, struct bpf_iter_meta *meta,
                     struct sock_common *sk_common, uid_t uid)
 
+#define INIT_BATCH_SZ 16
+
 static int bpf_iter_init_tcp(void *priv_data, struct bpf_iter_aux_info *aux)
 {
-       struct tcp_iter_state *st = priv_data;
-       struct tcp_seq_afinfo *afinfo;
-       int ret;
+       struct bpf_tcp_iter_state *iter = priv_data;
+       int err;
 
-       afinfo = kmalloc(sizeof(*afinfo), GFP_USER | __GFP_NOWARN);
-       if (!afinfo)
-               return -ENOMEM;
+       err = bpf_iter_init_seq_net(priv_data, aux);
+       if (err)
+               return err;
 
-       afinfo->family = AF_UNSPEC;
-       st->bpf_seq_afinfo = afinfo;
-       ret = bpf_iter_init_seq_net(priv_data, aux);
-       if (ret)
-               kfree(afinfo);
-       return ret;
+       err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ);
+       if (err) {
+               bpf_iter_fini_seq_net(priv_data);
+               return err;
+       }
+
+       return 0;
 }
 
 static void bpf_iter_fini_tcp(void *priv_data)
 {
-       struct tcp_iter_state *st = priv_data;
+       struct bpf_tcp_iter_state *iter = priv_data;
 
-       kfree(st->bpf_seq_afinfo);
        bpf_iter_fini_seq_net(priv_data);
+       kvfree(iter->batch);
 }
 
 static const struct bpf_iter_seq_info tcp_seq_info = {
        .seq_ops                = &bpf_iter_tcp_seq_ops,
        .init_seq_private       = bpf_iter_init_tcp,
        .fini_seq_private       = bpf_iter_fini_tcp,
-       .seq_priv_size          = sizeof(struct tcp_iter_state),
+       .seq_priv_size          = sizeof(struct bpf_tcp_iter_state),
 };
 
+static const struct bpf_func_proto *
+bpf_iter_tcp_get_func_proto(enum bpf_func_id func_id,
+                           const struct bpf_prog *prog)
+{
+       switch (func_id) {
+       case BPF_FUNC_setsockopt:
+               return &bpf_sk_setsockopt_proto;
+       case BPF_FUNC_getsockopt:
+               return &bpf_sk_getsockopt_proto;
+       default:
+               return NULL;
+       }
+}
+
 static struct bpf_iter_reg tcp_reg_info = {
        .target                 = "tcp",
        .ctx_arg_info_size      = 1,
@@ -3043,6 +3279,7 @@ static struct bpf_iter_reg tcp_reg_info = {
                { offsetof(struct bpf_iter__tcp, sk_common),
                  PTR_TO_BTF_ID_OR_NULL },
        },
+       .get_func_proto         = bpf_iter_tcp_get_func_proto,
        .seq_info               = &tcp_seq_info,
 };