Merge tag 'powerpc-5.9-3' of git://git.kernel.org/pub/scm/linux/kernel/git/powerpc...
[linux-2.6-microblaze.git] / net / mptcp / protocol.c
index c0abe73..1aad411 100644 (file)
@@ -16,6 +16,7 @@
 #include <net/inet_hashtables.h>
 #include <net/protocol.h>
 #include <net/tcp.h>
+#include <net/tcp_states.h>
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
 #include <net/transp_v6.h>
 #endif
@@ -52,18 +53,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
        return msk->subflow;
 }
 
-static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk)
-{
-       return msk->first && !sk_is_mptcp(msk->first);
-}
-
-static struct socket *mptcp_is_tcpsk(struct sock *sk)
+static bool mptcp_is_tcpsk(struct sock *sk)
 {
        struct socket *sock = sk->sk_socket;
 
-       if (sock->sk != sk)
-               return NULL;
-
        if (unlikely(sk->sk_prot == &tcp_prot)) {
                /* we are being invoked after mptcp_accept() has
                 * accepted a non-mp-capable flow: sk is a tcp_sk,
@@ -73,59 +66,37 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk)
                 * bypass mptcp.
                 */
                sock->ops = &inet_stream_ops;
-               return sock;
+               return true;
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
        } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
                sock->ops = &inet6_stream_ops;
-               return sock;
+               return true;
 #endif
        }
 
-       return NULL;
+       return false;
 }
 
-static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
+static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk)
 {
-       struct socket *sock;
-
        sock_owned_by_me((const struct sock *)msk);
 
-       sock = mptcp_is_tcpsk((struct sock *)msk);
-       if (unlikely(sock))
-               return sock;
-
-       if (likely(!__mptcp_needs_tcp_fallback(msk)))
+       if (likely(!__mptcp_check_fallback(msk)))
                return NULL;
 
-       return msk->subflow;
-}
-
-static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
-{
-       return !msk->first;
+       return msk->first;
 }
 
-static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
+static int __mptcp_socket_create(struct mptcp_sock *msk)
 {
        struct mptcp_subflow_context *subflow;
        struct sock *sk = (struct sock *)msk;
        struct socket *ssock;
        int err;
 
-       ssock = __mptcp_tcp_fallback(msk);
-       if (unlikely(ssock))
-               return ssock;
-
-       ssock = __mptcp_nmpc_socket(msk);
-       if (ssock)
-               goto set_state;
-
-       if (!__mptcp_can_create_subflow(msk))
-               return ERR_PTR(-EINVAL);
-
        err = mptcp_subflow_create_socket(sk, &ssock);
        if (err)
-               return ERR_PTR(err);
+               return err;
 
        msk->first = ssock->sk;
        msk->subflow = ssock;
@@ -133,10 +104,12 @@ static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
        list_add(&subflow->node, &msk->conn_list);
        subflow->request_mptcp = 1;
 
-set_state:
-       if (state != MPTCP_SAME_STATE)
-               inet_sk_state_store(sk, state);
-       return ssock;
+       /* accept() will wait on first subflow sk_wq, and we always wakes up
+        * via msk->sk_socket
+        */
+       RCU_INIT_POINTER(msk->first->sk_wq, &sk->sk_socket->wq);
+
+       return 0;
 }
 
 static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk,
@@ -170,6 +143,14 @@ static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk,
        MPTCP_SKB_CB(skb)->offset = offset;
 }
 
+static void mptcp_stop_timer(struct sock *sk)
+{
+       struct inet_connection_sock *icsk = inet_csk(sk);
+
+       sk_stop_timer(sk, &icsk->icsk_retransmit_timer);
+       mptcp_sk(sk)->timer_ival = 0;
+}
+
 /* both sockets must be locked */
 static bool mptcp_subflow_dsn_valid(const struct mptcp_sock *msk,
                                    struct sock *ssk)
@@ -191,6 +172,139 @@ static bool mptcp_subflow_dsn_valid(const struct mptcp_sock *msk,
        return mptcp_subflow_data_available(ssk);
 }
 
