Merge branch 'work.recursive_removal' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-microblaze.git] / net / bridge / br_mdb.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/err.h>
3 #include <linux/igmp.h>
4 #include <linux/kernel.h>
5 #include <linux/netdevice.h>
6 #include <linux/rculist.h>
7 #include <linux/skbuff.h>
8 #include <linux/if_ether.h>
9 #include <net/ip.h>
10 #include <net/netlink.h>
11 #include <net/switchdev.h>
12 #if IS_ENABLED(CONFIG_IPV6)
13 #include <net/ipv6.h>
14 #include <net/addrconf.h>
15 #endif
16
17 #include "br_private.h"
18
19 static int br_rports_fill_info(struct sk_buff *skb, struct netlink_callback *cb,
20                                struct net_device *dev)
21 {
22         struct net_bridge *br = netdev_priv(dev);
23         struct net_bridge_port *p;
24         struct nlattr *nest, *port_nest;
25
26         if (!br->multicast_router || hlist_empty(&br->router_list))
27                 return 0;
28
29         nest = nla_nest_start_noflag(skb, MDBA_ROUTER);
30         if (nest == NULL)
31                 return -EMSGSIZE;
32
33         hlist_for_each_entry_rcu(p, &br->router_list, rlist) {
34                 if (!p)
35                         continue;
36                 port_nest = nla_nest_start_noflag(skb, MDBA_ROUTER_PORT);
37                 if (!port_nest)
38                         goto fail;
39                 if (nla_put_nohdr(skb, sizeof(u32), &p->dev->ifindex) ||
40                     nla_put_u32(skb, MDBA_ROUTER_PATTR_TIMER,
41                                 br_timer_value(&p->multicast_router_timer)) ||
42                     nla_put_u8(skb, MDBA_ROUTER_PATTR_TYPE,
43                                p->multicast_router)) {
44                         nla_nest_cancel(skb, port_nest);
45                         goto fail;
46                 }
47                 nla_nest_end(skb, port_nest);
48         }
49
50         nla_nest_end(skb, nest);
51         return 0;
52 fail:
53         nla_nest_cancel(skb, nest);
54         return -EMSGSIZE;
55 }
56
57 static void __mdb_entry_fill_flags(struct br_mdb_entry *e, unsigned char flags)
58 {
59         e->state = flags & MDB_PG_FLAGS_PERMANENT;
60         e->flags = 0;
61         if (flags & MDB_PG_FLAGS_OFFLOAD)
62                 e->flags |= MDB_FLAGS_OFFLOAD;
63         if (flags & MDB_PG_FLAGS_FAST_LEAVE)
64                 e->flags |= MDB_FLAGS_FAST_LEAVE;
65         if (flags & MDB_PG_FLAGS_STAR_EXCL)
66                 e->flags |= MDB_FLAGS_STAR_EXCL;
67         if (flags & MDB_PG_FLAGS_BLOCKED)
68                 e->flags |= MDB_FLAGS_BLOCKED;
69 }
70
71 static void __mdb_entry_to_br_ip(struct br_mdb_entry *entry, struct br_ip *ip,
72                                  struct nlattr **mdb_attrs)
73 {
74         memset(ip, 0, sizeof(struct br_ip));
75         ip->vid = entry->vid;
76         ip->proto = entry->addr.proto;
77         switch (ip->proto) {
78         case htons(ETH_P_IP):
79                 ip->dst.ip4 = entry->addr.u.ip4;
80                 if (mdb_attrs && mdb_attrs[MDBE_ATTR_SOURCE])
81                         ip->src.ip4 = nla_get_in_addr(mdb_attrs[MDBE_ATTR_SOURCE]);
82                 break;
83 #if IS_ENABLED(CONFIG_IPV6)
84         case htons(ETH_P_IPV6):
85                 ip->dst.ip6 = entry->addr.u.ip6;
86                 if (mdb_attrs && mdb_attrs[MDBE_ATTR_SOURCE])
87                         ip->src.ip6 = nla_get_in6_addr(mdb_attrs[MDBE_ATTR_SOURCE]);
88                 break;
89 #endif
90         default:
91                 ether_addr_copy(ip->dst.mac_addr, entry->addr.u.mac_addr);
92         }
93
94 }
95
/* Dump the source list of a port group as an MDBA_MDB_EATTR_SRC_LIST
 * nested attribute: one MDBA_MDB_SRCLIST_ENTRY per source carrying the
 * source address and the remaining source timer.
 *
 * Runs either under rcu_read_lock() (dump path) or with the bridge's
 * multicast_lock held (notification path) -- hence the lockdep
 * expression on the list iteration.
 *
 * Returns 0 on success or when the source list is empty, -EMSGSIZE if
 * the skb ran out of space (the whole SRC_LIST nest is cancelled).
 */
static int __mdb_fill_srcs(struct sk_buff *skb,
			   struct net_bridge_port_group *p)
{
	struct net_bridge_group_src *ent;
	struct nlattr *nest, *nest_ent;

	if (hlist_empty(&p->src_list))
		return 0;

	nest = nla_nest_start(skb, MDBA_MDB_EATTR_SRC_LIST);
	if (!nest)
		return -EMSGSIZE;

	hlist_for_each_entry_rcu(ent, &p->src_list, node,
				 lockdep_is_held(&p->key.port->br->multicast_lock)) {
		nest_ent = nla_nest_start(skb, MDBA_MDB_SRCLIST_ENTRY);
		if (!nest_ent)
			goto out_cancel_err;
		switch (ent->addr.proto) {
		case htons(ETH_P_IP):
			if (nla_put_in_addr(skb, MDBA_MDB_SRCATTR_ADDRESS,
					    ent->addr.src.ip4)) {
				nla_nest_cancel(skb, nest_ent);
				goto out_cancel_err;
			}
			break;
#if IS_ENABLED(CONFIG_IPV6)
		case htons(ETH_P_IPV6):
			if (nla_put_in6_addr(skb, MDBA_MDB_SRCATTR_ADDRESS,
					     &ent->addr.src.ip6)) {
				nla_nest_cancel(skb, nest_ent);
				goto out_cancel_err;
			}
			break;
#endif
		default:
			/* Unknown address family: cancel this entry's nest
			 * and skip it instead of failing the whole dump.
			 */
			nla_nest_cancel(skb, nest_ent);
			continue;
		}
		if (nla_put_u32(skb, MDBA_MDB_SRCATTR_TIMER,
				br_timer_value(&ent->timer))) {
			nla_nest_cancel(skb, nest_ent);
			goto out_cancel_err;
		}
		nla_nest_end(skb, nest_ent);
	}

	nla_nest_end(skb, nest);

