Merge tag 'linux-kselftest-fixes-5.10-rc1' of git://git.kernel.org/pub/scm/linux...
[linux-2.6-microblaze.git] / net / mptcp / pm_netlink.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2020, Red Hat, Inc.
5  */
6
7 #define pr_fmt(fmt) "MPTCP: " fmt
8
9 #include <linux/inet.h>
10 #include <linux/kernel.h>
11 #include <net/tcp.h>
12 #include <net/netns/generic.h>
13 #include <net/mptcp.h>
14 #include <net/genetlink.h>
15 #include <uapi/linux/mptcp.h>
16
17 #include "protocol.h"
18
/* forward declaration: the family is defined near the end of the file but
 * is referenced by the reply and dump helpers that precede it
 */
static struct genl_family mptcp_genl_family;

/* per-netns storage id for struct pm_nl_pernet; wired up through
 * mptcp_pm_pernet_ops.id and used with net_generic()
 */
static int pm_nl_pernet_id;
23
/* One endpoint of the per-netns address list. Entries are linked into
 * pm_nl_pernet.local_addr_list with list_add_tail_rcu() and released with
 * kfree_rcu(), so readers may walk the list under rcu_read_lock().
 */
struct mptcp_pm_addr_entry {
	struct list_head	list;
	/* MPTCP_PM_ADDR_FLAG_* bits (SIGNAL and/or SUBFLOW) */
	unsigned int		flags;
	/* optional interface index; 0 means "unset" and is not dumped */
	int			ifindex;
	struct mptcp_addr_info	addr;
	struct rcu_head		rcu;
};
31
/* Per network namespace path manager configuration. */
struct pm_nl_pernet {
	/* protects pernet updates */
	spinlock_t		lock;
	/* configured endpoints (struct mptcp_pm_addr_entry), RCU list */
	struct list_head	local_addr_list;
	/* number of entries in local_addr_list, at most MPTCP_PM_ADDR_MAX */
	unsigned int		addrs;
	/* number of endpoints flagged MPTCP_PM_ADDR_FLAG_SIGNAL */
	unsigned int		add_addr_signal_max;
	/* accepted ADD_ADDR limit, set via MPTCP_PM_ATTR_RCV_ADD_ADDRS */
	unsigned int		add_addr_accept_max;
	/* number of endpoints flagged MPTCP_PM_ADDR_FLAG_SUBFLOW */
	unsigned int		local_addr_max;
	/* subflow limit, set via MPTCP_PM_ATTR_SUBFLOWS */
	unsigned int		subflows_max;
	/* next address id to hand out; only ever incremented */
	unsigned int		next_id;
};
43
/* upper bound for both the number of endpoints per netns and the limits
 * accepted by MPTCP_PM_CMD_SET_LIMITS
 */
