229fd1af2e2972fd7eeff982818a6e927ccfa394
[linux-2.6-microblaze.git] / net / mptcp / pm_netlink.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2020, Red Hat, Inc.
5  */
6
7 #define pr_fmt(fmt) "MPTCP: " fmt
8
9 #include <linux/inet.h>
10 #include <linux/kernel.h>
11 #include <net/tcp.h>
12 #include <net/netns/generic.h>
13 #include <net/mptcp.h>
14 #include <net/genetlink.h>
15 #include <uapi/linux/mptcp.h>
16
17 #include "protocol.h"
18 #include "mib.h"
19
20 /* forward declaration */
21 static struct genl_family mptcp_genl_family;
22
23 static int pm_nl_pernet_id;
24
25 struct mptcp_pm_addr_entry {
26         struct list_head        list;
27         struct mptcp_addr_info  addr;
28         struct rcu_head         rcu;
29         struct socket           *lsk;
30 };
31
32 struct mptcp_pm_add_entry {
33         struct list_head        list;
34         struct mptcp_addr_info  addr;
35         struct timer_list       add_timer;
36         struct mptcp_sock       *sock;
37         u8                      retrans_times;
38 };
39
40 #define MAX_ADDR_ID             255
41 #define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)
42
43 struct pm_nl_pernet {
44         /* protects pernet updates */
45         spinlock_t              lock;
46         struct list_head        local_addr_list;
47         unsigned int            addrs;
48         unsigned int            add_addr_signal_max;
49         unsigned int            add_addr_accept_max;
50         unsigned int            local_addr_max;
51         unsigned int            subflows_max;
52         unsigned int            next_id;
53         unsigned long           id_bitmap[BITMAP_SZ];
54 };
55
/* maximum number of endpoints on pm_nl_pernet.local_addr_list */
#define MPTCP_PM_ADDR_MAX	8
/* maximum number of ADD_ADDR retransmissions per announced address */
#define ADD_ADDR_RETRANS_MAX	3

