Merge tag 'for-5.13/parisc' of git://git.kernel.org/pub/scm/linux/kernel/git/deller...
[linux-2.6-microblaze.git] / net / mptcp / pm_netlink.c
index 5857b82..6ba0408 100644 (file)
@@ -25,6 +25,8 @@ static int pm_nl_pernet_id;
 struct mptcp_pm_addr_entry {
        struct list_head        list;
        struct mptcp_addr_info  addr;
+       u8                      flags;          /* MPTCP_PM_ADDR_FLAG_* bits */
+       int                     ifindex;        /* MPTCP_PM_ADDR_ATTR_IF_IDX value, 0 if none */
        struct rcu_head         rcu;
        struct socket           *lsk;
 };
@@ -56,8 +58,6 @@ struct pm_nl_pernet {
 #define MPTCP_PM_ADDR_MAX      8
 #define ADD_ADDR_RETRANS_MAX   3
 
-static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk);
-
 static bool addresses_equal(const struct mptcp_addr_info *a,
                            struct mptcp_addr_info *b, bool use_port)
 {
@@ -140,6 +140,24 @@ static bool lookup_subflow_by_saddr(const struct list_head *list,
        return false;
 }
 
+static bool lookup_subflow_by_daddr(const struct list_head *list,
+                                   struct mptcp_addr_info *daddr)
+{      /* true when remote_address() of a subflow on @list equals @daddr */
+       struct mptcp_subflow_context *subflow;
+       struct mptcp_addr_info cur;
+       struct sock_common *skc;
+
+       list_for_each_entry(subflow, list, node) {
+               skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
+               /* use_port is daddr->port, i.e. false when the port is zero */
+               remote_address(skc, &cur);
+               if (addresses_equal(&cur, daddr, daddr->port))
+                       return true;
+       }
+
+       return false;
+}
+
 static struct mptcp_pm_addr_entry *
 select_local_address(const struct pm_nl_pernet *pernet,
                     struct mptcp_sock *msk)
@@ -152,7 +170,7 @@ select_local_address(const struct pm_nl_pernet *pernet,
        rcu_read_lock();
        __mptcp_flush_join_list(msk);
        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
-               if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
+               if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
                        continue;
 
                if (entry->addr.family != sk->sk_family) {
@@ -190,7 +208,7 @@ select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
         * can lead to additional addresses not being announced.
         */
        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
-               if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
+               if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
                        continue;
                if (i++ == pos) {
                        ret = entry;
@@ -245,9 +263,9 @@ static void check_work_pending(struct mptcp_sock *msk)
                WRITE_ONCE(msk->pm.work_pending, false);
 }
 
-static struct mptcp_pm_add_entry *
-lookup_anno_list_by_saddr(struct mptcp_sock *msk,
-                         struct mptcp_addr_info *addr)
+struct mptcp_pm_add_entry *
+mptcp_lookup_anno_list_by_saddr(struct mptcp_sock *msk,
+                               struct mptcp_addr_info *addr)
 {
        struct mptcp_pm_add_entry *entry;
 
@@ -308,7 +326,7 @@ static void mptcp_pm_add_timer(struct timer_list *timer)
 
        if (!mptcp_pm_should_add_signal(msk)) {
                pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
-               mptcp_pm_announce_addr(msk, &entry->addr, false, entry->addr.port);
+               mptcp_pm_announce_addr(msk, &entry->addr, false);
                mptcp_pm_add_addr_send_ack(msk);
                entry->retrans_times++;
        }
@@ -319,6 +337,9 @@ static void mptcp_pm_add_timer(struct timer_list *timer)
 
        spin_unlock_bh(&msk->pm.lock);
 
+       if (entry->retrans_times == ADD_ADDR_RETRANS_MAX)
+               mptcp_pm_subflow_established(msk);
+
 out:
        __sock_put(sk);
 }
@@ -331,7 +352,7 @@ mptcp_pm_del_add_timer(struct mptcp_sock *msk,
        struct sock *sk = (struct sock *)msk;
 
        spin_lock_bh(&msk->pm.lock);
-       entry = lookup_anno_list_by_saddr(msk, addr);
+       entry = mptcp_lookup_anno_list_by_saddr(msk, addr);
        if (entry)
                entry->retrans_times = ADD_ADDR_RETRANS_MAX;
        spin_unlock_bh(&msk->pm.lock);
@@ -351,7 +372,7 @@ static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
 
        lockdep_assert_held(&msk->pm.lock);
 
-       if (lookup_anno_list_by_saddr(msk, &entry->addr))
+       if (mptcp_lookup_anno_list_by_saddr(msk, &entry->addr))
                return false;
 
        add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
@@ -417,8 +438,8 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
                if (local) {
                        if (mptcp_pm_alloc_anno_list(msk, local)) {
                                msk->pm.add_addr_signaled++;
-                               mptcp_pm_announce_addr(msk, &local->addr, false, local->addr.port);
-                               mptcp_pm_nl_add_addr_send_ack(msk);
+                               mptcp_pm_announce_addr(msk, &local->addr, false);
+                               mptcp_pm_nl_addr_send_ack(msk);
                        }
                } else {
                        /* pick failed, avoid fourther attempts later */
@@ -440,7 +461,8 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
                        check_work_pending(msk);
                        remote_address((struct sock_common *)sk, &remote);
                        spin_unlock_bh(&msk->pm.lock);
-                       __mptcp_subflow_connect(sk, &local->addr, &remote);
+                       __mptcp_subflow_connect(sk, &local->addr, &remote,
+                                               local->flags, local->ifindex);
                        spin_lock_bh(&msk->pm.lock);
                        return;
                }
@@ -468,7 +490,6 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
        struct mptcp_addr_info remote;
        struct mptcp_addr_info local;
        unsigned int subflows_max;
-       bool use_port = false;
 
        add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
        subflows_max = mptcp_pm_get_subflows_max(msk);
@@ -476,6 +497,10 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
        pr_debug("accepted %d:%d remote family %d",
                 msk->pm.add_addr_accepted, add_addr_accept_max,
                 msk->pm.remote.family);
+
+       if (lookup_subflow_by_daddr(&msk->conn_list, &msk->pm.remote))
+               goto add_addr_echo;
+
        msk->pm.add_addr_accepted++;
        msk->pm.subflows++;
        if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
@@ -488,37 +513,37 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
        remote = msk->pm.remote;
        if (!remote.port)
                remote.port = sk->sk_dport;
-       else
-               use_port = true;
        memset(&local, 0, sizeof(local));
        local.family = remote.family;
 
        spin_unlock_bh(&msk->pm.lock);
-       __mptcp_subflow_connect(sk, &local, &remote);
+       __mptcp_subflow_connect(sk, &local, &remote, 0, 0);
        spin_lock_bh(&msk->pm.lock);
 
-       mptcp_pm_announce_addr(msk, &remote, true, use_port);
-       mptcp_pm_nl_add_addr_send_ack(msk);
+add_addr_echo:
+       mptcp_pm_announce_addr(msk, &msk->pm.remote, true);
+       mptcp_pm_nl_addr_send_ack(msk);
 }
 
-static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
+void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
 {
        struct mptcp_subflow_context *subflow;
 
        msk_owned_by_me(msk);
        lockdep_assert_held(&msk->pm.lock);
 
-       if (!mptcp_pm_should_add_signal(msk))
+       if (!mptcp_pm_should_add_signal(msk) &&
+           !mptcp_pm_should_rm_signal(msk))
                return;
 
        __mptcp_flush_join_list(msk);
        subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
        if (subflow) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-               u8 add_addr;
 
                spin_unlock_bh(&msk->pm.lock);
-               pr_debug("send ack for add_addr%s%s",
+               pr_debug("send ack for %s%s%s",
+                        mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr",
                         mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
                         mptcp_pm_should_add_signal_port(msk) ? " [port]" : "");
 
@@ -526,13 +551,6 @@ static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
                tcp_send_ack(ssk);
                release_sock(ssk);
                spin_lock_bh(&msk->pm.lock);
-
-               add_addr = READ_ONCE(msk->pm.addr_signal);
-               if (mptcp_pm_should_add_signal_ipv6(msk))
-                       add_addr &= ~BIT(MPTCP_ADD_ADDR_IPV6);
-               if (mptcp_pm_should_add_signal_port(msk))
-                       add_addr &= ~BIT(MPTCP_ADD_ADDR_PORT);
-               WRITE_ONCE(msk->pm.addr_signal, add_addr);
        }
 }
 
@@ -571,47 +589,68 @@ int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
        return -EINVAL;
 }
 
-static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
+static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
+                                          const struct mptcp_rm_list *rm_list,
+                                          enum linux_mptcp_mib_field rm_type)
 {
        struct mptcp_subflow_context *subflow, *tmp;
        struct sock *sk = (struct sock *)msk;
        u8 i;
 
-       pr_debug("address rm_list_nr %d", msk->pm.rm_list_rx.nr);
+       pr_debug("%s rm_list_nr %d",
+                rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr);
 
        msk_owned_by_me(msk);
 
-       if (!msk->pm.rm_list_rx.nr)
+       if (!rm_list->nr)
                return;
 
        if (list_empty(&msk->conn_list))
                return;
 
-       for (i = 0; i < msk->pm.rm_list_rx.nr; i++) {
+       for (i = 0; i < rm_list->nr; i++) {
                list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
                        struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                        int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
+                       u8 id = subflow->local_id;      /* default key: local_id */
+
+                       if (rm_type == MPTCP_MIB_RMADDR)
+                               id = subflow->remote_id;        /* RMADDR keys on remote_id */
 
-                       if (msk->pm.rm_list_rx.ids[i] != subflow->remote_id)
+                       if (rm_list->ids[i] != id)      /* every match is closed, not only the first */
                                continue;
 
-                       pr_debug(" -> address rm_list_ids[%d]=%u", i, msk->pm.rm_list_rx.ids[i]);
+                       pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u",
+                                rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow",
+                                i, rm_list->ids[i], subflow->local_id, subflow->remote_id);
                        spin_unlock_bh(&msk->pm.lock);
                        mptcp_subflow_shutdown(sk, ssk, how);
                        mptcp_close_ssk(sk, ssk, subflow);
                        spin_lock_bh(&msk->pm.lock);
 
-                       msk->pm.add_addr_accepted--;
+                       if (rm_type == MPTCP_MIB_RMADDR) {
+                               msk->pm.add_addr_accepted--;
+                               WRITE_ONCE(msk->pm.accept_addr, true);
+                       } else if (rm_type == MPTCP_MIB_RMSUBFLOW) {
+                               msk->pm.local_addr_used--;
+                       }
                        msk->pm.subflows--;
-                       WRITE_ONCE(msk->pm.accept_addr, true);
-
-                       __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR);
-
-                       break;
+                       __MPTCP_INC_STATS(sock_net(sk), rm_type);       /* rm_type is also the MIB index */
                }
        }
 }
 
