Merge branch 'pm-cpuidle'
[linux-2.6-microblaze.git] / net / mptcp / pm_netlink.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2020, Red Hat, Inc.
5  */
6
7 #include <linux/inet.h>
8 #include <linux/kernel.h>
9 #include <net/tcp.h>
10 #include <net/netns/generic.h>
11 #include <net/mptcp.h>
12 #include <net/genetlink.h>
13 #include <uapi/linux/mptcp.h>
14
15 #include "protocol.h"
16
/* forward declaration: the family is defined near the end of this file but
 * is referenced earlier by the netlink reply/dump helpers
 */
static struct genl_family mptcp_genl_family;

/* net_generic() key for struct pm_nl_pernet; assigned through
 * mptcp_pm_pernet_ops.id when the pernet subsystem is registered
 */
static int pm_nl_pernet_id;
21
/* One configured local endpoint of the in-kernel path manager. */
struct mptcp_pm_addr_entry {
	/* node in pm_nl_pernet::local_addr_list (walked under RCU by readers) */
	struct list_head        list;
	/* MPTCP_PM_ADDR_FLAG_* bits (SIGNAL and/or SUBFLOW) */
	unsigned int            flags;
	/* interface index handed to __mptcp_subflow_connect(); 0 when unset */
	int                     ifindex;
	/* endpoint address; addr.id is the ID assigned by this PM */
	struct mptcp_addr_info  addr;
	/* used by kfree_rcu() when the entry is removed */
	struct rcu_head         rcu;
};
29
/* Per network namespace state of the in-kernel (netlink) path manager. */
struct pm_nl_pernet {
	/* protects pernet updates */
	spinlock_t              lock;
	/* list of struct mptcp_pm_addr_entry, in increasing addr.id order */
	struct list_head        local_addr_list;
	/* number of entries on local_addr_list */
	unsigned int            addrs;
	/* number of entries carrying MPTCP_PM_ADDR_FLAG_SIGNAL */
	unsigned int            add_addr_signal_max;
	/* limit configured via MPTCP_PM_CMD_SET_LIMITS (RCV_ADD_ADDRS) */
	unsigned int            add_addr_accept_max;
	/* number of entries carrying MPTCP_PM_ADDR_FLAG_SUBFLOW */
	unsigned int            local_addr_max;
	/* limit configured via MPTCP_PM_CMD_SET_LIMITS (SUBFLOWS) */
	unsigned int            subflows_max;
	/* next address ID to hand out; starts at 1, never reused */
	unsigned int            next_id;
};