static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk);
60
61 static bool addresses_equal(const struct mptcp_addr_info *a,
62                             struct mptcp_addr_info *b, bool use_port)
63 {
64         bool addr_equals = false;
65
66         if (a->family == b->family) {
67                 if (a->family == AF_INET)
68                         addr_equals = a->addr.s_addr == b->addr.s_addr;
69 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
70                 else
71                         addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
72         } else if (a->family == AF_INET) {
73                 if (ipv6_addr_v4mapped(&b->addr6))
74                         addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
75         } else if (b->family == AF_INET) {
76                 if (ipv6_addr_v4mapped(&a->addr6))
77                         addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
78 #endif
79         }
80
81         if (!addr_equals)
82                 return false;
83         if (!use_port)
84                 return true;
85
86         return a->port == b->port;
87 }
88
89 static bool address_zero(const struct mptcp_addr_info *addr)
90 {
91         struct mptcp_addr_info zero;
92
93         memset(&zero, 0, sizeof(zero));
94         zero.family = addr->family;
95
96         return addresses_equal(addr, &zero, true);
97 }
98
99 static void local_address(const struct sock_common *skc,
100                           struct mptcp_addr_info *addr)
101 {
102         addr->family = skc->skc_family;
103         addr->port = htons(skc->skc_num);
104         if (addr->family == AF_INET)
105                 addr->addr.s_addr = skc->skc_rcv_saddr;
106 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
107         else if (addr->family == AF_INET6)
108                 addr->addr6 = skc->skc_v6_rcv_saddr;
109 #endif
110 }
111
112 static void remote_address(const struct sock_common *skc,
113                            struct mptcp_addr_info *addr)
114 {
115         addr->family = skc->skc_family;
116         addr->port = skc->skc_dport;
117         if (addr->family == AF_INET)
118                 addr->addr.s_addr = skc->skc_daddr;
119 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
120         else if (addr->family == AF_INET6)
121                 addr->addr6 = skc->skc_v6_daddr;
122 #endif
123 }
124
125 static bool lookup_subflow_by_saddr(const struct list_head *list,
126                                     struct mptcp_addr_info *saddr)
127 {
128         struct mptcp_subflow_context *subflow;
129         struct mptcp_addr_info cur;
130         struct sock_common *skc;
131
132         list_for_each_entry(subflow, list, node) {
133                 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
134
135                 local_address(skc, &cur);
136                 if (addresses_equal(&cur, saddr, saddr->port))
137                         return true;
138         }
139
140         return false;
141 }
142
143 static struct mptcp_pm_addr_entry *
144 select_local_address(const struct pm_nl_pernet *pernet,
145                      struct mptcp_sock *msk)
146 {
147         struct mptcp_pm_addr_entry *entry, *ret = NULL;
148         struct sock *sk = (struct sock *)msk;
149
150         msk_owned_by_me(msk);
151
152         rcu_read_lock();
153         __mptcp_flush_join_list(msk);
154         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
155                 if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
156                         continue;
157
158                 if (entry->addr.family != sk->sk_family) {
159 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
160                         if ((entry->addr.family == AF_INET &&
161                              !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
162                             (sk->sk_family == AF_INET &&
163                              !ipv6_addr_v4mapped(&entry->addr.addr6)))
164 #endif
165                                 continue;
166                 }
167
168                 /* avoid any address already in use by subflows and
169                  * pending join
170                  */
171                 if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
172                         ret = entry;
173                         break;
174                 }
175         }
176         rcu_read_unlock();
177         return ret;
178 }
179
180 static struct mptcp_pm_addr_entry *
181 select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
182 {
183         struct mptcp_pm_addr_entry *entry, *ret = NULL;
184         int i = 0;
185
186         rcu_read_lock();
187         /* do not keep any additional per socket state, just signal
188          * the address list in order.
189          * Note: removal from the local address list during the msk life-cycle
190          * can lead to additional addresses not being announced.
191          */
192         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
193                 if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
194                         continue;
195                 if (i++ == pos) {
196                         ret = entry;
197                         break;
198                 }
199         }
200         rcu_read_unlock();
201         return ret;
202 }
203
204 unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
205 {
206         struct pm_nl_pernet *pernet;
207
208         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
209         return READ_ONCE(pernet->add_addr_signal_max);
210 }
211 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
212
213 unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
214 {
215         struct pm_nl_pernet *pernet;
216
217         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
218         return READ_ONCE(pernet->add_addr_accept_max);
219 }
220 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
221
222 unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
223 {
224         struct pm_nl_pernet *pernet;
225
226         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
227         return READ_ONCE(pernet->subflows_max);
228 }
229 EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
230
231 static unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
232 {
233         struct pm_nl_pernet *pernet;
234
235         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
236         return READ_ONCE(pernet->local_addr_max);
237 }
238
239 static void check_work_pending(struct mptcp_sock *msk)
240 {
241         if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
242             (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
243              msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
244                 WRITE_ONCE(msk->pm.work_pending, false);
245 }
246
247 static struct mptcp_pm_add_entry *
248 lookup_anno_list_by_saddr(struct mptcp_sock *msk,
249                           struct mptcp_addr_info *addr)
250 {
251         struct mptcp_pm_add_entry *entry;
252
253         lockdep_assert_held(&msk->pm.lock);
254
255         list_for_each_entry(entry, &msk->pm.anno_list, list) {
256                 if (addresses_equal(&entry->addr, addr, true))
257                         return entry;
258         }
259
260         return NULL;
261 }
262
263 bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
264 {
265         struct mptcp_pm_add_entry *entry;
266         struct mptcp_addr_info saddr;
267         bool ret = false;
268
269         local_address((struct sock_common *)sk, &saddr);
270
271         spin_lock_bh(&msk->pm.lock);
272         list_for_each_entry(entry, &msk->pm.anno_list, list) {
273                 if (addresses_equal(&entry->addr, &saddr, true)) {
274                         ret = true;
275                         goto out;
276                 }
277         }
278
279 out:
280         spin_unlock_bh(&msk->pm.lock);
281         return ret;
282 }
283
/*
 * Timer callback installed by mptcp_pm_alloc_anno_list() for each
 * msk->pm.anno_list entry. It re-sends the ADD_ADDR for entry->addr and
 * re-arms itself until entry->retrans_times reaches ADD_ADDR_RETRANS_MAX.
 */
static void mptcp_pm_add_timer(struct timer_list *timer)
{
	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
	struct mptcp_sock *msk = entry->sock;
	struct sock *sk = (struct sock *)msk;

	pr_debug("msk=%p", msk);

	if (!msk)
		return;

	/* nothing to retransmit once the MPTCP socket is closed */
	if (inet_sk_state_load(sk) == TCP_CLOSE)
		return;

	/* entries with address id 0 are never retransmitted */
	if (!entry->addr.id)
		return;

	/*
	 * NOTE(review): the three early returns above do not reach the
	 * __sock_put() at "out:". If arming via sk_reset_timer() takes a
	 * socket reference (as that __sock_put() suggests), these paths may
	 * leak it — confirm against sk_reset_timer()/sk_stop_timer_sync().
	 */

	/*
	 * An ADD_ADDR signal is still pending: poll again later without
	 * counting this as a retransmission.
	 */
	if (mptcp_pm_should_add_signal(msk)) {
		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
		goto out;
	}

	spin_lock_bh(&msk->pm.lock);

	/* re-check under the lock before re-announcing */
	if (!mptcp_pm_should_add_signal(msk)) {
		pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
		mptcp_pm_announce_addr(msk, &entry->addr, false, entry->addr.port);
		mptcp_pm_add_addr_send_ack(msk);
		entry->retrans_times++;
	}

	/*
	 * mptcp_pm_del_add_timer() sets retrans_times to the maximum, which
	 * also stops re-arming here.
	 */
	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
		sk_reset_timer(sk, timer,
			       jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));

	spin_unlock_bh(&msk->pm.lock);

out:
	__sock_put(sk);
}
324
325 struct mptcp_pm_add_entry *
326 mptcp_pm_del_add_timer(struct mptcp_sock *msk,
327                        struct mptcp_addr_info *addr)
328 {
329         struct mptcp_pm_add_entry *entry;
330         struct sock *sk = (struct sock *)msk;
331
332         spin_lock_bh(&msk->pm.lock);
333         entry = lookup_anno_list_by_saddr(msk, addr);
334         if (entry)
335                 entry->retrans_times = ADD_ADDR_RETRANS_MAX;
336         spin_unlock_bh(&msk->pm.lock);
337
338         if (entry)
339                 sk_stop_timer_sync(sk, &entry->add_timer);
340
341         return entry;
342 }
343
344 static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
345                                      struct mptcp_pm_addr_entry *entry)
346 {
347         struct mptcp_pm_add_entry *add_entry = NULL;
348         struct sock *sk = (struct sock *)msk;
349         struct net *net = sock_net(sk);
350
351         lockdep_assert_held(&msk->pm.lock);
352
353         if (lookup_anno_list_by_saddr(msk, &entry->addr))
354                 return false;
355
356         add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
357         if (!add_entry)
358                 return false;
359
360         list_add(&add_entry->list, &msk->pm.anno_list);
361
362         add_entry->addr = entry->addr;
363         add_entry->sock = msk;
364         add_entry->retrans_times = 0;
365
366         timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
367         sk_reset_timer(sk, &add_entry->add_timer,
368                        jiffies + mptcp_get_add_addr_timeout(net));
369
370         return true;
371 }
372
373 void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
374 {
375         struct mptcp_pm_add_entry *entry, *tmp;
376         struct sock *sk = (struct sock *)msk;
377         LIST_HEAD(free_list);
378
379         pr_debug("msk=%p", msk);
380
381         spin_lock_bh(&msk->pm.lock);
382         list_splice_init(&msk->pm.anno_list, &free_list);
383         spin_unlock_bh(&msk->pm.lock);
384
385         list_for_each_entry_safe(entry, tmp, &free_list, list) {
386                 sk_stop_timer_sync(sk, &entry->add_timer);
387                 kfree(entry);
388         }
389 }
390
391 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
392 {
393         struct sock *sk = (struct sock *)msk;
394         struct mptcp_pm_addr_entry *local;
395         unsigned int add_addr_signal_max;
396         unsigned int local_addr_max;
397         struct pm_nl_pernet *pernet;
398         unsigned int subflows_max;
399
400         pernet = net_generic(sock_net(sk), pm_nl_pernet_id);
401
402         add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
403         local_addr_max = mptcp_pm_get_local_addr_max(msk);
404         subflows_max = mptcp_pm_get_subflows_max(msk);
405
406         pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
407                  msk->pm.local_addr_used, local_addr_max,
408                  msk->pm.add_addr_signaled, add_addr_signal_max,
409                  msk->pm.subflows, subflows_max);
410
411         /* check first for announce */
412         if (msk->pm.add_addr_signaled < add_addr_signal_max) {
413                 local = select_signal_address(pernet,
414                                               msk->pm.add_addr_signaled);
415
416                 if (local) {
417                         if (mptcp_pm_alloc_anno_list(msk, local)) {
418                                 msk->pm.add_addr_signaled++;
419                                 mptcp_pm_announce_addr(msk, &local->addr, false, local->addr.port);
420                                 mptcp_pm_nl_add_addr_send_ack(msk);
421                         }
422                 } else {
423                         /* pick failed, avoid fourther attempts later */
424                         msk->pm.local_addr_used = add_addr_signal_max;
425                 }
426
427                 check_work_pending(msk);
428         }
429
430         /* check if should create a new subflow */
431         if (msk->pm.local_addr_used < local_addr_max &&
432             msk->pm.subflows < subflows_max) {
433                 local = select_local_address(pernet, msk);
434                 if (local) {
435                         struct mptcp_addr_info remote = { 0 };
436
437                         msk->pm.local_addr_used++;
438                         msk->pm.subflows++;
439                         check_work_pending(msk);
440                         remote_address((struct sock_common *)sk, &remote);
441                         spin_unlock_bh(&msk->pm.lock);
442                         __mptcp_subflow_connect(sk, &local->addr, &remote);
443                         spin_lock_bh(&msk->pm.lock);
444                         return;
445                 }
446
447                 /* lookup failed, avoid fourther attempts later */
448                 msk->pm.local_addr_used = local_addr_max;
449                 check_work_pending(msk);
450         }
451 }
452
/* MPTCP_PM_ESTABLISHED handler: run one path-manager step. */
static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
457
/* MPTCP_PM_SUBFLOW_ESTABLISHED handler: run one path-manager step. */
static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
462
463 static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
464 {
465         struct sock *sk = (struct sock *)msk;
466         unsigned int add_addr_accept_max;
467         struct mptcp_addr_info remote;
468         struct mptcp_addr_info local;
469         unsigned int subflows_max;
470         bool use_port = false;
471
472         add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
473         subflows_max = mptcp_pm_get_subflows_max(msk);
474
475         pr_debug("accepted %d:%d remote family %d",
476                  msk->pm.add_addr_accepted, add_addr_accept_max,
477                  msk->pm.remote.family);
478         msk->pm.add_addr_accepted++;
479         msk->pm.subflows++;
480         if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
481             msk->pm.subflows >= subflows_max)
482                 WRITE_ONCE(msk->pm.accept_addr, false);
483
484         /* connect to the specified remote address, using whatever
485          * local address the routing configuration will pick.
486          */
487         remote = msk->pm.remote;
488         if (!remote.port)
489                 remote.port = sk->sk_dport;
490         else
491                 use_port = true;
492         memset(&local, 0, sizeof(local));
493         local.family = remote.family;
494
495         spin_unlock_bh(&msk->pm.lock);
496         __mptcp_subflow_connect(sk, &local, &remote);
497         spin_lock_bh(&msk->pm.lock);
498
499         mptcp_pm_announce_addr(msk, &remote, true, use_port);
500         mptcp_pm_nl_add_addr_send_ack(msk);
501 }
502
503 static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
504 {
505         struct mptcp_subflow_context *subflow;
506
507         msk_owned_by_me(msk);
508         lockdep_assert_held(&msk->pm.lock);
509
510         if (!mptcp_pm_should_add_signal(msk))
511                 return;
512
513         __mptcp_flush_join_list(msk);
514         subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
515         if (subflow) {
516                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
517                 u8 add_addr;
518
519                 spin_unlock_bh(&msk->pm.lock);
520                 pr_debug("send ack for add_addr%s%s",
521                          mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
522                          mptcp_pm_should_add_signal_port(msk) ? " [port]" : "");
523
524                 lock_sock(ssk);
525                 tcp_send_ack(ssk);
526                 release_sock(ssk);
527                 spin_lock_bh(&msk->pm.lock);
528
529                 add_addr = READ_ONCE(msk->pm.addr_signal);
530                 if (mptcp_pm_should_add_signal_ipv6(msk))
531                         add_addr &= ~BIT(MPTCP_ADD_ADDR_IPV6);
532                 if (mptcp_pm_should_add_signal_port(msk))
533                         add_addr &= ~BIT(MPTCP_ADD_ADDR_PORT);
534                 WRITE_ONCE(msk->pm.addr_signal, add_addr);
535         }
536 }
537
538 int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
539                                  struct mptcp_addr_info *addr,
540                                  u8 bkup)
541 {
542         struct mptcp_subflow_context *subflow;
543
544         pr_debug("bkup=%d", bkup);
545
546         mptcp_for_each_subflow(msk, subflow) {
547                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
548                 struct sock *sk = (struct sock *)msk;
549                 struct mptcp_addr_info local;
550
551                 local_address((struct sock_common *)ssk, &local);
552                 if (!addresses_equal(&local, addr, addr->port))
553                         continue;
554
555                 subflow->backup = bkup;
556                 subflow->send_mp_prio = 1;
557                 subflow->request_bkup = bkup;
558                 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);
559
560                 spin_unlock_bh(&msk->pm.lock);
561                 pr_debug("send ack for mp_prio");
562                 lock_sock(ssk);
563                 tcp_send_ack(ssk);
564                 release_sock(ssk);
565                 spin_lock_bh(&msk->pm.lock);
566
567                 return 0;
568         }
569
570         return -EINVAL;
571 }
572
573 static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
574 {
575         struct mptcp_subflow_context *subflow, *tmp;
576         struct sock *sk = (struct sock *)msk;
577
578         pr_debug("address rm_id %d", msk->pm.rm_id);
579
580         msk_owned_by_me(msk);
581
582         if (!msk->pm.rm_id)
583                 return;
584
585         if (list_empty(&msk->conn_list))
586                 return;
587
588         list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
589                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
590                 int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
591
592                 if (msk->pm.rm_id != subflow->remote_id)
593                         continue;
594
595                 spin_unlock_bh(&msk->pm.lock);
596                 mptcp_subflow_shutdown(sk, ssk, how);
597                 mptcp_close_ssk(sk, ssk, subflow);
598                 spin_lock_bh(&msk->pm.lock);
599
600                 msk->pm.add_addr_accepted--;
601                 msk->pm.subflows--;
602                 WRITE_ONCE(msk->pm.accept_addr, true);
603
604                 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR);
605
606                 break;
607         }
608 }
609
610 void mptcp_pm_nl_work(struct mptcp_sock *msk)
611 {
612         struct mptcp_pm_data *pm = &msk->pm;
613
614         msk_owned_by_me(msk);
615
616         spin_lock_bh(&msk->pm.lock);
617
618         pr_debug("msk=%p status=%x", msk, pm->status);
619         if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
620                 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
621                 mptcp_pm_nl_add_addr_received(msk);
622         }
623         if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
624                 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
625                 mptcp_pm_nl_add_addr_send_ack(msk);
626         }
627         if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
628                 pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
629                 mptcp_pm_nl_rm_addr_received(msk);
630         }
631         if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
632                 pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
633                 mptcp_pm_nl_fully_established(msk);
634         }
635         if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
636                 pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
637                 mptcp_pm_nl_subflow_established(msk);
638         }
639
640         spin_unlock_bh(&msk->pm.lock);
641 }
642
643 void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
644 {
645         struct mptcp_subflow_context *subflow, *tmp;
646         struct sock *sk = (struct sock *)msk;
647
648         pr_debug("subflow rm_id %d", rm_id);
649
650         msk_owned_by_me(msk);
651
652         if (!rm_id)
653                 return;
654
655         if (list_empty(&msk->conn_list))
656                 return;
657
658         list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
659                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
660                 int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
661
662                 if (rm_id != subflow->local_id)
663                         continue;
664
665                 spin_unlock_bh(&msk->pm.lock);
666                 mptcp_subflow_shutdown(sk, ssk, how);
667                 mptcp_close_ssk(sk, ssk, subflow);
668                 spin_lock_bh(&msk->pm.lock);
669
670                 msk->pm.local_addr_used--;
671                 msk->pm.subflows--;
672
673                 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);
674
675                 break;
676         }
677 }
678
679 static bool address_use_port(struct mptcp_pm_addr_entry *entry)
680 {
681         return (entry->addr.flags &
682                 (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
683                 MPTCP_PM_ADDR_FLAG_SIGNAL;
684 }
685
/*
 * Append @entry to pernet->local_addr_list under pernet->lock.
 *
 * If entry->addr.id is 0, an id is assigned automatically; otherwise the
 * given id must not already be in use. SIGNAL/SUBFLOW flags bump
 * add_addr_signal_max / local_addr_max respectively.
 *
 * Returns the entry's id (> 0) on success, or -EINVAL when:
 *  - the list already holds MPTCP_PM_ADDR_MAX entries,
 *  - the requested id is taken,
 *  - an equivalent address is present, or
 *  - no free id exists.
 * On success the list holds @entry. On failure the caller still owns it;
 * mptcp_pm_nl_get_local_id() frees it.
 */
static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
					     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_addr_entry *cur;
	unsigned int addr_max;
	int ret = -EINVAL;

	spin_lock_bh(&pernet->lock);
	/* to keep the code simple, don't do IDR-like allocation for address ID,
	 * just bail when we exceed limits
	 */
	if (pernet->next_id == MAX_ADDR_ID)
		pernet->next_id = 1;
	if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
		goto out;
	if (test_bit(entry->addr.id, pernet->id_bitmap))
		goto out;

	/* do not insert duplicate address, differentiate on port only
	 * signaled addresses
	 */
	list_for_each_entry(cur, &pernet->local_addr_list, list) {
		if (addresses_equal(&cur->addr, &entry->addr,
				    address_use_port(entry) &&
				    address_use_port(cur)))
			goto out;
	}

	if (!entry->addr.id) {
		/*
		 * Search for a free id starting at next_id. If nothing usable
		 * is found and the search did not already start at 1, wrap
		 * around once and retry from 1.
		 *
		 * NOTE(review): find_next_zero_bit() returns the bitmap size
		 * (MAX_ADDR_ID + 1) when no zero bit exists. If addr.id is a
		 * u8 — presumably, given the NLA_U8 policy for the id
		 * attribute — that value truncates to 0 and is caught by the
		 * !id test, making "> MAX_ADDR_ID" unreachable. Confirm.
		 */
find_next:
		entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
						    MAX_ADDR_ID + 1,
						    pernet->next_id);
		if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
		    pernet->next_id != 1) {
			pernet->next_id = 1;
			goto find_next;
		}
	}

	/* id 0 is never handed out here; failing to find an id ends here */
	if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
		goto out;

	__set_bit(entry->addr.id, pernet->id_bitmap);
	if (entry->addr.id > pernet->next_id)
		pernet->next_id = entry->addr.id;

	/* the *_max counters are read locklessly via READ_ONCE() elsewhere */
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
	}

	pernet->addrs++;
	list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
	ret = entry->addr.id;