+static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
+{      /* RMADDR flavour: pm.rm_list_rx ids are compared with subflow remote_id */
+       mptcp_pm_nl_rm_addr_or_subflow(msk, &msk->pm.rm_list_rx, MPTCP_MIB_RMADDR);
+}
+
+void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
+                                    const struct mptcp_rm_list *rm_list)
+{      /* RMSUBFLOW flavour: @rm_list ids are compared with subflow local_id */
+       mptcp_pm_nl_rm_addr_or_subflow(msk, rm_list, MPTCP_MIB_RMSUBFLOW);
+}
+
 void mptcp_pm_nl_work(struct mptcp_sock *msk)
 {
        struct mptcp_pm_data *pm = &msk->pm;
@@ -627,7 +666,7 @@ void mptcp_pm_nl_work(struct mptcp_sock *msk)
        }
        if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
                pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
-               mptcp_pm_nl_add_addr_send_ack(msk);
+               mptcp_pm_nl_addr_send_ack(msk);
        }
        if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
                pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
@@ -645,50 +684,9 @@ void mptcp_pm_nl_work(struct mptcp_sock *msk)
        spin_unlock_bh(&msk->pm.lock);
 }
 
-void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
-                                    const struct mptcp_rm_list *rm_list)
-{
-       struct mptcp_subflow_context *subflow, *tmp;
-       struct sock *sk = (struct sock *)msk;
-       u8 i;
-
-       pr_debug("subflow rm_list_nr %d", rm_list->nr);
-
-       msk_owned_by_me(msk);
-
-       if (!rm_list->nr)
-               return;
-
-       if (list_empty(&msk->conn_list))
-               return;
-
-       for (i = 0; i < rm_list->nr; i++) {
-               list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
-                       struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-                       int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
-
-                       if (rm_list->ids[i] != subflow->local_id)
-                               continue;
-
-                       pr_debug(" -> subflow rm_list_ids[%d]=%u", i, rm_list->ids[i]);
-                       spin_unlock_bh(&msk->pm.lock);
-                       mptcp_subflow_shutdown(sk, ssk, how);
-                       mptcp_close_ssk(sk, ssk, subflow);
-                       spin_lock_bh(&msk->pm.lock);
-
-                       msk->pm.local_addr_used--;
-                       msk->pm.subflows--;
-
-                       __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);
-
-                       break;
-               }
-       }
-}
-
 static bool address_use_port(struct mptcp_pm_addr_entry *entry)
 {
-       return (entry->addr.flags &
+       return (entry->flags &
                (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
                MPTCP_PM_ADDR_FLAG_SIGNAL;
 }
@@ -740,11 +738,11 @@ find_next:
        if (entry->addr.id > pernet->next_id)
                pernet->next_id = entry->addr.id;
 
-       if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
+       if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
                addr_max = pernet->add_addr_signal_max;
                WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
        }
-       if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
+       if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
                addr_max = pernet->local_addr_max;
                WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
        }
@@ -846,10 +844,10 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
                return -ENOMEM;
 
        entry->addr = skc_local;
-       entry->addr.ifindex = 0;
-       entry->addr.flags = 0;
        entry->addr.id = 0;
        entry->addr.port = 0;
+       entry->ifindex = 0;
+       entry->flags = 0;
        entry->lsk = NULL;
        ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
        if (ret < 0)
@@ -964,14 +962,14 @@ skip_family:
        if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
                u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
 
-               entry->addr.ifindex = val;
+               entry->ifindex = val;
        }
 
        if (tb[MPTCP_PM_ADDR_ATTR_ID])
                entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
 
        if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