+static void mptcp_check_data_fin_ack(struct sock *sk)
+{
+       struct mptcp_sock *msk = mptcp_sk(sk);
+
+       if (__mptcp_check_fallback(msk))
+               return;
+
+       /* Look for an acknowledged DATA_FIN */
+       if (((1 << sk->sk_state) &
+            (TCPF_FIN_WAIT1 | TCPF_CLOSING | TCPF_LAST_ACK)) &&
+           msk->write_seq == atomic64_read(&msk->snd_una)) {
+               mptcp_stop_timer(sk);
+
+               WRITE_ONCE(msk->snd_data_fin_enable, 0);
+
+               switch (sk->sk_state) {
+               case TCP_FIN_WAIT1:
+                       inet_sk_state_store(sk, TCP_FIN_WAIT2);
+                       sk->sk_state_change(sk);
+                       break;
+               case TCP_CLOSING:
+                       fallthrough;
+               case TCP_LAST_ACK:
+                       inet_sk_state_store(sk, TCP_CLOSE);
+                       sk->sk_state_change(sk);
+                       break;
+               }
+
+               if (sk->sk_shutdown == SHUTDOWN_MASK ||
+                   sk->sk_state == TCP_CLOSE)
+                       sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP);
+               else
+                       sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
+       }
+}
+
+static bool mptcp_pending_data_fin(struct sock *sk, u64 *seq)
+{
+       struct mptcp_sock *msk = mptcp_sk(sk);
+
+       if (READ_ONCE(msk->rcv_data_fin) &&
+           ((1 << sk->sk_state) &
+            (TCPF_ESTABLISHED | TCPF_FIN_WAIT1 | TCPF_FIN_WAIT2))) {
+               u64 rcv_data_fin_seq = READ_ONCE(msk->rcv_data_fin_seq);
+
+               if (msk->ack_seq == rcv_data_fin_seq) {
+                       if (seq)
+                               *seq = rcv_data_fin_seq;
+
+                       return true;
+               }
+       }
+
+       return false;
+}
+
+static void mptcp_set_timeout(const struct sock *sk, const struct sock *ssk)
+{
+       long tout = ssk && inet_csk(ssk)->icsk_pending ?
+                                     inet_csk(ssk)->icsk_timeout - jiffies : 0;
+
+       if (tout <= 0)
+               tout = mptcp_sk(sk)->timer_ival;
+       mptcp_sk(sk)->timer_ival = tout > 0 ? tout : TCP_RTO_MIN;
+}
+
+static void mptcp_check_data_fin(struct sock *sk)
+{
+       struct mptcp_sock *msk = mptcp_sk(sk);
+       u64 rcv_data_fin_seq;
+
+       if (__mptcp_check_fallback(msk) || !msk->first)
+               return;
+
+       /* Need to ack a DATA_FIN received from a peer while this side
+        * of the connection is in ESTABLISHED, FIN_WAIT1, or FIN_WAIT2.
+        * msk->rcv_data_fin was set when parsing the incoming options
+        * at the subflow level and the msk lock was not held, so this
+        * is the first opportunity to act on the DATA_FIN and change
+        * the msk state.
+        *
+        * If we are caught up to the sequence number of the incoming
+        * DATA_FIN, send the DATA_ACK now and do state transition.  If
+        * not caught up, do nothing and let the recv code send DATA_ACK
+        * when catching up.
+        */
+
+       if (mptcp_pending_data_fin(sk, &rcv_data_fin_seq)) {
+               struct mptcp_subflow_context *subflow;
+
+               msk->ack_seq++;
+               WRITE_ONCE(msk->rcv_data_fin, 0);
+
+               sk->sk_shutdown |= RCV_SHUTDOWN;
+               smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
+               set_bit(MPTCP_DATA_READY, &msk->flags);
+
+               switch (sk->sk_state) {
+               case TCP_ESTABLISHED:
+                       inet_sk_state_store(sk, TCP_CLOSE_WAIT);
+                       break;
+               case TCP_FIN_WAIT1:
+                       inet_sk_state_store(sk, TCP_CLOSING);
+                       break;
+               case TCP_FIN_WAIT2:
+                       inet_sk_state_store(sk, TCP_CLOSE);
+                       // @@ Close subflows now?
+                       break;
+               default:
+                       /* Other states not expected */
+                       WARN_ON_ONCE(1);
+                       break;
+               }
+
+               mptcp_set_timeout(sk, NULL);
+               mptcp_for_each_subflow(msk, subflow) {
+                       struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
+
+                       lock_sock(ssk);
+                       tcp_send_ack(ssk);
+                       release_sock(ssk);
+               }
+
+               sk->sk_state_change(sk);
+
+               if (sk->sk_shutdown == SHUTDOWN_MASK ||
+                   sk->sk_state == TCP_CLOSE)
+                       sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP);
+               else
+                       sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
+       }
+}
+
 static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
                                           struct sock *ssk,
                                           unsigned int *bytes)
@@ -207,13 +321,6 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
                return false;
        }
 
-       if (!(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) {
-               int rcvbuf = max(ssk->sk_rcvbuf, sk->sk_rcvbuf);
-
-               if (rcvbuf > sk->sk_rcvbuf)
-                       sk->sk_rcvbuf = rcvbuf;
-       }
-
        tp = tcp_sk(ssk);
        do {
                u32 map_remaining, offset;
@@ -229,6 +336,15 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
                if (!skb)
                        break;
 
+               if (__mptcp_check_fallback(msk)) {
+                       /* if we are running under the workqueue, TCP could have
+                        * collapsed skbs between dummy map creation and now
+                        * be sure to adjust the size
+                        */
+                       map_remaining = skb->len;
+                       subflow->map_data_len = skb->len;
+               }
+
                offset = seq - TCP_SKB_CB(skb)->seq;
                fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN;
                if (fin) {
@@ -265,6 +381,15 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
 
        *bytes = moved;
 
+       /* If the moves have caught up with the DATA_FIN sequence number
+        * it's time to ack the DATA_FIN and change socket state, but
+        * this is not a good place to change state. Let the workqueue
+        * do it.
+        */
+       if (mptcp_pending_data_fin(sk, NULL) &&
+           schedule_work(&msk->work))
+               sock_hold(sk);
+
        return done;
 }
 
@@ -329,16 +454,6 @@ static void __mptcp_flush_join_list(struct mptcp_sock *msk)
        spin_unlock_bh(&msk->join_list_lock);
 }
 
-static void mptcp_set_timeout(const struct sock *sk, const struct sock *ssk)
-{
-       long tout = ssk && inet_csk(ssk)->icsk_pending ?
-                                     inet_csk(ssk)->icsk_timeout - jiffies : 0;
-
-       if (tout <= 0)
-               tout = mptcp_sk(sk)->timer_ival;
-       mptcp_sk(sk)->timer_ival = tout > 0 ? tout : TCP_RTO_MIN;
-}
-
 static bool mptcp_timer_pending(struct sock *sk)
 {
        return timer_pending(&inet_csk(sk)->icsk_retransmit_timer);
@@ -360,7 +475,8 @@ void mptcp_data_acked(struct sock *sk)
 {
        mptcp_reset_timer(sk);
 
-       if (!sk_stream_is_writeable(sk) &&
+       if ((!sk_stream_is_writeable(sk) ||
+            (inet_sk_state_load(sk) != TCP_ESTABLISHED)) &&
            schedule_work(&mptcp_sk(sk)->work))
                sock_hold(sk);
 }
@@ -395,14 +511,6 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)
        }
 }
 
