mptcp: do not block subflows creation on errors
[linux-2.6-microblaze.git] / net / mptcp / pm_netlink.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2020, Red Hat, Inc.
5  */
6
7 #define pr_fmt(fmt) "MPTCP: " fmt
8
9 #include <linux/inet.h>
10 #include <linux/kernel.h>
11 #include <net/tcp.h>
12 #include <net/netns/generic.h>
13 #include <net/mptcp.h>
14 #include <net/genetlink.h>
15 #include <uapi/linux/mptcp.h>
16
17 #include "protocol.h"
18 #include "mib.h"
19
/* forward declaration */
static struct genl_family mptcp_genl_family;

/* key passed to net_generic() to fetch this file's struct pm_nl_pernet */
static int pm_nl_pernet_id;
24
/* A configured local endpoint; entries are linked on
 * pm_nl_pernet::local_addr_list.
 */
struct mptcp_pm_addr_entry {
        struct list_head        list;           /* node in pernet->local_addr_list */
        struct mptcp_addr_info  addr;           /* family/address/port/id of the endpoint */
        u8                      flags;          /* MPTCP_PM_ADDR_FLAG_* bits */
        int                     ifindex;        /* not used in this part of the file; presumably a device index - TODO confirm */
        struct socket           *lsk;           /* not used in this part of the file; presumably a listening socket - TODO confirm */
};
32
/* Per-connection record of an announced address, linked on
 * msk->pm.anno_list together with its retransmission timer.
 */
struct mptcp_pm_add_entry {
        struct list_head        list;           /* node in msk->pm.anno_list */
        struct mptcp_addr_info  addr;           /* copy of the announced endpoint address */
        struct timer_list       add_timer;      /* runs mptcp_pm_add_timer() */
        struct mptcp_sock       *sock;          /* connection owning this entry */
        u8                      retrans_times;  /* retransmissions so far, up to ADD_ADDR_RETRANS_MAX */
};
40
/* Path-manager state kept per network namespace and retrieved with
 * net_generic(net, pm_nl_pernet_id).
 */
struct pm_nl_pernet {
        /* protects pernet updates */
        spinlock_t              lock;
        struct list_head        local_addr_list;        /* list of struct mptcp_pm_addr_entry */
        unsigned int            addrs;                  /* number of entries on local_addr_list */
        unsigned int            stale_loss_cnt;
        unsigned int            add_addr_signal_max;    /* bumped for every SIGNAL endpoint added */
        unsigned int            add_addr_accept_max;
        unsigned int            local_addr_max;         /* bumped for every SUBFLOW endpoint added */
        unsigned int            subflows_max;
        unsigned int            next_id;                /* search start for the next address id */
        DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);    /* address ids currently in use */
};
54
/* upper bound on pernet->addrs and size of the on-stack address vectors */
#define MPTCP_PM_ADDR_MAX       8
/* retransmission count at which the ADD_ADDR timer stops re-arming */
#define ADD_ADDR_RETRANS_MAX    3
57
58 static bool addresses_equal(const struct mptcp_addr_info *a,
59                             const struct mptcp_addr_info *b, bool use_port)
60 {
61         bool addr_equals = false;
62
63         if (a->family == b->family) {
64                 if (a->family == AF_INET)
65                         addr_equals = a->addr.s_addr == b->addr.s_addr;
66 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
67                 else
68                         addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
69         } else if (a->family == AF_INET) {
70                 if (ipv6_addr_v4mapped(&b->addr6))
71                         addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
72         } else if (b->family == AF_INET) {
73                 if (ipv6_addr_v4mapped(&a->addr6))
74                         addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
75 #endif
76         }
77
78         if (!addr_equals)
79                 return false;
80         if (!use_port)
81                 return true;
82
83         return a->port == b->port;
84 }
85
86 static bool address_zero(const struct mptcp_addr_info *addr)
87 {
88         struct mptcp_addr_info zero;
89
90         memset(&zero, 0, sizeof(zero));
91         zero.family = addr->family;
92
93         return addresses_equal(addr, &zero, true);
94 }
95
96 static void local_address(const struct sock_common *skc,
97                           struct mptcp_addr_info *addr)
98 {
99         addr->family = skc->skc_family;
100         addr->port = htons(skc->skc_num);
101         if (addr->family == AF_INET)
102                 addr->addr.s_addr = skc->skc_rcv_saddr;
103 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
104         else if (addr->family == AF_INET6)
105                 addr->addr6 = skc->skc_v6_rcv_saddr;
106 #endif
107 }
108
109 static void remote_address(const struct sock_common *skc,
110                            struct mptcp_addr_info *addr)
111 {
112         addr->family = skc->skc_family;
113         addr->port = skc->skc_dport;
114         if (addr->family == AF_INET)
115                 addr->addr.s_addr = skc->skc_daddr;
116 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
117         else if (addr->family == AF_INET6)
118                 addr->addr6 = skc->skc_v6_daddr;
119 #endif
120 }
121
122 static bool lookup_subflow_by_saddr(const struct list_head *list,
123                                     struct mptcp_addr_info *saddr)
124 {
125         struct mptcp_subflow_context *subflow;
126         struct mptcp_addr_info cur;
127         struct sock_common *skc;
128
129         list_for_each_entry(subflow, list, node) {
130                 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
131
132                 local_address(skc, &cur);
133                 if (addresses_equal(&cur, saddr, saddr->port))
134                         return true;
135         }
136
137         return false;
138 }
139
140 static bool lookup_subflow_by_daddr(const struct list_head *list,
141                                     struct mptcp_addr_info *daddr)
142 {
143         struct mptcp_subflow_context *subflow;
144         struct mptcp_addr_info cur;
145         struct sock_common *skc;
146
147         list_for_each_entry(subflow, list, node) {
148                 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
149
150                 remote_address(skc, &cur);
151                 if (addresses_equal(&cur, daddr, daddr->port))
152                         return true;
153         }
154
155         return false;
156 }
157
158 static struct mptcp_pm_addr_entry *
159 select_local_address(const struct pm_nl_pernet *pernet,
160                      struct mptcp_sock *msk)
161 {
162         struct mptcp_pm_addr_entry *entry, *ret = NULL;
163         struct sock *sk = (struct sock *)msk;
164
165         msk_owned_by_me(msk);
166
167         rcu_read_lock();
168         __mptcp_flush_join_list(msk);
169         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
170                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
171                         continue;
172
173                 if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap))
174                         continue;
175
176                 if (entry->addr.family != sk->sk_family) {
177 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
178                         if ((entry->addr.family == AF_INET &&
179                              !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
180                             (sk->sk_family == AF_INET &&
181                              !ipv6_addr_v4mapped(&entry->addr.addr6)))
182 #endif
183                                 continue;
184                 }
185
186                 ret = entry;
187                 break;
188         }
189         rcu_read_unlock();
190         return ret;
191 }
192
193 static struct mptcp_pm_addr_entry *
194 select_signal_address(struct pm_nl_pernet *pernet, struct mptcp_sock *msk)
195 {
196         struct mptcp_pm_addr_entry *entry, *ret = NULL;
197
198         rcu_read_lock();
199         /* do not keep any additional per socket state, just signal
200          * the address list in order.
201          * Note: removal from the local address list during the msk life-cycle
202          * can lead to additional addresses not being announced.
203          */
204         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
205                 if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap))
206                         continue;
207
208                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
209                         continue;
210
211                 ret = entry;
212                 break;
213         }
214         rcu_read_unlock();
215         return ret;
216 }
217
218 unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
219 {
220         struct pm_nl_pernet *pernet;
221
222         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
223         return READ_ONCE(pernet->add_addr_signal_max);
224 }
225 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
226
227 unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
228 {
229         struct pm_nl_pernet *pernet;
230
231         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
232         return READ_ONCE(pernet->add_addr_accept_max);
233 }
234 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
235
236 unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
237 {
238         struct pm_nl_pernet *pernet;
239
240         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
241         return READ_ONCE(pernet->subflows_max);
242 }
243 EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
244
245 unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
246 {
247         struct pm_nl_pernet *pernet;
248
249         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
250         return READ_ONCE(pernet->local_addr_max);
251 }
252 EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max);
253
254 bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk)
255 {
256         struct pm_nl_pernet *pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
257
258         if (msk->pm.subflows == mptcp_pm_get_subflows_max(msk) ||
259             (find_next_and_bit(pernet->id_bitmap, msk->pm.id_avail_bitmap,
260                                MPTCP_PM_MAX_ADDR_ID + 1, 0) == MPTCP_PM_MAX_ADDR_ID + 1)) {
261                 WRITE_ONCE(msk->pm.work_pending, false);
262                 return false;
263         }
264         return true;
265 }
266
267 struct mptcp_pm_add_entry *
268 mptcp_lookup_anno_list_by_saddr(struct mptcp_sock *msk,
269                                 struct mptcp_addr_info *addr)
270 {
271         struct mptcp_pm_add_entry *entry;
272
273         lockdep_assert_held(&msk->pm.lock);
274
275         list_for_each_entry(entry, &msk->pm.anno_list, list) {
276                 if (addresses_equal(&entry->addr, addr, true))
277                         return entry;
278         }
279
280         return NULL;
281 }
282
283 bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
284 {
285         struct mptcp_pm_add_entry *entry;
286         struct mptcp_addr_info saddr;
287         bool ret = false;
288
289         local_address((struct sock_common *)sk, &saddr);
290
291         spin_lock_bh(&msk->pm.lock);
292         list_for_each_entry(entry, &msk->pm.anno_list, list) {
293                 if (addresses_equal(&entry->addr, &saddr, true)) {
294                         ret = true;
295                         goto out;
296                 }
297         }
298
299 out:
300         spin_unlock_bh(&msk->pm.lock);
301         return ret;
302 }
303
/* Timer callback for an announce-list entry: retransmit the ADD_ADDR for
 * the entry's address and re-arm the timer until ADD_ADDR_RETRANS_MAX
 * retransmissions have been counted.
 */