#define MPTCP_PM_ADDR_MAX       8
45
46 static bool addresses_equal(const struct mptcp_addr_info *a,
47                             struct mptcp_addr_info *b, bool use_port)
48 {
49         bool addr_equals = false;
50
51         if (a->family != b->family)
52                 return false;
53
54         if (a->family == AF_INET)
55                 addr_equals = a->addr.s_addr == b->addr.s_addr;
56 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
57         else
58                 addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
59 #endif
60
61         if (!addr_equals)
62                 return false;
63         if (!use_port)
64                 return true;
65
66         return a->port == b->port;
67 }
68
69 static bool address_zero(const struct mptcp_addr_info *addr)
70 {
71         struct mptcp_addr_info zero;
72
73         memset(&zero, 0, sizeof(zero));
74         zero.family = addr->family;
75
76         return addresses_equal(addr, &zero, false);
77 }
78
79 static void local_address(const struct sock_common *skc,
80                           struct mptcp_addr_info *addr)
81 {
82         addr->port = 0;
83         addr->family = skc->skc_family;
84         if (addr->family == AF_INET)
85                 addr->addr.s_addr = skc->skc_rcv_saddr;
86 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
87         else if (addr->family == AF_INET6)
88                 addr->addr6 = skc->skc_v6_rcv_saddr;
89 #endif
90 }
91
92 static void remote_address(const struct sock_common *skc,
93                            struct mptcp_addr_info *addr)
94 {
95         addr->family = skc->skc_family;
96         addr->port = skc->skc_dport;
97         if (addr->family == AF_INET)
98                 addr->addr.s_addr = skc->skc_daddr;
99 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
100         else if (addr->family == AF_INET6)
101                 addr->addr6 = skc->skc_v6_daddr;
102 #endif
103 }
104
105 static bool lookup_subflow_by_saddr(const struct list_head *list,
106                                     struct mptcp_addr_info *saddr)
107 {
108         struct mptcp_subflow_context *subflow;
109         struct mptcp_addr_info cur;
110         struct sock_common *skc;
111
112         list_for_each_entry(subflow, list, node) {
113                 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
114
115                 local_address(skc, &cur);
116                 if (addresses_equal(&cur, saddr, false))
117                         return true;
118         }
119
120         return false;
121 }
122
123 static struct mptcp_pm_addr_entry *
124 select_local_address(const struct pm_nl_pernet *pernet,
125                      struct mptcp_sock *msk)
126 {
127         struct mptcp_pm_addr_entry *entry, *ret = NULL;
128
129         rcu_read_lock();
130         spin_lock_bh(&msk->join_list_lock);
131         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
132                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
133                         continue;
134
135                 /* avoid any address already in use by subflows and
136                  * pending join
137                  */
138                 if (entry->addr.family == ((struct sock *)msk)->sk_family &&
139                     !lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
140                     !lookup_subflow_by_saddr(&msk->join_list, &entry->addr)) {
141                         ret = entry;
142                         break;
143                 }
144         }
145         spin_unlock_bh(&msk->join_list_lock);
146         rcu_read_unlock();
147         return ret;
148 }
149
150 static struct mptcp_pm_addr_entry *
151 select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
152 {
153         struct mptcp_pm_addr_entry *entry, *ret = NULL;
154         int i = 0;
155
156         rcu_read_lock();
157         /* do not keep any additional per socket state, just signal
158          * the address list in order.
159          * Note: removal from the local address list during the msk life-cycle
160          * can lead to additional addresses not being announced.
161          */
162         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
163                 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
164                         continue;
165                 if (i++ == pos) {
166                         ret = entry;
167                         break;
168                 }
169         }
170         rcu_read_unlock();
171         return ret;
172 }
173
174 static void check_work_pending(struct mptcp_sock *msk)
175 {
176         if (msk->pm.add_addr_signaled == msk->pm.add_addr_signal_max &&
177             (msk->pm.local_addr_used == msk->pm.local_addr_max ||
178              msk->pm.subflows == msk->pm.subflows_max))
179                 WRITE_ONCE(msk->pm.work_pending, false);
180 }
181
182 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
183 {
184         struct mptcp_addr_info remote = { 0 };
185         struct sock *sk = (struct sock *)msk;
186         struct mptcp_pm_addr_entry *local;
187         struct pm_nl_pernet *pernet;
188
189         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
190
191         pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
192                  msk->pm.local_addr_used, msk->pm.local_addr_max,
193                  msk->pm.add_addr_signaled, msk->pm.add_addr_signal_max,
194                  msk->pm.subflows, msk->pm.subflows_max);
195
196         /* check first for announce */
197         if (msk->pm.add_addr_signaled < msk->pm.add_addr_signal_max) {
198                 local = select_signal_address(pernet,
199                                               msk->pm.add_addr_signaled);
200
201                 if (local) {
202                         msk->pm.add_addr_signaled++;
203                         mptcp_pm_announce_addr(msk, &local->addr);
204                 } else {
205                         /* pick failed, avoid fourther attempts later */
206                         msk->pm.local_addr_used = msk->pm.add_addr_signal_max;
207                 }
208
209                 check_work_pending(msk);
210         }
211
212         /* check if should create a new subflow */
213         if (msk->pm.local_addr_used < msk->pm.local_addr_max &&
214             msk->pm.subflows < msk->pm.subflows_max) {
215                 remote_address((struct sock_common *)sk, &remote);
216
217                 local = select_local_address(pernet, msk);
218                 if (local) {
219                         msk->pm.local_addr_used++;
220                         msk->pm.subflows++;
221                         check_work_pending(msk);
222                         spin_unlock_bh(&msk->pm.lock);
223                         __mptcp_subflow_connect(sk, local->ifindex,
224                                                 &local->addr, &remote);
225                         spin_lock_bh(&msk->pm.lock);
226                         return;
227                 }
228
229                 /* lookup failed, avoid fourther attempts later */
230                 msk->pm.local_addr_used = msk->pm.local_addr_max;
231                 check_work_pending(msk);
232         }
233 }
234
/* Netlink PM hook for a fully established msk: kick off announcing
 * addresses and/or creating additional subflows.
 */