	return 0;

out_cancel_err:
	nla_nest_cancel(skb, nest);
	return -EMSGSIZE;
}
151
152 static int __mdb_fill_info(struct sk_buff *skb,
153                            struct net_bridge_mdb_entry *mp,
154                            struct net_bridge_port_group *p)
155 {
156         bool dump_srcs_mode = false;
157         struct timer_list *mtimer;
158         struct nlattr *nest_ent;
159         struct br_mdb_entry e;
160         u8 flags = 0;
161         int ifindex;
162
163         memset(&e, 0, sizeof(e));
164         if (p) {
165                 ifindex = p->key.port->dev->ifindex;
166                 mtimer = &p->timer;
167                 flags = p->flags;
168         } else {
169                 ifindex = mp->br->dev->ifindex;
170                 mtimer = &mp->timer;
171         }
172
173         __mdb_entry_fill_flags(&e, flags);
174         e.ifindex = ifindex;
175         e.vid = mp->addr.vid;
176         if (mp->addr.proto == htons(ETH_P_IP))
177                 e.addr.u.ip4 = mp->addr.dst.ip4;
178 #if IS_ENABLED(CONFIG_IPV6)
179         else if (mp->addr.proto == htons(ETH_P_IPV6))
180                 e.addr.u.ip6 = mp->addr.dst.ip6;
181 #endif
182         else
183                 ether_addr_copy(e.addr.u.mac_addr, mp->addr.dst.mac_addr);
184         e.addr.proto = mp->addr.proto;
185         nest_ent = nla_nest_start_noflag(skb,
186                                          MDBA_MDB_ENTRY_INFO);
187         if (!nest_ent)
188                 return -EMSGSIZE;
189
190         if (nla_put_nohdr(skb, sizeof(e), &e) ||
191             nla_put_u32(skb,
192                         MDBA_MDB_EATTR_TIMER,
193                         br_timer_value(mtimer)))
194                 goto nest_err;
195
196         switch (mp->addr.proto) {
197         case htons(ETH_P_IP):
198                 dump_srcs_mode = !!(mp->br->multicast_igmp_version == 3);
199                 if (mp->addr.src.ip4) {
200                         if (nla_put_in_addr(skb, MDBA_MDB_EATTR_SOURCE,
201                                             mp->addr.src.ip4))
202                                 goto nest_err;
203                         break;
204                 }
205                 break;
206 #if IS_ENABLED(CONFIG_IPV6)
207         case htons(ETH_P_IPV6):
208                 dump_srcs_mode = !!(mp->br->multicast_mld_version == 2);
209                 if (!ipv6_addr_any(&mp->addr.src.ip6)) {
210                         if (nla_put_in6_addr(skb, MDBA_MDB_EATTR_SOURCE,
211                                              &mp->addr.src.ip6))
212                                 goto nest_err;
213                         break;
214                 }
215                 break;
216 #endif
217         default:
218                 ether_addr_copy(e.addr.u.mac_addr, mp->addr.dst.mac_addr);
219         }
220         if (p) {
221                 if (nla_put_u8(skb, MDBA_MDB_EATTR_RTPROT, p->rt_protocol))
222                         goto nest_err;
223                 if (dump_srcs_mode &&
224                     (__mdb_fill_srcs(skb, p) ||
225                      nla_put_u8(skb, MDBA_MDB_EATTR_GROUP_MODE,
226                                 p->filter_mode)))
227                         goto nest_err;
228         }
229         nla_nest_end(skb, nest_ent);
230
231         return 0;
232
233 nest_err:
234         nla_nest_cancel(skb, nest_ent);
235         return -EMSGSIZE;
236 }
237
/* Fill the MDBA_MDB portion of an RTM_GETMDB dump for one bridge.
 *
 * Multi-part dump resume state lives in the netlink callback:
 *   cb->args[1] -- index of the mdb entry to resume from (s_idx),
 *   cb->args[2] -- index of the port group within it (s_pidx).
 * Entries/groups below the saved indices were sent in a previous dump
 * round and are skipped.
 *
 * Called under rcu_read_lock().  Returns 0 when everything fit, or
 * -EMSGSIZE with the resume indices updated so the next call continues
 * where this one stopped.
 */
static int br_mdb_fill_info(struct sk_buff *skb, struct netlink_callback *cb,
			    struct net_device *dev)
{
	int idx = 0, s_idx = cb->args[1], err = 0, pidx = 0, s_pidx = cb->args[2];
	struct net_bridge *br = netdev_priv(dev);
	struct net_bridge_mdb_entry *mp;
	struct nlattr *nest, *nest2;

	if (!br_opt_get(br, BROPT_MULTICAST_ENABLED))
		return 0;

	nest = nla_nest_start_noflag(skb, MDBA_MDB);
	if (nest == NULL)
		return -EMSGSIZE;

	hlist_for_each_entry_rcu(mp, &br->mdb_list, mdb_node) {
		struct net_bridge_port_group *p;
		struct net_bridge_port_group __rcu **pp;

		if (idx < s_idx)
			goto skip;

		nest2 = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
		if (!nest2) {
			err = -EMSGSIZE;
			break;
		}

		/* Host-joined entry goes first; skipped when resuming in
		 * the middle of this entry's port group list.
		 */
		if (!s_pidx && mp->host_joined) {
			err = __mdb_fill_info(skb, mp, NULL);
			if (err) {
				nla_nest_cancel(skb, nest2);
				break;
			}
		}

		for (pp = &mp->ports; (p = rcu_dereference(*pp)) != NULL;
		      pp = &p->next) {
			if (!p->key.port)
				continue;
			if (pidx < s_pidx)
				goto skip_pg;

			err = __mdb_fill_info(skb, mp, p);
			if (err) {
				/* End (not cancel) the entry nest: the port
				 * groups already written are valid and sent.
				 */
				nla_nest_end(skb, nest2);
				goto out;
			}
skip_pg:
			pidx++;
		}
		pidx = 0;
		s_pidx = 0;
		nla_nest_end(skb, nest2);
skip:
		idx++;
	}

out:
	cb->args[1] = idx;
	cb->args[2] = pidx;
	nla_nest_end(skb, nest);
	return err;
}
302
303 static int br_mdb_valid_dump_req(const struct nlmsghdr *nlh,
304                                  struct netlink_ext_ack *extack)
305 {
306         struct br_port_msg *bpm;
307
308         if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*bpm))) {
309                 NL_SET_ERR_MSG_MOD(extack, "Invalid header for mdb dump request");
310                 return -EINVAL;
311         }
312
313         bpm = nlmsg_data(nlh);
314         if (bpm->ifindex) {
315                 NL_SET_ERR_MSG_MOD(extack, "Filtering by device index is not supported for mdb dump request");
316                 return -EINVAL;
317         }
318         if (nlmsg_attrlen(nlh, sizeof(*bpm))) {
319                 NL_SET_ERR_MSG(extack, "Invalid data after header in mdb dump request");
320                 return -EINVAL;
321         }
322
323         return 0;
324 }
325
/* RTM_GETMDB dump handler: iterate every bridge device in the netns and
 * emit one NLM_F_MULTI message per bridge, containing its MDB entries
 * (br_mdb_fill_info) and its multicast router ports
 * (br_rports_fill_info).  cb->args[0] holds the device index to resume
 * from; per-bridge resume state is kept in cb->args[1]/[2].
 */