static void mptcp_pm_add_timer(struct timer_list *timer)
{
        struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
        struct mptcp_sock *msk = entry->sock;
        struct sock *sk = (struct sock *)msk;

        pr_debug("msk=%p", msk);

        /* NOTE(review): the three early returns below skip the
         * __sock_put() done at 'out' - confirm this is intended with
         * respect to whatever reference sk_reset_timer() takes.
         */
        if (!msk)
                return;

        if (inet_sk_state_load(sk) == TCP_CLOSE)
                return;

        if (!entry->addr.id)
                return;

        /* another signal is still pending: try again later without
         * counting this run as a retransmission
         */
        if (mptcp_pm_should_add_signal_addr(msk)) {
                sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
                goto out;
        }

        spin_lock_bh(&msk->pm.lock);

        /* re-check under the PM lock before announcing again */
        if (!mptcp_pm_should_add_signal_addr(msk)) {
                pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
                mptcp_pm_announce_addr(msk, &entry->addr, false);
                mptcp_pm_add_addr_send_ack(msk);
                entry->retrans_times++;
        }

        if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
                sk_reset_timer(sk, timer,
                               jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));

        spin_unlock_bh(&msk->pm.lock);

        /* retransmissions exhausted: notify the PM via the
         * subflow-established hook
         */
        if (entry->retrans_times == ADD_ADDR_RETRANS_MAX)
                mptcp_pm_subflow_established(msk);

out:
        __sock_put(sk);
}
347
348 struct mptcp_pm_add_entry *
349 mptcp_pm_del_add_timer(struct mptcp_sock *msk,
350                        struct mptcp_addr_info *addr, bool check_id)
351 {
352         struct mptcp_pm_add_entry *entry;
353         struct sock *sk = (struct sock *)msk;
354
355         spin_lock_bh(&msk->pm.lock);
356         entry = mptcp_lookup_anno_list_by_saddr(msk, addr);
357         if (entry && (!check_id || entry->addr.id == addr->id))
358                 entry->retrans_times = ADD_ADDR_RETRANS_MAX;
359         spin_unlock_bh(&msk->pm.lock);
360
361         if (entry && (!check_id || entry->addr.id == addr->id))
362                 sk_stop_timer_sync(sk, &entry->add_timer);
363
364         return entry;
365 }
366
367 static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
368                                      struct mptcp_pm_addr_entry *entry)
369 {
370         struct mptcp_pm_add_entry *add_entry = NULL;
371         struct sock *sk = (struct sock *)msk;
372         struct net *net = sock_net(sk);
373
374         lockdep_assert_held(&msk->pm.lock);
375
376         if (mptcp_lookup_anno_list_by_saddr(msk, &entry->addr))
377                 return false;
378
379         add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
380         if (!add_entry)
381                 return false;
382
383         list_add(&add_entry->list, &msk->pm.anno_list);
384
385         add_entry->addr = entry->addr;
386         add_entry->sock = msk;
387         add_entry->retrans_times = 0;
388
389         timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
390         sk_reset_timer(sk, &add_entry->add_timer,
391                        jiffies + mptcp_get_add_addr_timeout(net));
392
393         return true;
394 }
395
396 void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
397 {
398         struct mptcp_pm_add_entry *entry, *tmp;
399         struct sock *sk = (struct sock *)msk;
400         LIST_HEAD(free_list);
401
402         pr_debug("msk=%p", msk);
403
404         spin_lock_bh(&msk->pm.lock);
405         list_splice_init(&msk->pm.anno_list, &free_list);
406         spin_unlock_bh(&msk->pm.lock);
407
408         list_for_each_entry_safe(entry, tmp, &free_list, list) {
409                 sk_stop_timer_sync(sk, &entry->add_timer);
410                 kfree(entry);
411         }
412 }
413
414 static bool lookup_address_in_vec(struct mptcp_addr_info *addrs, unsigned int nr,
415                                   struct mptcp_addr_info *addr)
416 {
417         int i;
418
419         for (i = 0; i < nr; i++) {
420                 if (addresses_equal(&addrs[i], addr, addr->port))
421                         return true;
422         }
423
424         return false;
425 }
426
427 /* Fill all the remote addresses into the array addrs[],
428  * and return the array size.
429  */
430 static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullmesh,
431                                               struct mptcp_addr_info *addrs)
432 {
433         bool deny_id0 = READ_ONCE(msk->pm.remote_deny_join_id0);
434         struct sock *sk = (struct sock *)msk, *ssk;
435         struct mptcp_subflow_context *subflow;
436         struct mptcp_addr_info remote = { 0 };
437         unsigned int subflows_max;
438         int i = 0;
439
440         subflows_max = mptcp_pm_get_subflows_max(msk);
441         remote_address((struct sock_common *)sk, &remote);
442
443         /* Non-fullmesh endpoint, fill in the single entry
444          * corresponding to the primary MPC subflow remote address
445          */
446         if (!fullmesh) {
447                 if (deny_id0)
448                         return 0;
449
450                 msk->pm.subflows++;
451                 addrs[i++] = remote;
452         } else {
453                 mptcp_for_each_subflow(msk, subflow) {
454                         ssk = mptcp_subflow_tcp_sock(subflow);
455                         remote_address((struct sock_common *)ssk, &addrs[i]);
456                         if (deny_id0 && addresses_equal(&addrs[i], &remote, false))
457                                 continue;
458
459                         if (!lookup_address_in_vec(addrs, i, &addrs[i]) &&
460                             msk->pm.subflows < subflows_max) {
461                                 msk->pm.subflows++;
462                                 i++;
463                         }
464                 }
465         }
466
467         return i;
468 }
469
470 static struct mptcp_pm_addr_entry *
471 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
472 {
473         struct mptcp_pm_addr_entry *entry;
474
475         list_for_each_entry(entry, &pernet->local_addr_list, list) {
476                 if (entry->addr.id == id)
477                         return entry;
478         }
479         return NULL;
480 }
481
482 static int
483 lookup_id_by_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *addr)
484 {
485         struct mptcp_pm_addr_entry *entry;
486         int ret = -1;
487
488         rcu_read_lock();
489         list_for_each_entry(entry, &pernet->local_addr_list, list) {
490                 if (addresses_equal(&entry->addr, addr, entry->addr.port)) {
491                         ret = entry->addr.id;
492                         break;
493                 }
494         }
495         rcu_read_unlock();
496         return ret;
497 }
498
/* Announce at most one additional local address and then create
 * subflows from the available local endpoints, within the configured
 * limits.
 *
 * msk->pm.lock is released and re-acquired around the
 * __mptcp_subflow_connect() calls, so the caller is expected to hold it
 * on entry.
 */
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
        struct sock *sk = (struct sock *)msk;
        struct mptcp_pm_addr_entry *local;
        unsigned int add_addr_signal_max;
        unsigned int local_addr_max;
        struct pm_nl_pernet *pernet;
        unsigned int subflows_max;

        pernet = net_generic(sock_net(sk), pm_nl_pernet_id);

        add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
        local_addr_max = mptcp_pm_get_local_addr_max(msk);
        subflows_max = mptcp_pm_get_subflows_max(msk);

        /* do lazy endpoint usage accounting for the MPC subflows */
        if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) {
                struct mptcp_addr_info mpc_addr;
                int mpc_id;

                /* if the first subflow uses a configured endpoint, mark
                 * that endpoint id as no longer available
                 */
                local_address((struct sock_common *)msk->first, &mpc_addr);
                mpc_id = lookup_id_by_addr(pernet, &mpc_addr);
                if (mpc_id >= 0)
                        __clear_bit(mpc_id, msk->pm.id_avail_bitmap);

                msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED);
        }

        pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
                 msk->pm.local_addr_used, local_addr_max,
                 msk->pm.add_addr_signaled, add_addr_signal_max,
                 msk->pm.subflows, subflows_max);

        /* check first for announce */
        if (msk->pm.add_addr_signaled < add_addr_signal_max) {
                local = select_signal_address(pernet, msk);

                if (local) {
                        /* the id is consumed only if the announce entry
                         * could be queued
                         */
                        if (mptcp_pm_alloc_anno_list(msk, local)) {
                                __clear_bit(local->addr.id, msk->pm.id_avail_bitmap);
                                msk->pm.add_addr_signaled++;
                                mptcp_pm_announce_addr(msk, &local->addr, false);
                                mptcp_pm_nl_addr_send_ack(msk);
                        }
                }
        }

        /* check if should create a new subflow */
        while (msk->pm.local_addr_used < local_addr_max &&
               msk->pm.subflows < subflows_max) {
                struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
                bool fullmesh;
                int i, nr;

                local = select_local_address(pernet, msk);
                if (!local)
                        break;

                fullmesh = !!(local->flags & MPTCP_PM_ADDR_FLAG_FULLMESH);

                /* counters are updated before the lock is dropped; the
                 * endpoint id is consumed only when at least one remote
                 * address was selected
                 */
                msk->pm.local_addr_used++;
                nr = fill_remote_addresses_vec(msk, fullmesh, addrs);
                if (nr)
                        __clear_bit(local->addr.id, msk->pm.id_avail_bitmap);
                spin_unlock_bh(&msk->pm.lock);
                for (i = 0; i < nr; i++)
                        __mptcp_subflow_connect(sk, &local->addr, &addrs[i]);
                spin_lock_bh(&msk->pm.lock);
        }
        mptcp_pm_nl_check_work_pending(msk);
}
570
/* MPTCP_PM_ESTABLISHED handler: try to announce an address and/or
 * create new subflows for @msk.
 */