/* upper bound on configured endpoints and on user-supplied limits */
#define MPTCP_PM_ADDR_MAX       8
43
44 static bool addresses_equal(const struct mptcp_addr_info *a,
45                             struct mptcp_addr_info *b, bool use_port)
46 {
47         bool addr_equals = false;
48
49         if (a->family != b->family)
50                 return false;
51
52         if (a->family == AF_INET)
53                 addr_equals = a->addr.s_addr == b->addr.s_addr;
54 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
55         else
56                 addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
57 #endif
58
59         if (!addr_equals)
60                 return false;
61         if (!use_port)
62                 return true;
63
64         return a->port == b->port;
65 }
66
67 static void local_address(const struct sock_common *skc,
68                           struct mptcp_addr_info *addr)
69 {
70         addr->port = 0;
71         addr->family = skc->skc_family;
72         if (addr->family == AF_INET)
73                 addr->addr.s_addr = skc->skc_rcv_saddr;
74 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
75         else if (addr->family == AF_INET6)
76                 addr->addr6 = skc->skc_v6_rcv_saddr;
77 #endif
78 }
79
80 static void remote_address(const struct sock_common *skc,
81                            struct mptcp_addr_info *addr)
82 {
83         addr->family = skc->skc_family;
84         addr->port = skc->skc_dport;
85         if (addr->family == AF_INET)
86                 addr->addr.s_addr = skc->skc_daddr;
87 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
88         else if (addr->family == AF_INET6)
89                 addr->addr6 = skc->skc_v6_daddr;
90 #endif
91 }
92
93 static bool lookup_subflow_by_saddr(const struct list_head *list,
94                                     struct mptcp_addr_info *saddr)
95 {
96         struct mptcp_subflow_context *subflow;
97         struct mptcp_addr_info cur;
98         struct sock_common *skc;
99
100         list_for_each_entry(subflow, list, node) {
101                 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
102
103                 local_address(skc, &cur);
104                 if (addresses_equal(&cur, saddr, false))
105                         return true;
106         }
107
108         return false;
109 }
110
111 static struct mptcp_pm_addr_entry *
112 select_local_address(const struct pm_nl_pernet *pernet,
113                      struct mptcp_sock *msk)
114 {
115         struct mptcp_pm_addr_entry *entry, *ret = NULL;
116
117         rcu_read_lock();
118         spin_lock_bh(&msk->join_list_lock);
119         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
120                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
121                         continue;
122
123                 /* avoid any address already in use by subflows and
124                  * pending join
125                  */
126                 if (entry->addr.family == ((struct sock *)msk)->sk_family &&
127                     !lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
128                     !lookup_subflow_by_saddr(&msk->join_list, &entry->addr)) {
129                         ret = entry;
130                         break;
131                 }
132         }
133         spin_unlock_bh(&msk->join_list_lock);
134         rcu_read_unlock();
135         return ret;
136 }
137
138 static struct mptcp_pm_addr_entry *
139 select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
140 {
141         struct mptcp_pm_addr_entry *entry, *ret = NULL;
142         int i = 0;
143
144         rcu_read_lock();
145         /* do not keep any additional per socket state, just signal
146          * the address list in order.
147          * Note: removal from the local address list during the msk life-cycle
148          * can lead to additional addresses not being announced.
149          */
150         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
151                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
152                         continue;
153                 if (i++ == pos) {
154                         ret = entry;
155                         break;
156                 }
157         }
158         rcu_read_unlock();
159         return ret;
160 }
161
162 static void check_work_pending(struct mptcp_sock *msk)
163 {
164         if (msk->pm.add_addr_signaled == msk->pm.add_addr_signal_max &&
165             (msk->pm.local_addr_used == msk->pm.local_addr_max ||
166              msk->pm.subflows == msk->pm.subflows_max))
167                 WRITE_ONCE(msk->pm.work_pending, false);
168 }
169
170 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
171 {
172         struct sock *sk = (struct sock *)msk;
173         struct mptcp_pm_addr_entry *local;
174         struct mptcp_addr_info remote;
175         struct pm_nl_pernet *pernet;
176
177         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
178
179         pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
180                  msk->pm.local_addr_used, msk->pm.local_addr_max,
181                  msk->pm.add_addr_signaled, msk->pm.add_addr_signal_max,
182                  msk->pm.subflows, msk->pm.subflows_max);
183
184         /* check first for announce */
185         if (msk->pm.add_addr_signaled < msk->pm.add_addr_signal_max) {
186                 local = select_signal_address(pernet,
187                                               msk->pm.add_addr_signaled);
188
189                 if (local) {
190                         msk->pm.add_addr_signaled++;
191                         mptcp_pm_announce_addr(msk, &local->addr);
192                 } else {
193                         /* pick failed, avoid fourther attempts later */
194                         msk->pm.local_addr_used = msk->pm.add_addr_signal_max;
195                 }
196
197                 check_work_pending(msk);
198         }
199
200         /* check if should create a new subflow */
201         if (msk->pm.local_addr_used < msk->pm.local_addr_max &&
202             msk->pm.subflows < msk->pm.subflows_max) {
203                 remote_address((struct sock_common *)sk, &remote);
204
205                 local = select_local_address(pernet, msk);
206                 if (local) {
207                         msk->pm.local_addr_used++;
208                         msk->pm.subflows++;
209                         check_work_pending(msk);
210                         spin_unlock_bh(&msk->pm.lock);
211                         __mptcp_subflow_connect(sk, local->ifindex,
212                                                 &local->addr, &remote);
213                         spin_lock_bh(&msk->pm.lock);
214                         return;
215                 }
216
217                 /* lookup failed, avoid fourther attempts later */
218                 msk->pm.local_addr_used = msk->pm.local_addr_max;
219                 check_work_pending(msk);
220         }
221 }
222
/*
 * mptcp_pm_nl_fully_established - PM hook for a fully established MPTCP
 * connection: announce addresses / create subflows as configured.
 * @msk: MPTCP socket (pm.lock held, as required by the worker below)
 */