-               entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
+               entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
 
        if (tb[MPTCP_PM_ADDR_ATTR_PORT])
                entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
@@ -1161,6 +1159,41 @@ static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry)
        }
 }
 
+static int mptcp_nl_remove_id_zero_address(struct net *net,
+                                          struct mptcp_addr_info *addr)
+{      /* id 0 removal on each msk in @net whose local_address() equals @addr */
+       struct mptcp_rm_list list = { .nr = 0 };
+       long s_slot = 0, s_num = 0;
+       struct mptcp_sock *msk;
+
+       list.ids[list.nr++] = 0;        /* the list carries the single id 0 */
+
+       while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
+               struct sock *sk = (struct sock *)msk;
+               struct mptcp_addr_info msk_local;
+
+               if (list_empty(&msk->conn_list))
+                       goto next;
+               /* use_port is addr->port, i.e. false when the port is zero */
+               local_address((struct sock_common *)msk, &msk_local);
+               if (!addresses_equal(&msk_local, addr, addr->port))
+                       goto next;
+               /* msk socket lock is taken before pm.lock for both removal calls */
+               lock_sock(sk);
+               spin_lock_bh(&msk->pm.lock);
+               mptcp_pm_remove_addr(msk, &list);
+               mptcp_pm_nl_rm_subflow_received(msk, &list);
+               spin_unlock_bh(&msk->pm.lock);
+               release_sock(sk);
+
+next:
+               sock_put(sk);   /* NOTE(review): presumably drops the iterator's ref - confirm */
+               cond_resched();
+       }
+
+       return 0;       /* no failure is reported to the caller */
+}
+
 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
 {
        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
@@ -1173,6 +1206,14 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
        if (ret < 0)
                return ret;
 
+       /* the zero id address is special: the first address used by the msk
+        * always gets such an id, so different subflows can have different zero
+        * id addresses. Additionally zero id is not accounted for in id_bitmap.
+        * Let's use an 'mptcp_rm_list' instead of the common remove code.
+        */
+       if (addr.addr.id == 0)
+               return mptcp_nl_remove_id_zero_address(sock_net(skb->sk), &addr.addr);
+
        spin_lock_bh(&pernet->lock);
        entry = __lookup_addr_by_id(pernet, addr.addr.id);
        if (!entry) {
@@ -1180,11 +1221,11 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
                spin_unlock_bh(&pernet->lock);
                return -EINVAL;
        }
-       if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
+       if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
                addr_max = pernet->add_addr_signal_max;
                WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
        }