out:
	spin_unlock_bh(&pernet->lock);
	return ret;
}
750
751 static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
752                                             struct mptcp_pm_addr_entry *entry)
753 {
754         struct sockaddr_storage addr;
755         struct mptcp_sock *msk;
756         struct socket *ssock;
757         int backlog = 1024;
758         int err;
759
760         err = sock_create_kern(sock_net(sk), entry->addr.family,
761                                SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
762         if (err)
763                 return err;
764
765         msk = mptcp_sk(entry->lsk->sk);
766         if (!msk) {
767                 err = -EINVAL;
768                 goto out;
769         }
770
771         ssock = __mptcp_nmpc_socket(msk);
772         if (!ssock) {
773                 err = -EINVAL;
774                 goto out;
775         }
776
777         mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
778         err = kernel_bind(ssock, (struct sockaddr *)&addr,
779                           sizeof(struct sockaddr_in));
780         if (err) {
781                 pr_warn("kernel_bind error, err=%d", err);
782                 goto out;
783         }
784
785         err = kernel_listen(ssock, backlog);
786         if (err) {
787                 pr_warn("kernel_listen error, err=%d", err);
788                 goto out;
789         }
790
791         return 0;
792
793 out:
794         sock_release(entry->lsk);
795         return err;
796 }
797
798 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
799 {
800         struct mptcp_pm_addr_entry *entry;
801         struct mptcp_addr_info skc_local;
802         struct mptcp_addr_info msk_local;
803         struct pm_nl_pernet *pernet;
804         int ret = -1;
805
806         if (WARN_ON_ONCE(!msk))
807                 return -1;
808
809         /* The 0 ID mapping is defined by the first subflow, copied into the msk
810          * addr
811          */
812         local_address((struct sock_common *)msk, &msk_local);
813         local_address((struct sock_common *)skc, &skc_local);
814         if (addresses_equal(&msk_local, &skc_local, false))
815                 return 0;
816
817         if (address_zero(&skc_local))
818                 return 0;
819
820         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
821
822         rcu_read_lock();
823         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
824                 if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
825                         ret = entry->addr.id;
826                         break;
827                 }
828         }
829         rcu_read_unlock();
830         if (ret >= 0)
831                 return ret;
832
833         /* address not found, add to local list */
834         entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
835         if (!entry)
836                 return -ENOMEM;
837
838         entry->addr = skc_local;
839         entry->addr.ifindex = 0;
840         entry->addr.flags = 0;
841         entry->addr.id = 0;
842         entry->addr.port = 0;
843         entry->lsk = NULL;
844         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
845         if (ret < 0)
846                 kfree(entry);
847
848         return ret;
849 }
850
851 void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
852 {
853         struct mptcp_pm_data *pm = &msk->pm;
854         bool subflows;
855
856         subflows = !!mptcp_pm_get_subflows_max(msk);
857         WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
858                    !!mptcp_pm_get_add_addr_signal_max(msk));
859         WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
860         WRITE_ONCE(pm->accept_subflow, subflows);
861 }
862
863 #define MPTCP_PM_CMD_GRP_OFFSET       0
864 #define MPTCP_PM_EV_GRP_OFFSET        1
865
866 static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
867         [MPTCP_PM_CMD_GRP_OFFSET]       = { .name = MPTCP_PM_CMD_GRP_NAME, },
868         [MPTCP_PM_EV_GRP_OFFSET]        = { .name = MPTCP_PM_EV_GRP_NAME,
869                                             .flags = GENL_UNS_ADMIN_PERM,
870                                           },
871 };
872
/* Attribute policy for the nested MPTCP_PM_ATTR_ADDR endpoint description.
 * mptcp_pm_policy references it through NLA_POLICY_NESTED(), and
 * mptcp_pm_parse_addr() reuses it when splitting the nest.
 * ADDR4 is read with nla_get_in_addr(); PORT is carried in host byte order
 * and converted with htons() by the parser; IF_IDX is signed.
 */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]	= { .type	= NLA_U16,	},
	[MPTCP_PM_ADDR_ATTR_ID]		= { .type	= NLA_U8,	},
	[MPTCP_PM_ADDR_ATTR_ADDR4]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ADDR_ATTR_ADDR6]	=
		NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
	[MPTCP_PM_ADDR_ATTR_PORT]	= { .type	= NLA_U16	},
	[MPTCP_PM_ADDR_ATTR_FLAGS]	= { .type	= NLA_U32	},
	[MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type	= NLA_S32	},
};
884
/* Top-level attribute policy of the MPTCP PM genl family.
 * ADDR is a nest validated against mptcp_pm_addr_policy.
 * RCV_ADD_ADDRS and SUBFLOWS are the limits handled by SET_LIMITS/GET_LIMITS.
 */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]		=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_SUBFLOWS]	= { .type	= NLA_U32,	},
};
891
892 static int mptcp_pm_family_to_addr(int family)
893 {
894 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
895         if (family == AF_INET6)
896                 return MPTCP_PM_ADDR_ATTR_ADDR6;
897 #endif
898         return MPTCP_PM_ADDR_ATTR_ADDR4;
899 }
900
901 static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
902                                bool require_family,
903                                struct mptcp_pm_addr_entry *entry)
904 {
905         struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
906         int err, addr_addr;
907
908         if (!attr) {
909                 GENL_SET_ERR_MSG(info, "missing address info");
910                 return -EINVAL;
911         }
912
913         /* no validation needed - was already done via nested policy */
914         err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
915                                           mptcp_pm_addr_policy, info->extack);
916         if (err)
917                 return err;
918
919         memset(entry, 0, sizeof(*entry));
920         if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
921                 if (!require_family)
922                         goto skip_family;
923
924                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
925                                     "missing family");
926                 return -EINVAL;
927         }
928
929         entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
930         if (entry->addr.family != AF_INET
931 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
932             && entry->addr.family != AF_INET6
933 #endif
934             ) {
935                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
936                                     "unknown address family");
937                 return -EINVAL;
938         }
939         addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
940         if (!tb[addr_addr]) {
941                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
942                                     "missing address data");
943                 return -EINVAL;
944         }
945
946 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
947         if (entry->addr.family == AF_INET6)
948                 entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
949         else
950 #endif
951                 entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);
952
953 skip_family:
954         if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
955                 u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
956
957                 entry->addr.ifindex = val;
958         }
959
960         if (tb[MPTCP_PM_ADDR_ATTR_ID])
961                 entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
962
963         if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
964                 entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
965
966         if (tb[MPTCP_PM_ADDR_ATTR_PORT])
967                 entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
968
969         return 0;
970 }
971
972 static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
973 {
974         return net_generic(genl_info_net(info), pm_nl_pernet_id);
975 }
976
977 static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
978 {
979         struct mptcp_sock *msk;
980         long s_slot = 0, s_num = 0;
981
982         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
983                 struct sock *sk = (struct sock *)msk;
984
985                 if (!READ_ONCE(msk->fully_established))
986                         goto next;
987
988                 lock_sock(sk);
989                 spin_lock_bh(&msk->pm.lock);
990                 mptcp_pm_create_subflow_or_signal_addr(msk);
991                 spin_unlock_bh(&msk->pm.lock);
992                 release_sock(sk);
993
994 next:
995                 sock_put(sk);
996                 cond_resched();
997         }
998
999         return 0;
1000 }
1001
1002 static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
1003 {
1004         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1005         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1006         struct mptcp_pm_addr_entry addr, *entry;
1007         int ret;
1008
1009         ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1010         if (ret < 0)
1011                 return ret;
1012
1013         entry = kmalloc(sizeof(*entry), GFP_KERNEL);
1014         if (!entry) {
1015                 GENL_SET_ERR_MSG(info, "can't allocate addr");
1016                 return -ENOMEM;
1017         }
1018
1019         *entry = addr;
1020         if (entry->addr.port) {
1021                 ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
1022                 if (ret) {
1023                         GENL_SET_ERR_MSG(info, "create listen socket error");
1024                         kfree(entry);
1025                         return ret;
1026                 }
1027         }
1028         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
1029         if (ret < 0) {
1030                 GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
1031                 if (entry->lsk)
1032                         sock_release(entry->lsk);
1033                 kfree(entry);
1034                 return ret;
1035         }
1036
1037         mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
1038
1039         return 0;
1040 }
1041
1042 static struct mptcp_pm_addr_entry *
1043 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
1044 {
1045         struct mptcp_pm_addr_entry *entry;
1046
1047         list_for_each_entry(entry, &pernet->local_addr_list, list) {
1048                 if (entry->addr.id == id)
1049                         return entry;
1050         }
1051         return NULL;
1052 }
1053
1054 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
1055                                       struct mptcp_addr_info *addr)
1056 {
1057         struct mptcp_pm_add_entry *entry;
1058
1059         entry = mptcp_pm_del_add_timer(msk, addr);
1060         if (entry) {
1061                 list_del(&entry->list);
1062                 kfree(entry);
1063                 return true;
1064         }
1065
1066         return false;
1067 }
1068
1069 static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
1070                                       struct mptcp_addr_info *addr,
1071                                       bool force)
1072 {
1073         bool ret;
1074
1075         ret = remove_anno_list_by_saddr(msk, addr);
1076         if (ret || force) {
1077                 spin_lock_bh(&msk->pm.lock);
1078                 mptcp_pm_remove_addr(msk, addr->id);
1079                 spin_unlock_bh(&msk->pm.lock);
1080         }
1081         return ret;
1082 }
1083
1084 static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
1085                                                    struct mptcp_addr_info *addr)
1086 {
1087         struct mptcp_sock *msk;
1088         long s_slot = 0, s_num = 0;
1089
1090         pr_debug("remove_id=%d", addr->id);
1091
1092         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1093                 struct sock *sk = (struct sock *)msk;
1094                 bool remove_subflow;
1095
1096                 if (list_empty(&msk->conn_list)) {
1097                         mptcp_pm_remove_anno_addr(msk, addr, false);
1098                         goto next;
1099                 }
1100
1101                 lock_sock(sk);
1102                 remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
1103                 mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
1104                 if (remove_subflow)
1105                         mptcp_pm_remove_subflow(msk, addr->id);
1106                 release_sock(sk);
1107
1108 next:
1109                 sock_put(sk);
1110                 cond_resched();
1111         }
1112
1113         return 0;
1114 }
1115
/* Deferred-free descriptor: queued with queue_rcu_work() so that
 * mptcp_pm_release_addr_entry() frees @entry (and its listening socket,
 * if any) only after an RCU grace period has elapsed.
 */