-static void mptcp_stop_timer(struct sock *sk)
-{
-       struct inet_connection_sock *icsk = inet_csk(sk);
-
-       sk_stop_timer(sk, &icsk->icsk_retransmit_timer);
-       mptcp_sk(sk)->timer_ival = 0;
-}
-
 static bool mptcp_ext_cache_refill(struct mptcp_sock *msk)
 {
        const struct sock *sk = (const struct sock *)msk;
@@ -466,8 +574,15 @@ static void mptcp_clean_una(struct sock *sk)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
        struct mptcp_data_frag *dtmp, *dfrag;
-       u64 snd_una = atomic64_read(&msk->snd_una);
        bool cleaned = false;
+       u64 snd_una;
+
+       /* on fallback we just need to ignore snd_una, as this is really
+        * plain TCP
+        */
+       if (__mptcp_check_fallback(msk))
+               atomic64_set(&msk->snd_una, msk->write_seq);
+       snd_una = atomic64_read(&msk->snd_una);
 
        list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) {
                if (after64(dfrag->data_seq + dfrag->data_len, snd_una))
@@ -479,15 +594,20 @@ static void mptcp_clean_una(struct sock *sk)
 
        dfrag = mptcp_rtx_head(sk);
        if (dfrag && after64(snd_una, dfrag->data_seq)) {
-               u64 delta = dfrag->data_seq + dfrag->data_len - snd_una;
+               u64 delta = snd_una - dfrag->data_seq;
+
+               if (WARN_ON_ONCE(delta > dfrag->data_len))
+                       goto out;
 
                dfrag->data_seq += delta;
+               dfrag->offset += delta;
                dfrag->data_len -= delta;
 
                dfrag_uncharge(sk, delta);
                cleaned = true;
        }
 
+out:
        if (cleaned) {
                sk_mem_reclaim_partial(sk);
 
@@ -605,8 +725,10 @@ static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
                if (!psize)
                        return -EINVAL;
 
-               if (!sk_wmem_schedule(sk, psize + dfrag->overhead))
+               if (!sk_wmem_schedule(sk, psize + dfrag->overhead)) {
+                       iov_iter_revert(&msg->msg_iter, psize);
                        return -ENOMEM;
+               }
        } else {
                offset = dfrag->offset;
                psize = min_t(size_t, dfrag->data_len, avail_size);
@@ -617,8 +739,11 @@ static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
         */
        ret = do_tcp_sendpages(ssk, page, offset, psize,
                               msg->msg_flags | MSG_SENDPAGE_NOTLAST | MSG_DONTWAIT);
-       if (ret <= 0)
+       if (ret <= 0) {
+               if (!retransmission)
+                       iov_iter_revert(&msg->msg_iter, psize);
                return ret;
+       }
 
        frag_truesize += ret;
        if (!retransmission) {
@@ -673,7 +798,7 @@ static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
 out:
        if (!retransmission)
                pfrag->offset += frag_truesize;
-       *write_seq += ret;
+       WRITE_ONCE(*write_seq, *write_seq + ret);
        mptcp_subflow_ctx(ssk)->rel_write_seq += ret;
 
        return ret;
@@ -740,7 +865,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        int mss_now = 0, size_goal = 0, ret = 0;
        struct mptcp_sock *msk = mptcp_sk(sk);
        struct page_frag *pfrag;
-       struct socket *ssock;
        size_t copied = 0;
        struct sock *ssk;
        bool tx_ok;
@@ -759,19 +883,15 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
                        goto out;
        }
 
-fallback:
-       ssock = __mptcp_tcp_fallback(msk);
-       if (unlikely(ssock)) {
-               release_sock(sk);
-               pr_debug("fallback passthrough");
-               ret = sock_sendmsg(ssock, msg);
-               return ret >= 0 ? ret + copied : (copied ? copied : ret);
-       }
-
        pfrag = sk_page_frag(sk);
 restart:
        mptcp_clean_una(sk);
 
+       if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) {
+               ret = -EPIPE;
+               goto out;
+       }
+
 wait_for_sndbuf:
        __mptcp_flush_join_list(msk);
        ssk = mptcp_subflow_get_send(msk);
@@ -819,17 +939,6 @@ wait_for_sndbuf:
                        }
                        break;
                }
-               if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) {
-                       /* Can happen for passive sockets:
-                        * 3WHS negotiated MPTCP, but first packet after is
-                        * plain TCP (e.g. due to middlebox filtering unknown
-                        * options).
-                        *
-                        * Fall back to TCP.
-                        */
-                       release_sock(ssk);
-                       goto fallback;
-               }
 
                copied += ret;
 
@@ -880,7 +989,6 @@ wait_for_sndbuf:
 
        mptcp_set_timeout(sk, ssk);
        if (copied) {
-               ret = copied;
                tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle,
                         size_goal);
 
@@ -893,7 +1001,7 @@ wait_for_sndbuf:
        release_sock(ssk);
 out:
        release_sock(sk);
-       return ret;
+       return copied ? : ret;
 }
 
 static void mptcp_wait_data(struct sock *sk, long *timeo)
@@ -949,6 +1057,100 @@ static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk,
        return copied;
 }
 
+/* receive buffer autotuning.  See tcp_rcv_space_adjust for more information.
+ *
+ * Only difference: Use highest rtt estimate of the subflows in use.
+ */
+static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)
+{
+       struct mptcp_subflow_context *subflow;
+       struct sock *sk = (struct sock *)msk;
+       u32 time, advmss = 1;
+       u64 rtt_us, mstamp;
+
+       sock_owned_by_me(sk);
+
+       if (copied <= 0)
+               return;
+
+       msk->rcvq_space.copied += copied;
+
+       mstamp = div_u64(tcp_clock_ns(), NSEC_PER_USEC);
+       time = tcp_stamp_us_delta(mstamp, msk->rcvq_space.time);
+
+       rtt_us = msk->rcvq_space.rtt_us;
+       if (rtt_us && time < (rtt_us >> 3))
+               return;
+
+       rtt_us = 0;
+       mptcp_for_each_subflow(msk, subflow) {
+               const struct tcp_sock *tp;
+               u64 sf_rtt_us;
+               u32 sf_advmss;
+
+               tp = tcp_sk(mptcp_subflow_tcp_sock(subflow));
+
+               sf_rtt_us = READ_ONCE(tp->rcv_rtt_est.rtt_us);
+               sf_advmss = READ_ONCE(tp->advmss);
+
+               rtt_us = max(sf_rtt_us, rtt_us);
+               advmss = max(sf_advmss, advmss);
+       }
+
+       msk->rcvq_space.rtt_us = rtt_us;
+       if (time < (rtt_us >> 3) || rtt_us == 0)
+               return;
+
+       if (msk->rcvq_space.copied <= msk->rcvq_space.space)
+               goto new_measure;
+
+       if (sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf &&
+           !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) {
+               int rcvmem, rcvbuf;
+               u64 rcvwin, grow;
+
+               rcvwin = ((u64)msk->rcvq_space.copied << 1) + 16 * advmss;
+
+               grow = rcvwin * (msk->rcvq_space.copied - msk->rcvq_space.space);
+
+               do_div(grow, msk->rcvq_space.space);
+               rcvwin += (grow << 1);
+
+               rcvmem = SKB_TRUESIZE(advmss + MAX_TCP_HEADER);
+               while (tcp_win_from_space(sk, rcvmem) < advmss)
+                       rcvmem += 128;
+
+               do_div(rcvwin, advmss);
+               rcvbuf = min_t(u64, rcvwin * rcvmem,
+                              sock_net(sk)->ipv4.sysctl_tcp_rmem[2]);
+
+               if (rcvbuf > sk->sk_rcvbuf) {
+                       u32 window_clamp;
+
+                       window_clamp = tcp_win_from_space(sk, rcvbuf);
+                       WRITE_ONCE(sk->sk_rcvbuf, rcvbuf);
+
+                       /* Make subflows follow along.  If we do not do this, we
+                        * get drops at subflow level if skbs can't be moved to
+                        * the mptcp rx queue fast enough (announced rcv_win can
+                        * exceed ssk->sk_rcvbuf).
+                        */
+                       mptcp_for_each_subflow(msk, subflow) {
+                               struct sock *ssk;
+
+                               ssk = mptcp_subflow_tcp_sock(subflow);
+                               WRITE_ONCE(ssk->sk_rcvbuf, rcvbuf);
+                               tcp_sk(ssk)->window_clamp = window_clamp;
+                       }
+               }
+       }
+
+       msk->rcvq_space.space = msk->rcvq_space.copied;
+new_measure:
+       msk->rcvq_space.copied = 0;
+       msk->rcvq_space.time = mstamp;
+}
+
 static bool __mptcp_move_skbs(struct mptcp_sock *msk)
 {
        unsigned int moved = 0;
@@ -972,7 +1174,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                         int nonblock, int flags, int *addr_len)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
-       struct socket *ssock;
        int copied = 0;
        int target;
        long timeo;
