mptcp: better msk-level shutdown.
authorPaolo Abeni <pabeni@redhat.com>
Tue, 12 Jan 2021 17:25:24 +0000 (18:25 +0100)
committerJakub Kicinski <kuba@kernel.org>
Wed, 13 Jan 2021 04:09:19 +0000 (20:09 -0800)
Instead of re-implementing most of inet_shutdown, re-use
such helper, and implement the MPTCP-specific bits at the
'proto' level.

The msk-level disconnect() can now be invoked, lets provide a
suitable implementation.

As a side effect, this fixes bad state management for listener
sockets. The latter could lead to division by 0 oops since
commit ea4ca586b16f ("mptcp: refine MPTCP-level ack scheduling").

Fixes: 43b54c6ee382 ("mptcp: Use full MPTCP-level disconnect state machine")
Fixes: ea4ca586b16f ("mptcp: refine MPTCP-level ack scheduling")
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Reviewed-by: Mat Martineau <mathew.j.martineau@linux.intel.com>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/mptcp/protocol.c

index 2ff8c7c..81faeff 100644 (file)
@@ -2642,11 +2642,12 @@ static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
 
 static int mptcp_disconnect(struct sock *sk, int flags)
 {
-       /* Should never be called.
-        * inet_stream_connect() calls ->disconnect, but that
-        * refers to the subflow socket, not the mptcp one.
-        */
-       WARN_ON_ONCE(1);
+       struct mptcp_subflow_context *subflow;
+       struct mptcp_sock *msk = mptcp_sk(sk);
+
+       __mptcp_flush_join_list(msk);
+       mptcp_for_each_subflow(msk, subflow)
+               tcp_disconnect(mptcp_subflow_tcp_sock(subflow), flags);
        return 0;
 }
 
@@ -3089,6 +3090,14 @@ bool mptcp_finish_join(struct sock *ssk)
        return true;
 }
 
+static void mptcp_shutdown(struct sock *sk, int how)
+{
+       pr_debug("sk=%p, how=%d", sk, how);
+
+       if ((how & SEND_SHUTDOWN) && mptcp_close_state(sk))
+               __mptcp_wr_shutdown(sk);
+}
+
 static struct proto mptcp_prot = {
        .name           = "MPTCP",
        .owner          = THIS_MODULE,
@@ -3098,7 +3107,7 @@ static struct proto mptcp_prot = {
        .accept         = mptcp_accept,
        .setsockopt     = mptcp_setsockopt,
        .getsockopt     = mptcp_getsockopt,
-       .shutdown       = tcp_shutdown,
+       .shutdown       = mptcp_shutdown,
        .destroy        = mptcp_destroy,
        .sendmsg        = mptcp_sendmsg,
        .recvmsg        = mptcp_recvmsg,
@@ -3344,43 +3353,6 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
        return mask;
 }
 
-static int mptcp_shutdown(struct socket *sock, int how)
-{
-       struct mptcp_sock *msk = mptcp_sk(sock->sk);
-       struct sock *sk = sock->sk;
-       int ret = 0;
-
-       pr_debug("sk=%p, how=%d", msk, how);
-
-       lock_sock(sk);
-
-       how++;
-       if ((how & ~SHUTDOWN_MASK) || !how) {
-               ret = -EINVAL;
-               goto out_unlock;
-       }
-
-       if (sock->state == SS_CONNECTING) {
-               if ((1 << sk->sk_state) &
-                   (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))
-                       sock->state = SS_DISCONNECTING;
-               else
-                       sock->state = SS_CONNECTED;
-       }
-
-       sk->sk_shutdown |= how;
-       if ((how & SEND_SHUTDOWN) && mptcp_close_state(sk))
-               __mptcp_wr_shutdown(sk);
-
-       /* Wake up anyone sleeping in poll. */
-       sk->sk_state_change(sk);
-
-out_unlock:
-       release_sock(sk);
-
-       return ret;
-}
-
 static const struct proto_ops mptcp_stream_ops = {
        .family            = PF_INET,
        .owner             = THIS_MODULE,
@@ -3394,7 +3366,7 @@ static const struct proto_ops mptcp_stream_ops = {
        .ioctl             = inet_ioctl,
        .gettstamp         = sock_gettstamp,
        .listen            = mptcp_listen,
-       .shutdown          = mptcp_shutdown,
+       .shutdown          = inet_shutdown,
        .setsockopt        = sock_common_setsockopt,
        .getsockopt        = sock_common_getsockopt,
        .sendmsg           = inet_sendmsg,
@@ -3444,7 +3416,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
        .ioctl             = inet6_ioctl,
        .gettstamp         = sock_gettstamp,
        .listen            = mptcp_listen,
-       .shutdown          = mptcp_shutdown,
+       .shutdown          = inet_shutdown,
        .setsockopt        = sock_common_setsockopt,
        .getsockopt        = sock_common_getsockopt,
        .sendmsg           = inet6_sendmsg,