void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
239
/* Netlink PM hook for a newly established subflow: continue announcing
 * addresses and/or creating subflows while budgets allow.
 */
void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
244
245 void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
246 {
247         struct sock *sk = (struct sock *)msk;
248         struct mptcp_addr_info remote;
249         struct mptcp_addr_info local;
250
251         pr_debug("accepted %d:%d remote family %d",
252                  msk->pm.add_addr_accepted, msk->pm.add_addr_accept_max,
253                  msk->pm.remote.family);
254         msk->pm.add_addr_accepted++;
255         msk->pm.subflows++;
256         if (msk->pm.add_addr_accepted >= msk->pm.add_addr_accept_max ||
257             msk->pm.subflows >= msk->pm.subflows_max)
258                 WRITE_ONCE(msk->pm.accept_addr, false);
259
260         /* connect to the specified remote address, using whatever
261          * local address the routing configuration will pick.
262          */
263         remote = msk->pm.remote;
264         if (!remote.port)
265                 remote.port = sk->sk_dport;
266         memset(&local, 0, sizeof(local));
267         local.family = remote.family;
268
269         spin_unlock_bh(&msk->pm.lock);
270         __mptcp_subflow_connect((struct sock *)msk, 0, &local, &remote);
271         spin_lock_bh(&msk->pm.lock);
272 }
273
274 static bool address_use_port(struct mptcp_pm_addr_entry *entry)
275 {
276         return (entry->flags &
277                 (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
278                 MPTCP_PM_ADDR_FLAG_SIGNAL;
279 }
280
281 static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
282                                              struct mptcp_pm_addr_entry *entry)
283 {
284         struct mptcp_pm_addr_entry *cur;
285         int ret = -EINVAL;
286
287         spin_lock_bh(&pernet->lock);
288         /* to keep the code simple, don't do IDR-like allocation for address ID,
289          * just bail when we exceed limits
290          */
291         if (pernet->next_id > 255)
292                 goto out;
293         if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
294                 goto out;
295
296         /* do not insert duplicate address, differentiate on port only
297          * singled addresses
298          */
299         list_for_each_entry(cur, &pernet->local_addr_list, list) {
300                 if (addresses_equal(&cur->addr, &entry->addr,
301                                     address_use_port(entry) &&
302                                     address_use_port(cur)))
303                         goto out;
304         }
305
306         if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
307                 pernet->add_addr_signal_max++;
308         if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
309                 pernet->local_addr_max++;
310
311         entry->addr.id = pernet->next_id++;
312         pernet->addrs++;
313         list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
314         ret = entry->addr.id;
315
316 out:
317         spin_unlock_bh(&pernet->lock);
318         return ret;
319 }
320
321 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
322 {
323         struct mptcp_pm_addr_entry *entry;
324         struct mptcp_addr_info skc_local;
325         struct mptcp_addr_info msk_local;
326         struct pm_nl_pernet *pernet;
327         int ret = -1;
328
329         if (WARN_ON_ONCE(!msk))
330                 return -1;
331
332         /* The 0 ID mapping is defined by the first subflow, copied into the msk
333          * addr
334          */
335         local_address((struct sock_common *)msk, &msk_local);
336         local_address((struct sock_common *)skc, &skc_local);
337         if (addresses_equal(&msk_local, &skc_local, false))
338                 return 0;
339
340         if (address_zero(&skc_local))
341                 return 0;
342
343         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
344
345         rcu_read_lock();
346         list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
347                 if (addresses_equal(&entry->addr, &skc_local, false)) {
348                         ret = entry->addr.id;
349                         break;
350                 }
351         }
352         rcu_read_unlock();
353         if (ret >= 0)
354                 return ret;
355
356         /* address not found, add to local list */
357         entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
358         if (!entry)
359                 return -ENOMEM;
360
361         entry->flags = 0;
362         entry->addr = skc_local;
363         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
364         if (ret < 0)
365                 kfree(entry);
366
367         return ret;
368 }
369
370 void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
371 {
372         struct mptcp_pm_data *pm = &msk->pm;
373         struct pm_nl_pernet *pernet;
374         bool subflows;
375
376         pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
377
378         pm->add_addr_signal_max = READ_ONCE(pernet->add_addr_signal_max);
379         pm->add_addr_accept_max = READ_ONCE(pernet->add_addr_accept_max);
380         pm->local_addr_max = READ_ONCE(pernet->local_addr_max);
381         pm->subflows_max = READ_ONCE(pernet->subflows_max);
382         subflows = !!pm->subflows_max;
383         WRITE_ONCE(pm->work_pending, (!!pm->local_addr_max && subflows) ||
384                    !!pm->add_addr_signal_max);
385         WRITE_ONCE(pm->accept_addr, !!pm->add_addr_accept_max && subflows);
386         WRITE_ONCE(pm->accept_subflow, subflows);
387 }
388
/* index of the PM command multicast group within mptcp_pm_mcgrps */
#define MPTCP_PM_CMD_GRP_OFFSET 0