@@ -981,16 +1182,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                return -EOPNOTSUPP;
 
        lock_sock(sk);
-       ssock = __mptcp_tcp_fallback(msk);
-       if (unlikely(ssock)) {
-fallback:
-               release_sock(sk);
-               pr_debug("fallback-read subflow=%p",
-                        mptcp_subflow_ctx(ssock->sk));
-               copied = sock_recvmsg(ssock, msg, flags);
-               return copied;
-       }
-
        timeo = sock_rcvtimeo(sk, nonblock);
 
        len = min_t(size_t, len, INT_MAX);
@@ -1056,9 +1247,6 @@ fallback:
 
                pr_debug("block timeout %ld", timeo);
                mptcp_wait_data(sk, &timeo);
-               ssock = __mptcp_tcp_fallback(msk);
-               if (unlikely(ssock))
-                       goto fallback;
        }
 
        if (skb_queue_empty(&sk->sk_receive_queue)) {
@@ -1075,6 +1263,8 @@ fallback:
                set_bit(MPTCP_DATA_READY, &msk->flags);
        }
 out_err:
+       mptcp_rcv_space_adjust(msk, copied);
+
        release_sock(sk);
        return copied;
 }
@@ -1083,7 +1273,7 @@ static void mptcp_retransmit_handler(struct sock *sk)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       if (atomic64_read(&msk->snd_una) == msk->write_seq) {
+       if (atomic64_read(&msk->snd_una) == READ_ONCE(msk->write_seq)) {
                mptcp_stop_timer(sk);
        } else {
                set_bit(MPTCP_WORK_RTX, &msk->flags);
@@ -1172,6 +1362,29 @@ static unsigned int mptcp_sync_mss(struct sock *sk, u32 pmtu)
        return 0;
 }
 
+static void pm_work(struct mptcp_sock *msk)
+{
+       struct mptcp_pm_data *pm = &msk->pm;
+
+       spin_lock_bh(&msk->pm.lock);
+
+       pr_debug("msk=%p status=%x", msk, pm->status);
+       if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
+               pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
+               mptcp_pm_nl_add_addr_received(msk);
+       }
+       if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
+               pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
+               mptcp_pm_nl_fully_established(msk);
+       }
+       if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
+               pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
+               mptcp_pm_nl_subflow_established(msk);
+       }
+
+       spin_unlock_bh(&msk->pm.lock);
+}
+
 static void mptcp_worker(struct work_struct *work)
 {
        struct mptcp_sock *msk = container_of(work, struct mptcp_sock, work);
@@ -1180,17 +1393,25 @@ static void mptcp_worker(struct work_struct *work)
        struct mptcp_data_frag *dfrag;
        u64 orig_write_seq;
        size_t copied = 0;
-       struct msghdr msg;
+       struct msghdr msg = {
+               .msg_flags = MSG_DONTWAIT,
+       };
        long timeo = 0;
 
        lock_sock(sk);
        mptcp_clean_una(sk);
+       mptcp_check_data_fin_ack(sk);
        __mptcp_flush_join_list(msk);
        __mptcp_move_skbs(msk);
 
+       if (msk->pm.status)
+               pm_work(msk);
+
        if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
                mptcp_check_for_eof(msk);
 
+       mptcp_check_data_fin(sk);
+
        if (!test_and_clear_bit(MPTCP_WORK_RTX, &msk->flags))
                goto unlock;
 
@@ -1207,7 +1428,6 @@ static void mptcp_worker(struct work_struct *work)
 
        lock_sock(ssk);
 
-       msg.msg_flags = MSG_DONTWAIT;
        orig_len = dfrag->data_len;
        orig_offset = dfrag->offset;
        orig_write_seq = dfrag->data_seq;
@@ -1283,7 +1503,12 @@ static int mptcp_init_sock(struct sock *sk)
        if (ret)
                return ret;
 
+       ret = __mptcp_socket_create(mptcp_sk(sk));
+       if (ret)
+               return ret;
+
        sk_sockets_allocated_inc(sk);
+       sk->sk_rcvbuf = sock_net(sk)->ipv4.sysctl_tcp_rmem[1];
        sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[2];
 
        return 0;
@@ -1308,8 +1533,7 @@ static void mptcp_cancel_work(struct sock *sk)
                sock_put(sk);
 }
 
