Merge tag 'devicetree-for-6.0' of git://git.kernel.org/pub/scm/linux/kernel/git/robh...
[linux-2.6-microblaze.git] / net / mptcp / pm_netlink.c
index 7c7395b..291b5da 100644 (file)
@@ -413,7 +413,7 @@ static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned
        int i;
 
        for (i = 0; i < nr; i++) {
-               if (mptcp_addresses_equal(&addrs[i], addr, addr->port))
+               if (addrs[i].id == addr->id)
                        return true;
        }
 
@@ -449,7 +449,8 @@ static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullm
                mptcp_for_each_subflow(msk, subflow) {
                        ssk = mptcp_subflow_tcp_sock(subflow);
                        remote_address((struct sock_common *)ssk, &addrs[i]);
-                       if (deny_id0 && mptcp_addresses_equal(&addrs[i], &remote, false))
+                       addrs[i].id = subflow->remote_id;
+                       if (deny_id0 && !addrs[i].id)
                                continue;
 
                        if (!lookup_address_in_vec(addrs, i, &addrs[i]) &&
@@ -463,6 +464,37 @@ static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullm
        return i;
 }
 
+static void __mptcp_pm_send_ack(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow,
+                               bool prio, bool backup)
+{
+       struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
+       bool slow;
+
+       pr_debug("send ack for %s",
+                prio ? "mp_prio" : (mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr"));
+
+       slow = lock_sock_fast(ssk);
+       if (prio) {
+               if (subflow->backup != backup)
+                       msk->last_snd = NULL;
+
+               subflow->send_mp_prio = 1;
+               subflow->backup = backup;
+               subflow->request_bkup = backup;
+       }
+
+       __mptcp_subflow_send_ack(ssk);
+       unlock_sock_fast(ssk, slow);
+}
+
+static void mptcp_pm_send_ack(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow,
+                             bool prio, bool backup)
+{
+       spin_unlock_bh(&msk->pm.lock);
+       __mptcp_pm_send_ack(msk, subflow, prio, backup);
+       spin_lock_bh(&msk->pm.lock);
+}
+
 static struct mptcp_pm_addr_entry *
 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
 {
@@ -482,30 +514,14 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info,
        struct mptcp_pm_addr_entry *entry;
 
        list_for_each_entry(entry, &pernet->local_addr_list, list) {
-               if ((!lookup_by_id && mptcp_addresses_equal(&entry->addr, info, true)) ||
+               if ((!lookup_by_id &&
+                    mptcp_addresses_equal(&entry->addr, info, entry->addr.port)) ||
                    (lookup_by_id && entry->addr.id == info->id))
                        return entry;
        }
        return NULL;
 }
 
-static int
-lookup_id_by_addr(const struct pm_nl_pernet *pernet, const struct mptcp_addr_info *addr)
-{
-       const struct mptcp_pm_addr_entry *entry;
-       int ret = -1;
-
-       rcu_read_lock();
-       list_for_each_entry(entry, &pernet->local_addr_list, list) {
-               if (mptcp_addresses_equal(&entry->addr, addr, entry->addr.port)) {
-                       ret = entry->addr.id;
-                       break;
-               }
-       }
-       rcu_read_unlock();
-       return ret;
-}
-
 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
 {
        struct sock *sk = (struct sock *)msk;
@@ -523,13 +539,23 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
 
        /* do lazy endpoint usage accounting for the MPC subflows */
        if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) {
+               struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(msk->first);
+               struct mptcp_pm_addr_entry *entry;
                struct mptcp_addr_info mpc_addr;
-               int mpc_id;
+               bool backup = false;
 
                local_address((struct sock_common *)msk->first, &mpc_addr);
-               mpc_id = lookup_id_by_addr(pernet, &mpc_addr);
-               if (mpc_id >= 0)
-                       __clear_bit(mpc_id, msk->pm.id_avail_bitmap);
+               rcu_read_lock();
+               entry = __lookup_addr(pernet, &mpc_addr, false);
+               if (entry) {
+                       __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap);
+                       msk->mpc_endpoint_id = entry->addr.id;
+                       backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP);
+               }
+               rcu_read_unlock();
+
+               if (backup)
+                       mptcp_pm_send_ack(msk, subflow, true, backup);
 
                msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED);
        }
@@ -705,16 +731,8 @@ void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
                return;
 
        subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
-       if (subflow) {
-               struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-
-               spin_unlock_bh(&msk->pm.lock);
-               pr_debug("send ack for %s",
-                        mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr");
-
-               mptcp_subflow_send_ack(ssk);
-               spin_lock_bh(&msk->pm.lock);
-       }
+       if (subflow)
+               mptcp_pm_send_ack(msk, subflow, false, false);
 }
 
 int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
@@ -729,7 +747,6 @@ int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
        mptcp_for_each_subflow(msk, subflow) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                struct mptcp_addr_info local, remote;
-               bool slow;
 
                local_address((struct sock_common *)ssk, &local);
                if (!mptcp_addresses_equal(&local, addr, addr->port))
@@ -741,23 +758,18 @@ int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
                                continue;
                }
 