struct addr_entry_release_work {
	struct rcu_work	rwork;
	struct mptcp_pm_addr_entry *entry;	/* endpoint to free */
};
1120
1121 static void mptcp_pm_release_addr_entry(struct work_struct *work)
1122 {
1123         struct addr_entry_release_work *w;
1124         struct mptcp_pm_addr_entry *entry;
1125
1126         w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork);
1127         entry = w->entry;
1128         if (entry) {
1129                 if (entry->lsk)
1130                         sock_release(entry->lsk);
1131                 kfree(entry);
1132         }
1133         kfree(w);
1134 }
1135
1136 static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry)
1137 {
1138         struct addr_entry_release_work *w;
1139
1140         w = kmalloc(sizeof(*w), GFP_ATOMIC);
1141         if (w) {
1142                 INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry);
1143                 w->entry = entry;
1144                 queue_rcu_work(system_wq, &w->rwork);
1145         }
1146 }
1147
1148 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
1149 {
1150         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1151         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1152         struct mptcp_pm_addr_entry addr, *entry;
1153         unsigned int addr_max;
1154         int ret;
1155
1156         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1157         if (ret < 0)
1158                 return ret;
1159
1160         spin_lock_bh(&pernet->lock);
1161         entry = __lookup_addr_by_id(pernet, addr.addr.id);
1162         if (!entry) {
1163                 GENL_SET_ERR_MSG(info, "address not found");
1164                 spin_unlock_bh(&pernet->lock);
1165                 return -EINVAL;
1166         }
1167         if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
1168                 addr_max = pernet->add_addr_signal_max;
1169                 WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
1170         }
1171         if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
1172                 addr_max = pernet->local_addr_max;
1173                 WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
1174         }
1175
1176         pernet->addrs--;
1177         list_del_rcu(&entry->list);
1178         __clear_bit(entry->addr.id, pernet->id_bitmap);
1179         spin_unlock_bh(&pernet->lock);
1180
1181         mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
1182         mptcp_pm_free_addr_entry(entry);
1183
1184         return ret;
1185 }
1186
1187 static void __flush_addrs(struct net *net, struct list_head *list)
1188 {
1189         while (!list_empty(list)) {
1190                 struct mptcp_pm_addr_entry *cur;
1191
1192                 cur = list_entry(list->next,
1193                                  struct mptcp_pm_addr_entry, list);
1194                 mptcp_nl_remove_subflow_and_signal_addr(net, &cur->addr);
1195                 list_del_rcu(&cur->list);
1196                 mptcp_pm_free_addr_entry(cur);
1197         }
1198 }
1199
1200 static void __reset_counters(struct pm_nl_pernet *pernet)
1201 {
1202         WRITE_ONCE(pernet->add_addr_signal_max, 0);
1203         WRITE_ONCE(pernet->add_addr_accept_max, 0);
1204         WRITE_ONCE(pernet->local_addr_max, 0);
1205         pernet->addrs = 0;
1206 }
1207
1208 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
1209 {
1210         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1211         LIST_HEAD(free_list);
1212
1213         spin_lock_bh(&pernet->lock);
1214         list_splice_init(&pernet->local_addr_list, &free_list);
1215         __reset_counters(pernet);
1216         pernet->next_id = 1;
1217         bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
1218         spin_unlock_bh(&pernet->lock);
1219         __flush_addrs(sock_net(skb->sk), &free_list);
1220         return 0;
1221 }
1222
1223 static int mptcp_nl_fill_addr(struct sk_buff *skb,
1224                               struct mptcp_pm_addr_entry *entry)
1225 {
1226         struct mptcp_addr_info *addr = &entry->addr;
1227         struct nlattr *attr;
1228
1229         attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
1230         if (!attr)
1231                 return -EMSGSIZE;
1232
1233         if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
1234                 goto nla_put_failure;
1235         if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
1236                 goto nla_put_failure;
1237         if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
1238                 goto nla_put_failure;
1239         if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags))
1240                 goto nla_put_failure;
1241         if (entry->addr.ifindex &&
1242             nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex))
1243                 goto nla_put_failure;
1244
1245         if (addr->family == AF_INET &&
1246             nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
1247                             addr->addr.s_addr))
1248                 goto nla_put_failure;
1249 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1250         else if (addr->family == AF_INET6 &&
1251                  nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
1252                 goto nla_put_failure;
1253 #endif
1254         nla_nest_end(skb, attr);
1255         return 0;
1256
1257 nla_put_failure:
1258         nla_nest_cancel(skb, attr);
1259         return -EMSGSIZE;
1260 }
1261
1262 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
1263 {
1264         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1265         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1266         struct mptcp_pm_addr_entry addr, *entry;
1267         struct sk_buff *msg;
1268         void *reply;
1269         int ret;
1270
1271         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1272         if (ret < 0)
1273                 return ret;
1274
1275         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1276         if (!msg)
1277                 return -ENOMEM;
1278
1279         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1280                                   info->genlhdr->cmd);
1281         if (!reply) {
1282                 GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1283                 ret = -EMSGSIZE;
1284                 goto fail;
1285         }
1286
1287         spin_lock_bh(&pernet->lock);
1288         entry = __lookup_addr_by_id(pernet, addr.addr.id);
1289         if (!entry) {
1290                 GENL_SET_ERR_MSG(info, "address not found");
1291                 ret = -EINVAL;
1292                 goto unlock_fail;
1293         }
1294
1295         ret = mptcp_nl_fill_addr(msg, entry);
1296         if (ret)
1297                 goto unlock_fail;
1298
1299         genlmsg_end(msg, reply);
1300         ret = genlmsg_reply(msg, info);
1301         spin_unlock_bh(&pernet->lock);
1302         return ret;
1303
1304 unlock_fail:
1305         spin_unlock_bh(&pernet->lock);
1306
1307 fail:
1308         nlmsg_free(msg);
1309         return ret;
1310 }
1311
1312 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
1313                                    struct netlink_callback *cb)
1314 {
1315         struct net *net = sock_net(msg->sk);
1316         struct mptcp_pm_addr_entry *entry;
1317         struct pm_nl_pernet *pernet;
1318         int id = cb->args[0];
1319         void *hdr;
1320         int i;
1321
1322         pernet = net_generic(net, pm_nl_pernet_id);
1323
1324         spin_lock_bh(&pernet->lock);
1325         for (i = id; i < MAX_ADDR_ID + 1; i++) {
1326                 if (test_bit(i, pernet->id_bitmap)) {
1327                         entry = __lookup_addr_by_id(pernet, i);
1328                         if (!entry)
1329                                 break;
1330
1331                         if (entry->addr.id <= id)
1332                                 continue;
1333
1334                         hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
1335                                           cb->nlh->nlmsg_seq, &mptcp_genl_family,
1336                                           NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
1337                         if (!hdr)
1338                                 break;
1339
1340                         if (mptcp_nl_fill_addr(msg, entry) < 0) {
1341                                 genlmsg_cancel(msg, hdr);
1342                                 break;
1343                         }
1344
1345                         id = entry->addr.id;
1346                         genlmsg_end(msg, hdr);
1347                 }
1348         }
1349         spin_unlock_bh(&pernet->lock);
1350
1351         cb->args[0] = id;
1352         return msg->len;
1353 }
1354
1355 static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
1356 {
1357         struct nlattr *attr = info->attrs[id];
1358
1359         if (!attr)
1360                 return 0;
1361
1362         *limit = nla_get_u32(attr);
1363         if (*limit > MPTCP_PM_ADDR_MAX) {
1364                 GENL_SET_ERR_MSG(info, "limit greater than maximum");
1365                 return -EINVAL;
1366         }
1367         return 0;
1368 }
1369
1370 static int
1371 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
1372 {
1373         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1374         unsigned int rcv_addrs, subflows;
1375         int ret;
1376
1377         spin_lock_bh(&pernet->lock);
1378         rcv_addrs = pernet->add_addr_accept_max;
1379         ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
1380         if (ret)
1381                 goto unlock;
1382
1383         subflows = pernet->subflows_max;
1384         ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
1385         if (ret)
1386                 goto unlock;
1387
1388         WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
1389         WRITE_ONCE(pernet->subflows_max, subflows);
1390
1391 unlock:
1392         spin_unlock_bh(&pernet->lock);
1393         return ret;
1394 }
1395
1396 static int
1397 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
1398 {
1399         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1400         struct sk_buff *msg;
1401         void *reply;
1402
1403         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1404         if (!msg)
1405                 return -ENOMEM;
1406
1407         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1408                                   MPTCP_PM_CMD_GET_LIMITS);
1409         if (!reply)
1410                 goto fail;
1411
1412         if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
1413                         READ_ONCE(pernet->add_addr_accept_max)))
1414                 goto fail;
1415
1416         if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
1417                         READ_ONCE(pernet->subflows_max)))
1418                 goto fail;
1419
1420         genlmsg_end(msg, reply);
1421         return genlmsg_reply(msg, info);
1422
1423 fail:
1424         GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1425         nlmsg_free(msg);
1426         return -EMSGSIZE;
1427 }
1428
1429 static int mptcp_nl_addr_backup(struct net *net,
1430                                 struct mptcp_addr_info *addr,
1431                                 u8 bkup)
1432 {
1433         long s_slot = 0, s_num = 0;
1434         struct mptcp_sock *msk;
1435         int ret = -EINVAL;
1436
1437         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1438                 struct sock *sk = (struct sock *)msk;
1439
1440                 if (list_empty(&msk->conn_list))
1441                         goto next;
1442
1443                 lock_sock(sk);
1444                 spin_lock_bh(&msk->pm.lock);
1445                 ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
1446                 spin_unlock_bh(&msk->pm.lock);
1447                 release_sock(sk);
1448
1449 next:
1450                 sock_put(sk);
1451                 cond_resched();
1452         }
1453
1454         return ret;
1455 }
1456
1457 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
1458 {
1459         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1460         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1461         struct mptcp_pm_addr_entry addr, *entry;
1462         struct net *net = sock_net(skb->sk);
1463         u8 bkup = 0;
1464         int ret;
1465
1466         ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1467         if (ret < 0)
1468                 return ret;
1469
1470         if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
1471                 bkup = 1;
1472
1473         list_for_each_entry(entry, &pernet->local_addr_list, list) {
1474                 if (addresses_equal(&entry->addr, &addr.addr, true)) {
1475                         ret = mptcp_nl_addr_backup(net, &entry->addr, bkup);
1476                         if (ret)
1477                                 return ret;
1478
1479                         if (bkup)
1480                                 entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
1481                         else
1482                                 entry->addr.flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
1483                 }
1484         }
1485
1486         return 0;
1487 }
1488
1489 static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
1490 {
1491         genlmsg_multicast_netns(&mptcp_genl_family, net,
1492                                 nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
1493 }
1494
1495 static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
1496 {
1497         const struct inet_sock *issk = inet_sk(ssk);
1498         const struct mptcp_subflow_context *sf;
1499
1500         if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
1501                 return -EMSGSIZE;
1502
1503         switch (ssk->sk_family) {
1504         case AF_INET:
1505                 if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
1506                         return -EMSGSIZE;
1507                 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
1508                         return -EMSGSIZE;
1509                 break;
1510 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1511         case AF_INET6: {
1512                 const struct ipv6_pinfo *np = inet6_sk(ssk);
1513
1514                 if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
1515                         return -EMSGSIZE;
1516                 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
1517                         return -EMSGSIZE;
1518                 break;
1519         }
1520 #endif
1521         default:
1522                 WARN_ON_ONCE(1);
1523                 return -EMSGSIZE;
1524         }
1525
1526         if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
1527                 return -EMSGSIZE;
1528         if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
1529                 return -EMSGSIZE;
1530
1531         sf = mptcp_subflow_ctx(ssk);
1532         if (WARN_ON_ONCE(!sf))
1533                 return -EINVAL;
1534
1535         if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
1536                 return -EMSGSIZE;
1537
1538         if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
1539                 return -EMSGSIZE;
1540
1541         return 0;
1542 }
1543
1544 static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
1545                                          const struct mptcp_sock *msk,
1546                                          const struct sock *ssk)
1547 {
1548         const struct sock *sk = (const struct sock *)msk;
1549         const struct mptcp_subflow_context *sf;
1550         u8 sk_err;
1551
1552         if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1553                 return -EMSGSIZE;
1554
1555         if (mptcp_event_add_subflow(skb, ssk))
1556                 return -EMSGSIZE;
1557
1558         sf = mptcp_subflow_ctx(ssk);
1559         if (WARN_ON_ONCE(!sf))
1560                 return -EINVAL;
1561
1562         if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
1563                 return -EMSGSIZE;
1564
1565         if (ssk->sk_bound_dev_if &&
1566             nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
1567                 return -EMSGSIZE;
1568
1569         sk_err = ssk->sk_err;
1570         if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
1571             nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
1572                 return -EMSGSIZE;
1573
1574         return 0;
1575 }
1576
/* Payload of SUB_ESTABLISHED / SUB_PRIORITY: token plus subflow details. */
static int mptcp_event_sub_established(struct sk_buff *skb,
				       const struct mptcp_sock *msk,
				       const struct sock *ssk)
{
	int err = mptcp_event_put_token_and_ssk(skb, msk, ssk);

	return err;
}
1583
1584 static int mptcp_event_sub_closed(struct sk_buff *skb,
1585                                   const struct mptcp_sock *msk,
1586                                   const struct sock *ssk)
1587 {
1588         if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
1589                 return -EMSGSIZE;
1590
1591         return 0;
1592 }
1593
1594 static int mptcp_event_created(struct sk_buff *skb,
1595                                const struct mptcp_sock *msk,
1596                                const struct sock *ssk)
1597 {
1598         int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
1599
1600         if (err)
1601                 return err;
1602
1603         return mptcp_event_add_subflow(skb, ssk);
1604 }
1605
1606 void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
1607 {
1608         struct net *net = sock_net((const struct sock *)msk);
1609         struct nlmsghdr *nlh;
1610         struct sk_buff *skb;
1611
1612         if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1613                 return;
1614
1615         skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1616         if (!skb)
1617                 return;
1618
1619         nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
1620         if (!nlh)
1621                 goto nla_put_failure;
1622
1623         if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1624                 goto nla_put_failure;
1625
1626         if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
1627                 goto nla_put_failure;
1628
1629         genlmsg_end(skb, nlh);
1630         mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1631         return;
1632
1633 nla_put_failure:
1634         kfree_skb(skb);
1635 }
1636
1637 void mptcp_event_addr_announced(const struct mptcp_sock *msk,
1638                                 const struct mptcp_addr_info *info)
1639 {
1640         struct net *net = sock_net((const struct sock *)msk);
1641         struct nlmsghdr *nlh;
1642         struct sk_buff *skb;
1643
1644         if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1645                 return;
1646
1647         skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1648         if (!skb)
1649                 return;
1650
1651         nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
1652                           MPTCP_EVENT_ANNOUNCED);
1653         if (!nlh)
1654                 goto nla_put_failure;
1655
1656         if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1657                 goto nla_put_failure;
1658
1659         if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
1660                 goto nla_put_failure;
1661
1662         if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port))
1663                 goto nla_put_failure;
1664
1665         switch (info->family) {
1666         case AF_INET:
1667                 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
1668                         goto nla_put_failure;
1669                 break;
1670 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1671         case AF_INET6:
1672                 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
1673                         goto nla_put_failure;
1674                 break;
1675 #endif
1676         default:
1677                 WARN_ON_ONCE(1);
1678                 goto nla_put_failure;
1679         }
1680
1681         genlmsg_end(skb, nlh);
1682         mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1683         return;
1684
1685 nla_put_failure:
1686         kfree_skb(skb);
1687 }
1688
1689 void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
1690                  const struct sock *ssk, gfp_t gfp)
1691 {
1692         struct net *net = sock_net((const struct sock *)msk);
1693         struct nlmsghdr *nlh;
1694         struct sk_buff *skb;
1695
1696         if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1697                 return;
1698
1699         skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
1700         if (!skb)
1701                 return;
1702
1703         nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
1704         if (!nlh)
1705                 goto nla_put_failure;
1706
1707         switch (type) {
1708         case MPTCP_EVENT_UNSPEC:
1709                 WARN_ON_ONCE(1);
1710                 break;
1711         case MPTCP_EVENT_CREATED:
1712         case MPTCP_EVENT_ESTABLISHED:
1713                 if (mptcp_event_created(skb, msk, ssk) < 0)
1714                         goto nla_put_failure;
1715                 break;
1716         case MPTCP_EVENT_CLOSED:
1717                 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
1718                         goto nla_put_failure;
1719                 break;
1720         case MPTCP_EVENT_ANNOUNCED:
1721         case MPTCP_EVENT_REMOVED:
1722                 /* call mptcp_event_addr_announced()/removed instead */
1723                 WARN_ON_ONCE(1);
1724                 break;
1725         case MPTCP_EVENT_SUB_ESTABLISHED:
1726         case MPTCP_EVENT_SUB_PRIORITY:
1727                 if (mptcp_event_sub_established(skb, msk, ssk) < 0)
1728                         goto nla_put_failure;
1729                 break;
1730         case MPTCP_EVENT_SUB_CLOSED:
1731                 if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
1732                         goto nla_put_failure;
1733                 break;
1734         }
1735
1736         genlmsg_end(skb, nlh);
1737         mptcp_nl_mcast_send(net, skb, gfp);
1738         return;
1739
1740 nla_put_failure:
1741         kfree_skb(skb);
1742 }
1743
/* Generic netlink command table of the MPTCP PM family.
 * Commands that modify pernet state carry GENL_ADMIN_PERM (CAP_NET_ADMIN).
 * GET_ADDR (single lookup via doit, full listing via dumpit) and GET_LIMITS
 * carry no permission flag.
 */