-       if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
+       if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
                addr_max = pernet->local_addr_max;
                WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
        }
@@ -1300,10 +1341,10 @@ static int mptcp_nl_fill_addr(struct sk_buff *skb,
                goto nla_put_failure;
        if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
                goto nla_put_failure;
-       if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags))
+       if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
                goto nla_put_failure;
-       if (entry->addr.ifindex &&
-           nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex))
+       if (entry->ifindex &&
+           nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
                goto nla_put_failure;
 
        if (addr->family == AF_INET &&
@@ -1531,7 +1572,7 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
        if (ret < 0)
                return ret;
 
-       if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
+       if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
                bkup = 1;
 
        list_for_each_entry(entry, &pernet->local_addr_list, list) {
@@ -1541,9 +1582,9 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
                                return ret;
 
                        if (bkup)
-                               entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
+                               entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
                        else
-                               entry->addr.flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
+                               entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
                }
        }
 
@@ -1649,9 +1690,21 @@ static int mptcp_event_sub_closed(struct sk_buff *skb,
                                  const struct mptcp_sock *msk,
                                  const struct sock *ssk)
 {
+       const struct mptcp_subflow_context *sf;
+
        if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
                return -EMSGSIZE;
 
+       sf = mptcp_subflow_ctx(ssk);
+       if (!sf->reset_seen)
+               return 0;
+
+       if (nla_put_u32(skb, MPTCP_ATTR_RESET_REASON, sf->reset_reason))
+               return -EMSGSIZE;
+
+       if (nla_put_u32(skb, MPTCP_ATTR_RESET_FLAGS, sf->reset_transient))
+               return -EMSGSIZE;
+
        return 0;
 }