/* multicast groups registered with mptcp_genl_family */
static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
	[MPTCP_PM_CMD_GRP_OFFSET]	= { .name = MPTCP_PM_CMD_GRP_NAME, },
};
394
/* policy for the attributes nested inside MPTCP_PM_ATTR_ADDR */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]	= { .type	= NLA_U16,	},
	[MPTCP_PM_ADDR_ATTR_ID]		= { .type	= NLA_U8,	},
	[MPTCP_PM_ADDR_ATTR_ADDR4]	= { .type	= NLA_U32,	},
	/* IPv6 address payload must be exactly one struct in6_addr */
	[MPTCP_PM_ADDR_ATTR_ADDR6]	= { .type	= NLA_EXACT_LEN,
					    .len   = sizeof(struct in6_addr), },
	[MPTCP_PM_ADDR_ATTR_PORT]	= { .type	= NLA_U16	},
	[MPTCP_PM_ADDR_ATTR_FLAGS]	= { .type	= NLA_U32	},
	[MPTCP_PM_ADDR_ATTR_IF_IDX]	= { .type	= NLA_S32	},
};
406
/* top-level attribute policy of mptcp_genl_family */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]		=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_SUBFLOWS]	= { .type	= NLA_U32,	},
};
413
414 static int mptcp_pm_family_to_addr(int family)
415 {
416 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
417         if (family == AF_INET6)
418                 return MPTCP_PM_ADDR_ATTR_ADDR6;
419 #endif
420         return MPTCP_PM_ADDR_ATTR_ADDR4;
421 }
422
423 static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
424                                bool require_family,
425                                struct mptcp_pm_addr_entry *entry)
426 {
427         struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
428         int err, addr_addr;
429
430         if (!attr) {
431                 GENL_SET_ERR_MSG(info, "missing address info");
432                 return -EINVAL;
433         }
434
435         /* no validation needed - was already done via nested policy */
436         err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
437                                           mptcp_pm_addr_policy, info->extack);
438         if (err)
439                 return err;
440
441         memset(entry, 0, sizeof(*entry));
442         if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
443                 if (!require_family)
444                         goto skip_family;
445
446                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
447                                     "missing family");
448                 return -EINVAL;
449         }
450
451         entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
452         if (entry->addr.family != AF_INET
453 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
454             && entry->addr.family != AF_INET6
455 #endif
456             ) {
457                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
458                                     "unknown address family");
459                 return -EINVAL;
460         }
461         addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
462         if (!tb[addr_addr]) {
463                 NL_SET_ERR_MSG_ATTR(info->extack, attr,
464                                     "missing address data");
465                 return -EINVAL;
466         }
467
468 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
469         if (entry->addr.family == AF_INET6)
470                 entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
471         else
472 #endif
473                 entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);
474
475 skip_family:
476         if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX])
477                 entry->ifindex = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
478
479         if (tb[MPTCP_PM_ADDR_ATTR_ID])
480                 entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
481
482         if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
483                 entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
484
485         return 0;
486 }
487
488 static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
489 {
490         return net_generic(genl_info_net(info), pm_nl_pernet_id);
491 }
492
493 static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
494 {
495         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
496         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
497         struct mptcp_pm_addr_entry addr, *entry;
498         int ret;
499
500         ret = mptcp_pm_parse_addr(attr, info, true, &addr);
501         if (ret < 0)
502                 return ret;
503
504         entry = kmalloc(sizeof(*entry), GFP_KERNEL);
505         if (!entry) {
506                 GENL_SET_ERR_MSG(info, "can't allocate addr");
507                 return -ENOMEM;
508         }
509
510         *entry = addr;
511         ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
512         if (ret < 0) {
513                 GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
514                 kfree(entry);
515                 return ret;
516         }
517
518         return 0;
519 }
520
521 static struct mptcp_pm_addr_entry *
522 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
523 {
524         struct mptcp_pm_addr_entry *entry;
525
526         list_for_each_entry(entry, &pernet->local_addr_list, list) {
527                 if (entry->addr.id == id)
528                         return entry;
529         }
530         return NULL;
531 }
532
533 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
534 {
535         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
536         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
537         struct mptcp_pm_addr_entry addr, *entry;
538         int ret;
539
540         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
541         if (ret < 0)
542                 return ret;
543
544         spin_lock_bh(&pernet->lock);
545         entry = __lookup_addr_by_id(pernet, addr.addr.id);
546         if (!entry) {
547                 GENL_SET_ERR_MSG(info, "address not found");
548                 ret = -EINVAL;
549                 goto out;
550         }
551         if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
552                 pernet->add_addr_signal_max--;
553         if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
554                 pernet->local_addr_max--;
555
556         pernet->addrs--;
557         list_del_rcu(&entry->list);
558         kfree_rcu(entry, rcu);
559 out:
560         spin_unlock_bh(&pernet->lock);
561         return ret;
562 }
563
564 static void __flush_addrs(struct pm_nl_pernet *pernet)
565 {
566         while (!list_empty(&pernet->local_addr_list)) {
567                 struct mptcp_pm_addr_entry *cur;
568
569                 cur = list_entry(pernet->local_addr_list.next,
570                                  struct mptcp_pm_addr_entry, list);
571                 list_del_rcu(&cur->list);
572                 kfree_rcu(cur, rcu);
573         }
574 }
575
576 static void __reset_counters(struct pm_nl_pernet *pernet)
577 {
578         pernet->add_addr_signal_max = 0;
579         pernet->add_addr_accept_max = 0;
580         pernet->local_addr_max = 0;
581         pernet->addrs = 0;
582 }
583
584 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
585 {
586         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
587
588         spin_lock_bh(&pernet->lock);
589         __flush_addrs(pernet);
590         __reset_counters(pernet);
591         spin_unlock_bh(&pernet->lock);
592         return 0;
593 }
594
595 static int mptcp_nl_fill_addr(struct sk_buff *skb,
596                               struct mptcp_pm_addr_entry *entry)
597 {
598         struct mptcp_addr_info *addr = &entry->addr;
599         struct nlattr *attr;
600
601         attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
602         if (!attr)
603                 return -EMSGSIZE;
604
605         if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
606                 goto nla_put_failure;
607         if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
608                 goto nla_put_failure;
609         if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
610                 goto nla_put_failure;
611         if (entry->ifindex &&
612             nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
613                 goto nla_put_failure;
614
615         if (addr->family == AF_INET &&
616             nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
617                             addr->addr.s_addr))
618                 goto nla_put_failure;
619 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
620         else if (addr->family == AF_INET6 &&
621                  nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
622                 goto nla_put_failure;
623 #endif
624         nla_nest_end(skb, attr);
625         return 0;
626
627 nla_put_failure:
628         nla_nest_cancel(skb, attr);
629         return -EMSGSIZE;
630 }
631
632 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
633 {
634         struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
635         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
636         struct mptcp_pm_addr_entry addr, *entry;
637         struct sk_buff *msg;
638         void *reply;
639         int ret;
640
641         ret = mptcp_pm_parse_addr(attr, info, false, &addr);
642         if (ret < 0)
643                 return ret;
644
645         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
646         if (!msg)
647                 return -ENOMEM;
648
649         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
650                                   info->genlhdr->cmd);
651         if (!reply) {
652                 GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
653                 ret = -EMSGSIZE;
654                 goto fail;
655         }
656
657         spin_lock_bh(&pernet->lock);
658         entry = __lookup_addr_by_id(pernet, addr.addr.id);
659         if (!entry) {
660                 GENL_SET_ERR_MSG(info, "address not found");
661                 ret = -EINVAL;
662                 goto unlock_fail;
663         }
664
665         ret = mptcp_nl_fill_addr(msg, entry);
666         if (ret)
667                 goto unlock_fail;
668
669         genlmsg_end(msg, reply);
670         ret = genlmsg_reply(msg, info);
671         spin_unlock_bh(&pernet->lock);
672         return ret;
673
674 unlock_fail:
675         spin_unlock_bh(&pernet->lock);
676
677 fail:
678         nlmsg_free(msg);
679         return ret;
680 }
681
682 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
683                                    struct netlink_callback *cb)
684 {
685         struct net *net = sock_net(msg->sk);
686         struct mptcp_pm_addr_entry *entry;
687         struct pm_nl_pernet *pernet;
688         int id = cb->args[0];
689         void *hdr;
690
691         pernet = net_generic(net, pm_nl_pernet_id);
692
693         spin_lock_bh(&pernet->lock);
694         list_for_each_entry(entry, &pernet->local_addr_list, list) {
695                 if (entry->addr.id <= id)
696                         continue;
697
698                 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
699                                   cb->nlh->nlmsg_seq, &mptcp_genl_family,
700                                   NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
701                 if (!hdr)
702                         break;
703
704                 if (mptcp_nl_fill_addr(msg, entry) < 0) {
705                         genlmsg_cancel(msg, hdr);
706                         break;
707                 }
708
709                 id = entry->addr.id;
710                 genlmsg_end(msg, hdr);
711         }
712         spin_unlock_bh(&pernet->lock);
713
714         cb->args[0] = id;
715         return msg->len;
716 }
717
718 static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
719 {
720         struct nlattr *attr = info->attrs[id];
721
722         if (!attr)
723                 return 0;
724
725         *limit = nla_get_u32(attr);
726         if (*limit > MPTCP_PM_ADDR_MAX) {
727                 GENL_SET_ERR_MSG(info, "limit greater than maximum");
728                 return -EINVAL;
729         }
730         return 0;
731 }
732
733 static int
734 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
735 {
736         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
737         unsigned int rcv_addrs, subflows;
738         int ret;
739
740         spin_lock_bh(&pernet->lock);
741         rcv_addrs = pernet->add_addr_accept_max;
742         ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
743         if (ret)
744                 goto unlock;
745
746         subflows = pernet->subflows_max;
747         ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
748         if (ret)
749                 goto unlock;
750
751         WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
752         WRITE_ONCE(pernet->subflows_max, subflows);
753
754 unlock:
755         spin_unlock_bh(&pernet->lock);
756         return ret;
757 }
758
759 static int
760 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
761 {
762         struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
763         struct sk_buff *msg;
764         void *reply;
765
766         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
767         if (!msg)
768                 return -ENOMEM;
769
770         reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
771                                   MPTCP_PM_CMD_GET_LIMITS);
772         if (!reply)
773                 goto fail;
774
775         if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
776                         READ_ONCE(pernet->add_addr_accept_max)))
777                 goto fail;
778
779         if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
780                         READ_ONCE(pernet->subflows_max)))
781                 goto fail;
782
783         genlmsg_end(msg, reply);
784         return genlmsg_reply(msg, info);
785
786 fail:
787         GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
788         nlmsg_free(msg);
789         return -EMSGSIZE;
790 }
791
792 static struct genl_ops mptcp_pm_ops[] = {
793         {
794                 .cmd    = MPTCP_PM_CMD_ADD_ADDR,
795                 .doit   = mptcp_nl_cmd_add_addr,
796                 .flags  = GENL_ADMIN_PERM,
797         },
798         {
799                 .cmd    = MPTCP_PM_CMD_DEL_ADDR,
800                 .doit   = mptcp_nl_cmd_del_addr,
801                 .flags  = GENL_ADMIN_PERM,
802         },
803         {
804                 .cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
805                 .doit   = mptcp_nl_cmd_flush_addrs,
806                 .flags  = GENL_ADMIN_PERM,
807         },
808         {
809                 .cmd    = MPTCP_PM_CMD_GET_ADDR,
810                 .doit   = mptcp_nl_cmd_get_addr,
811                 .dumpit   = mptcp_nl_cmd_dump_addrs,
812         },
813         {
814                 .cmd    = MPTCP_PM_CMD_SET_LIMITS,
815                 .doit   = mptcp_nl_cmd_set_limits,
816                 .flags  = GENL_ADMIN_PERM,
817         },
818         {
819                 .cmd    = MPTCP_PM_CMD_GET_LIMITS,
820                 .doit   = mptcp_nl_cmd_get_limits,
821         },
822 };
823
/* The MPTCP path manager generic netlink family: netns-aware, validated
 * against mptcp_pm_policy, exposing mptcp_pm_ops and mptcp_pm_mcgrps.
 */
