mptcp: cache msk on MP_JOIN init_req
[linux-2.6-microblaze.git] / net / mptcp / subflow.c
index bbdb74b..4068bdb 100644 (file)
@@ -69,6 +69,9 @@ static void subflow_req_destructor(struct request_sock *req)
 
        pr_debug("subflow_req=%p", subflow_req);
 
+       if (subflow_req->msk)
+               sock_put((struct sock *)subflow_req->msk);
+
        if (subflow_req->mp_capable)
                mptcp_token_destroy_request(subflow_req->token);
        tcp_request_sock_ops.destructor(req);
@@ -86,8 +89,8 @@ static void subflow_generate_hmac(u64 key1, u64 key2, u32 nonce1, u32 nonce2,
 }
 
 /* validate received token and create truncated hmac and nonce for SYN-ACK */
-static bool subflow_token_join_request(struct request_sock *req,
-                                      const struct sk_buff *skb)
+static struct mptcp_sock *subflow_token_join_request(struct request_sock *req,
+                                                    const struct sk_buff *skb)
 {
        struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
        u8 hmac[SHA256_DIGEST_SIZE];
@@ -97,13 +100,13 @@ static bool subflow_token_join_request(struct request_sock *req,
        msk = mptcp_token_get_sock(subflow_req->token);
        if (!msk) {
                SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINNOTOKEN);
-               return false;
+               return NULL;
        }
 
        local_id = mptcp_pm_get_local_id(msk, (struct sock_common *)req);
        if (local_id < 0) {
                sock_put((struct sock *)msk);
-               return false;
+               return NULL;
        }
        subflow_req->local_id = local_id;
 
@@ -114,9 +117,7 @@ static bool subflow_token_join_request(struct request_sock *req,
                              subflow_req->remote_nonce, hmac);
 
        subflow_req->thmac = get_unaligned_be64(hmac);
-
-       sock_put((struct sock *)msk);
-       return true;
+       return msk;
 }
 
 static void subflow_init_req(struct request_sock *req,
@@ -133,6 +134,7 @@ static void subflow_init_req(struct request_sock *req,
 
        subflow_req->mp_capable = 0;
        subflow_req->mp_join = 0;
+       subflow_req->msk = NULL;
 
 #ifdef CONFIG_TCP_MD5SIG
        /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
@@ -166,12 +168,9 @@ static void subflow_init_req(struct request_sock *req,
                subflow_req->remote_id = mp_opt.join_id;
                subflow_req->token = mp_opt.token;
                subflow_req->remote_nonce = mp_opt.nonce;
-               pr_debug("token=%u, remote_nonce=%u", subflow_req->token,
-                        subflow_req->remote_nonce);
-               if (!subflow_token_join_request(req, skb)) {
-                       subflow_req->mp_join = 0;
-                       // @@ need to trigger RST
-               }
+               subflow_req->msk = subflow_token_join_request(req, skb);
+               pr_debug("token=%u, remote_nonce=%u msk=%p", subflow_req->token,
+                        subflow_req->remote_nonce, subflow_req->msk);
        }
 }
 
@@ -354,10 +353,9 @@ static bool subflow_hmac_valid(const struct request_sock *req,
        const struct mptcp_subflow_request_sock *subflow_req;
        u8 hmac[SHA256_DIGEST_SIZE];
        struct mptcp_sock *msk;
-       bool ret;
 
        subflow_req = mptcp_subflow_rsk(req);
-       msk = mptcp_token_get_sock(subflow_req->token);
+       msk = subflow_req->msk;
        if (!msk)
                return false;
 
@@ -365,12 +363,7 @@ static bool subflow_hmac_valid(const struct request_sock *req,
                              subflow_req->remote_nonce,
                              subflow_req->local_nonce, hmac);
 
-       ret = true;
-       if (crypto_memneq(hmac, mp_opt->hmac, MPTCPOPT_HMAC_LEN))
-               ret = false;
-
-       sock_put((struct sock *)msk);
-       return ret;
+       return !crypto_memneq(hmac, mp_opt->hmac, MPTCPOPT_HMAC_LEN);
 }
 
 static void mptcp_sock_destruct(struct sock *sk)
@@ -522,10 +515,12 @@ create_child:
                } else if (ctx->mp_join) {
                        struct mptcp_sock *owner;
 
-                       owner = mptcp_token_get_sock(ctx->token);
+                       owner = subflow_req->msk;
                        if (!owner)
                                goto dispose_child;
 
+                       /* move the msk reference ownership to the subflow */
+                       subflow_req->msk = NULL;
                        ctx->conn = (struct sock *)owner;
                        if (!mptcp_finish_join(child))
                                goto dispose_child;