void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	struct mptcp_sock *conn = msk;

	mptcp_pm_create_subflow_or_signal_addr(conn);
}
227
/*
 * mptcp_pm_nl_subflow_established - PM hook run after a subflow is up:
 * keep announcing addresses / creating subflows while budget remains.
 * @msk: MPTCP socket (pm.lock held, as required by the worker below)
 */
void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	struct mptcp_sock *conn = msk;

	mptcp_pm_create_subflow_or_signal_addr(conn);
}
232
233 void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
234 {
235         struct sock *sk = (struct sock *)msk;
236         struct mptcp_addr_info remote;
237         struct mptcp_addr_info local;
238
239         pr_debug("accepted %d:%d remote family %d",
240                  msk->pm.add_addr_accepted, msk->pm.add_addr_accept_max,
241                  msk->pm.remote.family);
242         msk->pm.add_addr_accepted++;
243         msk->pm.subflows++;
244         if (msk->pm.add_addr_accepted >= msk->pm.add_addr_accept_max ||
245             msk->pm.subflows >= msk->pm.subflows_max)
246                 WRITE_ONCE(msk->pm.accept_addr, false);
247
248         /* connect to the specified remote address, using whatever
249          * local address the routing configuration will pick.
250          */
251         remote = msk->pm.remote;
252         if (!remote.port)
253                 remote.port = sk->sk_dport;
254         memset(&local, 0, sizeof(local));
255         local.family = remote.family;
256
257         spin_unlock_bh(&msk->pm.lock);
258         __mptcp_subflow_connect((struct sock *)msk, 0, &local, &remote);
259         spin_lock_bh(&msk->pm.lock);
260 }
261
262 static bool address_use_port(struct mptcp_pm_addr_entry *entry)
263 {
264         return (entry->flags &
265                 (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
266                 MPTCP_PM_ADDR_FLAG_SIGNAL;
267 }
268
269 static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
270                                              struct mptcp_pm_addr_entry *entry)
271 {
272         struct mptcp_pm_addr_entry *cur;
273         int ret = -EINVAL;
274
275         spin_lock_bh(&pernet->lock);
276         /* to keep the code simple, don't do IDR-like allocation for address ID,
277          * just bail when we exceed limits
278          */
279         if (pernet->next_id > 255)
280                 goto out;
281         if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
282                 goto out;
283
284         /* do not insert duplicate address, differentiate on port only
285          * singled addresses
286          */
287         list_for_each_entry(cur, &pernet->local_addr_list, list) {
288                 if (addresses_equal(&cur->addr, &entry->addr,
289                                     address_use_port(entry) &&
290                                     address_use_port(cur)))
291                         goto out;
292         }
293
294         if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
295                 pernet->add_addr_signal_max++;
296         if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
297                 pernet->local_addr_max++;
298
299         entry->addr.id = pernet->next_id++;
300         pernet->addrs++;
301         list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
302         ret = entry->addr.id;
303
304 out:
305         spin_unlock_bh(&pernet->lock);
306         return ret;
307 }
308
309 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
310 {
311         struct mptcp_pm_addr_entry *entry;
312         struct mptcp_addr_info skc_local;
313         struct mptcp_addr_info msk_local;
314         struct pm_nl_pernet *pernet;
315         int ret = -1;
316
317         if (WARN_ON_ONCE(!msk))
318                 return -1;
319
320         /* The 0 ID mapping is defined by the first subflow, copied into the msk
321          * addr
322          */
323         local_address((struct sock_common *)msk, &msk_local);
324         local_address((struct sock_common *)msk, &skc_local);
325         if (addresses_equal(&msk_local, &skc_local, false))
326                 return 0;
327
328         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
329
330         rcu_read_lock();
331         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
332                 if (addresses_equal(&entry->addr, &skc_local, false)) {
333                         ret = entry->addr.id;
334                         break;
335                 }
336         }
337         rcu_read_unlock();
338         if (ret >= 0)
339                 return ret;
340
341         /* address not found, add to local list */
342         entry = kmalloc(sizeof(*entry), GFP_KERNEL);
343         if (!entry)
344                 return -ENOMEM;
345
346         entry->flags = 0;
347         entry->addr = skc_local;
348         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
349         if (ret < 0)
350                 kfree(entry);
351
352         return ret;
353 }
354
355 void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
356 {
357         struct mptcp_pm_data *pm = &msk->pm;
358         struct pm_nl_pernet *pernet;
359         bool subflows;
360
361         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
362
363         pm->add_addr_signal_max = READ_ONCE(pernet->add_addr_signal_max);
364         pm->add_addr_accept_max = READ_ONCE(pernet->add_addr_accept_max);
365         pm->local_addr_max = READ_ONCE(pernet->local_addr_max);
366         pm->subflows_max = READ_ONCE(pernet->subflows_max);
367         subflows = !!pm->subflows_max;
368         WRITE_ONCE(pm->work_pending, (!!pm->local_addr_max && subflows) ||
369                    !!pm->add_addr_signal_max);
370         WRITE_ONCE(pm->accept_addr, !!pm->add_addr_accept_max && subflows);
371         WRITE_ONCE(pm->accept_subflow, subflows);
372 }
373
/* index of the (only) multicast group in mptcp_pm_mcgrps */
#define MPTCP_PM_CMD_GRP_OFFSET 0