static int br_mdb_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct net_device *dev;
	struct net *net = sock_net(skb->sk);
	struct nlmsghdr *nlh = NULL;
	int idx = 0, s_idx;

	if (cb->strict_check) {
		int err = br_mdb_valid_dump_req(cb->nlh, cb->extack);

		if (err < 0)
			return err;
	}

	s_idx = cb->args[0];

	rcu_read_lock();

	/* Lets netlink detect device list changes between dump rounds. */
	cb->seq = net->dev_base_seq;

	for_each_netdev_rcu(net, dev) {
		if (dev->priv_flags & IFF_EBRIDGE) {
			struct br_port_msg *bpm;

			if (idx < s_idx)
				goto skip;

			nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
					cb->nlh->nlmsg_seq, RTM_GETMDB,
					sizeof(*bpm), NLM_F_MULTI);
			if (nlh == NULL)
				break;

			bpm = nlmsg_data(nlh);
			memset(bpm, 0, sizeof(*bpm));
			bpm->ifindex = dev->ifindex;
			if (br_mdb_fill_info(skb, cb, dev) < 0)
				goto out;
			if (br_rports_fill_info(skb, cb, dev) < 0)
				goto out;

			cb->args[1] = 0;
			nlmsg_end(skb, nlh);
		skip:
			idx++;
		}
	}

out:
	/* On early exit the partially filled message is still finalized
	 * and sent; the dump resumes from cb->args[] on the next call.
	 */
	if (nlh)
		nlmsg_end(skb, nlh);
	rcu_read_unlock();
	cb->args[0] = idx;
	return skb->len;
}
381
382 static int nlmsg_populate_mdb_fill(struct sk_buff *skb,
383                                    struct net_device *dev,
384                                    struct net_bridge_mdb_entry *mp,
385                                    struct net_bridge_port_group *pg,
386                                    int type)
387 {
388         struct nlmsghdr *nlh;
389         struct br_port_msg *bpm;
390         struct nlattr *nest, *nest2;
391
392         nlh = nlmsg_put(skb, 0, 0, type, sizeof(*bpm), 0);
393         if (!nlh)
394                 return -EMSGSIZE;
395
396         bpm = nlmsg_data(nlh);
397         memset(bpm, 0, sizeof(*bpm));
398         bpm->family  = AF_BRIDGE;
399         bpm->ifindex = dev->ifindex;
400         nest = nla_nest_start_noflag(skb, MDBA_MDB);
401         if (nest == NULL)
402                 goto cancel;
403         nest2 = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
404         if (nest2 == NULL)
405                 goto end;
406
407         if (__mdb_fill_info(skb, mp, pg))
408                 goto end;
409
410         nla_nest_end(skb, nest2);
411         nla_nest_end(skb, nest);
412         nlmsg_end(skb, nlh);
413         return 0;
414
415 end:
416         nla_nest_end(skb, nest);
417 cancel:
418         nlmsg_cancel(skb, nlh);
419         return -EMSGSIZE;
420 }
421
/* Compute an upper bound on the size of the notification built by
 * nlmsg_populate_mdb_fill() for @pg, mirroring the attributes emitted
 * by __mdb_fill_info()/__mdb_fill_srcs().  With pg == NULL (host-joined
 * entry) only the base entry + timer attributes are accounted.
 */
static size_t rtnl_mdb_nlmsg_size(struct net_bridge_port_group *pg)
{
	/* Base: br_port_msg header + br_mdb_entry + MDBA_MDB_EATTR_TIMER */
	size_t nlmsg_size = NLMSG_ALIGN(sizeof(struct br_port_msg)) +
			    nla_total_size(sizeof(struct br_mdb_entry)) +
			    nla_total_size(sizeof(u32));
	struct net_bridge_group_src *ent;
	size_t addr_size = 0;

	if (!pg)
		goto out;

	/* MDBA_MDB_EATTR_RTPROT */
	nlmsg_size += nla_total_size(sizeof(u8));

	switch (pg->key.addr.proto) {
	case htons(ETH_P_IP):
		/* MDBA_MDB_EATTR_SOURCE */
		if (pg->key.addr.src.ip4)
			nlmsg_size += nla_total_size(sizeof(__be32));
		/* IGMPv2 entries carry no filter mode or source list. */
		if (pg->key.port->br->multicast_igmp_version == 2)
			goto out;
		addr_size = sizeof(__be32);
		break;
#if IS_ENABLED(CONFIG_IPV6)
	case htons(ETH_P_IPV6):
		/* MDBA_MDB_EATTR_SOURCE */
		if (!ipv6_addr_any(&pg->key.addr.src.ip6))
			nlmsg_size += nla_total_size(sizeof(struct in6_addr));
		/* MLDv1 entries carry no filter mode or source list. */
		if (pg->key.port->br->multicast_mld_version == 1)
			goto out;
		addr_size = sizeof(struct in6_addr);
		break;
#endif
	}
	/* L2 (non-IP) groups match no case above and fall through here;
	 * __mdb_fill_info() never emits GROUP_MODE/src list for them, so
	 * this slightly over-accounts -- harmless for an upper bound.
	 */

	/* MDBA_MDB_EATTR_GROUP_MODE */
	nlmsg_size += nla_total_size(sizeof(u8));

	/* MDBA_MDB_EATTR_SRC_LIST nested attr */
	if (!hlist_empty(&pg->src_list))
		nlmsg_size += nla_total_size(0);

	hlist_for_each_entry(ent, &pg->src_list, node) {
		/* MDBA_MDB_SRCLIST_ENTRY nested attr +
		 * MDBA_MDB_SRCATTR_ADDRESS + MDBA_MDB_SRCATTR_TIMER
		 */
		nlmsg_size += nla_total_size(0) +
			      nla_total_size(addr_size) +
			      nla_total_size(sizeof(u32));
	}
out:
	return nlmsg_size;
}
475
/* Context handed to the deferred switchdev completion callback
 * (br_mdb_complete) for a port MDB add; identifies the (port, group)
 * pair whose port groups should be marked as offloaded.
 */