static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
        mptcp_pm_create_subflow_or_signal_addr(msk);
}
575
/* MPTCP_PM_SUBFLOW_ESTABLISHED handler: try to announce an address
 * and/or create new subflows for @msk.
 */
static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
        mptcp_pm_create_subflow_or_signal_addr(msk);
}
580
581 /* Fill all the local addresses into the array addrs[],
582  * and return the array size.
583  */
584 static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
585                                              struct mptcp_addr_info *addrs)
586 {
587         struct sock *sk = (struct sock *)msk;
588         struct mptcp_pm_addr_entry *entry;
589         struct mptcp_addr_info local;
590         struct pm_nl_pernet *pernet;
591         unsigned int subflows_max;
592         int i = 0;
593
594         pernet = net_generic(sock_net(sk), pm_nl_pernet_id);
595         subflows_max = mptcp_pm_get_subflows_max(msk);
596
597         rcu_read_lock();
598         __mptcp_flush_join_list(msk);
599         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
600                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_FULLMESH))
601                         continue;
602
603                 if (entry->addr.family != sk->sk_family) {
604 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
605                         if ((entry->addr.family == AF_INET &&
606                              !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
607                             (sk->sk_family == AF_INET &&
608                              !ipv6_addr_v4mapped(&entry->addr.addr6)))
609 #endif
610                                 continue;
611                 }
612
613                 if (msk->pm.subflows < subflows_max) {
614                         msk->pm.subflows++;
615                         addrs[i++] = entry->addr;
616                 }
617         }
618         rcu_read_unlock();
619
620         /* If the array is empty, fill in the single
621          * 'IPADDRANY' local address
622          */
623         if (!i) {
624                 memset(&local, 0, sizeof(local));
625                 local.family = msk->pm.remote.family;
626
627                 msk->pm.subflows++;
628                 addrs[i++] = local;
629         }
630
631         return i;
632 }
633
634 static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
635 {
636         struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
637         struct sock *sk = (struct sock *)msk;
638         unsigned int add_addr_accept_max;
639         struct mptcp_addr_info remote;
640         unsigned int subflows_max;
641         int i, nr;
642
643         add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
644         subflows_max = mptcp_pm_get_subflows_max(msk);
645
646         pr_debug("accepted %d:%d remote family %d",
647                  msk->pm.add_addr_accepted, add_addr_accept_max,
648                  msk->pm.remote.family);
649
650         if (lookup_subflow_by_daddr(&msk->conn_list, &msk->pm.remote))
651                 goto add_addr_echo;
652
653         /* connect to the specified remote address, using whatever
654          * local address the routing configuration will pick.
655          */
656         remote = msk->pm.remote;
657         if (!remote.port)
658                 remote.port = sk->sk_dport;
659         nr = fill_local_addresses_vec(msk, addrs);
660
661         msk->pm.add_addr_accepted++;
662         if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
663             msk->pm.subflows >= subflows_max)
664                 WRITE_ONCE(msk->pm.accept_addr, false);
665
666         spin_unlock_bh(&msk->pm.lock);
667         for (i = 0; i < nr; i++)
668                 __mptcp_subflow_connect(sk, &addrs[i], &remote);
669         spin_lock_bh(&msk->pm.lock);
670
671 add_addr_echo:
672         mptcp_pm_announce_addr(msk, &msk->pm.remote, true);
673         mptcp_pm_nl_addr_send_ack(msk);
674 }
675
676 void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
677 {
678         struct mptcp_subflow_context *subflow;
679
680         msk_owned_by_me(msk);
681         lockdep_assert_held(&msk->pm.lock);
682
683         if (!mptcp_pm_should_add_signal(msk) &&
684             !mptcp_pm_should_rm_signal(msk))
685                 return;
686
687         __mptcp_flush_join_list(msk);
688         subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
689         if (subflow) {
690                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
691
692                 spin_unlock_bh(&msk->pm.lock);
693                 pr_debug("send ack for %s",
694                          mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr");
695
696                 mptcp_subflow_send_ack(ssk);
697                 spin_lock_bh(&msk->pm.lock);
698         }
699 }
700
701 static int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
702                                         struct mptcp_addr_info *addr,
703                                         u8 bkup)
704 {
705         struct mptcp_subflow_context *subflow;
706
707         pr_debug("bkup=%d", bkup);
708
709         mptcp_for_each_subflow(msk, subflow) {
710                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
711                 struct sock *sk = (struct sock *)msk;
712                 struct mptcp_addr_info local;
713
714                 local_address((struct sock_common *)ssk, &local);
715                 if (!addresses_equal(&local, addr, addr->port))
716                         continue;
717
718                 subflow->backup = bkup;
719                 subflow->send_mp_prio = 1;
720                 subflow->request_bkup = bkup;
721                 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);
722
723                 spin_unlock_bh(&msk->pm.lock);
724                 pr_debug("send ack for mp_prio");
725                 mptcp_subflow_send_ack(ssk);
726                 spin_lock_bh(&msk->pm.lock);
727
728                 return 0;
729         }
730
731         return -EINVAL;
732 }
733
734 static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
735                                            const struct mptcp_rm_list *rm_list,
736                                            enum linux_mptcp_mib_field rm_type)
737 {
738         struct mptcp_subflow_context *subflow, *tmp;
739         struct sock *sk = (struct sock *)msk;
740         u8 i;
741
742         pr_debug("%s rm_list_nr %d",
743                  rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr);
744
745         msk_owned_by_me(msk);
746
747         if (sk->sk_state == TCP_LISTEN)
748                 return;
749
750         if (!rm_list->nr)
751                 return;
752
753         if (list_empty(&msk->conn_list))
754                 return;
755
756         for (i = 0; i < rm_list->nr; i++) {
757                 bool removed = false;
758
759                 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
760                         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
761                         int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
762                         u8 id = subflow->local_id;
763
764                         if (rm_type == MPTCP_MIB_RMADDR)
765                                 id = subflow->remote_id;
766
767                         if (rm_list->ids[i] != id)
768                                 continue;
769
770                         pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u",
771                                  rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow",
772                                  i, rm_list->ids[i], subflow->local_id, subflow->remote_id);
773                         spin_unlock_bh(&msk->pm.lock);
774                         mptcp_subflow_shutdown(sk, ssk, how);
775
776                         /* the following takes care of updating the subflows counter */
777                         mptcp_close_ssk(sk, ssk, subflow);
778                         spin_lock_bh(&msk->pm.lock);
779
780                         removed = true;
781                         __MPTCP_INC_STATS(sock_net(sk), rm_type);
782                 }
783                 __set_bit(rm_list->ids[1], msk->pm.id_avail_bitmap);
784                 if (!removed)
785                         continue;
786
787                 if (rm_type == MPTCP_MIB_RMADDR) {
788                         msk->pm.add_addr_accepted--;
789                         WRITE_ONCE(msk->pm.accept_addr, true);
790                 } else if (rm_type == MPTCP_MIB_RMSUBFLOW) {
791                         msk->pm.local_addr_used--;
792                 }
793         }
794 }
795
/* MPTCP_PM_RM_ADDR_RECEIVED handler: close the subflows whose remote id
 * is listed in msk->pm.rm_list_rx.
 */