-static void mptcp_subflow_shutdown(struct sock *ssk, int how,
-                                  bool data_fin_tx_enable, u64 data_fin_tx_seq)
+static void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how)
 {
        lock_sock(ssk);
 
@@ -1322,36 +1546,84 @@ static void mptcp_subflow_shutdown(struct sock *ssk, int how,
                tcp_disconnect(ssk, O_NONBLOCK);
                break;
        default:
-               if (data_fin_tx_enable) {
-                       struct mptcp_subflow_context *subflow;
-
-                       subflow = mptcp_subflow_ctx(ssk);
-                       subflow->data_fin_tx_seq = data_fin_tx_seq;
-                       subflow->data_fin_tx_enable = 1;
+               if (__mptcp_check_fallback(mptcp_sk(sk))) {
+                       pr_debug("Fallback");
+                       ssk->sk_shutdown |= how;
+                       tcp_shutdown(ssk, how);
+               } else {
+                       pr_debug("Sending DATA_FIN on subflow %p", ssk);
+                       mptcp_set_timeout(sk, ssk);
+                       tcp_send_ack(ssk);
                }
-
-               ssk->sk_shutdown |= how;
-               tcp_shutdown(ssk, how);
                break;
        }
 
-       /* Wake up anyone sleeping in poll. */
-       ssk->sk_state_change(ssk);
        release_sock(ssk);
 }
 
-/* Called with msk lock held, releases such lock before returning */
+static const unsigned char new_state[16] = {
+       /* current state:     new state:      action:   */
+       [0 /* (Invalid) */] = TCP_CLOSE,
+       [TCP_ESTABLISHED]   = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
+       [TCP_SYN_SENT]      = TCP_CLOSE,
+       [TCP_SYN_RECV]      = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
+       [TCP_FIN_WAIT1]     = TCP_FIN_WAIT1,
+       [TCP_FIN_WAIT2]     = TCP_FIN_WAIT2,
+       [TCP_TIME_WAIT]     = TCP_CLOSE,        /* should not happen ! */
+       [TCP_CLOSE]         = TCP_CLOSE,
+       [TCP_CLOSE_WAIT]    = TCP_LAST_ACK  | TCP_ACTION_FIN,
+       [TCP_LAST_ACK]      = TCP_LAST_ACK,
+       [TCP_LISTEN]        = TCP_CLOSE,
+       [TCP_CLOSING]       = TCP_CLOSING,
+       [TCP_NEW_SYN_RECV]  = TCP_CLOSE,        /* should not happen ! */
+};
+
+static int mptcp_close_state(struct sock *sk)
+{
+       int next = (int)new_state[sk->sk_state];
+       int ns = next & TCP_STATE_MASK;
+
+       inet_sk_state_store(sk, ns);
+
+       return next & TCP_ACTION_FIN;
+}
+
 static void mptcp_close(struct sock *sk, long timeout)
 {
        struct mptcp_subflow_context *subflow, *tmp;
        struct mptcp_sock *msk = mptcp_sk(sk);
        LIST_HEAD(conn_list);
-       u64 data_fin_tx_seq;
 
        lock_sock(sk);
+       sk->sk_shutdown = SHUTDOWN_MASK;
+
+       if (sk->sk_state == TCP_LISTEN) {
+               inet_sk_state_store(sk, TCP_CLOSE);
+               goto cleanup;
+       } else if (sk->sk_state == TCP_CLOSE) {
+               goto cleanup;
+       }
+
+       if (__mptcp_check_fallback(msk)) {
+               goto update_state;
+       } else if (mptcp_close_state(sk)) {
+               pr_debug("Sending DATA_FIN sk=%p", sk);
+               WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
+               WRITE_ONCE(msk->snd_data_fin_enable, 1);
+
+               mptcp_for_each_subflow(msk, subflow) {
+                       struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+
+                       mptcp_subflow_shutdown(sk, tcp_sk, SHUTDOWN_MASK);
+               }
+       }
 
+       sk_stream_wait_close(sk, timeout);
+
+update_state:
        inet_sk_state_store(sk, TCP_CLOSE);
 
+cleanup:
        /* be sure to always acquire the join list lock, to sync vs
         * mptcp_finish_join().
         */
@@ -1360,22 +1632,16 @@ static void mptcp_close(struct sock *sk, long timeout)
        spin_unlock_bh(&msk->join_list_lock);
        list_splice_init(&msk->conn_list, &conn_list);
 
-       data_fin_tx_seq = msk->write_seq;
-
        __mptcp_clear_xmit(sk);
 
        release_sock(sk);
 
        list_for_each_entry_safe(subflow, tmp, &conn_list, node) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-
-               subflow->data_fin_tx_seq = data_fin_tx_seq;
-               subflow->data_fin_tx_enable = 1;
                __mptcp_close_ssk(sk, ssk, subflow, timeout);
        }
 
        mptcp_cancel_work(sk);
-       mptcp_pm_close(msk);
 
        __skb_queue_purge(&sk->sk_receive_queue);
 
@@ -1447,20 +1713,7 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
        msk->local_key = subflow_req->local_key;
        msk->token = subflow_req->token;
        msk->subflow = NULL;
-
-       if (unlikely(mptcp_token_new_accept(subflow_req->token, nsk))) {
-               nsk->sk_state = TCP_CLOSE;
-               bh_unlock_sock(nsk);
-
-               /* we can't call into mptcp_close() here - possible BH context
-                * free the sock directly.
-                * sk_clone_lock() sets nsk refcnt to two, hence call sk_free()
-                * too.
-                */
-               sk_common_release(nsk);
-               sk_free(nsk);
-               return NULL;
-       }
+       WRITE_ONCE(msk->fully_established, false);
 
        msk->write_seq = subflow_req->idsn + 1;
        atomic64_set(&msk->snd_una, msk->write_seq);
@@ -1482,6 +1735,22 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
        return nsk;
 }
 