struct br_mdb_complete_info {
	struct net_bridge_port *port;
	struct br_ip ip;
};
480
/* Deferred switchdev completion callback for a port MDB add: when the
 * object was accepted (err == 0), set MDB_PG_FLAGS_OFFLOAD on the port
 * group(s) matching the recorded (port, group) pair.  Takes
 * multicast_lock because it runs outside the original caller's context.
 * Always frees the br_mdb_complete_info allocated by br_mdb_notify().
 */
static void br_mdb_complete(struct net_device *dev, int err, void *priv)
{
	struct br_mdb_complete_info *data = priv;
	struct net_bridge_port_group __rcu **pp;
	struct net_bridge_port_group *p;
	struct net_bridge_mdb_entry *mp;
	struct net_bridge_port *port = data->port;
	struct net_bridge *br = port->br;

	if (err)
		goto err;

	spin_lock_bh(&br->multicast_lock);
	mp = br_mdb_ip_get(br, &data->ip);
	if (!mp)
		goto out;
	for (pp = &mp->ports; (p = mlock_dereference(*pp, br)) != NULL;
	     pp = &p->next) {
		if (p->key.port != port)
			continue;
		p->flags |= MDB_PG_FLAGS_OFFLOAD;
	}
out:
	spin_unlock_bh(&br->multicast_lock);
err:
	kfree(priv);
}
508
509 static void br_switchdev_mdb_populate(struct switchdev_obj_port_mdb *mdb,
510                                       const struct net_bridge_mdb_entry *mp)
511 {
512         if (mp->addr.proto == htons(ETH_P_IP))
513                 ip_eth_mc_map(mp->addr.dst.ip4, mdb->addr);
514 #if IS_ENABLED(CONFIG_IPV6)
515         else if (mp->addr.proto == htons(ETH_P_IPV6))
516                 ipv6_eth_mc_map(&mp->addr.dst.ip6, mdb->addr);
517 #endif
518         else
519                 ether_addr_copy(mdb->addr, mp->addr.dst.mac_addr);
520
521         mdb->vid = mp->addr.vid;
522 }
523
524 static int br_mdb_replay_one(struct notifier_block *nb, struct net_device *dev,
525                              struct switchdev_obj_port_mdb *mdb,
526                              struct netlink_ext_ack *extack)
527 {
528         struct switchdev_notifier_port_obj_info obj_info = {
529                 .info = {
530                         .dev = dev,
531                         .extack = extack,
532                 },
533                 .obj = &mdb->obj,
534         };
535         int err;
536
537         err = nb->notifier_call(nb, SWITCHDEV_PORT_OBJ_ADD, &obj_info);
538         return notifier_to_errno(err);
539 }
540
541 static int br_mdb_queue_one(struct list_head *mdb_list,
542                             enum switchdev_obj_id id,
543                             const struct net_bridge_mdb_entry *mp,
544                             struct net_device *orig_dev)
545 {
546         struct switchdev_obj_port_mdb *mdb;
547
548         mdb = kzalloc(sizeof(*mdb), GFP_ATOMIC);
549         if (!mdb)
550                 return -ENOMEM;
551
552         mdb->obj.id = id;
553         mdb->obj.orig_dev = orig_dev;
554         br_switchdev_mdb_populate(mdb, mp);
555         list_add_tail(&mdb->obj.list, mdb_list);
556
557         return 0;
558 }
559
/**
 * br_mdb_replay - replay a bridge's MDB entries towards a switchdev driver
 * @br_dev: bridge device
 * @dev: bridge port device whose entries should be replayed
 * @nb: switchdev blocking notifier block of the driver backing @dev
 * @extack: netlink extended ack for error reporting
 *
 * Generates SWITCHDEV_PORT_OBJ_ADD events for every host-joined group
 * and for every port group belonging to @dev, so that a driver which
 * starts offloading the bridge can populate its hardware MDB.  Must be
 * called with rtnl held.  Returns 0 on success or a negative errno.
 */
int br_mdb_replay(struct net_device *br_dev, struct net_device *dev,
		  struct notifier_block *nb, struct netlink_ext_ack *extack)
{
	struct net_bridge_mdb_entry *mp;
	struct switchdev_obj *obj, *tmp;
	struct net_bridge *br;
	LIST_HEAD(mdb_list);
	int err = 0;

	ASSERT_RTNL();

	if (!netif_is_bridge_master(br_dev) || !netif_is_bridge_port(dev))
		return -EINVAL;

	br = netdev_priv(br_dev);

	if (!br_opt_get(br, BROPT_MULTICAST_ENABLED))
		return 0;

	/* We cannot walk over br->mdb_list protected just by the rtnl_mutex,
	 * because the write-side protection is br->multicast_lock. But we
	 * need to emulate the [ blocking ] calling context of a regular
	 * switchdev event, so since both br->multicast_lock and RCU read side
	 * critical sections are atomic, we have no choice but to pick the RCU
	 * read side lock, queue up all our events, leave the critical section
	 * and notify switchdev from blocking context.
	 */
	rcu_read_lock();

	hlist_for_each_entry_rcu(mp, &br->mdb_list, mdb_node) {
		struct net_bridge_port_group __rcu **pp;
		struct net_bridge_port_group *p;

		if (mp->host_joined) {
			err = br_mdb_queue_one(&mdb_list,
					       SWITCHDEV_OBJ_ID_HOST_MDB,
					       mp, br_dev);
			if (err) {
				rcu_read_unlock();
				goto out_free_mdb;
			}
		}

		for (pp = &mp->ports; (p = rcu_dereference(*pp)) != NULL;
		     pp = &p->next) {
			if (p->key.port->dev != dev)
				continue;

			err = br_mdb_queue_one(&mdb_list,
					       SWITCHDEV_OBJ_ID_PORT_MDB,
					       mp, dev);
			if (err) {
				rcu_read_unlock();
				goto out_free_mdb;
			}
		}
	}

	rcu_read_unlock();

	/* Now in blocking context: deliver the queued objects. */
	list_for_each_entry(obj, &mdb_list, list) {
		err = br_mdb_replay_one(nb, dev, SWITCHDEV_OBJ_PORT_MDB(obj),
					extack);
		if (err)
			goto out_free_mdb;
	}

out_free_mdb:
	list_for_each_entry_safe(obj, tmp, &mdb_list, list) {
		list_del(&obj->list);
		kfree(SWITCHDEV_OBJ_PORT_MDB(obj));
	}