/* multicast groups exposed by the MPTCP PM generic netlink family */
static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
	[MPTCP_PM_CMD_GRP_OFFSET]       = { .name = MPTCP_PM_CMD_GRP_NAME, },
};
379
/* policy for the attributes nested inside MPTCP_PM_ATTR_ADDR; the IPv6
 * address must be exactly sizeof(struct in6_addr) bytes
 */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]     = { .type       = NLA_U16,      },
	[MPTCP_PM_ADDR_ATTR_ID]         = { .type       = NLA_U8,       },
	[MPTCP_PM_ADDR_ATTR_ADDR4]      = { .type       = NLA_U32,      },
	[MPTCP_PM_ADDR_ATTR_ADDR6]      = { .type       = NLA_EXACT_LEN,
					    .len   = sizeof(struct in6_addr), },
	[MPTCP_PM_ADDR_ATTR_PORT]       = { .type       = NLA_U16       },
	[MPTCP_PM_ADDR_ATTR_FLAGS]      = { .type       = NLA_U32       },
	[MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type       = NLA_S32       },
};
391
/* top-level attribute policy of the family; ADDR is validated recursively
 * against mptcp_pm_addr_policy
 */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]            =
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]   = { .type       = NLA_U32,      },
	[MPTCP_PM_ATTR_SUBFLOWS]        = { .type       = NLA_U32,      },
};
398
399 static int mptcp_pm_family_to_addr(int family)
400 {
401 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
402         if (family == AF_INET6)
403                 return MPTCP_PM_ADDR_ATTR_ADDR6;
404 #endif
405         return MPTCP_PM_ADDR_ATTR_ADDR4;
406 }
407
408 static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
409                                bool require_family,
410                                struct mptcp_pm_addr_entry *entry)
411 {
412         struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
413         int err, addr_addr;
414
415         if (!attr) {
416                 GENL_SET_ERR_MSG(info, "missing address info");
417                 return -EINVAL;
418         }
419
420         /* no validation needed - was already done via nested policy */
421         err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
422                                           mptcp_pm_addr_policy, info->extack);
423         if (err)
424                 return err;
425
426         memset(entry, 0, sizeof(*entry));
427         if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
428                 if (!require_family)
429                         goto skip_family;
430
431                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
432                                     "missing family");
433                 return -EINVAL;
434         }
435
436         entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
437         if (entry->addr.family != AF_INET
438 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
439             && entry->addr.family != AF_INET6
440 #endif
441             ) {
442                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
443                                     "unknown address family");
444                 return -EINVAL;
445         }
446         addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
447         if (!tb[addr_addr]) {
448                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
449                                     "missing address data");
450                 return -EINVAL;
451         }
452
453 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
454         if (entry->addr.family == AF_INET6)
455                 entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
456         else
457 #endif
458                 entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);
459
460 skip_family:
461         if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX])
462                 entry->ifindex = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
463
464         if (tb[MPTCP_PM_ADDR_ATTR_ID])
465                 entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
466
467         if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
468                 entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
469
470         return 0;
471 }
472
473 static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
474 {
475         return net_generic(genl_info_net(info), pm_nl_pernet_id);
476 }
477
478 static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
479 {
480         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
481         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
482         struct mptcp_pm_addr_entry addr, *entry;
483         int ret;
484
485         ret = mptcp_pm_parse_addr(attr, info, true, &addr);
486         if (ret < 0)
487                 return ret;
488
489         entry = kmalloc(sizeof(*entry), GFP_KERNEL);
490         if (!entry) {
491                 GENL_SET_ERR_MSG(info, "can't allocate addr");
492                 return -ENOMEM;
493         }
494
495         *entry = addr;
496         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
497         if (ret < 0) {
498                 GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
499                 kfree(entry);
500                 return ret;
501         }
502
503         return 0;
504 }
505
506 static struct mptcp_pm_addr_entry *
507 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
508 {
509         struct mptcp_pm_addr_entry *entry;
510
511         list_for_each_entry(entry, &pernet->local_addr_list, list) {
512                 if (entry->addr.id == id)
513                         return entry;
514         }
515         return NULL;
516 }
517
518 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
519 {
520         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
521         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
522         struct mptcp_pm_addr_entry addr, *entry;
523         int ret;
524
525         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
526         if (ret < 0)
527                 return ret;
528
529         spin_lock_bh(&pernet->lock);
530         entry = __lookup_addr_by_id(pernet, addr.addr.id);
531         if (!entry) {
532                 GENL_SET_ERR_MSG(info, "address not found");
533                 ret = -EINVAL;
534                 goto out;
535         }
536         if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
537                 pernet->add_addr_signal_max--;
538         if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
539                 pernet->local_addr_max--;
540
541         pernet->addrs--;
542         list_del_rcu(&entry->list);
543         kfree_rcu(entry, rcu);
544 out:
545         spin_unlock_bh(&pernet->lock);
546         return ret;
547 }
548
549 static void __flush_addrs(struct pm_nl_pernet *pernet)
550 {
551         while (!list_empty(&pernet->local_addr_list)) {
552                 struct mptcp_pm_addr_entry *cur;
553
554                 cur = list_entry(pernet->local_addr_list.next,
555                                  struct mptcp_pm_addr_entry, list);
556                 list_del_rcu(&cur->list);
557                 kfree_rcu(cur, rcu);
558         }
559 }
560
561 static void __reset_counters(struct pm_nl_pernet *pernet)
562 {
563         pernet->add_addr_signal_max = 0;
564         pernet->add_addr_accept_max = 0;
565         pernet->local_addr_max = 0;
566         pernet->addrs = 0;
567 }
568
569 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
570 {
571         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
572
573         spin_lock_bh(&pernet->lock);
574         __flush_addrs(pernet);
575         __reset_counters(pernet);
576         spin_unlock_bh(&pernet->lock);
577         return 0;
578 }
579
580 static int mptcp_nl_fill_addr(struct sk_buff *skb,
581                               struct mptcp_pm_addr_entry *entry)
582 {
583         struct mptcp_addr_info *addr = &entry->addr;
584         struct nlattr *attr;
585
586         attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
587         if (!attr)
588                 return -EMSGSIZE;
589
590         if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
591                 goto nla_put_failure;
592         if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
593                 goto nla_put_failure;
594         if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
595                 goto nla_put_failure;
596         if (entry->ifindex &&
597             nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
598                 goto nla_put_failure;
599
600         if (addr->family == AF_INET)
601                 nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
602                                 addr->addr.s_addr);
603 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
604         else if (addr->family == AF_INET6)
605                 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6);
606 #endif
607         nla_nest_end(skb, attr);
608         return 0;
609
610 nla_put_failure:
611         nla_nest_cancel(skb, attr);
612         return -EMSGSIZE;
613 }
614
615 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
616 {
617         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
618         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
619         struct mptcp_pm_addr_entry addr, *entry;
620         struct sk_buff *msg;
621         void *reply;
622         int ret;
623
624         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
625         if (ret < 0)
626                 return ret;
627
628         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
629         if (!msg)
630                 return -ENOMEM;
631
632         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
633                                   info->genlhdr->cmd);
634         if (!reply) {
635                 GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
636                 ret = -EMSGSIZE;
637                 goto fail;
638         }
639
640         spin_lock_bh(&pernet->lock);
641         entry = __lookup_addr_by_id(pernet, addr.addr.id);
642         if (!entry) {
643                 GENL_SET_ERR_MSG(info, "address not found");
644                 ret = -EINVAL;
645                 goto unlock_fail;
646         }
647
648         ret = mptcp_nl_fill_addr(msg, entry);
649         if (ret)
650                 goto unlock_fail;
651
652         genlmsg_end(msg, reply);
653         ret = genlmsg_reply(msg, info);
654         spin_unlock_bh(&pernet->lock);
655         return ret;
656
657 unlock_fail:
658         spin_unlock_bh(&pernet->lock);
659
660 fail:
661         nlmsg_free(msg);
662         return ret;
663 }
664
665 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
666                                    struct netlink_callback *cb)
667 {
668         struct net *net = sock_net(msg->sk);
669         struct mptcp_pm_addr_entry *entry;
670         struct pm_nl_pernet *pernet;
671         int id = cb->args[0];
672         void *hdr;
673
674         pernet = net_generic(net, pm_nl_pernet_id);
675
676         spin_lock_bh(&pernet->lock);
677         list_for_each_entry(entry, &pernet->local_addr_list, list) {
678                 if (entry->addr.id <= id)
679                         continue;
680
681                 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
682                                   cb->nlh->nlmsg_seq, &mptcp_genl_family,
683                                   NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
684                 if (!hdr)
685                         break;
686
687                 if (mptcp_nl_fill_addr(msg, entry) < 0) {
688                         genlmsg_cancel(msg, hdr);
689                         break;
690                 }
691
692                 id = entry->addr.id;
693                 genlmsg_end(msg, hdr);
694         }
695         spin_unlock_bh(&pernet->lock);
696
697         cb->args[0] = id;
698         return msg->len;
699 }
700
701 static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
702 {
703         struct nlattr *attr = info->attrs[id];
704
705         if (!attr)
706                 return 0;
707
708         *limit = nla_get_u32(attr);
709         if (*limit > MPTCP_PM_ADDR_MAX) {
710                 GENL_SET_ERR_MSG(info, "limit greater than maximum");
711                 return -EINVAL;
712         }
713         return 0;
714 }
715
716 static int
717 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
718 {
719         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
720         unsigned int rcv_addrs, subflows;
721         int ret;
722
723         spin_lock_bh(&pernet->lock);
724         rcv_addrs = pernet->add_addr_accept_max;
725         ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
726         if (ret)
727                 goto unlock;
728
729         subflows = pernet->subflows_max;
730         ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
731         if (ret)
732                 goto unlock;
733
734         WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
735         WRITE_ONCE(pernet->subflows_max, subflows);
736
737 unlock:
738         spin_unlock_bh(&pernet->lock);
739         return ret;
740 }
741
742 static int
743 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
744 {
745         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
746         struct sk_buff *msg;
747         void *reply;
748
749         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
750         if (!msg)
751                 return -ENOMEM;
752
753         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
754                                   MPTCP_PM_CMD_GET_LIMITS);
755         if (!reply)
756                 goto fail;
757
758         if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
759                         READ_ONCE(pernet->add_addr_accept_max)))
760                 goto fail;
761
762         if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
763                         READ_ONCE(pernet->subflows_max)))
764                 goto fail;
765
766         genlmsg_end(msg, reply);
767         return genlmsg_reply(msg, info);
768
769 fail:
770         GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
771         nlmsg_free(msg);
772         return -EMSGSIZE;
773 }
774
775 static struct genl_ops mptcp_pm_ops[] = {
776         {
777                 .cmd    = MPTCP_PM_CMD_ADD_ADDR,
778                 .doit   = mptcp_nl_cmd_add_addr,
779                 .flags  = GENL_ADMIN_PERM,
780         },
781         {
782                 .cmd    = MPTCP_PM_CMD_DEL_ADDR,
783                 .doit   = mptcp_nl_cmd_del_addr,
784                 .flags  = GENL_ADMIN_PERM,
785         },
786         {
787                 .cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
788                 .doit   = mptcp_nl_cmd_flush_addrs,
789                 .flags  = GENL_ADMIN_PERM,
790         },
791         {
792                 .cmd    = MPTCP_PM_CMD_GET_ADDR,
793                 .doit   = mptcp_nl_cmd_get_addr,
794                 .dumpit   = mptcp_nl_cmd_dump_addrs,
795         },
796         {
797                 .cmd    = MPTCP_PM_CMD_SET_LIMITS,
798                 .doit   = mptcp_nl_cmd_set_limits,
799                 .flags  = GENL_ADMIN_PERM,
800         },
801         {
802                 .cmd    = MPTCP_PM_CMD_GET_LIMITS,
803                 .doit   = mptcp_nl_cmd_get_limits,
804         },
805 };
806
/* MPTCP path-manager generic netlink family; netnsok lets it be used from
 * any network namespace (handlers resolve per-netns state via
 * genl_info_pm_nl()). Registered once by mptcp_pm_nl_init().
 */