+void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk)
+{
+       const struct tcp_sock *tp = tcp_sk(ssk);
+
+       msk->rcvq_space.copied = 0;
+       msk->rcvq_space.rtt_us = 0;
+
+       msk->rcvq_space.time = tp->tcp_mstamp;
+
+       /* initial rcv_space offering made to peer */
+       msk->rcvq_space.space = min_t(u32, tp->rcv_wnd,
+                                     TCP_INIT_CWND * tp->advmss);
+       if (msk->rcvq_space.space == 0)
+               msk->rcvq_space.space = TCP_INIT_CWND * TCP_MSS_DEFAULT;
+}
+
 static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
                                 bool kern)
 {
@@ -1501,7 +1770,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
                return NULL;
 
        pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
-
        if (sk_is_mptcp(newsk)) {
                struct mptcp_subflow_context *subflow;
                struct sock *new_mptcp_sock;
@@ -1529,8 +1797,8 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
                newsk = new_mptcp_sock;
                mptcp_copy_inaddrs(newsk, ssk);
                list_add(&subflow->node, &msk->conn_list);
-               inet_sk_state_store(newsk, TCP_ESTABLISHED);
 
+               mptcp_rcv_space_init(msk, ssk);
                bh_unlock_sock(new_mptcp_sock);
 
                __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
@@ -1547,21 +1815,82 @@ static void mptcp_destroy(struct sock *sk)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       mptcp_token_destroy(msk->token);
+       mptcp_token_destroy(msk);
        if (msk->cached_ext)
                __skb_ext_put(msk->cached_ext);
 
        sk_sockets_allocated_dec(sk);
 }
 
+static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
+                                      sockptr_t optval, unsigned int optlen)
+{
+       struct sock *sk = (struct sock *)msk;
+       struct socket *ssock;
+       int ret;
+
+       switch (optname) {
+       case SO_REUSEPORT:
+       case SO_REUSEADDR:
+               lock_sock(sk);
+               ssock = __mptcp_nmpc_socket(msk);
+               if (!ssock) {
+                       release_sock(sk);
+                       return -EINVAL;
+               }
+
+               ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen);
+               if (ret == 0) {
+                       if (optname == SO_REUSEPORT)
+                               sk->sk_reuseport = ssock->sk->sk_reuseport;
+                       else if (optname == SO_REUSEADDR)
+                               sk->sk_reuse = ssock->sk->sk_reuse;
+               }
+               release_sock(sk);
+               return ret;
+       }
+
+       return sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, optval, optlen);
+}
+
+static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
+                              sockptr_t optval, unsigned int optlen)
+{
+       struct sock *sk = (struct sock *)msk;
+       int ret = -EOPNOTSUPP;
+       struct socket *ssock;
+
+       switch (optname) {
+       case IPV6_V6ONLY:
+               lock_sock(sk);
+               ssock = __mptcp_nmpc_socket(msk);
+               if (!ssock) {
+                       release_sock(sk);
+                       return -EINVAL;
+               }
+
+               ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen);
+               if (ret == 0)
+                       sk->sk_ipv6only = ssock->sk->sk_ipv6only;
+
+               release_sock(sk);
+               break;
+       }
+
+       return ret;
+}
+
 static int mptcp_setsockopt(struct sock *sk, int level, int optname,
-                           char __user *optval, unsigned int optlen)
+                           sockptr_t optval, unsigned int optlen)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
-       struct socket *ssock;
+       struct sock *ssk;
 
        pr_debug("msk=%p", msk);
 
+       if (level == SOL_SOCKET)
+               return mptcp_setsockopt_sol_socket(msk, optname, optval, optlen);
+
        /* @@ the meaning of setsockopt() when the socket is connected and
         * there are multiple subflows is not yet defined. It is up to the
         * MPTCP-level socket to configure the subflows until the subflow
@@ -1569,11 +1898,13 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname,
         * to the one remaining subflow.
         */
        lock_sock(sk);
-       ssock = __mptcp_tcp_fallback(msk);
+       ssk = __mptcp_tcp_fallback(msk);
        release_sock(sk);
-       if (ssock)
-               return tcp_setsockopt(ssock->sk, level, optname, optval,
-                                     optlen);
+       if (ssk)
+               return tcp_setsockopt(ssk, level, optname, optval, optlen);
+
+       if (level == SOL_IPV6)
+               return mptcp_setsockopt_v6(msk, optname, optval, optlen);
 
        return -EOPNOTSUPP;
 }
@@ -1582,7 +1913,7 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
                            char __user *optval, int __user *option)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
-       struct socket *ssock;
+       struct sock *ssk;
 
        pr_debug("msk=%p", msk);
 
@@ -1593,11 +1924,10 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
         * to the one remaining subflow.
         */
        lock_sock(sk);
-       ssock = __mptcp_tcp_fallback(msk);
+       ssk = __mptcp_tcp_fallback(msk);
        release_sock(sk);
-       if (ssock)
-               return tcp_getsockopt(ssock->sk, level, optname, optval,
-                                     option);
+       if (ssk)
+               return tcp_getsockopt(ssk, level, optname, optval, option);
 
        return -EOPNOTSUPP;
 }
@@ -1636,6 +1966,20 @@ static void mptcp_release_cb(struct sock *sk)
        }
 }
 
+static int mptcp_hash(struct sock *sk)
+{
+       /* should never be called,
+        * we hash the TCP subflows not the master socket
+        */
+       WARN_ON_ONCE(1);
+       return 0;
+}
+
+static void mptcp_unhash(struct sock *sk)
+{
+       /* called from sk_common_release(), but nothing to do here */
+}
+
 static int mptcp_get_port(struct sock *sk, unsigned short snum)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
@@ -1660,32 +2004,26 @@ void mptcp_finish_connect(struct sock *ssk)
        sk = subflow->conn;
        msk = mptcp_sk(sk);
 
-       if (!subflow->mp_capable) {
-               MPTCP_INC_STATS(sock_net(sk),
-                               MPTCP_MIB_MPCAPABLEACTIVEFALLBACK);
-               return;
-       }
-
        pr_debug("msk=%p, token=%u", sk, subflow->token);
 
        mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
        ack_seq++;
        subflow->map_seq = ack_seq;
        subflow->map_subflow_seq = 1;
-       subflow->rel_write_seq = 1;
 
        /* the socket is not connected yet, no msk/subflow ops can access/race
         * accessing the field below
         */
        WRITE_ONCE(msk->remote_key, subflow->remote_key);
        WRITE_ONCE(msk->local_key, subflow->local_key);
-       WRITE_ONCE(msk->token, subflow->token);
        WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
        WRITE_ONCE(msk->ack_seq, ack_seq);
        WRITE_ONCE(msk->can_ack, 1);
        atomic64_set(&msk->snd_una, msk->write_seq);
 
        mptcp_pm_new_connection(msk, 0);
+
+       mptcp_rcv_space_init(msk, ssk);
 }
 
 static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
@@ -1708,7 +2046,7 @@ bool mptcp_finish_join(struct sock *sk)
        pr_debug("msk=%p, subflow=%p", msk, subflow);
 
        /* mptcp socket already closing? */