	return err;
}
EXPORT_SYMBOL_GPL(br_mdb_replay);
636
637 static void br_mdb_switchdev_host_port(struct net_device *dev,
638                                        struct net_device *lower_dev,
639                                        struct net_bridge_mdb_entry *mp,
640                                        int type)
641 {
642         struct switchdev_obj_port_mdb mdb = {
643                 .obj = {
644                         .id = SWITCHDEV_OBJ_ID_HOST_MDB,
645                         .flags = SWITCHDEV_F_DEFER,
646                         .orig_dev = dev,
647                 },
648         };
649
650         br_switchdev_mdb_populate(&mdb, mp);
651
652         switch (type) {
653         case RTM_NEWMDB:
654                 switchdev_port_obj_add(lower_dev, &mdb.obj, NULL);
655                 break;
656         case RTM_DELMDB:
657                 switchdev_port_obj_del(lower_dev, &mdb.obj);
658                 break;
659         }
660 }
661
662 static void br_mdb_switchdev_host(struct net_device *dev,
663                                   struct net_bridge_mdb_entry *mp, int type)
664 {
665         struct net_device *lower_dev;
666         struct list_head *iter;
667
668         netdev_for_each_lower_dev(dev, lower_dev, iter)
669                 br_mdb_switchdev_host_port(dev, lower_dev, mp, type);
670 }
671
/* Notify interested parties of an MDB change (RTM_NEWMDB/RTM_DELMDB).
 *
 * First propagate the change to switchdev: to the port's device for a
 * port group (@pg != NULL), or to all lower devices for a host-joined
 * entry (@pg == NULL).  For adds, a completion callback
 * (br_mdb_complete) later sets the offload flag once the hardware has
 * accepted the entry; the completion info is freed by the callback, or
 * here if the deferred add could not be queued.
 *
 * Then send an rtnetlink notification to RTNLGRP_MDB listeners; on
 * allocation/fill failure the group's socket error is set instead.
 */
void br_mdb_notify(struct net_device *dev,
		   struct net_bridge_mdb_entry *mp,
		   struct net_bridge_port_group *pg,
		   int type)
{
	struct br_mdb_complete_info *complete_info;
	struct switchdev_obj_port_mdb mdb = {
		.obj = {
			.id = SWITCHDEV_OBJ_ID_PORT_MDB,
			.flags = SWITCHDEV_F_DEFER,
		},
	};
	struct net *net = dev_net(dev);
	struct sk_buff *skb;
	int err = -ENOBUFS;

	if (pg) {
		br_switchdev_mdb_populate(&mdb, mp);

		mdb.obj.orig_dev = pg->key.port->dev;
		switch (type) {
		case RTM_NEWMDB:
			complete_info = kmalloc(sizeof(*complete_info), GFP_ATOMIC);
			if (!complete_info)
				break;
			complete_info->port = pg->key.port;
			complete_info->ip = mp->addr;
			mdb.obj.complete_priv = complete_info;
			mdb.obj.complete = br_mdb_complete;
			if (switchdev_port_obj_add(pg->key.port->dev, &mdb.obj, NULL))
				kfree(complete_info);
			break;
		case RTM_DELMDB:
			switchdev_port_obj_del(pg->key.port->dev, &mdb.obj);
			break;
		}
	} else {
		br_mdb_switchdev_host(dev, mp, type);
	}

	skb = nlmsg_new(rtnl_mdb_nlmsg_size(pg), GFP_ATOMIC);
	if (!skb)
		goto errout;

	err = nlmsg_populate_mdb_fill(skb, dev, mp, pg, type);
	if (err < 0) {
		kfree_skb(skb);
		goto errout;
	}

	rtnl_notify(skb, net, 0, RTNLGRP_MDB, NULL, GFP_ATOMIC);
	return;
errout:
	rtnl_set_sk_err(net, RTNLGRP_MDB, err);
}
727
/* Build a router-port notification message of @type: a br_port_msg
 * header for @dev carrying a single MDBA_ROUTER_PORT (= @ifindex)
 * inside an MDBA_ROUTER nest.
 *
 * Note: @flags is accepted but currently unused -- the message is
 * always created with zero nlmsg flags.
 *
 * Returns 0 on success, -EMSGSIZE if the skb is too small (the message
 * is cancelled).
 */
static int nlmsg_populate_rtr_fill(struct sk_buff *skb,
				   struct net_device *dev,
				   int ifindex, u32 pid,
				   u32 seq, int type, unsigned int flags)
{
	struct br_port_msg *bpm;
	struct nlmsghdr *nlh;
	struct nlattr *nest;

	nlh = nlmsg_put(skb, pid, seq, type, sizeof(*bpm), 0);
	if (!nlh)
		return -EMSGSIZE;

	bpm = nlmsg_data(nlh);
	memset(bpm, 0, sizeof(*bpm));
	bpm->family = AF_BRIDGE;
	bpm->ifindex = dev->ifindex;
	nest = nla_nest_start_noflag(skb, MDBA_ROUTER);
	if (!nest)
		goto cancel;

	if (nla_put_u32(skb, MDBA_ROUTER_PORT, ifindex))
		goto end;

	nla_nest_end(skb, nest);
	nlmsg_end(skb, nlh);
	return 0;

end:
	nla_nest_end(skb, nest);
cancel:
	nlmsg_cancel(skb, nlh);
	return -EMSGSIZE;
}
762
763 static inline size_t rtnl_rtr_nlmsg_size(void)
764 {
765         return NLMSG_ALIGN(sizeof(struct br_port_msg))
766                 + nla_total_size(sizeof(__u32));
767 }
768
769 void br_rtr_notify(struct net_device *dev, struct net_bridge_port *port,
770                    int type)
771 {
772         struct net *net = dev_net(dev);
773         struct sk_buff *skb;
774         int err = -ENOBUFS;
775         int ifindex;
776
777         ifindex = port ? port->dev->ifindex : 0;
778         skb = nlmsg_new(rtnl_rtr_nlmsg_size(), GFP_ATOMIC);
779         if (!skb)
780                 goto errout;
781
782         err = nlmsg_populate_rtr_fill(skb, dev, ifindex, 0, 0, type, NTF_SELF);
783         if (err < 0) {
784                 kfree_skb(skb);
785                 goto errout;
786         }
787
788         rtnl_notify(skb, net, 0, RTNLGRP_MDB, NULL, GFP_ATOMIC);
789         return;
790
791 errout:
792         rtnl_set_sk_err(net, RTNLGRP_MDB, err);
793 }
794
/* Validate a br_mdb_entry received from userspace (RTM_NEWMDB/DELMDB):
 * checks the ifindex, the group address per protocol (IPv4, IPv6, or
 * proto == 0 for L2 MAC groups), the entry state and the VLAN id.
 * Returns true when valid; otherwise sets an extack message and
 * returns false.
 */
static bool is_valid_mdb_entry(struct br_mdb_entry *entry,
			       struct netlink_ext_ack *extack)
{
	if (entry->ifindex == 0) {
		NL_SET_ERR_MSG_MOD(extack, "Zero entry ifindex is not allowed");
		return false;
	}