static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
{
        mptcp_pm_nl_rm_addr_or_subflow(msk, &msk->pm.rm_list_rx, MPTCP_MIB_RMADDR);
}
800
/* Close the subflows of @msk whose local id is listed in @rm_list. */
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
                                     const struct mptcp_rm_list *rm_list)
{
        mptcp_pm_nl_rm_addr_or_subflow(msk, rm_list, MPTCP_MIB_RMSUBFLOW);
}
806
/* Run the pending path-manager events of @msk.
 *
 * Each event bit set in pm->status is cleared and its handler invoked,
 * in the fixed order below, with msk->pm.lock held. The status word is
 * re-read before every handler because handlers may release and
 * re-acquire the lock.
 */
void mptcp_pm_nl_work(struct mptcp_sock *msk)
{
        struct mptcp_pm_data *pm = &msk->pm;

        msk_owned_by_me(msk);

        /* lockless quick check: nothing to do without pending events */
        if (!(pm->status & MPTCP_PM_WORK_MASK))
                return;

        spin_lock_bh(&msk->pm.lock);

        pr_debug("msk=%p status=%x", msk, pm->status);
        if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
                pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
                mptcp_pm_nl_add_addr_received(msk);
        }
        if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
                pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
                mptcp_pm_nl_addr_send_ack(msk);
        }
        if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
                pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
                mptcp_pm_nl_rm_addr_received(msk);
        }
        if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
                pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
                mptcp_pm_nl_fully_established(msk);
        }
        if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
                pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
                mptcp_pm_nl_subflow_established(msk);
        }

        spin_unlock_bh(&msk->pm.lock);
}
842
843 static bool address_use_port(struct mptcp_pm_addr_entry *entry)
844 {
845         return (entry->flags &
846                 (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
847                 MPTCP_PM_ADDR_FLAG_SIGNAL;
848 }
849
/* Append @entry to the per-netns endpoint list, allocating an address
 * id for it when entry->addr.id is zero.
 *
 * Takes pernet->lock. Returns the (non-zero) address id on success, or
 * -EINVAL when the list is full, the requested id is already in use,
 * the address duplicates an existing endpoint, or no id is free.
 */
static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
                                             struct mptcp_pm_addr_entry *entry)
{
        struct mptcp_pm_addr_entry *cur;
        unsigned int addr_max;
        int ret = -EINVAL;

        spin_lock_bh(&pernet->lock);
        /* to keep the code simple, don't do IDR-like allocation for address ID,
         * just bail when we exceed limits
         */
        if (pernet->next_id == MPTCP_PM_MAX_ADDR_ID)
                pernet->next_id = 1;
        if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
                goto out;
        if (test_bit(entry->addr.id, pernet->id_bitmap))
                goto out;

        /* do not insert duplicate address, differentiate on port only
         * when both endpoints are signal-only addresses
         */
        list_for_each_entry(cur, &pernet->local_addr_list, list) {
                if (addresses_equal(&cur->addr, &entry->addr,
                                    address_use_port(entry) &&
                                    address_use_port(cur)))
                        goto out;
        }

        if (!entry->addr.id) {
find_next:
                /* search a free id starting from next_id; on failure retry
                 * once from id 1.
                 * NOTE(review): a failed search returns
                 * MPTCP_PM_MAX_ADDR_ID + 1, which presumably truncates to 0
                 * when stored in addr.id - confirm the field width.
                 */
                entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
                                                    MPTCP_PM_MAX_ADDR_ID + 1,
                                                    pernet->next_id);
                if (!entry->addr.id && pernet->next_id != 1) {
                        pernet->next_id = 1;
                        goto find_next;
                }
        }

        /* still no id: the id space is exhausted */
        if (!entry->addr.id)
                goto out;

        __set_bit(entry->addr.id, pernet->id_bitmap);
        if (entry->addr.id > pernet->next_id)
                pernet->next_id = entry->addr.id;

        /* each signal / subflow endpoint raises the matching limit by one */
        if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
                addr_max = pernet->add_addr_signal_max;
                WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
        }
        if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
                addr_max = pernet->local_addr_max;
                WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
        }

        pernet->addrs++;
        list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
        ret = entry->addr.id;

