vxlan: vni filtering support on collect metadata device
[linux-2.6-microblaze.git] / drivers / net / vxlan / vxlan_vnifilter.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  *      Vxlan vni filter for collect metadata mode
4  *
5  *      Authors: Roopa Prabhu <roopa@nvidia.com>
6  *
7  */
8
9 #include <linux/kernel.h>
10 #include <linux/slab.h>
11 #include <linux/etherdevice.h>
12 #include <linux/rhashtable.h>
13 #include <net/rtnetlink.h>
14 #include <net/net_namespace.h>
15 #include <net/sock.h>
16 #include <net/vxlan.h>
17
18 #include "vxlan_private.h"
19
20 static inline int vxlan_vni_cmp(struct rhashtable_compare_arg *arg,
21                                 const void *ptr)
22 {
23         const struct vxlan_vni_node *vnode = ptr;
24         __be32 vni = *(__be32 *)arg->key;
25
26         return vnode->vni != vni;
27 }
28
29 const struct rhashtable_params vxlan_vni_rht_params = {
30         .head_offset = offsetof(struct vxlan_vni_node, vnode),
31         .key_offset = offsetof(struct vxlan_vni_node, vni),
32         .key_len = sizeof(__be32),
33         .nelem_hint = 3,
34         .max_size = VXLAN_N_VID,
35         .obj_cmpfn = vxlan_vni_cmp,
36         .automatic_shrinking = true,
37 };
38
39 static void vxlan_vs_add_del_vninode(struct vxlan_dev *vxlan,
40                                      struct vxlan_vni_node *v,
41                                      bool del)
42 {
43         struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
44         struct vxlan_dev_node *node;
45         struct vxlan_sock *vs;
46
47         spin_lock(&vn->sock_lock);
48         if (del) {
49                 if (!hlist_unhashed(&v->hlist4.hlist))
50                         hlist_del_init_rcu(&v->hlist4.hlist);
51 #if IS_ENABLED(CONFIG_IPV6)
52                 if (!hlist_unhashed(&v->hlist6.hlist))
53                         hlist_del_init_rcu(&v->hlist6.hlist);
54 #endif
55                 goto out;
56         }
57
58 #if IS_ENABLED(CONFIG_IPV6)
59         vs = rtnl_dereference(vxlan->vn6_sock);
60         if (vs && v) {
61                 node = &v->hlist6;
62                 hlist_add_head_rcu(&node->hlist, vni_head(vs, v->vni));
63         }
64 #endif
65         vs = rtnl_dereference(vxlan->vn4_sock);
66         if (vs && v) {
67                 node = &v->hlist4;
68                 hlist_add_head_rcu(&node->hlist, vni_head(vs, v->vni));
69         }
70 out:
71         spin_unlock(&vn->sock_lock);
72 }
73
74 void vxlan_vs_add_vnigrp(struct vxlan_dev *vxlan,
75                          struct vxlan_sock *vs,
76                          bool ipv6)
77 {
78         struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
79         struct vxlan_vni_group *vg = rtnl_dereference(vxlan->vnigrp);
80         struct vxlan_vni_node *v, *tmp;
81         struct vxlan_dev_node *node;
82
83         if (!vg)
84                 return;
85
86         spin_lock(&vn->sock_lock);
87         list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
88 #if IS_ENABLED(CONFIG_IPV6)
89                 if (ipv6)
90                         node = &v->hlist6;
91                 else
92 #endif
93                         node = &v->hlist4;
94                 node->vxlan = vxlan;
95                 hlist_add_head_rcu(&node->hlist, vni_head(vs, v->vni));
96         }
97         spin_unlock(&vn->sock_lock);
98 }
99
100 void vxlan_vs_del_vnigrp(struct vxlan_dev *vxlan)
101 {
102         struct vxlan_vni_group *vg = rtnl_dereference(vxlan->vnigrp);
103         struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
104         struct vxlan_vni_node *v, *tmp;
105
106         if (!vg)
107                 return;
108
109         spin_lock(&vn->sock_lock);
110         list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
111                 hlist_del_init_rcu(&v->hlist4.hlist);
112 #if IS_ENABLED(CONFIG_IPV6)
113                 hlist_del_init_rcu(&v->hlist6.hlist);
114 #endif
115         }
116         spin_unlock(&vn->sock_lock);
117 }
118
119 static u32 vnirange(struct vxlan_vni_node *vbegin,
120                     struct vxlan_vni_node *vend)
121 {
122         return (be32_to_cpu(vend->vni) - be32_to_cpu(vbegin->vni));
123 }
124
125 static size_t vxlan_vnifilter_entry_nlmsg_size(void)
126 {
127         return NLMSG_ALIGN(sizeof(struct tunnel_msg))
128                 + nla_total_size(0) /* VXLAN_VNIFILTER_ENTRY */
129                 + nla_total_size(sizeof(u32)) /* VXLAN_VNIFILTER_ENTRY_START */
130                 + nla_total_size(sizeof(u32)) /* VXLAN_VNIFILTER_ENTRY_END */
131                 + nla_total_size(sizeof(struct in6_addr));/* VXLAN_VNIFILTER_ENTRY_GROUP{6} */
132 }
133
134 static bool vxlan_fill_vni_filter_entry(struct sk_buff *skb,
135                                         struct vxlan_vni_node *vbegin,
136                                         struct vxlan_vni_node *vend)
137 {
138         struct nlattr *ventry;
139         u32 vs = be32_to_cpu(vbegin->vni);
140         u32 ve = 0;
141
142         if (vbegin != vend)
143                 ve = be32_to_cpu(vend->vni);
144
145         ventry = nla_nest_start(skb, VXLAN_VNIFILTER_ENTRY);
146         if (!ventry)
147                 return false;
148
149         if (nla_put_u32(skb, VXLAN_VNIFILTER_ENTRY_START, vs))
150                 goto out_err;
151
152         if (ve && nla_put_u32(skb, VXLAN_VNIFILTER_ENTRY_END, ve))
153                 goto out_err;
154
155         if (!vxlan_addr_any(&vbegin->remote_ip)) {
156                 if (vbegin->remote_ip.sa.sa_family == AF_INET) {
157                         if (nla_put_in_addr(skb, VXLAN_VNIFILTER_ENTRY_GROUP,
158                                             vbegin->remote_ip.sin.sin_addr.s_addr))
159                                 goto out_err;
160 #if IS_ENABLED(CONFIG_IPV6)
161                 } else {
162                         if (nla_put_in6_addr(skb, VXLAN_VNIFILTER_ENTRY_GROUP6,
163                                              &vbegin->remote_ip.sin6.sin6_addr))
164                                 goto out_err;
165 #endif
166                 }
167         }
168
169         nla_nest_end(skb, ventry);
170
171         return true;
172
173 out_err:
174         nla_nest_cancel(skb, ventry);
175
176         return false;
177 }
178
179 static void vxlan_vnifilter_notify(const struct vxlan_dev *vxlan,
180                                    struct vxlan_vni_node *vninode, int cmd)
181 {
182         struct tunnel_msg *tmsg;
183         struct sk_buff *skb;
184         struct nlmsghdr *nlh;
185         struct net *net = dev_net(vxlan->dev);
186         int err = -ENOBUFS;
187
188         skb = nlmsg_new(vxlan_vnifilter_entry_nlmsg_size(), GFP_KERNEL);
189         if (!skb)
190                 goto out_err;
191
192         err = -EMSGSIZE;
193         nlh = nlmsg_put(skb, 0, 0, cmd, sizeof(*tmsg), 0);
194         if (!nlh)
195                 goto out_err;
196         tmsg = nlmsg_data(nlh);
197         memset(tmsg, 0, sizeof(*tmsg));
198         tmsg->family = AF_BRIDGE;
199         tmsg->ifindex = vxlan->dev->ifindex;
200
201         if (!vxlan_fill_vni_filter_entry(skb, vninode, vninode))
202                 goto out_err;
203
204         nlmsg_end(skb, nlh);
205         rtnl_notify(skb, net, 0, RTNLGRP_TUNNEL, NULL, GFP_KERNEL);
206
207         return;
208
209 out_err:
210         rtnl_set_sk_err(net, RTNLGRP_TUNNEL, err);
211
212         kfree_skb(skb);
213 }
214
215 static int vxlan_vnifilter_dump_dev(const struct net_device *dev,
216                                     struct sk_buff *skb,
217                                     struct netlink_callback *cb)
218 {
219         struct vxlan_vni_node *tmp, *v, *vbegin = NULL, *vend = NULL;
220         struct vxlan_dev *vxlan = netdev_priv(dev);
221         struct tunnel_msg *new_tmsg;
222         int idx = 0, s_idx = cb->args[1];
223         struct vxlan_vni_group *vg;
224         struct nlmsghdr *nlh;
225         int err = 0;
226
227         if (!(vxlan->cfg.flags & VXLAN_F_VNIFILTER))
228                 return -EINVAL;
229
230         /* RCU needed because of the vni locking rules (rcu || rtnl) */
231         vg = rcu_dereference(vxlan->vnigrp);
232         if (!vg || !vg->num_vnis)
233                 return 0;
234
235         nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
236                         RTM_NEWTUNNEL, sizeof(*new_tmsg), NLM_F_MULTI);
237         if (!nlh)
238                 return -EMSGSIZE;
239         new_tmsg = nlmsg_data(nlh);
240         memset(new_tmsg, 0, sizeof(*new_tmsg));
241         new_tmsg->family = PF_BRIDGE;
242         new_tmsg->ifindex = dev->ifindex;
243
244         list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
245                 if (idx < s_idx) {
246                         idx++;
247                         continue;
248                 }
249                 if (!vbegin) {
250                         vbegin = v;
251                         vend = v;
252                         continue;
253                 }
254                 if (vnirange(vend, v) == 1 &&
255                     vxlan_addr_equal(&v->remote_ip, &vend->remote_ip)) {
256                         goto update_end;
257                 } else {
258                         if (!vxlan_fill_vni_filter_entry(skb, vbegin, vend)) {
259                                 err = -EMSGSIZE;
260                                 break;
261                         }
262                         idx += vnirange(vbegin, vend) + 1;
263                         vbegin = v;
264                 }
265 update_end:
266                 vend = v;
267         }
268
269         if (!err && vbegin) {
270                 if (!vxlan_fill_vni_filter_entry(skb, vbegin, vend))
271                         err = -EMSGSIZE;
272         }
273
274         cb->args[1] = err ? idx : 0;
275
276         nlmsg_end(skb, nlh);
277
278         return err;
279 }
280
281 static int vxlan_vnifilter_dump(struct sk_buff *skb, struct netlink_callback *cb)
282 {
283         int idx = 0, err = 0, s_idx = cb->args[0];
284         struct net *net = sock_net(skb->sk);
285         struct tunnel_msg *tmsg;
286         struct net_device *dev;
287
288         tmsg = nlmsg_data(cb->nlh);
289
290         rcu_read_lock();
291         if (tmsg->ifindex) {
292                 dev = dev_get_by_index_rcu(net, tmsg->ifindex);
293                 if (!dev) {
294                         err = -ENODEV;
295                         goto out_err;
296                 }
297                 err = vxlan_vnifilter_dump_dev(dev, skb, cb);
298                 /* if the dump completed without an error we return 0 here */
299                 if (err != -EMSGSIZE)
300                         goto out_err;
301         } else {
302                 for_each_netdev_rcu(net, dev) {
303                         if (!netif_is_vxlan(dev))
304                                 continue;
305                         if (idx < s_idx)
306                                 goto skip;
307                         err = vxlan_vnifilter_dump_dev(dev, skb, cb);
308                         if (err == -EMSGSIZE)
309                                 break;
310 skip:
311                         idx++;
312                 }
313         }
314         cb->args[0] = idx;
315         rcu_read_unlock();
316
317         return skb->len;
318
319 out_err:
320         rcu_read_unlock();
321
322         return err;
323 }
324
325 static const struct nla_policy vni_filter_entry_policy[VXLAN_VNIFILTER_ENTRY_MAX + 1] = {
326         [VXLAN_VNIFILTER_ENTRY_START] = { .type = NLA_U32 },
327         [VXLAN_VNIFILTER_ENTRY_END] = { .type = NLA_U32 },
328         [VXLAN_VNIFILTER_ENTRY_GROUP]   = { .type = NLA_BINARY,
329                                             .len = sizeof_field(struct iphdr, daddr) },
330         [VXLAN_VNIFILTER_ENTRY_GROUP6]  = { .type = NLA_BINARY,
331                                             .len = sizeof(struct in6_addr) },
332 };
333
334 static const struct nla_policy vni_filter_policy[VXLAN_VNIFILTER_MAX + 1] = {
335         [VXLAN_VNIFILTER_ENTRY] = { .type = NLA_NESTED },
336 };
337
338 static int vxlan_update_default_fdb_entry(struct vxlan_dev *vxlan, __be32 vni,
339                                           union vxlan_addr *old_remote_ip,
340                                           union vxlan_addr *remote_ip,
341                                           struct netlink_ext_ack *extack)
342 {
343         struct vxlan_rdst *dst = &vxlan->default_dst;
344         u32 hash_index;
345         int err = 0;
346
347         hash_index = fdb_head_index(vxlan, all_zeros_mac, vni);
348         spin_lock_bh(&vxlan->hash_lock[hash_index]);
349         if (remote_ip && !vxlan_addr_any(remote_ip)) {
350                 err = vxlan_fdb_update(vxlan, all_zeros_mac,
351                                        remote_ip,
352                                        NUD_REACHABLE | NUD_PERMANENT,
353                                        NLM_F_APPEND | NLM_F_CREATE,
354                                        vxlan->cfg.dst_port,
355                                        vni,
356                                        vni,
357                                        dst->remote_ifindex,
358                                        NTF_SELF, 0, true, extack);
359                 if (err) {
360                         spin_unlock_bh(&vxlan->hash_lock[hash_index]);
361                         return err;
362                 }
363         }
364
365         if (old_remote_ip && !vxlan_addr_any(old_remote_ip)) {
366                 __vxlan_fdb_delete(vxlan, all_zeros_mac,
367                                    *old_remote_ip,
368                                    vxlan->cfg.dst_port,
369                                    vni, vni,
370                                    dst->remote_ifindex,
371                                    true);
372         }
373         spin_unlock_bh(&vxlan->hash_lock[hash_index]);
374
375         return err;
376 }
377
378 static int vxlan_vni_update_group(struct vxlan_dev *vxlan,
379                                   struct vxlan_vni_node *vninode,
380                                   union vxlan_addr *group,
381                                   bool create, bool *changed,
382                                   struct netlink_ext_ack *extack)
383 {
384         struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
385         struct vxlan_rdst *dst = &vxlan->default_dst;
386         union vxlan_addr *newrip = NULL, *oldrip = NULL;
387         union vxlan_addr old_remote_ip;
388         int ret = 0;
389
390         memcpy(&old_remote_ip, &vninode->remote_ip, sizeof(old_remote_ip));
391
392         /* if per vni remote ip is not present use vxlan dev
393          * default dst remote ip for fdb entry
394          */
395         if (group && !vxlan_addr_any(group)) {
396                 newrip = group;
397         } else {
398                 if (!vxlan_addr_any(&dst->remote_ip))
399                         newrip = &dst->remote_ip;
400         }
401
402         /* if old rip exists, and no newrip,
403          * explicitly delete old rip
404          */
405         if (!newrip && !vxlan_addr_any(&old_remote_ip))
406                 oldrip = &old_remote_ip;
407
408         if (!newrip && !oldrip)
409                 return 0;
410
411         if (!create && oldrip && newrip && vxlan_addr_equal(oldrip, newrip))
412                 return 0;
413
414         ret = vxlan_update_default_fdb_entry(vxlan, vninode->vni,
415                                              oldrip, newrip,
416                                              extack);
417         if (ret)
418                 goto out;
419
420         if (group)
421                 memcpy(&vninode->remote_ip, group, sizeof(vninode->remote_ip));
422
423         if (vxlan->dev->flags & IFF_UP) {
424                 if (vxlan_addr_multicast(&old_remote_ip) &&
425                     !vxlan_group_used(vn, vxlan, vninode->vni,
426                                       &old_remote_ip,
427                                       vxlan->default_dst.remote_ifindex)) {
428                         ret = vxlan_igmp_leave(vxlan, &old_remote_ip,
429                                                0);
430                         if (ret)
431                                 goto out;
432                 }
433
434                 if (vxlan_addr_multicast(&vninode->remote_ip)) {
435                         ret = vxlan_igmp_join(vxlan, &vninode->remote_ip, 0);
436                         if (ret == -EADDRINUSE)
437                                 ret = 0;
438                         if (ret)
439                                 goto out;
440                 }
441         }
442
443         *changed = true;
444
445         return 0;
446 out:
447         return ret;
448 }
449
450 int vxlan_vnilist_update_group(struct vxlan_dev *vxlan,
451                                union vxlan_addr *old_remote_ip,
452                                union vxlan_addr *new_remote_ip,
453                                struct netlink_ext_ack *extack)
454 {
455         struct list_head *headp, *hpos;
456         struct vxlan_vni_group *vg;
457         struct vxlan_vni_node *vent;
458         int ret;
459
460         vg = rtnl_dereference(vxlan->vnigrp);
461
462         headp = &vg->vni_list;
463         list_for_each_prev(hpos, headp) {
464                 vent = list_entry(hpos, struct vxlan_vni_node, vlist);
465                 if (vxlan_addr_any(&vent->remote_ip)) {
466                         ret = vxlan_update_default_fdb_entry(vxlan, vent->vni,
467                                                              old_remote_ip,
468                                                              new_remote_ip,
469                                                              extack);
470                         if (ret)
471                                 return ret;
472                 }
473         }
474
475         return 0;
476 }
477
478 static void vxlan_vni_delete_group(struct vxlan_dev *vxlan,
479                                    struct vxlan_vni_node *vninode)
480 {
481         struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
482         struct vxlan_rdst *dst = &vxlan->default_dst;
483
484         /* if per vni remote_ip not present, delete the
485          * default dst remote_ip previously added for this vni
486          */
487         if (!vxlan_addr_any(&vninode->remote_ip) ||
488             !vxlan_addr_any(&dst->remote_ip))
489                 __vxlan_fdb_delete(vxlan, all_zeros_mac,
490                                    (vxlan_addr_any(&vninode->remote_ip) ?
491                                    dst->remote_ip : vninode->remote_ip),
492                                    vxlan->cfg.dst_port,
493                                    vninode->vni, vninode->vni,
494                                    dst->remote_ifindex,
495                                    true);
496
497         if (vxlan->dev->flags & IFF_UP) {
498                 if (vxlan_addr_multicast(&vninode->remote_ip) &&
499                     !vxlan_group_used(vn, vxlan, vninode->vni,
500                                       &vninode->remote_ip,
501                                       dst->remote_ifindex)) {
502                         vxlan_igmp_leave(vxlan, &vninode->remote_ip, 0);
503                 }
504         }
505 }
506
507 static int vxlan_vni_update(struct vxlan_dev *vxlan,
508                             struct vxlan_vni_group *vg,
509                             __be32 vni, union vxlan_addr *group,
510                             bool *changed,
511                             struct netlink_ext_ack *extack)
512 {
513         struct vxlan_vni_node *vninode;
514         int ret;
515
516         vninode = rhashtable_lookup_fast(&vg->vni_hash, &vni,
517                                          vxlan_vni_rht_params);
518         if (!vninode)
519                 return 0;
520
521         ret = vxlan_vni_update_group(vxlan, vninode, group, false, changed,
522                                      extack);
523         if (ret)
524                 return ret;
525
526         if (changed)
527                 vxlan_vnifilter_notify(vxlan, vninode, RTM_NEWTUNNEL);
528
529         return 0;
530 }
531
532 static void __vxlan_vni_add_list(struct vxlan_vni_group *vg,
533                                  struct vxlan_vni_node *v)
534 {
535         struct list_head *headp, *hpos;
536         struct vxlan_vni_node *vent;
537
538         headp = &vg->vni_list;
539         list_for_each_prev(hpos, headp) {
540                 vent = list_entry(hpos, struct vxlan_vni_node, vlist);
541                 if (be32_to_cpu(v->vni) < be32_to_cpu(vent->vni))
542                         continue;
543                 else
544                         break;
545         }
546         list_add_rcu(&v->vlist, hpos);
547         vg->num_vnis++;
548 }
549
550 static void __vxlan_vni_del_list(struct vxlan_vni_group *vg,
551                                  struct vxlan_vni_node *v)
552 {
553         list_del_rcu(&v->vlist);
554         vg->num_vnis--;
555 }
556
557 static struct vxlan_vni_node *vxlan_vni_alloc(struct vxlan_dev *vxlan,
558                                               __be32 vni)
559 {
560         struct vxlan_vni_node *vninode;
561
562         vninode = kzalloc(sizeof(*vninode), GFP_ATOMIC);
563         if (!vninode)
564                 return NULL;
565         vninode->vni = vni;
566         vninode->hlist4.vxlan = vxlan;
567 #if IS_ENABLED(CONFIG_IPV6)
568         vninode->hlist6.vxlan = vxlan;
569 #endif
570
571         return vninode;
572 }
573
574 static int vxlan_vni_add(struct vxlan_dev *vxlan,
575                          struct vxlan_vni_group *vg,
576                          u32 vni, union vxlan_addr *group,
577                          struct netlink_ext_ack *extack)
578 {
579         struct vxlan_vni_node *vninode;
580         __be32 v = cpu_to_be32(vni);
581         bool changed = false;
582         int err = 0;
583
584         if (vxlan_vnifilter_lookup(vxlan, v))
585                 return vxlan_vni_update(vxlan, vg, v, group, &changed, extack);
586
587         err = vxlan_vni_in_use(vxlan->net, vxlan, &vxlan->cfg, v);
588         if (err) {
589                 NL_SET_ERR_MSG(extack, "VNI in use");
590                 return err;
591         }
592
593         vninode = vxlan_vni_alloc(vxlan, v);
594         if (!vninode)
595                 return -ENOMEM;
596
597         err = rhashtable_lookup_insert_fast(&vg->vni_hash,
598                                             &vninode->vnode,
599                                             vxlan_vni_rht_params);
600         if (err) {
601                 kfree(vninode);
602                 return err;
603         }
604
605         __vxlan_vni_add_list(vg, vninode);
606
607         if (vxlan->dev->flags & IFF_UP)
608                 vxlan_vs_add_del_vninode(vxlan, vninode, false);
609
610         err = vxlan_vni_update_group(vxlan, vninode, group, true, &changed,
611                                      extack);
612
613         if (changed)
614                 vxlan_vnifilter_notify(vxlan, vninode, RTM_NEWTUNNEL);
615
616         return err;
617 }
618
619 static void vxlan_vni_node_rcu_free(struct rcu_head *rcu)
620 {
621         struct vxlan_vni_node *v;
622
623         v = container_of(rcu, struct vxlan_vni_node, rcu);
624         kfree(v);
625 }
626
627 static int vxlan_vni_del(struct vxlan_dev *vxlan,
628                          struct vxlan_vni_group *vg,
629                          u32 vni, struct netlink_ext_ack *extack)
630 {
631         struct vxlan_vni_node *vninode;
632         __be32 v = cpu_to_be32(vni);
633         int err = 0;
634
635         vg = rtnl_dereference(vxlan->vnigrp);
636
637         vninode = rhashtable_lookup_fast(&vg->vni_hash, &v,
638                                          vxlan_vni_rht_params);
639         if (!vninode) {
640                 err = -ENOENT;
641                 goto out;
642         }
643
644         vxlan_vni_delete_group(vxlan, vninode);
645
646         err = rhashtable_remove_fast(&vg->vni_hash,
647                                      &vninode->vnode,
648                                      vxlan_vni_rht_params);
649         if (err)
650                 goto out;
651
652         __vxlan_vni_del_list(vg, vninode);
653
654         vxlan_vnifilter_notify(vxlan, vninode, RTM_DELTUNNEL);
655
656         if (vxlan->dev->flags & IFF_UP)
657                 vxlan_vs_add_del_vninode(vxlan, vninode, true);
658
659         call_rcu(&vninode->rcu, vxlan_vni_node_rcu_free);
660
661         return 0;
662 out:
663         return err;
664 }
665
666 static int vxlan_vni_add_del(struct vxlan_dev *vxlan, __u32 start_vni,
667                              __u32 end_vni, union vxlan_addr *group,
668                              int cmd, struct netlink_ext_ack *extack)
669 {
670         struct vxlan_vni_group *vg;
671         int v, err = 0;
672
673         vg = rtnl_dereference(vxlan->vnigrp);
674
675         for (v = start_vni; v <= end_vni; v++) {
676                 switch (cmd) {
677                 case RTM_NEWTUNNEL:
678                         err = vxlan_vni_add(vxlan, vg, v, group, extack);
679                         break;
680                 case RTM_DELTUNNEL:
681                         err = vxlan_vni_del(vxlan, vg, v, extack);
682                         break;
683                 default:
684                         err = -EOPNOTSUPP;
685                         break;
686                 }
687                 if (err)
688                         goto out;
689         }
690
691         return 0;
692 out:
693         return err;
694 }
695
696 static int vxlan_process_vni_filter(struct vxlan_dev *vxlan,
697                                     struct nlattr *nlvnifilter,
698                                     int cmd, struct netlink_ext_ack *extack)
699 {
700         struct nlattr *vattrs[VXLAN_VNIFILTER_ENTRY_MAX + 1];
701         u32 vni_start = 0, vni_end = 0;
702         union vxlan_addr group;
703         int err;
704
705         err = nla_parse_nested(vattrs,
706                                VXLAN_VNIFILTER_ENTRY_MAX,
707                                nlvnifilter, vni_filter_entry_policy,
708                                extack);
709         if (err)
710                 return err;
711
712         if (vattrs[VXLAN_VNIFILTER_ENTRY_START]) {
713                 vni_start = nla_get_u32(vattrs[VXLAN_VNIFILTER_ENTRY_START]);
714                 vni_end = vni_start;
715         }
716
717         if (vattrs[VXLAN_VNIFILTER_ENTRY_END])
718                 vni_end = nla_get_u32(vattrs[VXLAN_VNIFILTER_ENTRY_END]);
719
720         if (!vni_start && !vni_end) {
721                 NL_SET_ERR_MSG_ATTR(extack, nlvnifilter,
722                                     "vni start nor end found in vni entry");
723                 return -EINVAL;
724         }
725
726         if (vattrs[VXLAN_VNIFILTER_ENTRY_GROUP]) {
727                 group.sin.sin_addr.s_addr =
728                         nla_get_in_addr(vattrs[VXLAN_VNIFILTER_ENTRY_GROUP]);
729                 group.sa.sa_family = AF_INET;
730         } else if (vattrs[VXLAN_VNIFILTER_ENTRY_GROUP6]) {
731                 group.sin6.sin6_addr =
732                         nla_get_in6_addr(vattrs[VXLAN_VNIFILTER_ENTRY_GROUP6]);
733                 group.sa.sa_family = AF_INET6;
734         } else {
735                 memset(&group, 0, sizeof(group));
736         }
737
738         if (vxlan_addr_multicast(&group) && !vxlan->default_dst.remote_ifindex) {
739                 NL_SET_ERR_MSG(extack,
740                                "Local interface required for multicast remote group");
741
742                 return -EINVAL;
743         }
744
745         err = vxlan_vni_add_del(vxlan, vni_start, vni_end, &group, cmd,
746                                 extack);
747         if (err)
748                 return err;
749
750         return 0;
751 }
752
753 void vxlan_vnigroup_uninit(struct vxlan_dev *vxlan)
754 {
755         struct vxlan_vni_node *v, *tmp;
756         struct vxlan_vni_group *vg;
757
758         vg = rtnl_dereference(vxlan->vnigrp);
759         list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
760                 rhashtable_remove_fast(&vg->vni_hash, &v->vnode,
761                                        vxlan_vni_rht_params);
762                 hlist_del_init_rcu(&v->hlist4.hlist);
763 #if IS_ENABLED(CONFIG_IPV6)
764                 hlist_del_init_rcu(&v->hlist6.hlist);
765 #endif
766                 __vxlan_vni_del_list(vg, v);
767                 vxlan_vnifilter_notify(vxlan, v, RTM_DELTUNNEL);
768                 call_rcu(&v->rcu, vxlan_vni_node_rcu_free);
769         }
770         rhashtable_destroy(&vg->vni_hash);
771         kfree(vg);
772 }
773
774 int vxlan_vnigroup_init(struct vxlan_dev *vxlan)
775 {
776         struct vxlan_vni_group *vg;
777         int ret;
778
779         vg = kzalloc(sizeof(*vg), GFP_KERNEL);
780         if (!vg)
781                 return -ENOMEM;
782         ret = rhashtable_init(&vg->vni_hash, &vxlan_vni_rht_params);
783         if (ret) {
784                 kfree(vg);
785                 return ret;
786         }
787         INIT_LIST_HEAD(&vg->vni_list);
788         rcu_assign_pointer(vxlan->vnigrp, vg);
789
790         return 0;
791 }
792
793 static int vxlan_vnifilter_process(struct sk_buff *skb, struct nlmsghdr *nlh,
794                                    struct netlink_ext_ack *extack)
795 {
796         struct net *net = sock_net(skb->sk);
797         struct tunnel_msg *tmsg;
798         struct vxlan_dev *vxlan;
799         struct net_device *dev;
800         struct nlattr *attr;
801         int err, vnis = 0;
802         int rem;
803
804         /* this should validate the header and check for remaining bytes */
805         err = nlmsg_parse(nlh, sizeof(*tmsg), NULL, VXLAN_VNIFILTER_MAX,
806                           vni_filter_policy, extack);
807         if (err < 0)
808                 return err;
809
810         tmsg = nlmsg_data(nlh);
811         dev = __dev_get_by_index(net, tmsg->ifindex);
812         if (!dev)
813                 return -ENODEV;
814
815         if (!netif_is_vxlan(dev)) {
816                 NL_SET_ERR_MSG_MOD(extack, "The device is not a vxlan device");
817                 return -EINVAL;
818         }
819
820         vxlan = netdev_priv(dev);
821
822         if (!(vxlan->cfg.flags & VXLAN_F_VNIFILTER))
823                 return -EOPNOTSUPP;
824
825         nlmsg_for_each_attr(attr, nlh, sizeof(*tmsg), rem) {
826                 switch (nla_type(attr)) {
827                 case VXLAN_VNIFILTER_ENTRY:
828                         err = vxlan_process_vni_filter(vxlan, attr,
829                                                        nlh->nlmsg_type, extack);
830                         break;
831                 default:
832                         continue;
833                 }
834                 vnis++;
835                 if (err)
836                         break;
837         }
838
839         if (!vnis) {
840                 NL_SET_ERR_MSG_MOD(extack, "No vnis found to process");
841                 err = -EINVAL;
842         }
843
844         return err;
845 }
846
847 void vxlan_vnifilter_init(void)
848 {
849         rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_GETTUNNEL, NULL,
850                              vxlan_vnifilter_dump, 0);
851         rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_NEWTUNNEL,
852                              vxlan_vnifilter_process, NULL, 0);
853         rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_DELTUNNEL,
854                              vxlan_vnifilter_process, NULL, 0);
855 }
856
857 void vxlan_vnifilter_uninit(void)
858 {
859         rtnl_unregister(PF_BRIDGE, RTM_GETTUNNEL);
860         rtnl_unregister(PF_BRIDGE, RTM_NEWTUNNEL);
861         rtnl_unregister(PF_BRIDGE, RTM_DELTUNNEL);
862 }