	if (entry->addr.proto == htons(ETH_P_IP)) {
		if (!ipv4_is_multicast(entry->addr.u.ip4)) {
			NL_SET_ERR_MSG_MOD(extack, "IPv4 entry group address is not multicast");
			return false;
		}
		/* 224.0.0.0/24 link-local groups are always flooded. */
		if (ipv4_is_local_multicast(entry->addr.u.ip4)) {
			NL_SET_ERR_MSG_MOD(extack, "IPv4 entry group address is local multicast");
			return false;
		}
#if IS_ENABLED(CONFIG_IPV6)
	} else if (entry->addr.proto == htons(ETH_P_IPV6)) {
		if (ipv6_addr_is_ll_all_nodes(&entry->addr.u.ip6)) {
			NL_SET_ERR_MSG_MOD(extack, "IPv6 entry group address is link-local all nodes");
			return false;
		}
#endif
	} else if (entry->addr.proto == 0) {
		/* L2 mdb */
		if (!is_multicast_ether_addr(entry->addr.u.mac_addr)) {
			NL_SET_ERR_MSG_MOD(extack, "L2 entry group is not multicast");
			return false;
		}
	} else {
		NL_SET_ERR_MSG_MOD(extack, "Unknown entry protocol");
		return false;
	}

	if (entry->state != MDB_PERMANENT && entry->state != MDB_TEMPORARY) {
		NL_SET_ERR_MSG_MOD(extack, "Unknown entry state");
		return false;
	}
	if (entry->vid >= VLAN_VID_MASK) {
		NL_SET_ERR_MSG_MOD(extack, "Invalid entry VLAN id");
		return false;
	}

	return true;
}
841
842 static bool is_valid_mdb_source(struct nlattr *attr, __be16 proto,
843                                 struct netlink_ext_ack *extack)
844 {
845         switch (proto) {
846         case htons(ETH_P_IP):
847                 if (nla_len(attr) != sizeof(struct in_addr)) {
848                         NL_SET_ERR_MSG_MOD(extack, "IPv4 invalid source address length");
849                         return false;
850                 }
851                 if (ipv4_is_multicast(nla_get_in_addr(attr))) {
852                         NL_SET_ERR_MSG_MOD(extack, "IPv4 multicast source address is not allowed");
853                         return false;
854                 }
855                 break;
856 #if IS_ENABLED(CONFIG_IPV6)
857         case htons(ETH_P_IPV6): {
858                 struct in6_addr src;
859
860                 if (nla_len(attr) != sizeof(struct in6_addr)) {
861                         NL_SET_ERR_MSG_MOD(extack, "IPv6 invalid source address length");
862                         return false;
863                 }
864                 src = nla_get_in6_addr(attr);
865                 if (ipv6_addr_is_multicast(&src)) {
866                         NL_SET_ERR_MSG_MOD(extack, "IPv6 multicast source address is not allowed");
867                         return false;
868                 }
869                 break;
870         }
871 #endif
872         default:
873                 NL_SET_ERR_MSG_MOD(extack, "Invalid protocol used with source address");
874                 return false;
875         }
876
877         return true;
878 }
879
/* Policy for the nested MDBA_SET_ENTRY_ATTRS attributes. The range only
 * bounds the raw length (IPv4..IPv6 address size); the exact per-protocol
 * length and unicast check happen in is_valid_mdb_source().
 */
static const struct nla_policy br_mdbe_attrs_pol[MDBE_ATTR_MAX + 1] = {
	[MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY,
					      sizeof(struct in_addr),
					      sizeof(struct in6_addr)),
};
885
/* Parse and validate the parts common to RTM_NEWMDB and RTM_DELMDB requests.
 *
 * On success fills:
 *   *pdev     - the bridge net_device addressed by bpm->ifindex
 *   *pentry   - the validated br_mdb_entry from MDBA_SET_ENTRY
 *   mdb_attrs - parsed MDBE_ATTR_* nested attributes (zeroed when the
 *               optional MDBA_SET_ENTRY_ATTRS attribute is absent)
 *
 * Returns 0 or a negative errno, setting extack on failure.
 */
static int br_mdb_parse(struct sk_buff *skb, struct nlmsghdr *nlh,
			struct net_device **pdev, struct br_mdb_entry **pentry,
			struct nlattr **mdb_attrs, struct netlink_ext_ack *extack)
{
	struct net *net = sock_net(skb->sk);
	struct br_mdb_entry *entry;
	struct br_port_msg *bpm;
	struct nlattr *tb[MDBA_SET_ENTRY_MAX+1];
	struct net_device *dev;
	int err;

	err = nlmsg_parse_deprecated(nlh, sizeof(*bpm), tb,
				     MDBA_SET_ENTRY_MAX, NULL, NULL);
	if (err < 0)
		return err;

	/* The ancillary header carries the target bridge's ifindex. */
	bpm = nlmsg_data(nlh);
	if (bpm->ifindex == 0) {
		NL_SET_ERR_MSG_MOD(extack, "Invalid bridge ifindex");
		return -EINVAL;
	}

	dev = __dev_get_by_index(net, bpm->ifindex);
	if (dev == NULL) {
		NL_SET_ERR_MSG_MOD(extack, "Bridge device doesn't exist");
		return -ENODEV;
	}

	if (!(dev->priv_flags & IFF_EBRIDGE)) {
		NL_SET_ERR_MSG_MOD(extack, "Device is not a bridge");
		return -EOPNOTSUPP;
	}

	*pdev = dev;

	if (!tb[MDBA_SET_ENTRY]) {
		NL_SET_ERR_MSG_MOD(extack, "Missing MDBA_SET_ENTRY attribute");
		return -EINVAL;
	}
	/* The entry is a fixed-size binary blob; reject any other length. */
	if (nla_len(tb[MDBA_SET_ENTRY]) != sizeof(struct br_mdb_entry)) {
		NL_SET_ERR_MSG_MOD(extack, "Invalid MDBA_SET_ENTRY attribute length");
		return -EINVAL;
	}

	entry = nla_data(tb[MDBA_SET_ENTRY]);
	if (!is_valid_mdb_entry(entry, extack))
		return -EINVAL;
	*pentry = entry;

	if (tb[MDBA_SET_ENTRY_ATTRS]) {
		err = nla_parse_nested(mdb_attrs, MDBE_ATTR_MAX,
				       tb[MDBA_SET_ENTRY_ATTRS],
				       br_mdbe_attrs_pol, extack);
		if (err)
			return err;
		/* A source address only makes sense for IP groups; reject it
		 * for L2 or unknown protocols.
		 */
		if (mdb_attrs[MDBE_ATTR_SOURCE] &&
		    !is_valid_mdb_source(mdb_attrs[MDBE_ATTR_SOURCE],
					 entry->addr.proto, extack))
			return -EINVAL;
	} else {
		/* No nested attributes: clear the array so callers can test
		 * mdb_attrs[MDBE_ATTR_*] unconditionally.
		 */
		memset(mdb_attrs, 0,
		       sizeof(struct nlattr *) * (MDBE_ATTR_MAX + 1));
	}