static const struct genl_small_ops mptcp_pm_ops[] = {
	{
		.cmd    = MPTCP_PM_CMD_ADD_ADDR,
		.doit   = mptcp_nl_cmd_add_addr,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_DEL_ADDR,
		.doit   = mptcp_nl_cmd_del_addr,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
		.doit   = mptcp_nl_cmd_flush_addrs,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_GET_ADDR,
		.doit   = mptcp_nl_cmd_get_addr,
		.dumpit   = mptcp_nl_cmd_dump_addrs,
	},
	{
		.cmd    = MPTCP_PM_CMD_SET_LIMITS,
		.doit   = mptcp_nl_cmd_set_limits,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_GET_LIMITS,
		.doit   = mptcp_nl_cmd_get_limits,
	},
	{
		.cmd    = MPTCP_PM_CMD_SET_FLAGS,
		.doit   = mptcp_nl_cmd_set_flags,
		.flags  = GENL_ADMIN_PERM,
	},
};
1780
1781 static struct genl_family mptcp_genl_family __ro_after_init = {
1782         .name           = MPTCP_PM_NAME,
1783         .version        = MPTCP_PM_VER,
1784         .maxattr        = MPTCP_PM_ATTR_MAX,
1785         .policy         = mptcp_pm_policy,
1786         .netnsok        = true,
1787         .module         = THIS_MODULE,
1788         .small_ops      = mptcp_pm_ops,
1789         .n_small_ops    = ARRAY_SIZE(mptcp_pm_ops),
1790         .mcgrps         = mptcp_pm_mcgrps,
1791         .n_mcgrps       = ARRAY_SIZE(mptcp_pm_mcgrps),
1792 };
1793
1794 static int __net_init pm_nl_init_net(struct net *net)
1795 {
1796         struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
1797
1798         INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
1799         __reset_counters(pernet);
1800         pernet->next_id = 1;
1801         bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
1802         spin_lock_init(&pernet->lock);
1803         return 0;
1804 }
1805
1806 static void __net_exit pm_nl_exit_net(struct list_head *net_list)
1807 {
1808         struct net *net;
1809
1810         list_for_each_entry(net, net_list, exit_list) {
1811                 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
1812
1813                 /* net is removed from namespace list, can't race with
1814                  * other modifiers
1815                  */
1816                 __flush_addrs(net, &pernet->local_addr_list);
1817         }
1818 }
1819
1820 static struct pernet_operations mptcp_pm_pernet_ops = {
1821         .init = pm_nl_init_net,
1822         .exit_batch = pm_nl_exit_net,
1823         .id = &pm_nl_pernet_id,
1824         .size = sizeof(struct pm_nl_pernet),
1825 };
1826
1827 void __init mptcp_pm_nl_init(void)
1828 {
1829         if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
1830                 panic("Failed to register MPTCP PM pernet subsystem.\n");
1831
1832         if (genl_register_family(&mptcp_genl_family))
1833                 panic("Failed to register MPTCP PM netlink family\n");
1834 }