-               slow = lock_sock_fast(ssk);
-               if (subflow->backup != bkup)
-                       msk->last_snd = NULL;
-               subflow->backup = bkup;
-               subflow->send_mp_prio = 1;
-               subflow->request_bkup = bkup;
-
-               pr_debug("send ack for mp_prio");
-               __mptcp_subflow_send_ack(ssk);
-               unlock_sock_fast(ssk, slow);
-
+               __mptcp_pm_send_ack(msk, subflow, true, bkup);
                return 0;
        }
 
        return -EINVAL;
 }
 
+static bool mptcp_local_id_match(const struct mptcp_sock *msk, u8 local_id, u8 id)
+{
+       return local_id == id || (!local_id && msk->mpc_endpoint_id == id);
+}
+
 static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
                                           const struct mptcp_rm_list *rm_list,
                                           enum linux_mptcp_mib_field rm_type)
@@ -781,6 +793,7 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
                return;
 
        for (i = 0; i < rm_list->nr; i++) {
+               u8 rm_id = rm_list->ids[i];
                bool removed = false;
 
                list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
@@ -788,15 +801,15 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
                        int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
                        u8 id = subflow->local_id;
 
-                       if (rm_type == MPTCP_MIB_RMADDR)
-                               id = subflow->remote_id;
-
-                       if (rm_list->ids[i] != id)
+                       if (rm_type == MPTCP_MIB_RMADDR && subflow->remote_id != rm_id)
+                               continue;
+                       if (rm_type == MPTCP_MIB_RMSUBFLOW && !mptcp_local_id_match(msk, id, rm_id))
                                continue;
 
-                       pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u",
+                       pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u mpc_id=%u",
                                 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow",
-                                i, rm_list->ids[i], subflow->local_id, subflow->remote_id);
+                                i, rm_id, subflow->local_id, subflow->remote_id,
+                                msk->mpc_endpoint_id);
                        spin_unlock_bh(&msk->pm.lock);
                        mptcp_subflow_shutdown(sk, ssk, how);
 
@@ -808,7 +821,7 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
                        __MPTCP_INC_STATS(sock_net(sk), rm_type);
                }
                if (rm_type == MPTCP_MIB_RMSUBFLOW)
-                       __set_bit(rm_list->ids[i], msk->pm.id_avail_bitmap);
+                       __set_bit(rm_id ? rm_id : msk->mpc_endpoint_id, msk->pm.id_avail_bitmap);
                if (!removed)
                        continue;
 
@@ -907,10 +920,11 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
        /* do not insert duplicate address, differentiate on port only
         * singled addresses
         */
+       if (!address_use_port(entry))
+               entry->addr.port = 0;
        list_for_each_entry(cur, &pernet->local_addr_list, list) {
                if (mptcp_addresses_equal(&cur->addr, &entry->addr,
-                                         address_use_port(entry) &&
-                                         address_use_port(cur))) {
+                                         cur->addr.port || entry->addr.port)) {
                        /* allow replacing the exiting endpoint only if such
                         * endpoint is an implicit one and the user-space
                         * did not provide an endpoint id
@@ -956,7 +970,10 @@ find_next:
        }
 
        pernet->addrs++;
-       list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
+       if (!entry->addr.port)
+               list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
+       else
+               list_add_rcu(&entry->list, &pernet->local_addr_list);
        ret = entry->addr.id;
 
 out:
@@ -1134,7 +1151,7 @@ void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ss
                        }
                        unlock_sock_fast(ssk, slow);
 
-                       /* always try to push the pending data regarless of re-injections:
+                       /* always try to push the pending data regardless of re-injections:
                         * we can possibly use backup subflows now, and subflow selection
                         * is cheap under the msk socket lock
                         */