	return 0;
}
952
/* Add a user-requested multicast group membership.
 *
 * @br:        bridge the group belongs to
 * @port:      joining bridge port, or NULL for a host join (the bridge
 *             device itself is the member)
 * @entry:     validated user-supplied entry
 * @mdb_attrs: parsed MDBE_ATTR_* attributes (may carry a source address)
 *
 * Must be called with br->multicast_lock held (see __br_mdb_add()).
 * Returns 0 on success or a negative errno with extack set.
 */
static int br_mdb_add_group(struct net_bridge *br, struct net_bridge_port *port,
			    struct br_mdb_entry *entry,
			    struct nlattr **mdb_attrs,
			    struct netlink_ext_ack *extack)
{
	struct net_bridge_mdb_entry *mp, *star_mp;
	struct net_bridge_port_group *p;
	struct net_bridge_port_group __rcu **pp;
	struct br_ip group, star_group;
	unsigned long now = jiffies;
	unsigned char flags = 0;
	u8 filter_mode;
	int err;

	__mdb_entry_to_br_ip(entry, &group, mdb_attrs);

	/* host join errors which can happen before creating the group */
	if (!port) {
		/* don't allow any flags for host-joined groups */
		if (entry->state) {
			NL_SET_ERR_MSG_MOD(extack, "Flags are not allowed for host groups");
			return -EINVAL;
		}
		if (!br_multicast_is_star_g(&group)) {
			NL_SET_ERR_MSG_MOD(extack, "Groups with sources cannot be manually host joined");
			return -EINVAL;
		}
	}

	if (br_group_is_l2(&group) && entry->state != MDB_PERMANENT) {
		NL_SET_ERR_MSG_MOD(extack, "Only permanent L2 entries allowed");
		return -EINVAL;
	}

	/* Look the group up, creating it on first use. */
	mp = br_mdb_ip_get(br, &group);
	if (!mp) {
		mp = br_multicast_new_group(br, &group);
		err = PTR_ERR_OR_ZERO(mp);
		if (err)
			return err;
	}

	/* host join */
	if (!port) {
		if (mp->host_joined) {
			NL_SET_ERR_MSG_MOD(extack, "Group is already joined by host");
			return -EEXIST;
		}

		br_multicast_host_join(mp, false);
		br_mdb_notify(br->dev, mp, NULL, RTM_NEWMDB);

		return 0;
	}

	/* Walk the port group list: reject a duplicate join and find the
	 * insertion point. The list is kept ordered by port pointer value
	 * (descending), so stop at the first entry with a smaller pointer.
	 */
	for (pp = &mp->ports;
	     (p = mlock_dereference(*pp, br)) != NULL;
	     pp = &p->next) {
		if (p->key.port == port) {
			NL_SET_ERR_MSG_MOD(extack, "Group is already joined by port");
			return -EEXIST;
		}
		if ((unsigned long)p->key.port < (unsigned long)port)
			break;
	}

	/* (*,G) entries default to EXCLUDE mode, (S,G) ones to INCLUDE. */
	filter_mode = br_multicast_is_star_g(&group) ? MCAST_EXCLUDE :
						       MCAST_INCLUDE;

	if (entry->state == MDB_PERMANENT)
		flags |= MDB_PG_FLAGS_PERMANENT;

	p = br_multicast_new_port_group(port, &group, *pp, flags, NULL,
					filter_mode, RTPROT_STATIC);
	if (unlikely(!p)) {
		NL_SET_ERR_MSG_MOD(extack, "Couldn't allocate new port group");
		return -ENOMEM;
	}
	/* Publish the new port group at the insertion point found above. */
	rcu_assign_pointer(*pp, p);
	/* Temporary entries age out unless refreshed before the timer fires. */
	if (entry->state == MDB_TEMPORARY)
		mod_timer(&p->timer, now + br->multicast_membership_interval);
	br_mdb_notify(br->dev, mp, p, RTM_NEWMDB);
	/* if we are adding a new EXCLUDE port group (*,G) it needs to be also
	 * added to all S,G entries for proper replication, if we are adding
	 * a new INCLUDE port (S,G) then all of *,G EXCLUDE ports need to be
	 * added to it for proper replication
	 */
	if (br_multicast_should_handle_mode(br, group.proto)) {
		switch (filter_mode) {
		case MCAST_EXCLUDE:
			br_multicast_star_g_handle_mode(p, MCAST_EXCLUDE);
			break;
		case MCAST_INCLUDE:
			/* Derive the matching (*,G) by zeroing the source. */
			star_group = p->key.addr;
			memset(&star_group.src, 0, sizeof(star_group.src));
			star_mp = br_mdb_ip_get(br, &star_group);
			if (star_mp)
				br_multicast_sg_add_exclude_ports(star_mp, p);
			break;
		}
	}

	return 0;
}
1057
1058 static int __br_mdb_add(struct net *net, struct net_bridge *br,
1059                         struct net_bridge_port *p,
1060                         struct br_mdb_entry *entry,
1061                         struct nlattr **mdb_attrs,
1062                         struct netlink_ext_ack *extack)
1063 {
1064         int ret;
1065
1066         spin_lock_bh(&br->multicast_lock);
1067         ret = br_mdb_add_group(br, p, entry, mdb_attrs, extack);
1068         spin_unlock_bh(&br->multicast_lock);
1069
1070         return ret;
1071 }
1072
/* RTM_NEWMDB handler (registered in br_mdb_init()).
 *
 * Resolves the bridge and (optionally) the member port from the request,
 * validates their state, then adds the group entry. When entry->ifindex
 * names the bridge itself, p stays NULL and this becomes a host join.
 * Returns 0 or a negative errno with extack set.
 */
static int br_mdb_add(struct sk_buff *skb, struct nlmsghdr *nlh,
		      struct netlink_ext_ack *extack)
{
	struct nlattr *mdb_attrs[MDBE_ATTR_MAX + 1];
	struct net *net = sock_net(skb->sk);
	struct net_bridge_vlan_group *vg;
	struct net_bridge_port *p = NULL;
	struct net_device *dev, *pdev;
	struct br_mdb_entry *entry;
	struct net_bridge_vlan *v;
	struct net_bridge *br;
	int err;

	err = br_mdb_parse(skb, nlh, &dev, &entry, mdb_attrs, extack);
	if (err < 0)
		return err;