static struct genl_family mptcp_genl_family __ro_after_init = {
	.name		= MPTCP_PM_NAME,
	.version	= MPTCP_PM_VER,
	.maxattr	= MPTCP_PM_ATTR_MAX,
	.policy		= mptcp_pm_policy,
	.netnsok	= true,
	.module		= THIS_MODULE,
	.ops		= mptcp_pm_ops,
	.n_ops		= ARRAY_SIZE(mptcp_pm_ops),
	.mcgrps		= mptcp_pm_mcgrps,
	.n_mcgrps	= ARRAY_SIZE(mptcp_pm_mcgrps),
};
836
837 static int __net_init pm_nl_init_net(struct net *net)
838 {
839         struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
840
841         INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
842         __reset_counters(pernet);
843         pernet->next_id = 1;
844         spin_lock_init(&pernet->lock);
845         return 0;
846 }
847
848 static void __net_exit pm_nl_exit_net(struct list_head *net_list)
849 {
850         struct net *net;
851
852         list_for_each_entry(net, net_list, exit_list) {
853                 /* net is removed from namespace list, can't race with
854                  * other modifiers
855                  */
856                 __flush_addrs(net_generic(net, pm_nl_pernet_id));
857         }
858 }
859
/* Per-netns hooks; .size makes the core allocate a struct pm_nl_pernet per
 * namespace, reachable via net_generic(net, pm_nl_pernet_id).
 */
static struct pernet_operations mptcp_pm_pernet_ops = {
	.init = pm_nl_init_net,
	.exit_batch = pm_nl_exit_net,
	.id = &pm_nl_pernet_id,
	.size = sizeof(struct pm_nl_pernet),
};
866
867 void __init mptcp_pm_nl_init(void)
868 {
869         if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
870                 panic("Failed to register MPTCP PM pernet subsystem.\n");
871
872         if (genl_register_family(&mptcp_genl_family))
873                 panic("Failed to register MPTCP PM netlink family\n");
874 }