mptcp: check for plain TCP sock at accept time
[linux-2.6-microblaze.git] / net / mptcp / protocol.c
index 84ae96b..dbeb6fe 100644 (file)
@@ -52,13 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
        return msk->subflow;
 }
 
-static struct socket *mptcp_is_tcpsk(struct sock *sk)
+static bool mptcp_is_tcpsk(struct sock *sk)
 {
        struct socket *sock = sk->sk_socket;
 
-       if (sock->sk != sk)
-               return NULL;
-
        if (unlikely(sk->sk_prot == &tcp_prot)) {
                /* we are being invoked after mptcp_accept() has
                 * accepted a non-mp-capable flow: sk is a tcp_sk,
@@ -68,27 +65,21 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk)
                 * bypass mptcp.
                 */
                sock->ops = &inet_stream_ops;
-               return sock;
+               return true;
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
        } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
                sock->ops = &inet6_stream_ops;
-               return sock;
+               return true;
 #endif
        }
 
-       return NULL;
+       return false;
 }
 
 static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
 {
-       struct socket *sock;
-
        sock_owned_by_me((const struct sock *)msk);
 
-       sock = mptcp_is_tcpsk((struct sock *)msk);
-       if (unlikely(sock))
-               return sock;
-
        if (likely(!__mptcp_check_fallback(msk)))
                return NULL;
 
@@ -1466,7 +1457,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
                return NULL;
 
        pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
-
        if (sk_is_mptcp(newsk)) {
                struct mptcp_subflow_context *subflow;
                struct sock *new_mptcp_sock;
@@ -1821,42 +1811,6 @@ unlock:
        return err;
 }
 
-static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
-                           int peer)
-{
-       if (sock->sk->sk_prot == &tcp_prot) {
-               /* we are being invoked from __sys_accept4, after
-                * mptcp_accept() has just accepted a non-mp-capable
-                * flow: sk is a tcp_sk, not an mptcp one.
-                *
-                * Hand the socket over to tcp so all further socket ops
-                * bypass mptcp.
-                */
-               sock->ops = &inet_stream_ops;
-       }
-
-       return inet_getname(sock, uaddr, peer);
-}
-
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
-                           int peer)
-{
-       if (sock->sk->sk_prot == &tcpv6_prot) {
-               /* we are being invoked from __sys_accept4 after
-                * mptcp_accept() has accepted a non-mp-capable
-                * subflow: sk is a tcp_sk, not mptcp.
-                *
-                * Hand the socket over to tcp so all further
-                * socket ops bypass mptcp.
-                */
-               sock->ops = &inet6_stream_ops;
-       }
-
-       return inet6_getname(sock, uaddr, peer);
-}
-#endif
-
 static int mptcp_listen(struct socket *sock, int backlog)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -1885,15 +1839,6 @@ unlock:
        return err;
 }
 
-static bool is_tcp_proto(const struct proto *p)
-{
-#if IS_ENABLED(CONFIG_MPTCP_IPV6)
-       return p == &tcp_prot || p == &tcpv6_prot;
-#else
-       return p == &tcp_prot;
-#endif
-}
-
 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                               int flags, bool kern)
 {
@@ -1915,7 +1860,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
        release_sock(sock->sk);
 
        err = ssock->ops->accept(sock, newsock, flags, kern);
-       if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
+       if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
                struct mptcp_sock *msk = mptcp_sk(newsock->sk);
                struct mptcp_subflow_context *subflow;
 
@@ -2011,7 +1956,7 @@ static const struct proto_ops mptcp_stream_ops = {
        .connect           = mptcp_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
-       .getname           = mptcp_v4_getname,
+       .getname           = inet_getname,
        .poll              = mptcp_poll,
        .ioctl             = inet_ioctl,
        .gettstamp         = sock_gettstamp,
@@ -2065,7 +2010,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
        .connect           = mptcp_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
-       .getname           = mptcp_v6_getname,
+       .getname           = inet6_getname,
        .poll              = mptcp_poll,
        .ioctl             = inet6_ioctl,
        .gettstamp         = sock_gettstamp,