	br = netdev_priv(dev);

	if (!netif_running(br->dev)) {
		NL_SET_ERR_MSG_MOD(extack, "Bridge device is not running");
		return -EINVAL;
	}

	if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) {
		NL_SET_ERR_MSG_MOD(extack, "Bridge's multicast processing is disabled");
		return -EINVAL;
	}

	/* entry->ifindex != bridge: resolve and validate the member port. */
	if (entry->ifindex != br->dev->ifindex) {
		pdev = __dev_get_by_index(net, entry->ifindex);
		if (!pdev) {
			NL_SET_ERR_MSG_MOD(extack, "Port net device doesn't exist");
			return -ENODEV;
		}

		p = br_port_get_rtnl(pdev);
		if (!p) {
			NL_SET_ERR_MSG_MOD(extack, "Net device is not a bridge port");
			return -EINVAL;
		}

		if (p->br != br) {
			NL_SET_ERR_MSG_MOD(extack, "Port belongs to a different bridge device");
			return -EINVAL;
		}
		if (p->state == BR_STATE_DISABLED) {
			NL_SET_ERR_MSG_MOD(extack, "Port is in disabled state");
			return -EINVAL;
		}
		vg = nbp_vlan_group(p);
	} else {
		vg = br_vlan_group(br);
	}

	/* If vlan filtering is enabled and VLAN is not specified
	 * install mdb entry on all vlans configured on the port.
	 */
	if (br_vlan_enabled(br->dev) && vg && entry->vid == 0) {
		/* Note: entry->vid is rewritten in place for each vlan and
		 * the loop stops at the first failure.
		 */
		list_for_each_entry(v, &vg->vlan_list, vlist) {
			entry->vid = v->vid;
			err = __br_mdb_add(net, br, p, entry, mdb_attrs, extack);
			if (err)
				break;
		}
	} else {
		err = __br_mdb_add(net, br, p, entry, mdb_attrs, extack);
	}

	return err;
}
1144
/* Delete one MDB entry (a host leave or a single port's membership).
 *
 * Takes br->multicast_lock itself. Returns 0 on success, -EINVAL when the
 * group or a matching port membership isn't found (or the port is
 * disabled).
 */
static int __br_mdb_del(struct net_bridge *br, struct br_mdb_entry *entry,
			struct nlattr **mdb_attrs)
{
	struct net_bridge_mdb_entry *mp;
	struct net_bridge_port_group *p;
	struct net_bridge_port_group __rcu **pp;
	struct br_ip ip;
	int err = -EINVAL;

	if (!netif_running(br->dev) || !br_opt_get(br, BROPT_MULTICAST_ENABLED))
		return -EINVAL;

	__mdb_entry_to_br_ip(entry, &ip, mdb_attrs);

	spin_lock_bh(&br->multicast_lock);
	mp = br_mdb_ip_get(br, &ip);
	if (!mp)
		goto unlock;

	/* host leave */
	if (entry->ifindex == mp->br->dev->ifindex && mp->host_joined) {
		br_multicast_host_leave(mp, false);
		err = 0;
		br_mdb_notify(br->dev, mp, NULL, RTM_DELMDB);
		/* No port members left: schedule the group for immediate
		 * expiry via its timer.
		 */
		if (!mp->ports && netif_running(br->dev))
			mod_timer(&mp->timer, jiffies);
		goto unlock;
	}

	/* Find the port group matching entry->ifindex and unlink it. */
	for (pp = &mp->ports;
	     (p = mlock_dereference(*pp, br)) != NULL;
	     pp = &p->next) {
		if (!p->key.port || p->key.port->dev->ifindex != entry->ifindex)
			continue;

		/* Don't touch memberships of a disabled port. */
		if (p->key.port->state == BR_STATE_DISABLED)
			goto unlock;

		br_multicast_del_pg(mp, p, pp);
		err = 0;
		break;
	}

unlock:
	spin_unlock_bh(&br->multicast_lock);
	return err;
}
1192
/* RTM_DELMDB handler (registered in br_mdb_init()).
 *
 * Mirrors br_mdb_add(): resolves the bridge and optional member port, then
 * deletes the entry — on every configured vlan when vlan filtering is on
 * and no VID was given. Returns 0 or a negative errno.
 */
static int br_mdb_del(struct sk_buff *skb, struct nlmsghdr *nlh,
		      struct netlink_ext_ack *extack)
{
	struct nlattr *mdb_attrs[MDBE_ATTR_MAX + 1];
	struct net *net = sock_net(skb->sk);
	struct net_bridge_vlan_group *vg;
	struct net_bridge_port *p = NULL;
	struct net_device *dev, *pdev;
	struct br_mdb_entry *entry;
	struct net_bridge_vlan *v;
	struct net_bridge *br;
	int err;

	err = br_mdb_parse(skb, nlh, &dev, &entry, mdb_attrs, extack);
	if (err < 0)
		return err;

	br = netdev_priv(dev);

	/* entry->ifindex != bridge: resolve and validate the member port. */
	if (entry->ifindex != br->dev->ifindex) {
		pdev = __dev_get_by_index(net, entry->ifindex);
		if (!pdev)
			return -ENODEV;

		p = br_port_get_rtnl(pdev);
		if (!p || p->br != br || p->state == BR_STATE_DISABLED)
			return -EINVAL;
		vg = nbp_vlan_group(p);
	} else {
		vg = br_vlan_group(br);
	}

	/* If vlan filtering is enabled and VLAN is not specified
	 * delete mdb entry on all vlans configured on the port.
	 */
	if (br_vlan_enabled(br->dev) && vg && entry->vid == 0) {
		/* Best effort across vlans: unlike br_mdb_add(), failures do
		 * not stop the loop, so the last vlan's result is returned.
		 * entry->vid is rewritten in place for each vlan.
		 */
		list_for_each_entry(v, &vg->vlan_list, vlist) {
			entry->vid = v->vid;
			err = __br_mdb_del(br, entry, mdb_attrs);
		}
	} else {
		err = __br_mdb_del(br, entry, mdb_attrs);
	}

	return err;
}
1239
1240 void br_mdb_init(void)
1241 {
1242         rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_GETMDB, NULL, br_mdb_dump, 0);
1243         rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_NEWMDB, br_mdb_add, NULL, 0);
1244         rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_DELMDB, br_mdb_del, NULL, 0);
1245 }
1246
1247 void br_mdb_uninit(void)
1248 {
1249         rtnl_unregister(PF_BRIDGE, RTM_GETMDB);
1250         rtnl_unregister(PF_BRIDGE, RTM_NEWMDB);
1251         rtnl_unregister(PF_BRIDGE, RTM_DELMDB);
1252 }