out:
        spin_unlock_bh(&pernet->lock);
        return ret;
}
913
914 static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
915                                             struct mptcp_pm_addr_entry *entry)
916 {
917         struct sockaddr_storage addr;
918         struct mptcp_sock *msk;
919         struct socket *ssock;
920         int backlog = 1024;
921         int err;
922
923         err = sock_create_kern(sock_net(sk), entry->addr.family,
924                                SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
925         if (err)
926                 return err;
927
928         msk = mptcp_sk(entry->lsk->sk);
929         if (!msk) {
930                 err = -EINVAL;
931                 goto out;
932         }
933
934         ssock = __mptcp_nmpc_socket(msk);
935         if (!ssock) {
936                 err = -EINVAL;
937                 goto out;
938         }
939
940         mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
941         err = kernel_bind(ssock, (struct sockaddr *)&addr,
942                           sizeof(struct sockaddr_in));
943         if (err) {
944                 pr_warn("kernel_bind error, err=%d", err);
945                 goto out;
946         }
947
948         err = kernel_listen(ssock, backlog);
949         if (err) {
950                 pr_warn("kernel_listen error, err=%d", err);
951                 goto out;
952         }
953
954         return 0;
955
956 out:
957         sock_release(entry->lsk);
958         return err;
959 }
960
961 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
962 {
963         struct mptcp_pm_addr_entry *entry;
964         struct mptcp_addr_info skc_local;
965         struct mptcp_addr_info msk_local;
966         struct pm_nl_pernet *pernet;
967         int ret = -1;
968
969         if (WARN_ON_ONCE(!msk))
970                 return -1;
971
972         /* The 0 ID mapping is defined by the first subflow, copied into the msk
973          * addr
974          */
975         local_address((struct sock_common *)msk, &msk_local);
976         local_address((struct sock_common *)skc, &skc_local);
977         if (addresses_equal(&msk_local, &skc_local, false))
978                 return 0;
979
980         if (address_zero(&skc_local))
981                 return 0;
982
983         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
984
985         rcu_read_lock();
986         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
987                 if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
988                         ret = entry->addr.id;
989                         break;
990                 }
991         }
992         rcu_read_unlock();
993         if (ret >= 0)
994                 return ret;
995
996         /* address not found, add to local list */
997         entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
998         if (!entry)
999                 return -ENOMEM;
1000
1001         entry->addr = skc_local;
1002         entry->addr.id = 0;
1003         entry->addr.port = 0;
1004         entry->ifindex = 0;
1005         entry->flags = 0;
1006         entry->lsk = NULL;
1007         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
1008         if (ret < 0)
1009                 kfree(entry);
1010
1011         return ret;
1012 }
1013
1014 void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
1015 {
1016         struct mptcp_pm_data *pm = &msk->pm;
1017         bool subflows;
1018
1019         subflows = !!mptcp_pm_get_subflows_max(msk);
1020         WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
1021                    !!mptcp_pm_get_add_addr_signal_max(msk));
1022         WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
1023         WRITE_ONCE(pm->accept_subflow, subflows);
1024 }
1025
1026 #define MPTCP_PM_CMD_GRP_OFFSET       0
1027 #define MPTCP_PM_EV_GRP_OFFSET        1
1028
1029 static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
1030         [MPTCP_PM_CMD_GRP_OFFSET]       = { .name = MPTCP_PM_CMD_GRP_NAME, },
1031         [MPTCP_PM_EV_GRP_OFFSET]        = { .name = MPTCP_PM_EV_GRP_NAME,
1032                                             .flags = GENL_UNS_ADMIN_PERM,
1033                                           },
1034 };
1035
1036 static const struct nla_policy
1037 mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
1038         [MPTCP_PM_ADDR_ATTR_FAMILY]     = { .type       = NLA_U16,      },
1039         [MPTCP_PM_ADDR_ATTR_ID]         = { .type       = NLA_U8,       },
1040         [MPTCP_PM_ADDR_ATTR_ADDR4]      = { .type       = NLA_U32,      },
1041         [MPTCP_PM_ADDR_ATTR_ADDR6]      =
1042                 NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
1043         [MPTCP_PM_ADDR_ATTR_PORT]       = { .type       = NLA_U16       },
1044         [MPTCP_PM_ADDR_ATTR_FLAGS]      = { .type       = NLA_U32       },
1045         [MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type       = NLA_S32       },
1046 };
1047
1048 static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
1049         [MPTCP_PM_ATTR_ADDR]            =
1050                                         NLA_POLICY_NESTED(mptcp_pm_addr_policy),
1051         [MPTCP_PM_ATTR_RCV_ADD_ADDRS]   = { .type       = NLA_U32,      },
1052         [MPTCP_PM_ATTR_SUBFLOWS]        = { .type       = NLA_U32,      },
1053 };
1054
1055 void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk)
1056 {
1057         struct mptcp_subflow_context *iter, *subflow = mptcp_subflow_ctx(ssk);
1058         struct sock *sk = (struct sock *)msk;
1059         unsigned int active_max_loss_cnt;
1060         struct net *net = sock_net(sk);
1061         unsigned int stale_loss_cnt;
1062         bool slow;
1063
1064         stale_loss_cnt = mptcp_stale_loss_cnt(net);
1065         if (subflow->stale || !stale_loss_cnt || subflow->stale_count <= stale_loss_cnt)
1066                 return;
1067
1068         /* look for another available subflow not in loss state */
1069         active_max_loss_cnt = max_t(int, stale_loss_cnt - 1, 1);
1070         mptcp_for_each_subflow(msk, iter) {
1071                 if (iter != subflow && mptcp_subflow_active(iter) &&
1072                     iter->stale_count < active_max_loss_cnt) {
1073                         /* we have some alternatives, try to mark this subflow as idle ...*/
1074                         slow = lock_sock_fast(ssk);
1075                         if (!tcp_rtx_and_write_queues_empty(ssk)) {
1076                                 subflow->stale = 1;
1077                                 __mptcp_retransmit_pending_data(sk);
1078                                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_SUBFLOWSTALE);
1079                         }
1080                         unlock_sock_fast(ssk, slow);
1081
1082                         /* always try to push the pending data regarless of re-injections:
1083                          * we can possibly use backup subflows now, and subflow selection
1084                          * is cheap under the msk socket lock
1085                          */
1086                         __mptcp_push_pending(sk, 0);
1087                         return;
1088                 }
1089         }
1090 }
1091
1092 static int mptcp_pm_family_to_addr(int family)
1093 {
1094 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1095         if (family == AF_INET6)
1096                 return MPTCP_PM_ADDR_ATTR_ADDR6;
1097 #endif
1098         return MPTCP_PM_ADDR_ATTR_ADDR4;
1099 }
1100
1101 static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
1102                                bool require_family,
1103                                struct mptcp_pm_addr_entry *entry)
1104 {
1105         struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
1106         int err, addr_addr;
1107
1108         if (!attr) {
1109                 GENL_SET_ERR_MSG(info, "missing address info");
1110                 return -EINVAL;
1111         }
1112
1113         /* no validation needed - was already done via nested policy */
1114         err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
1115                                           mptcp_pm_addr_policy, info->extack);
1116         if (err)
1117                 return err;
1118
1119         memset(entry, 0, sizeof(*entry));
1120         if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
1121                 if (!require_family)
1122                         goto skip_family;
1123
1124                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
1125                                     "missing family");
1126                 return -EINVAL;
1127         }
1128
1129         entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
1130         if (entry->addr.family != AF_INET
1131 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1132             && entry->addr.family != AF_INET6
1133 #endif
1134             ) {
1135                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
1136                                     "unknown address family");
1137                 return -EINVAL;
1138         }
1139         addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
1140         if (!tb[addr_addr]) {
1141                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
1142                                     "missing address data");
1143                 return -EINVAL;
1144         }
1145
1146 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1147         if (entry->addr.family == AF_INET6)
1148                 entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
1149         else
1150 #endif
1151                 entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);
1152
1153 skip_family:
1154         if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
1155                 u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
1156
1157                 entry->ifindex = val;
1158         }
1159
1160         if (tb[MPTCP_PM_ADDR_ATTR_ID])
1161                 entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
1162
1163         if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
1164                 entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
1165
1166         if (tb[MPTCP_PM_ADDR_ATTR_PORT]) {
1167                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
1168                         NL_SET_ERR_MSG_ATTR(info->extack, attr,
1169                                             "flags must have signal when using port");
1170                         return -EINVAL;
1171                 }
1172                 entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
1173         }
1174
1175         return 0;
1176 }
1177
1178 static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
1179 {
1180         return net_generic(genl_info_net(info), pm_nl_pernet_id);
1181 }
1182
1183 static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
1184 {
1185         struct mptcp_sock *msk;
1186         long s_slot = 0, s_num = 0;
1187
1188         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1189                 struct sock *sk = (struct sock *)msk;
1190
1191                 if (!READ_ONCE(msk->fully_established))
1192                         goto next;
1193
1194                 lock_sock(sk);
1195                 spin_lock_bh(&msk->pm.lock);
1196                 mptcp_pm_create_subflow_or_signal_addr(msk);
1197                 spin_unlock_bh(&msk->pm.lock);
1198                 release_sock(sk);
1199
1200 next:
1201                 sock_put(sk);
1202                 cond_resched();
1203         }
1204
1205         return 0;
1206 }
1207
1208 static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
1209 {
1210         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1211         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1212         struct mptcp_pm_addr_entry addr, *entry;
1213         int ret;
1214
1215         ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1216         if (ret < 0)
1217                 return ret;
1218
1219         entry = kmalloc(sizeof(*entry), GFP_KERNEL);
1220         if (!entry) {
1221                 GENL_SET_ERR_MSG(info, "can't allocate addr");
1222                 return -ENOMEM;
1223         }
1224
1225         *entry = addr;
1226         if (entry->addr.port) {
1227                 ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
1228                 if (ret) {
1229                         GENL_SET_ERR_MSG(info, "create listen socket error");
1230                         kfree(entry);
1231                         return ret;
1232                 }
1233         }
1234         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
1235         if (ret < 0) {
1236                 GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
1237                 if (entry->lsk)
1238                         sock_release(entry->lsk);
1239                 kfree(entry);
1240                 return ret;
1241         }
1242
1243         mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
1244
1245         return 0;
1246 }
1247
1248 int mptcp_pm_get_flags_and_ifindex_by_id(struct net *net, unsigned int id,
1249                                          u8 *flags, int *ifindex)
1250 {
1251         struct mptcp_pm_addr_entry *entry;
1252
1253         *flags = 0;
1254         *ifindex = 0;
1255
1256         if (id) {
1257                 rcu_read_lock();
1258                 entry = __lookup_addr_by_id(net_generic(net, pm_nl_pernet_id), id);
1259                 if (entry) {
1260                         *flags = entry->flags;
1261                         *ifindex = entry->ifindex;
1262                 }
1263                 rcu_read_unlock();
1264         }
1265
1266         return 0;
1267 }
1268
1269 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
1270                                       struct mptcp_addr_info *addr)
1271 {
1272         struct mptcp_pm_add_entry *entry;
1273
1274         entry = mptcp_pm_del_add_timer(msk, addr, false);
1275         if (entry) {
1276                 list_del(&entry->list);
1277                 kfree(entry);
1278                 return true;
1279         }
1280
1281         return false;
1282 }
1283
1284 static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
1285                                       struct mptcp_addr_info *addr,
1286                                       bool force)
1287 {
1288         struct mptcp_rm_list list = { .nr = 0 };
1289         bool ret;
1290
1291         list.ids[list.nr++] = addr->id;
1292
1293         ret = remove_anno_list_by_saddr(msk, addr);
1294         if (ret || force) {
1295                 spin_lock_bh(&msk->pm.lock);
1296                 mptcp_pm_remove_addr(msk, &list);
1297                 spin_unlock_bh(&msk->pm.lock);
1298         }
1299         return ret;
1300 }
1301
1302 static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
1303                                                    struct mptcp_addr_info *addr)
1304 {
1305         struct mptcp_sock *msk;
1306         long s_slot = 0, s_num = 0;
1307         struct mptcp_rm_list list = { .nr = 0 };
1308
1309         pr_debug("remove_id=%d", addr->id);
1310
1311         list.ids[list.nr++] = addr->id;
1312
1313         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1314                 struct sock *sk = (struct sock *)msk;
1315                 bool remove_subflow;
1316
1317                 if (list_empty(&msk->conn_list)) {
1318                         mptcp_pm_remove_anno_addr(msk, addr, false);
1319                         goto next;
1320                 }
1321
1322                 lock_sock(sk);
1323                 remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
1324                 mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
1325                 if (remove_subflow)
1326                         mptcp_pm_remove_subflow(msk, &list);
1327                 release_sock(sk);
1328
1329 next:
1330                 sock_put(sk);
1331                 cond_resched();
1332         }
1333
1334         return 0;
1335 }
1336
1337 /* caller must ensure the RCU grace period is already elapsed */
1338 static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry)
1339 {
1340         if (entry->lsk)
1341                 sock_release(entry->lsk);
1342         kfree(entry);
1343 }
1344
1345 static int mptcp_nl_remove_id_zero_address(struct net *net,
1346                                            struct mptcp_addr_info *addr)
1347 {
1348         struct mptcp_rm_list list = { .nr = 0 };
1349         long s_slot = 0, s_num = 0;
1350         struct mptcp_sock *msk;
1351
1352         list.ids[list.nr++] = 0;
1353
1354         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1355                 struct sock *sk = (struct sock *)msk;
1356                 struct mptcp_addr_info msk_local;
1357
1358                 if (list_empty(&msk->conn_list))
1359                         goto next;
1360
1361                 local_address((struct sock_common *)msk, &msk_local);
1362                 if (!addresses_equal(&msk_local, addr, addr->port))
1363                         goto next;
1364
1365                 lock_sock(sk);
1366                 spin_lock_bh(&msk->pm.lock);
1367                 mptcp_pm_remove_addr(msk, &list);
1368                 mptcp_pm_nl_rm_subflow_received(msk, &list);
1369                 spin_unlock_bh(&msk->pm.lock);
1370                 release_sock(sk);
1371
1372 next:
1373                 sock_put(sk);
1374                 cond_resched();
1375         }
1376
1377         return 0;
1378 }
1379
1380 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
1381 {
1382         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1383         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1384         struct mptcp_pm_addr_entry addr, *entry;
1385         unsigned int addr_max;
1386         int ret;
1387
1388         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1389         if (ret < 0)
1390                 return ret;
1391
1392         /* the zero id address is special: the first address used by the msk
1393          * always gets such an id, so different subflows can have different zero
1394          * id addresses. Additionally zero id is not accounted for in id_bitmap.
1395          * Let's use an 'mptcp_rm_list' instead of the common remove code.
1396          */
1397         if (addr.addr.id == 0)
1398                 return mptcp_nl_remove_id_zero_address(sock_net(skb->sk), &addr.addr);
1399
1400         spin_lock_bh(&pernet->lock);
1401         entry = __lookup_addr_by_id(pernet, addr.addr.id);
1402         if (!entry) {
1403                 GENL_SET_ERR_MSG(info, "address not found");
1404                 spin_unlock_bh(&pernet->lock);
1405                 return -EINVAL;
1406         }
1407         if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
1408                 addr_max = pernet->add_addr_signal_max;
1409                 WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
1410         }
1411         if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
1412                 addr_max = pernet->local_addr_max;
1413                 WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
1414         }
1415
1416         pernet->addrs--;
1417         list_del_rcu(&entry->list);
1418         __clear_bit(entry->addr.id, pernet->id_bitmap);
1419         spin_unlock_bh(&pernet->lock);
1420
1421         mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
1422         synchronize_rcu();
1423         __mptcp_pm_release_addr_entry(entry);
1424
1425         return ret;
1426 }
1427
1428 static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
1429                                                struct list_head *rm_list)
1430 {
1431         struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 };
1432         struct mptcp_pm_addr_entry *entry;
1433
1434         list_for_each_entry(entry, rm_list, list) {
1435                 if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
1436                     alist.nr < MPTCP_RM_IDS_MAX &&
1437                     slist.nr < MPTCP_RM_IDS_MAX) {
1438                         alist.ids[alist.nr++] = entry->addr.id;
1439                         slist.ids[slist.nr++] = entry->addr.id;
1440                 } else if (remove_anno_list_by_saddr(msk, &entry->addr) &&
1441                          alist.nr < MPTCP_RM_IDS_MAX) {
1442                         alist.ids[alist.nr++] = entry->addr.id;
1443                 }
1444         }
1445
1446         if (alist.nr) {
1447                 spin_lock_bh(&msk->pm.lock);
1448                 mptcp_pm_remove_addr(msk, &alist);
1449                 spin_unlock_bh(&msk->pm.lock);
1450         }
1451         if (slist.nr)
1452                 mptcp_pm_remove_subflow(msk, &slist);
1453 }
1454
1455 static void mptcp_nl_remove_addrs_list(struct net *net,
1456                                        struct list_head *rm_list)
1457 {
1458         long s_slot = 0, s_num = 0;
1459         struct mptcp_sock *msk;
1460
1461         if (list_empty(rm_list))
1462                 return;
1463
1464         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1465                 struct sock *sk = (struct sock *)msk;
1466
1467                 lock_sock(sk);
1468                 mptcp_pm_remove_addrs_and_subflows(msk, rm_list);
1469                 release_sock(sk);
1470
1471                 sock_put(sk);
1472                 cond_resched();
1473         }
1474 }
1475
1476 /* caller must ensure the RCU grace period is already elapsed */
1477 static void __flush_addrs(struct list_head *list)
1478 {
1479         while (!list_empty(list)) {
1480                 struct mptcp_pm_addr_entry *cur;
1481
1482                 cur = list_entry(list->next,
1483                                  struct mptcp_pm_addr_entry, list);
1484                 list_del_rcu(&cur->list);
1485                 __mptcp_pm_release_addr_entry(cur);
1486         }
1487 }
1488
1489 static void __reset_counters(struct pm_nl_pernet *pernet)
1490 {
1491         WRITE_ONCE(pernet->add_addr_signal_max, 0);
1492         WRITE_ONCE(pernet->add_addr_accept_max, 0);
1493         WRITE_ONCE(pernet->local_addr_max, 0);
1494         pernet->addrs = 0;
1495 }
1496
1497 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
1498 {
1499         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1500         LIST_HEAD(free_list);
1501
1502         spin_lock_bh(&pernet->lock);
1503         list_splice_init(&pernet->local_addr_list, &free_list);
1504         __reset_counters(pernet);
1505         pernet->next_id = 1;
1506         bitmap_zero(pernet->id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
1507         spin_unlock_bh(&pernet->lock);
1508         mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list);
1509         synchronize_rcu();
1510         __flush_addrs(&free_list);
1511         return 0;
1512 }
1513
1514 static int mptcp_nl_fill_addr(struct sk_buff *skb,
1515                               struct mptcp_pm_addr_entry *entry)
1516 {
1517         struct mptcp_addr_info *addr = &entry->addr;
1518         struct nlattr *attr;
1519
1520         attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
1521         if (!attr)
1522                 return -EMSGSIZE;
1523
1524         if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
1525                 goto nla_put_failure;
1526         if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
1527                 goto nla_put_failure;
1528         if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
1529                 goto nla_put_failure;
1530         if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
1531                 goto nla_put_failure;
1532         if (entry->ifindex &&
1533             nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
1534                 goto nla_put_failure;
1535
1536         if (addr->family == AF_INET &&
1537             nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
1538                             addr->addr.s_addr))
1539                 goto nla_put_failure;
1540 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1541         else if (addr->family == AF_INET6 &&
1542                  nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
1543                 goto nla_put_failure;
1544 #endif
1545         nla_nest_end(skb, attr);
1546         return 0;
1547
1548 nla_put_failure:
1549         nla_nest_cancel(skb, attr);
1550         return -EMSGSIZE;
1551 }
1552
1553 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
1554 {
1555         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1556         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1557         struct mptcp_pm_addr_entry addr, *entry;
1558         struct sk_buff *msg;
1559         void *reply;
1560         int ret;
1561
1562         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1563         if (ret < 0)
1564                 return ret;
1565
1566         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1567         if (!msg)
1568                 return -ENOMEM;
1569
1570         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1571                                   info->genlhdr->cmd);
1572         if (!reply) {
1573                 GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1574                 ret = -EMSGSIZE;
1575                 goto fail;
1576         }
1577
1578         spin_lock_bh(&pernet->lock);
1579         entry = __lookup_addr_by_id(pernet, addr.addr.id);
1580         if (!entry) {
1581                 GENL_SET_ERR_MSG(info, "address not found");
1582                 ret = -EINVAL;
1583                 goto unlock_fail;
1584         }
1585
1586         ret = mptcp_nl_fill_addr(msg, entry);
1587         if (ret)
1588                 goto unlock_fail;
1589
1590         genlmsg_end(msg, reply);
1591         ret = genlmsg_reply(msg, info);
1592         spin_unlock_bh(&pernet->lock);
1593         return ret;
1594
1595 unlock_fail:
1596         spin_unlock_bh(&pernet->lock);
1597
1598 fail:
1599         nlmsg_free(msg);
1600         return ret;
1601 }
1602
1603 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
1604                                    struct netlink_callback *cb)
1605 {
1606         struct net *net = sock_net(msg->sk);
1607         struct mptcp_pm_addr_entry *entry;
1608         struct pm_nl_pernet *pernet;
1609         int id = cb->args[0];
1610         void *hdr;
1611         int i;
1612
1613         pernet = net_generic(net, pm_nl_pernet_id);
1614
1615         spin_lock_bh(&pernet->lock);
1616         for (i = id; i < MPTCP_PM_MAX_ADDR_ID + 1; i++) {
1617                 if (test_bit(i, pernet->id_bitmap)) {
1618                         entry = __lookup_addr_by_id(pernet, i);
1619                         if (!entry)
1620                                 break;
1621
1622                         if (entry->addr.id <= id)
1623                                 continue;
1624
1625                         hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
1626                                           cb->nlh->nlmsg_seq, &mptcp_genl_family,
1627                                           NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
1628                         if (!hdr)
1629                                 break;
1630
1631                         if (mptcp_nl_fill_addr(msg, entry) < 0) {
1632                                 genlmsg_cancel(msg, hdr);
1633                                 break;
1634                         }
1635
1636                         id = entry->addr.id;
1637                         genlmsg_end(msg, hdr);
1638                 }
1639         }
1640         spin_unlock_bh(&pernet->lock);
1641
1642         cb->args[0] = id;
1643         return msg->len;
1644 }
1645
1646 static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
1647 {
1648         struct nlattr *attr = info->attrs[id];
1649
1650         if (!attr)
1651                 return 0;
1652
1653         *limit = nla_get_u32(attr);
1654         if (*limit > MPTCP_PM_ADDR_MAX) {
1655                 GENL_SET_ERR_MSG(info, "limit greater than maximum");
1656                 return -EINVAL;
1657         }
1658         return 0;
1659 }
1660
1661 static int
1662 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
1663 {
1664         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1665         unsigned int rcv_addrs, subflows;
1666         int ret;
1667
1668         spin_lock_bh(&pernet->lock);
1669         rcv_addrs = pernet->add_addr_accept_max;
1670         ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
1671         if (ret)
1672                 goto unlock;
1673
1674         subflows = pernet->subflows_max;
1675         ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
1676         if (ret)
1677                 goto unlock;
1678
1679         WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
1680         WRITE_ONCE(pernet->subflows_max, subflows);
1681
1682 unlock:
1683         spin_unlock_bh(&pernet->lock);
1684         return ret;
1685 }
1686
1687 static int
1688 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
1689 {
1690         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1691         struct sk_buff *msg;
1692         void *reply;
1693
1694         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1695         if (!msg)
1696                 return -ENOMEM;
1697
1698         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1699                                   MPTCP_PM_CMD_GET_LIMITS);
1700         if (!reply)
1701                 goto fail;
1702
1703         if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
1704                         READ_ONCE(pernet->add_addr_accept_max)))
1705                 goto fail;
1706
1707         if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
1708                         READ_ONCE(pernet->subflows_max)))
1709                 goto fail;
1710
1711         genlmsg_end(msg, reply);
1712         return genlmsg_reply(msg, info);
1713
1714 fail:
1715         GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1716         nlmsg_free(msg);
1717         return -EMSGSIZE;
1718 }
1719
1720 static int mptcp_nl_addr_backup(struct net *net,
1721                                 struct mptcp_addr_info *addr,
1722                                 u8 bkup)
1723 {
1724         long s_slot = 0, s_num = 0;
1725         struct mptcp_sock *msk;
1726         int ret = -EINVAL;
1727
1728         while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1729                 struct sock *sk = (struct sock *)msk;
1730
1731                 if (list_empty(&msk->conn_list))
1732                         goto next;
1733
1734                 lock_sock(sk);
1735                 spin_lock_bh(&msk->pm.lock);
1736                 ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
1737                 spin_unlock_bh(&msk->pm.lock);
1738                 release_sock(sk);
1739
1740 next:
1741                 sock_put(sk);
1742                 cond_resched();
1743         }
1744
1745         return ret;
1746 }
1747
1748 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
1749 {
1750         struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry;
1751         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1752         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1753         struct net *net = sock_net(skb->sk);
1754         u8 bkup = 0, lookup_by_id = 0;
1755         int ret;
1756
1757         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1758         if (ret < 0)
1759                 return ret;
1760
1761         if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
1762                 bkup = 1;
1763         if (addr.addr.family == AF_UNSPEC) {
1764                 lookup_by_id = 1;
1765                 if (!addr.addr.id)
1766                         return -EOPNOTSUPP;
1767         }
1768
1769         list_for_each_entry(entry, &pernet->local_addr_list, list) {
1770                 if ((!lookup_by_id && addresses_equal(&entry->addr, &addr.addr, true)) ||
1771                     (lookup_by_id && entry->addr.id == addr.addr.id)) {
1772                         mptcp_nl_addr_backup(net, &entry->addr, bkup);
1773
1774                         if (bkup)
1775                                 entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
1776                         else
1777                                 entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
1778                 }
1779         }
1780
1781         return 0;
1782 }
1783
1784 static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
1785 {
1786         genlmsg_multicast_netns(&mptcp_genl_family, net,
1787                                 nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
1788 }
1789
1790 static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
1791 {
1792         const struct inet_sock *issk = inet_sk(ssk);
1793         const struct mptcp_subflow_context *sf;
1794
1795         if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
1796                 return -EMSGSIZE;
1797
1798         switch (ssk->sk_family) {
1799         case AF_INET:
1800                 if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
1801                         return -EMSGSIZE;
1802                 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
1803                         return -EMSGSIZE;
1804                 break;
1805 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1806         case AF_INET6: {
1807                 const struct ipv6_pinfo *np = inet6_sk(ssk);
1808
1809                 if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
1810                         return -EMSGSIZE;
1811                 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
1812                         return -EMSGSIZE;
1813                 break;
1814         }
1815 #endif
1816         default:
1817                 WARN_ON_ONCE(1);
1818                 return -EMSGSIZE;
1819         }
1820
1821         if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
1822                 return -EMSGSIZE;
1823         if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
1824                 return -EMSGSIZE;
1825
1826         sf = mptcp_subflow_ctx(ssk);
1827         if (WARN_ON_ONCE(!sf))
1828                 return -EINVAL;
1829
1830         if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
1831                 return -EMSGSIZE;
1832
1833         if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
1834                 return -EMSGSIZE;
1835
1836         return 0;
1837 }
1838
1839 static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
1840                                          const struct mptcp_sock *msk,
1841                                          const struct sock *ssk)
1842 {
1843         const struct sock *sk = (const struct sock *)msk;
1844         const struct mptcp_subflow_context *sf;
1845         u8 sk_err;
1846
1847         if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1848                 return -EMSGSIZE;
1849
1850         if (mptcp_event_add_subflow(skb, ssk))
1851                 return -EMSGSIZE;
1852
1853         sf = mptcp_subflow_ctx(ssk);
1854         if (WARN_ON_ONCE(!sf))
1855                 return -EINVAL;
1856
1857         if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
1858                 return -EMSGSIZE;
1859
1860         if (ssk->sk_bound_dev_if &&
1861             nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
1862                 return -EMSGSIZE;
1863
1864         sk_err = ssk->sk_err;
1865         if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
1866             nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
1867                 return -EMSGSIZE;
1868
1869         return 0;
1870 }
1871
/* Payload for the MPTCP_EVENT_SUB_ESTABLISHED and MPTCP_EVENT_SUB_PRIORITY
 * events: just the common token + subflow attribute set.
 *
 * Returns whatever mptcp_event_put_token_and_ssk() returns.
 */
static int mptcp_event_sub_established(struct sk_buff *skb,
				       const struct mptcp_sock *msk,
				       const struct sock *ssk)
{
	int err;

	err = mptcp_event_put_token_and_ssk(skb, msk, ssk);
	return err;
}
1878
1879 static int mptcp_event_sub_closed(struct sk_buff *skb,
1880                                   const struct mptcp_sock *msk,
1881                                   const struct sock *ssk)
1882 {
1883         const struct mptcp_subflow_context *sf;
1884
1885         if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
1886                 return -EMSGSIZE;
1887
1888         sf = mptcp_subflow_ctx(ssk);
1889         if (!sf->reset_seen)
1890                 return 0;
1891
1892         if (nla_put_u32(skb, MPTCP_ATTR_RESET_REASON, sf->reset_reason))
1893                 return -EMSGSIZE;
1894
1895         if (nla_put_u32(skb, MPTCP_ATTR_RESET_FLAGS, sf->reset_transient))
1896                 return -EMSGSIZE;
1897
1898         return 0;
1899 }
1900
1901 static int mptcp_event_created(struct sk_buff *skb,
1902                                const struct mptcp_sock *msk,
1903                                const struct sock *ssk)
1904 {
1905         int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
1906
1907         if (err)
1908                 return err;
1909
1910         return mptcp_event_add_subflow(skb, ssk);
1911 }
1912
1913 void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
1914 {
1915         struct net *net = sock_net((const struct sock *)msk);
1916         struct nlmsghdr *nlh;
1917         struct sk_buff *skb;
1918
1919         if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1920                 return;
1921
1922         skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1923         if (!skb)
1924                 return;
1925
1926         nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
1927         if (!nlh)
1928                 goto nla_put_failure;
1929
1930         if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1931                 goto nla_put_failure;
1932
1933         if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
1934                 goto nla_put_failure;
1935
1936         genlmsg_end(skb, nlh);
1937         mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1938         return;
1939
1940 nla_put_failure:
1941         kfree_skb(skb);
1942 }
1943
1944 void mptcp_event_addr_announced(const struct mptcp_sock *msk,
1945                                 const struct mptcp_addr_info *info)
1946 {
1947         struct net *net = sock_net((const struct sock *)msk);
1948         struct nlmsghdr *nlh;
1949         struct sk_buff *skb;
1950
1951         if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1952                 return;
1953
1954         skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1955         if (!skb)
1956                 return;
1957
1958         nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
1959                           MPTCP_EVENT_ANNOUNCED);
1960         if (!nlh)
1961                 goto nla_put_failure;
1962
1963         if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1964                 goto nla_put_failure;
1965
1966         if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
1967                 goto nla_put_failure;
1968
1969         if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port))
1970                 goto nla_put_failure;
1971
1972         switch (info->family) {
1973         case AF_INET:
1974                 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
1975                         goto nla_put_failure;
1976                 break;
1977 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1978         case AF_INET6:
1979                 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
1980                         goto nla_put_failure;
1981                 break;
1982 #endif
1983         default:
1984                 WARN_ON_ONCE(1);
1985                 goto nla_put_failure;
1986         }
1987
1988         genlmsg_end(skb, nlh);
1989         mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1990         return;
1991
1992 nla_put_failure:
1993         kfree_skb(skb);
1994 }
1995
1996 void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
1997                  const struct sock *ssk, gfp_t gfp)
1998 {
1999         struct net *net = sock_net((const struct sock *)msk);
2000         struct nlmsghdr *nlh;
2001         struct sk_buff *skb;
2002
2003         if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
2004                 return;
2005
2006         skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
2007         if (!skb)
2008                 return;
2009
2010         nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
2011         if (!nlh)
2012                 goto nla_put_failure;
2013
2014         switch (type) {
2015         case MPTCP_EVENT_UNSPEC:
2016                 WARN_ON_ONCE(1);
2017                 break;
2018         case MPTCP_EVENT_CREATED:
2019         case MPTCP_EVENT_ESTABLISHED:
2020                 if (mptcp_event_created(skb, msk, ssk) < 0)
2021                         goto nla_put_failure;
2022                 break;
2023         case MPTCP_EVENT_CLOSED:
2024                 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
2025                         goto nla_put_failure;
2026                 break;
2027         case MPTCP_EVENT_ANNOUNCED:
2028         case MPTCP_EVENT_REMOVED:
2029                 /* call mptcp_event_addr_announced()/removed instead */
2030                 WARN_ON_ONCE(1);
2031                 break;
2032         case MPTCP_EVENT_SUB_ESTABLISHED:
2033         case MPTCP_EVENT_SUB_PRIORITY:
2034                 if (mptcp_event_sub_established(skb, msk, ssk) < 0)
2035                         goto nla_put_failure;
2036                 break;
2037         case MPTCP_EVENT_SUB_CLOSED:
2038                 if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
2039                         goto nla_put_failure;
2040                 break;
2041         }
2042
2043         genlmsg_end(skb, nlh);
2044         mptcp_nl_mcast_send(net, skb, gfp);
2045         return;
2046
2047 nla_put_failure:
2048         kfree_skb(skb);
2049 }
2050
2051 static const struct genl_small_ops mptcp_pm_ops[] = {
2052         {
2053                 .cmd    = MPTCP_PM_CMD_ADD_ADDR,
2054                 .doit   = mptcp_nl_cmd_add_addr,
2055                 .flags  = GENL_ADMIN_PERM,
2056         },
2057         {
2058                 .cmd    = MPTCP_PM_CMD_DEL_ADDR,
2059                 .doit   = mptcp_nl_cmd_del_addr,
2060                 .flags  = GENL_ADMIN_PERM,
2061         },
2062         {
2063                 .cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
2064                 .doit   = mptcp_nl_cmd_flush_addrs,
2065                 .flags  = GENL_ADMIN_PERM,
2066         },
2067         {
2068                 .cmd    = MPTCP_PM_CMD_GET_ADDR,
2069                 .doit   = mptcp_nl_cmd_get_addr,
2070                 .dumpit   = mptcp_nl_cmd_dump_addrs,
2071         },
2072         {
2073                 .cmd    = MPTCP_PM_CMD_SET_LIMITS,
2074                 .doit   = mptcp_nl_cmd_set_limits,
2075                 .flags  = GENL_ADMIN_PERM,
2076         },
2077         {
2078                 .cmd    = MPTCP_PM_CMD_GET_LIMITS,
2079                 .doit   = mptcp_nl_cmd_get_limits,
2080         },
2081         {
2082                 .cmd    = MPTCP_PM_CMD_SET_FLAGS,
2083                 .doit   = mptcp_nl_cmd_set_flags,
2084                 .flags  = GENL_ADMIN_PERM,
2085         },
2086 };
2087
2088 static struct genl_family mptcp_genl_family __ro_after_init = {
2089         .name           = MPTCP_PM_NAME,
2090         .version        = MPTCP_PM_VER,
2091         .maxattr        = MPTCP_PM_ATTR_MAX,
2092         .policy         = mptcp_pm_policy,
2093         .netnsok        = true,
2094         .module         = THIS_MODULE,
2095         .small_ops      = mptcp_pm_ops,
2096         .n_small_ops    = ARRAY_SIZE(mptcp_pm_ops),
2097         .mcgrps         = mptcp_pm_mcgrps,
2098         .n_mcgrps       = ARRAY_SIZE(mptcp_pm_mcgrps),
2099 };
2100
2101 static int __net_init pm_nl_init_net(struct net *net)
2102 {
2103         struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
2104
2105         INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
2106
2107         /* Cit. 2 subflows ought to be enough for anybody. */
2108         pernet->subflows_max = 2;
2109         pernet->next_id = 1;
2110         pernet->stale_loss_cnt = 4;
2111         spin_lock_init(&pernet->lock);
2112
2113         /* No need to initialize other pernet fields, the struct is zeroed at
2114          * allocation time.
2115          */
2116
2117         return 0;
2118 }
2119
2120 static void __net_exit pm_nl_exit_net(struct list_head *net_list)
2121 {
2122         struct net *net;
2123
2124         list_for_each_entry(net, net_list, exit_list) {
2125                 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
2126
2127                 /* net is removed from namespace list, can't race with
2128                  * other modifiers, also netns core already waited for a
2129                  * RCU grace period.
2130                  */
2131                 __flush_addrs(&pernet->local_addr_list);
2132         }
2133 }
2134
2135 static struct pernet_operations mptcp_pm_pernet_ops = {
2136         .init = pm_nl_init_net,
2137         .exit_batch = pm_nl_exit_net,
2138         .id = &pm_nl_pernet_id,
2139         .size = sizeof(struct pm_nl_pernet),
2140 };
2141
2142 void __init mptcp_pm_nl_init(void)
2143 {
2144         if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
2145                 panic("Failed to register MPTCP PM pernet subsystem.\n");
2146
2147         if (genl_register_family(&mptcp_genl_family))
2148                 panic("Failed to register MPTCP PM netlink family\n");
2149 }