static struct genl_family mptcp_genl_family __ro_after_init = {
	.name           = MPTCP_PM_NAME,
	.version        = MPTCP_PM_VER,
	.maxattr        = MPTCP_PM_ATTR_MAX,
	.policy         = mptcp_pm_policy,
	.netnsok        = true,
	.module         = THIS_MODULE,
	.ops            = mptcp_pm_ops,
	.n_ops          = ARRAY_SIZE(mptcp_pm_ops),
	.mcgrps         = mptcp_pm_mcgrps,
	.n_mcgrps       = ARRAY_SIZE(mptcp_pm_mcgrps),
};
819
820 static int __net_init pm_nl_init_net(struct net *net)
821 {
822         struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
823
824         INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
825         __reset_counters(pernet);
826         pernet->next_id = 1;
827         spin_lock_init(&pernet->lock);
828         return 0;
829 }
830
831 static void __net_exit pm_nl_exit_net(struct list_head *net_list)
832 {
833         struct net *net;
834
835         list_for_each_entry(net, net_list, exit_list) {
836                 /* net is removed from namespace list, can't race with
837                  * other modifiers
838                  */
839                 __flush_addrs(net_generic(net, pm_nl_pernet_id));
840         }
841 }
842
/* pernet hooks: .size makes the core allocate a struct pm_nl_pernet per
 * namespace, reachable through net_generic(net, pm_nl_pernet_id)
 */
static struct pernet_operations mptcp_pm_pernet_ops = {
	.init = pm_nl_init_net,
	.exit_batch = pm_nl_exit_net,
	.id = &pm_nl_pernet_id,
	.size = sizeof(struct pm_nl_pernet),
};
849
850 void mptcp_pm_nl_init(void)
851 {
852         if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
853                 panic("Failed to register MPTCP PM pernet subsystem.\n");
854
855         if (genl_register_family(&mptcp_genl_family))
856                 panic("Failed to register MPTCP PM netlink family\n");
857 }