-       if (inet_sk_state_load(parent) != TCP_ESTABLISHED)
+       if (!mptcp_is_fully_established(parent))
                return false;
 
        if (!msk->pm.server_side)
@@ -1761,8 +2099,8 @@ static struct proto mptcp_prot = {
        .sendmsg        = mptcp_sendmsg,
        .recvmsg        = mptcp_recvmsg,
        .release_cb     = mptcp_release_cb,
-       .hash           = inet_hash,
-       .unhash         = inet_unhash,
+       .hash           = mptcp_hash,
+       .unhash         = mptcp_unhash,
        .get_port       = mptcp_get_port,
        .sockets_allocated      = &mptcp_sockets_allocated,
        .memory_allocated       = &tcp_memory_allocated,
@@ -1771,6 +2109,7 @@ static struct proto mptcp_prot = {
        .sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_tcp_wmem),
        .sysctl_mem     = sysctl_tcp_mem,
        .obj_size       = sizeof(struct mptcp_sock),
+       .slab_flags     = SLAB_TYPESAFE_BY_RCU,
        .no_autobind    = true,
 };
 
@@ -1781,9 +2120,9 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
        int err;
 
        lock_sock(sock->sk);
-       ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
-       if (IS_ERR(ssock)) {
-               err = PTR_ERR(ssock);
+       ssock = __mptcp_nmpc_socket(msk);
+       if (!ssock) {
+               err = -EINVAL;
                goto unlock;
        }
 
@@ -1796,10 +2135,18 @@ unlock:
        return err;
 }
 
+static void mptcp_subflow_early_fallback(struct mptcp_sock *msk,
+                                        struct mptcp_subflow_context *subflow)
+{
+       subflow->request_mptcp = 0;
+       __mptcp_do_fallback(msk);
+}
+
 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                                int addr_len, int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
+       struct mptcp_subflow_context *subflow;
        struct socket *ssock;
        int err;
 
@@ -1812,19 +2159,24 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                goto do_connect;
        }
 
-       ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
-       if (IS_ERR(ssock)) {
-               err = PTR_ERR(ssock);
+       ssock = __mptcp_nmpc_socket(msk);
+       if (!ssock) {
+               err = -EINVAL;
                goto unlock;
        }
 
+       mptcp_token_destroy(msk);
+       inet_sk_state_store(sock->sk, TCP_SYN_SENT);
+       subflow = mptcp_subflow_ctx(ssock->sk);
 #ifdef CONFIG_TCP_MD5SIG
        /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
         * TCP option space.
         */
        if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
-               mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
+               mptcp_subflow_early_fallback(msk, subflow);
 #endif
+       if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk))
+               mptcp_subflow_early_fallback(msk, subflow);
 
 do_connect:
        err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
@@ -1843,42 +2195,6 @@ unlock:
        return err;
 }
 
-static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
-                           int peer)
-{
-       if (sock->sk->sk_prot == &tcp_prot) {
-               /* we are being invoked from __sys_accept4, after
-                * mptcp_accept() has just accepted a non-mp-capable
-                * flow: sk is a tcp_sk, not an mptcp one.
-                *
-                * Hand the socket over to tcp so all further socket ops
-                * bypass mptcp.
-                */
-               sock->ops = &inet_stream_ops;
-       }
-
-       return inet_getname(sock, uaddr, peer);
-}
-
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
-                           int peer)
-{
-       if (sock->sk->sk_prot == &tcpv6_prot) {
-               /* we are being invoked from __sys_accept4 after
-                * mptcp_accept() has accepted a non-mp-capable
-                * subflow: sk is a tcp_sk, not mptcp.
-                *
-                * Hand the socket over to tcp so all further
-                * socket ops bypass mptcp.
-                */
-               sock->ops = &inet6_stream_ops;
-       }
-
-       return inet6_getname(sock, uaddr, peer);
-}
-#endif
-
 static int mptcp_listen(struct socket *sock, int backlog)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -1888,12 +2204,14 @@ static int mptcp_listen(struct socket *sock, int backlog)
        pr_debug("msk=%p", msk);
 
        lock_sock(sock->sk);
-       ssock = __mptcp_socket_create(msk, TCP_LISTEN);
-       if (IS_ERR(ssock)) {
-               err = PTR_ERR(ssock);
+       ssock = __mptcp_nmpc_socket(msk);
+       if (!ssock) {
+               err = -EINVAL;
                goto unlock;
        }
 
+       mptcp_token_destroy(msk);
+       inet_sk_state_store(sock->sk, TCP_LISTEN);
        sock_set_flag(sock->sk, SOCK_RCU_FREE);
 
        err = ssock->ops->listen(ssock, backlog);
@@ -1906,15 +2224,6 @@ unlock:
        return err;
 }
 
-static bool is_tcp_proto(const struct proto *p)
-{
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-       return p == &tcp_prot || p == &tcpv6_prot;
-#else
-       return p == &tcp_prot;
-#endif
-}
-
 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                               int flags, bool kern)
 {
@@ -1932,11 +2241,12 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
        if (!ssock)
                goto unlock_fail;
 
+       clear_bit(MPTCP_DATA_READY, &msk->flags);
        sock_hold(ssock->sk);
        release_sock(sock->sk);
 
        err = ssock->ops->accept(sock, newsock, flags, kern);
-       if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
+       if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
                struct mptcp_sock *msk = mptcp_sk(newsock->sk);
                struct mptcp_subflow_context *subflow;
 
@@ -1944,7 +2254,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                 * This is needed so NOSPACE flag can be set from tcp stack.
                 */
                __mptcp_flush_join_list(msk);
-               list_for_each_entry(subflow, &msk->conn_list, node) {
+               mptcp_for_each_subflow(msk, subflow) {
                        struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
 
                        if (!ssk->sk_socket)
@@ -1952,6 +2262,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                }
        }
 
+       if (inet_csk_listen_poll(ssock->sk))
+               set_bit(MPTCP_DATA_READY, &msk->flags);
        sock_put(ssock->sk);
        return err;
 
@@ -1960,39 +2272,36 @@ unlock_fail:
        return -EINVAL;
 }
 
+static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
+{
+       return test_bit(MPTCP_DATA_READY, &msk->flags) ? EPOLLIN | EPOLLRDNORM :
+              0;
+}
+
 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
                           struct poll_table_struct *wait)
 {
        struct sock *sk = sock->sk;
        struct mptcp_sock *msk;
-       struct socket *ssock;
        __poll_t mask = 0;
+       int state;
 
        msk = mptcp_sk(sk);
-       lock_sock(sk);
-       ssock = __mptcp_tcp_fallback(msk);
-       if (!ssock)
-               ssock = __mptcp_nmpc_socket(msk);
-       if (ssock) {
-               mask = ssock->ops->poll(file, ssock, wait);
-               release_sock(sk);
-               return mask;
-       }
-
-       release_sock(sk);
        sock_poll_wait(file, sock, wait);
-       lock_sock(sk);
 
-       if (test_bit(MPTCP_DATA_READY, &msk->flags))
-               mask = EPOLLIN | EPOLLRDNORM;
-       if (sk_stream_is_writeable(sk) &&
-           test_bit(MPTCP_SEND_SPACE, &msk->flags))
-               mask |= EPOLLOUT | EPOLLWRNORM;
+       state = inet_sk_state_load(sk);
+       if (state == TCP_LISTEN)
+               return mptcp_check_readable(msk);
+
+       if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
+               mask |= mptcp_check_readable(msk);
+               if (sk_stream_is_writeable(sk) &&
+                   test_bit(MPTCP_SEND_SPACE, &msk->flags))
+                       mask |= EPOLLOUT | EPOLLWRNORM;
+       }
        if (sk->sk_shutdown & RCV_SHUTDOWN)
                mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
 
-       release_sock(sk);
-
        return mask;
 }
 
@@ -2000,23 +2309,13 @@ static int mptcp_shutdown(struct socket *sock, int how)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
        struct mptcp_subflow_context *subflow;
-       struct socket *ssock;
        int ret = 0;
 
        pr_debug("sk=%p, how=%d", msk, how);
 
        lock_sock(sock->sk);
-       ssock = __mptcp_tcp_fallback(msk);
-       if (ssock) {
-               release_sock(sock->sk);
-               return inet_shutdown(ssock, how);
-       }
-
-       if (how == SHUT_WR || how == SHUT_RDWR)
-               inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
 
        how++;
-
        if ((how & ~SHUTDOWN_MASK) || !how) {
                ret = -EINVAL;
                goto out_unlock;
@@ -2030,13 +2329,36 @@ static int mptcp_shutdown(struct socket *sock, int how)
                        sock->state = SS_CONNECTED;
        }
 
-       __mptcp_flush_join_list(msk);
-       mptcp_for_each_subflow(msk, subflow) {
-               struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+       /* If we've already sent a FIN, or it's a closed state, skip this. */
+       if (__mptcp_check_fallback(msk)) {
+               if (how == SHUT_WR || how == SHUT_RDWR)
+                       inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
 
-               mptcp_subflow_shutdown(tcp_sk, how, 1, msk->write_seq);
+               mptcp_for_each_subflow(msk, subflow) {
+                       struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+
+                       mptcp_subflow_shutdown(sock->sk, tcp_sk, how);
+               }
+       } else if ((how & SEND_SHUTDOWN) &&
+                  ((1 << sock->sk->sk_state) &
+                   (TCPF_ESTABLISHED | TCPF_SYN_SENT |
+                    TCPF_SYN_RECV | TCPF_CLOSE_WAIT)) &&
+                  mptcp_close_state(sock->sk)) {
+               __mptcp_flush_join_list(msk);
+
+               WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
+               WRITE_ONCE(msk->snd_data_fin_enable, 1);
+
+               mptcp_for_each_subflow(msk, subflow) {
+                       struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
+
+                       mptcp_subflow_shutdown(sock->sk, tcp_sk, how);
+               }
        }
 
+       /* Wake up anyone sleeping in poll. */
+       sock->sk->sk_state_change(sock->sk);
+
 out_unlock:
        release_sock(sock->sk);
 
@@ -2051,7 +2373,7 @@ static const struct proto_ops mptcp_stream_ops = {
        .connect           = mptcp_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
-       .getname           = mptcp_v4_getname,
+       .getname           = inet_getname,
        .poll              = mptcp_poll,
        .ioctl             = inet_ioctl,
        .gettstamp         = sock_gettstamp,
@@ -2063,10 +2385,6 @@ static const struct proto_ops mptcp_stream_ops = {
        .recvmsg           = inet_recvmsg,
        .mmap              = sock_no_mmap,
        .sendpage          = inet_sendpage,
-#ifdef CONFIG_COMPAT
-       .compat_setsockopt = compat_sock_common_setsockopt,
-       .compat_getsockopt = compat_sock_common_getsockopt,
-#endif
 };
 
 static struct inet_protosw mptcp_protosw = {
@@ -2077,7 +2395,7 @@ static struct inet_protosw mptcp_protosw = {
        .flags          = INET_PROTOSW_ICSK,
 };
 
-void mptcp_proto_init(void)
+void __init mptcp_proto_init(void)
 {
        mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
 
@@ -2086,6 +2404,7 @@ void mptcp_proto_init(void)
 
        mptcp_subflow_init();
        mptcp_pm_init();
+       mptcp_token_init();
 
        if (proto_register(&mptcp_prot, 1) != 0)
                panic("Failed to register MPTCP proto.\n");
@@ -2104,7 +2423,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
        .connect           = mptcp_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
-       .getname           = mptcp_v6_getname,
+       .getname           = inet6_getname,
        .poll              = mptcp_poll,
        .ioctl             = inet6_ioctl,
        .gettstamp         = sock_gettstamp,
@@ -2118,8 +2437,6 @@ static const struct proto_ops mptcp_v6_stream_ops = {
        .sendpage          = inet_sendpage,
 #ifdef CONFIG_COMPAT
        .compat_ioctl      = inet6_compat_ioctl,
-       .compat_setsockopt = compat_sock_common_setsockopt,
-       .compat_getsockopt = compat_sock_common_getsockopt,
 #endif
 };
 
@@ -2139,7 +2456,7 @@ static struct inet_protosw mptcp_v6_protosw = {
        .flags          = INET_PROTOSW_ICSK,
 };
 
-int mptcp_proto_v6_init(void)
+int __init mptcp_proto_v6_init(void)
 {
        int err;