Merge tag 'x86-microcode-2022-06-05' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-microblaze.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/sch_generic.h>
15 #include <net/net_namespace.h>
16 #include <net/ip.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <net/icmp.h>
20 #include <net/mld.h>
21 #include <net/amt.h>
22 #include <uapi/linux/amt.h>
23 #include <linux/security.h>
24 #include <net/gro_cells.h>
25 #include <net/ipv6.h>
26 #include <net/if_inet6.h>
27 #include <net/ndisc.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/inet_common.h>
31 #include <net/ip6_checksum.h>
32
33 static struct workqueue_struct *amt_wq;
34
35 static HLIST_HEAD(source_gc_list);
36 /* Lock for source_gc_list */
37 static spinlock_t source_gc_lock;
38 static struct delayed_work source_gc_wq;
39 static char *status_str[] = {
40         "AMT_STATUS_INIT",
41         "AMT_STATUS_SENT_DISCOVERY",
42         "AMT_STATUS_RECEIVED_DISCOVERY",
43         "AMT_STATUS_SENT_ADVERTISEMENT",
44         "AMT_STATUS_RECEIVED_ADVERTISEMENT",
45         "AMT_STATUS_SENT_REQUEST",
46         "AMT_STATUS_RECEIVED_REQUEST",
47         "AMT_STATUS_SENT_QUERY",
48         "AMT_STATUS_RECEIVED_QUERY",
49         "AMT_STATUS_SENT_UPDATE",
50         "AMT_STATUS_RECEIVED_UPDATE",
51 };
52
53 static char *type_str[] = {
54         "AMT_MSG_DISCOVERY",
55         "AMT_MSG_ADVERTISEMENT",
56         "AMT_MSG_REQUEST",
57         "AMT_MSG_MEMBERSHIP_QUERY",
58         "AMT_MSG_MEMBERSHIP_UPDATE",
59         "AMT_MSG_MULTICAST_DATA",
60         "AMT_MSG_TEARDOWN",
61 };
62
63 static char *action_str[] = {
64         "AMT_ACT_GMI",
65         "AMT_ACT_GMI_ZERO",
66         "AMT_ACT_GT",
67         "AMT_ACT_STATUS_FWD_NEW",
68         "AMT_ACT_STATUS_D_FWD_NEW",
69         "AMT_ACT_STATUS_NONE_NEW",
70 };
71
72 static struct igmpv3_grec igmpv3_zero_grec;
73
74 #if IS_ENABLED(CONFIG_IPV6)
75 #define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
76 static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
77 static struct mld2_grec mldv2_zero_grec;
78 #endif
79
80 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
81 {
82         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
83                      sizeof_field(struct sk_buff, cb));
84
85         return (struct amt_skb_cb *)((void *)skb->cb +
86                 sizeof(struct qdisc_skb_cb));
87 }
88
89 static void __amt_source_gc_work(void)
90 {
91         struct amt_source_node *snode;
92         struct hlist_head gc_list;
93         struct hlist_node *t;
94
95         spin_lock_bh(&source_gc_lock);
96         hlist_move_list(&source_gc_list, &gc_list);
97         spin_unlock_bh(&source_gc_lock);
98
99         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
100                 hlist_del_rcu(&snode->node);
101                 kfree_rcu(snode, rcu);
102         }
103 }
104
105 static void amt_source_gc_work(struct work_struct *work)
106 {
107         __amt_source_gc_work();
108
109         spin_lock_bh(&source_gc_lock);
110         mod_delayed_work(amt_wq, &source_gc_wq,
111                          msecs_to_jiffies(AMT_GC_INTERVAL));
112         spin_unlock_bh(&source_gc_lock);
113 }
114
115 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
116 {
117         return !memcmp(a, b, sizeof(union amt_addr));
118 }
119
120 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
121 {
122         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
123
124         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
125 }
126
127 static bool amt_status_filter(struct amt_source_node *snode,
128                               enum amt_filter filter)
129 {
130         bool rc = false;
131
132         switch (filter) {
133         case AMT_FILTER_FWD:
134                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
135                     snode->flags == AMT_SOURCE_OLD)
136                         rc = true;
137                 break;
138         case AMT_FILTER_D_FWD:
139                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
140                     snode->flags == AMT_SOURCE_OLD)
141                         rc = true;
142                 break;
143         case AMT_FILTER_FWD_NEW:
144                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
145                     snode->flags == AMT_SOURCE_NEW)
146                         rc = true;
147                 break;
148         case AMT_FILTER_D_FWD_NEW:
149                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
150                     snode->flags == AMT_SOURCE_NEW)
151                         rc = true;
152                 break;
153         case AMT_FILTER_ALL:
154                 rc = true;
155                 break;
156         case AMT_FILTER_NONE_NEW:
157                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
158                     snode->flags == AMT_SOURCE_NEW)
159                         rc = true;
160                 break;
161         case AMT_FILTER_BOTH:
162                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
163                      snode->status == AMT_SOURCE_STATUS_FWD) &&
164                     snode->flags == AMT_SOURCE_OLD)
165                         rc = true;
166                 break;
167         case AMT_FILTER_BOTH_NEW:
168                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
169                      snode->status == AMT_SOURCE_STATUS_FWD) &&
170                     snode->flags == AMT_SOURCE_NEW)
171                         rc = true;
172                 break;
173         default:
174                 WARN_ON_ONCE(1);
175                 break;
176         }
177
178         return rc;
179 }
180
181 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
182                                               struct amt_group_node *gnode,
183                                               enum amt_filter filter,
184                                               union amt_addr *src)
185 {
186         u32 hash = amt_source_hash(tunnel, src);
187         struct amt_source_node *snode;
188
189         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
190                 if (amt_status_filter(snode, filter) &&
191                     amt_addr_equal(&snode->source_addr, src))
192                         return snode;
193
194         return NULL;
195 }
196
197 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
198 {
199         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
200
201         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
202 }
203
204 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
205                                                union amt_addr *group,
206                                                union amt_addr *host,
207                                                bool v6)
208 {
209         u32 hash = amt_group_hash(tunnel, group);
210         struct amt_group_node *gnode;
211
212         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
213                 if (amt_addr_equal(&gnode->group_addr, group) &&
214                     amt_addr_equal(&gnode->host_addr, host) &&
215                     gnode->v6 == v6)
216                         return gnode;
217         }
218
219         return NULL;
220 }
221
222 static void amt_destroy_source(struct amt_source_node *snode)
223 {
224         struct amt_group_node *gnode = snode->gnode;
225         struct amt_tunnel_list *tunnel;
226
227         tunnel = gnode->tunnel_list;
228
229         if (!gnode->v6) {
230                 netdev_dbg(snode->gnode->amt->dev,
231                            "Delete source %pI4 from %pI4\n",
232                            &snode->source_addr.ip4,
233                            &gnode->group_addr.ip4);
234 #if IS_ENABLED(CONFIG_IPV6)
235         } else {
236                 netdev_dbg(snode->gnode->amt->dev,
237                            "Delete source %pI6 from %pI6\n",
238                            &snode->source_addr.ip6,
239                            &gnode->group_addr.ip6);
240 #endif
241         }
242
243         cancel_delayed_work(&snode->source_timer);
244         hlist_del_init_rcu(&snode->node);
245         tunnel->nr_sources--;
246         gnode->nr_sources--;
247         spin_lock_bh(&source_gc_lock);
248         hlist_add_head_rcu(&snode->node, &source_gc_list);
249         spin_unlock_bh(&source_gc_lock);
250 }
251
252 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
253 {
254         struct amt_source_node *snode;
255         struct hlist_node *t;
256         int i;
257
258         if (cancel_delayed_work(&gnode->group_timer))
259                 dev_put(amt->dev);
260         hlist_del_rcu(&gnode->node);
261         gnode->tunnel_list->nr_groups--;
262
263         if (!gnode->v6)
264                 netdev_dbg(amt->dev, "Leave group %pI4\n",
265                            &gnode->group_addr.ip4);
266 #if IS_ENABLED(CONFIG_IPV6)
267         else
268                 netdev_dbg(amt->dev, "Leave group %pI6\n",
269                            &gnode->group_addr.ip6);
270 #endif
271         for (i = 0; i < amt->hash_buckets; i++)
272                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
273                         amt_destroy_source(snode);
274
275         /* tunnel->lock was acquired outside of amt_del_group()
276          * But rcu_read_lock() was acquired too so It's safe.
277          */
278         kfree_rcu(gnode, rcu);
279 }
280
281 /* If a source timer expires with a router filter-mode for the group of
282  * INCLUDE, the router concludes that traffic from this particular
283  * source is no longer desired on the attached network, and deletes the
284  * associated source record.
285  */
286 static void amt_source_work(struct work_struct *work)
287 {
288         struct amt_source_node *snode = container_of(to_delayed_work(work),
289                                                      struct amt_source_node,
290                                                      source_timer);
291         struct amt_group_node *gnode = snode->gnode;
292         struct amt_dev *amt = gnode->amt;
293         struct amt_tunnel_list *tunnel;
294
295         tunnel = gnode->tunnel_list;
296         spin_lock_bh(&tunnel->lock);
297         rcu_read_lock();
298         if (gnode->filter_mode == MCAST_INCLUDE) {
299                 amt_destroy_source(snode);
300                 if (!gnode->nr_sources)
301                         amt_del_group(amt, gnode);
302         } else {
303                 /* When a router filter-mode for a group is EXCLUDE,
304                  * source records are only deleted when the group timer expires
305                  */
306                 snode->status = AMT_SOURCE_STATUS_D_FWD;
307         }
308         rcu_read_unlock();
309         spin_unlock_bh(&tunnel->lock);
310 }
311
312 static void amt_act_src(struct amt_tunnel_list *tunnel,
313                         struct amt_group_node *gnode,
314                         struct amt_source_node *snode,
315                         enum amt_act act)
316 {
317         struct amt_dev *amt = tunnel->amt;
318
319         switch (act) {
320         case AMT_ACT_GMI:
321                 mod_delayed_work(amt_wq, &snode->source_timer,
322                                  msecs_to_jiffies(amt_gmi(amt)));
323                 break;
324         case AMT_ACT_GMI_ZERO:
325                 cancel_delayed_work(&snode->source_timer);
326                 break;
327         case AMT_ACT_GT:
328                 mod_delayed_work(amt_wq, &snode->source_timer,
329                                  gnode->group_timer.timer.expires);
330                 break;
331         case AMT_ACT_STATUS_FWD_NEW:
332                 snode->status = AMT_SOURCE_STATUS_FWD;
333                 snode->flags = AMT_SOURCE_NEW;
334                 break;
335         case AMT_ACT_STATUS_D_FWD_NEW:
336                 snode->status = AMT_SOURCE_STATUS_D_FWD;
337                 snode->flags = AMT_SOURCE_NEW;
338                 break;
339         case AMT_ACT_STATUS_NONE_NEW:
340                 cancel_delayed_work(&snode->source_timer);
341                 snode->status = AMT_SOURCE_STATUS_NONE;
342                 snode->flags = AMT_SOURCE_NEW;
343                 break;
344         default:
345                 WARN_ON_ONCE(1);
346                 return;
347         }
348
349         if (!gnode->v6)
350                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
351                            &snode->source_addr.ip4,
352                            &gnode->group_addr.ip4,
353                            action_str[act]);
354 #if IS_ENABLED(CONFIG_IPV6)
355         else
356                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
357                            &snode->source_addr.ip6,
358                            &gnode->group_addr.ip6,
359                            action_str[act]);
360 #endif
361 }
362
363 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
364                                                union amt_addr *src)
365 {
366         struct amt_source_node *snode;
367
368         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
369         if (!snode)
370                 return NULL;
371
372         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
373         snode->gnode = gnode;
374         snode->status = AMT_SOURCE_STATUS_NONE;
375         snode->flags = AMT_SOURCE_NEW;
376         INIT_HLIST_NODE(&snode->node);
377         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
378
379         return snode;
380 }
381
382 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
383  *
384  *  Router Mode          Filter Timer         Actions/Comments
385  *  -----------       -----------------       ----------------
386  *
387  *    INCLUDE             Not Used            All listeners in
388  *                                            INCLUDE mode.
389  *
390  *    EXCLUDE             Timer > 0           At least one listener
391  *                                            in EXCLUDE mode.
392  *
393  *    EXCLUDE             Timer == 0          No more listeners in
394  *                                            EXCLUDE mode for the
395  *                                            multicast address.
396  *                                            If the Requested List
397  *                                            is empty, delete
398  *                                            Multicast Address
399  *                                            Record.  If not, switch
400  *                                            to INCLUDE filter mode;
401  *                                            the sources in the
402  *                                            Requested List are
403  *                                            moved to the Include
404  *                                            List, and the Exclude
405  *                                            List is deleted.
406  */
407 static void amt_group_work(struct work_struct *work)
408 {
409         struct amt_group_node *gnode = container_of(to_delayed_work(work),
410                                                     struct amt_group_node,
411                                                     group_timer);
412         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
413         struct amt_dev *amt = gnode->amt;
414         struct amt_source_node *snode;
415         bool delete_group = true;
416         struct hlist_node *t;
417         int i, buckets;
418
419         buckets = amt->hash_buckets;
420
421         spin_lock_bh(&tunnel->lock);
422         if (gnode->filter_mode == MCAST_INCLUDE) {
423                 /* Not Used */
424                 spin_unlock_bh(&tunnel->lock);
425                 goto out;
426         }
427
428         rcu_read_lock();
429         for (i = 0; i < buckets; i++) {
430                 hlist_for_each_entry_safe(snode, t,
431                                           &gnode->sources[i], node) {
432                         if (!delayed_work_pending(&snode->source_timer) ||
433                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
434                                 amt_destroy_source(snode);
435                         } else {
436                                 delete_group = false;
437                                 snode->status = AMT_SOURCE_STATUS_FWD;
438                         }
439                 }
440         }
441         if (delete_group)
442                 amt_del_group(amt, gnode);
443         else
444                 gnode->filter_mode = MCAST_INCLUDE;
445         rcu_read_unlock();
446         spin_unlock_bh(&tunnel->lock);
447 out:
448         dev_put(amt->dev);
449 }
450
451 /* Non-existant group is created as INCLUDE {empty}:
452  *
453  * RFC 3376 - 5.1. Action on Change of Interface State
454  *
455  * If no interface state existed for that multicast address before
456  * the change (i.e., the change consisted of creating a new
457  * per-interface record), or if no state exists after the change
458  * (i.e., the change consisted of deleting a per-interface record),
459  * then the "non-existent" state is considered to have a filter mode
460  * of INCLUDE and an empty source list.
461  */
462 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
463                                             struct amt_tunnel_list *tunnel,
464                                             union amt_addr *group,
465                                             union amt_addr *host,
466                                             bool v6)
467 {
468         struct amt_group_node *gnode;
469         u32 hash;
470         int i;
471
472         if (tunnel->nr_groups >= amt->max_groups)
473                 return ERR_PTR(-ENOSPC);
474
475         gnode = kzalloc(sizeof(*gnode) +
476                         (sizeof(struct hlist_head) * amt->hash_buckets),
477                         GFP_ATOMIC);
478         if (unlikely(!gnode))
479                 return ERR_PTR(-ENOMEM);
480
481         gnode->amt = amt;
482         gnode->group_addr = *group;
483         gnode->host_addr = *host;
484         gnode->v6 = v6;
485         gnode->tunnel_list = tunnel;
486         gnode->filter_mode = MCAST_INCLUDE;
487         INIT_HLIST_NODE(&gnode->node);
488         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
489         for (i = 0; i < amt->hash_buckets; i++)
490                 INIT_HLIST_HEAD(&gnode->sources[i]);
491
492         hash = amt_group_hash(tunnel, group);
493         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
494         tunnel->nr_groups++;
495
496         if (!gnode->v6)
497                 netdev_dbg(amt->dev, "Join group %pI4\n",
498                            &gnode->group_addr.ip4);
499 #if IS_ENABLED(CONFIG_IPV6)
500         else
501                 netdev_dbg(amt->dev, "Join group %pI6\n",
502                            &gnode->group_addr.ip6);
503 #endif
504
505         return gnode;
506 }
507
508 static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
509 {
510         u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
511         int hlen = LL_RESERVED_SPACE(amt->dev);
512         int tlen = amt->dev->needed_tailroom;
513         struct igmpv3_query *ihv3;
514         void *csum_start = NULL;
515         __sum16 *csum = NULL;
516         struct sk_buff *skb;
517         struct ethhdr *eth;
518         struct iphdr *iph;
519         unsigned int len;
520         int offset;
521
522         len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
523         skb = netdev_alloc_skb_ip_align(amt->dev, len);
524         if (!skb)
525                 return NULL;
526
527         skb_reserve(skb, hlen);
528         skb_push(skb, sizeof(*eth));
529         skb->protocol = htons(ETH_P_IP);
530         skb_reset_mac_header(skb);
531         skb->priority = TC_PRIO_CONTROL;
532         skb_put(skb, sizeof(*iph));
533         skb_put_data(skb, ra, sizeof(ra));
534         skb_put(skb, sizeof(*ihv3));
535         skb_pull(skb, sizeof(*eth));
536         skb_reset_network_header(skb);
537
538         iph             = ip_hdr(skb);
539         iph->version    = 4;
540         iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
541         iph->tos        = AMT_TOS;
542         iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
543         iph->frag_off   = htons(IP_DF);
544         iph->ttl        = 1;
545         iph->id         = 0;
546         iph->protocol   = IPPROTO_IGMP;
547         iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
548         iph->saddr      = htonl(INADDR_ANY);
549         ip_send_check(iph);
550
551         eth = eth_hdr(skb);
552         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
553         ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
554         eth->h_proto = htons(ETH_P_IP);
555
556         ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
557         skb_reset_transport_header(skb);
558         ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
559         ihv3->code      = 1;
560         ihv3->group     = 0;
561         ihv3->qqic      = amt->qi;
562         ihv3->nsrcs     = 0;
563         ihv3->resv      = 0;
564         ihv3->suppress  = false;
565         ihv3->qrv       = amt->net->ipv4.sysctl_igmp_qrv;
566         ihv3->csum      = 0;
567         csum            = &ihv3->csum;
568         csum_start      = (void *)ihv3;
569         *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
570         offset          = skb_transport_offset(skb);
571         skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
572         skb->ip_summed  = CHECKSUM_NONE;
573
574         skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);
575
576         return skb;
577 }
578
579 static void __amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
580                                    bool validate)
581 {
582         if (validate && amt->status >= status)
583                 return;
584         netdev_dbg(amt->dev, "Update GW status %s -> %s",
585                    status_str[amt->status], status_str[status]);
586         amt->status = status;
587 }
588
589 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
590                                       enum amt_status status,
591                                       bool validate)
592 {
593         if (validate && tunnel->status >= status)
594                 return;
595         netdev_dbg(tunnel->amt->dev,
596                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
597                    &tunnel->ip4, ntohs(tunnel->source_port),
598                    status_str[tunnel->status], status_str[status]);
599         tunnel->status = status;
600 }
601
602 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
603                                  bool validate)
604 {
605         spin_lock_bh(&amt->lock);
606         __amt_update_gw_status(amt, status, validate);
607         spin_unlock_bh(&amt->lock);
608 }
609
610 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
611                                     enum amt_status status, bool validate)
612 {
613         spin_lock_bh(&tunnel->lock);
614         __amt_update_relay_status(tunnel, status, validate);
615         spin_unlock_bh(&tunnel->lock);
616 }
617
618 static void amt_send_discovery(struct amt_dev *amt)
619 {
620         struct amt_header_discovery *amtd;
621         int hlen, tlen, offset;
622         struct socket *sock;
623         struct udphdr *udph;
624         struct sk_buff *skb;
625         struct iphdr *iph;
626         struct rtable *rt;
627         struct flowi4 fl4;
628         u32 len;
629         int err;
630
631         rcu_read_lock();
632         sock = rcu_dereference(amt->sock);
633         if (!sock)
634                 goto out;
635
636         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
637                 goto out;
638
639         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
640                                    amt->discovery_ip, amt->local_ip,
641                                    amt->gw_port, amt->relay_port,
642                                    IPPROTO_UDP, 0,
643                                    amt->stream_dev->ifindex);
644         if (IS_ERR(rt)) {
645                 amt->dev->stats.tx_errors++;
646                 goto out;
647         }
648
649         hlen = LL_RESERVED_SPACE(amt->dev);
650         tlen = amt->dev->needed_tailroom;
651         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
652         skb = netdev_alloc_skb_ip_align(amt->dev, len);
653         if (!skb) {
654                 ip_rt_put(rt);
655                 amt->dev->stats.tx_errors++;
656                 goto out;
657         }
658
659         skb->priority = TC_PRIO_CONTROL;
660         skb_dst_set(skb, &rt->dst);
661
662         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
663         skb_reset_network_header(skb);
664         skb_put(skb, len);
665         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
666         amtd->version   = 0;
667         amtd->type      = AMT_MSG_DISCOVERY;
668         amtd->reserved  = 0;
669         amtd->nonce     = amt->nonce;
670         skb_push(skb, sizeof(*udph));
671         skb_reset_transport_header(skb);
672         udph            = udp_hdr(skb);
673         udph->source    = amt->gw_port;
674         udph->dest      = amt->relay_port;
675         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
676         udph->check     = 0;
677         offset = skb_transport_offset(skb);
678         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
679         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
680                                         sizeof(*udph) + sizeof(*amtd),
681                                         IPPROTO_UDP, skb->csum);
682
683         skb_push(skb, sizeof(*iph));
684         iph             = ip_hdr(skb);
685         iph->version    = 4;
686         iph->ihl        = (sizeof(struct iphdr)) >> 2;
687         iph->tos        = AMT_TOS;
688         iph->frag_off   = 0;
689         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
690         iph->daddr      = amt->discovery_ip;
691         iph->saddr      = amt->local_ip;
692         iph->protocol   = IPPROTO_UDP;
693         iph->tot_len    = htons(len);
694
695         skb->ip_summed = CHECKSUM_NONE;
696         ip_select_ident(amt->net, skb, NULL);
697         ip_send_check(iph);
698         err = ip_local_out(amt->net, sock->sk, skb);
699         if (unlikely(net_xmit_eval(err)))
700                 amt->dev->stats.tx_errors++;
701
702         spin_lock_bh(&amt->lock);
703         __amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
704         spin_unlock_bh(&amt->lock);
705 out:
706         rcu_read_unlock();
707 }
708
709 static void amt_send_request(struct amt_dev *amt, bool v6)
710 {
711         struct amt_header_request *amtrh;
712         int hlen, tlen, offset;
713         struct socket *sock;
714         struct udphdr *udph;
715         struct sk_buff *skb;
716         struct iphdr *iph;
717         struct rtable *rt;
718         struct flowi4 fl4;
719         u32 len;
720         int err;
721
722         rcu_read_lock();
723         sock = rcu_dereference(amt->sock);
724         if (!sock)
725                 goto out;
726
727         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
728                 goto out;
729
730         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
731                                    amt->remote_ip, amt->local_ip,
732                                    amt->gw_port, amt->relay_port,
733                                    IPPROTO_UDP, 0,
734                                    amt->stream_dev->ifindex);
735         if (IS_ERR(rt)) {
736                 amt->dev->stats.tx_errors++;
737                 goto out;
738         }
739
740         hlen = LL_RESERVED_SPACE(amt->dev);
741         tlen = amt->dev->needed_tailroom;
742         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
743         skb = netdev_alloc_skb_ip_align(amt->dev, len);
744         if (!skb) {
745                 ip_rt_put(rt);
746                 amt->dev->stats.tx_errors++;
747                 goto out;
748         }
749
750         skb->priority = TC_PRIO_CONTROL;
751         skb_dst_set(skb, &rt->dst);
752
753         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
754         skb_reset_network_header(skb);
755         skb_put(skb, len);
756         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
757         amtrh->version   = 0;
758         amtrh->type      = AMT_MSG_REQUEST;
759         amtrh->reserved1 = 0;
760         amtrh->p         = v6;
761         amtrh->reserved2 = 0;
762         amtrh->nonce     = amt->nonce;
763         skb_push(skb, sizeof(*udph));
764         skb_reset_transport_header(skb);
765         udph            = udp_hdr(skb);
766         udph->source    = amt->gw_port;
767         udph->dest      = amt->relay_port;
768         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
769         udph->check     = 0;
770         offset = skb_transport_offset(skb);
771         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
772         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
773                                         sizeof(*udph) + sizeof(*amtrh),
774                                         IPPROTO_UDP, skb->csum);
775
776         skb_push(skb, sizeof(*iph));
777         iph             = ip_hdr(skb);
778         iph->version    = 4;
779         iph->ihl        = (sizeof(struct iphdr)) >> 2;
780         iph->tos        = AMT_TOS;
781         iph->frag_off   = 0;
782         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
783         iph->daddr      = amt->remote_ip;
784         iph->saddr      = amt->local_ip;
785         iph->protocol   = IPPROTO_UDP;
786         iph->tot_len    = htons(len);
787
788         skb->ip_summed = CHECKSUM_NONE;
789         ip_select_ident(amt->net, skb, NULL);
790         ip_send_check(iph);
791         err = ip_local_out(amt->net, sock->sk, skb);
792         if (unlikely(net_xmit_eval(err)))
793                 amt->dev->stats.tx_errors++;
794
795 out:
796         rcu_read_unlock();
797 }
798
799 static void amt_send_igmp_gq(struct amt_dev *amt,
800                              struct amt_tunnel_list *tunnel)
801 {
802         struct sk_buff *skb;
803
804         skb = amt_build_igmp_gq(amt);
805         if (!skb)
806                 return;
807
808         amt_skb_cb(skb)->tunnel = tunnel;
809         dev_queue_xmit(skb);
810 }
811
812 #if IS_ENABLED(CONFIG_IPV6)
813 static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
814 {
815         u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
816                                    2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
817         int hlen = LL_RESERVED_SPACE(amt->dev);
818         int tlen = amt->dev->needed_tailroom;
819         struct mld2_query *mld2q;
820         void *csum_start = NULL;
821         struct ipv6hdr *ip6h;
822         struct sk_buff *skb;
823         struct ethhdr *eth;
824         u32 len;
825
826         len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
827         skb = netdev_alloc_skb_ip_align(amt->dev, len);
828         if (!skb)
829                 return NULL;
830
831         skb_reserve(skb, hlen);
832         skb_push(skb, sizeof(*eth));
833         skb_reset_mac_header(skb);
834         eth = eth_hdr(skb);
835         skb->priority = TC_PRIO_CONTROL;
836         skb->protocol = htons(ETH_P_IPV6);
837         skb_put_zero(skb, sizeof(*ip6h));
838         skb_put_data(skb, ra, sizeof(ra));
839         skb_put_zero(skb, sizeof(*mld2q));
840         skb_pull(skb, sizeof(*eth));
841         skb_reset_network_header(skb);
842         ip6h                    = ipv6_hdr(skb);
843         ip6h->payload_len       = htons(sizeof(ra) + sizeof(*mld2q));
844         ip6h->nexthdr           = NEXTHDR_HOP;
845         ip6h->hop_limit         = 1;
846         ip6h->daddr             = mld2_all_node;
847         ip6_flow_hdr(ip6h, 0, 0);
848
849         if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
850                                &ip6h->saddr)) {
851                 amt->dev->stats.tx_errors++;
852                 kfree_skb(skb);
853                 return NULL;
854         }
855
856         eth->h_proto = htons(ETH_P_IPV6);
857         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
858         ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);
859
860         skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
861         skb_reset_transport_header(skb);
862         mld2q                   = (struct mld2_query *)icmp6_hdr(skb);
863         mld2q->mld2q_mrc        = htons(1);
864         mld2q->mld2q_type       = ICMPV6_MGM_QUERY;
865         mld2q->mld2q_code       = 0;
866         mld2q->mld2q_cksum      = 0;
867         mld2q->mld2q_resv1      = 0;
868         mld2q->mld2q_resv2      = 0;
869         mld2q->mld2q_suppress   = 0;
870         mld2q->mld2q_qrv        = amt->qrv;
871         mld2q->mld2q_nsrcs      = 0;
872         mld2q->mld2q_qqic       = amt->qi;
873         csum_start              = (void *)mld2q;
874         mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
875                                              sizeof(*mld2q),
876                                              IPPROTO_ICMPV6,
877                                              csum_partial(csum_start,
878                                                           sizeof(*mld2q), 0));
879
880         skb->ip_summed = CHECKSUM_NONE;
881         skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
882         return skb;
883 }
884
885 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
886 {
887         struct sk_buff *skb;
888
889         skb = amt_build_mld_gq(amt);
890         if (!skb)
891                 return;
892
893         amt_skb_cb(skb)->tunnel = tunnel;
894         dev_queue_xmit(skb);
895 }
896 #else
/* Without CONFIG_IPV6 there is no MLD query to send: intentionally a no-op. */
static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
{
}
900 #endif
901
902 static void amt_secret_work(struct work_struct *work)
903 {
904         struct amt_dev *amt = container_of(to_delayed_work(work),
905                                            struct amt_dev,
906                                            secret_wq);
907
908         spin_lock_bh(&amt->lock);
909         get_random_bytes(&amt->key, sizeof(siphash_key_t));
910         spin_unlock_bh(&amt->lock);
911         mod_delayed_work(amt_wq, &amt->secret_wq,
912                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
913 }
914
915 static void amt_discovery_work(struct work_struct *work)
916 {
917         struct amt_dev *amt = container_of(to_delayed_work(work),
918                                            struct amt_dev,
919                                            discovery_wq);
920
921         spin_lock_bh(&amt->lock);
922         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
923                 goto out;
924         get_random_bytes(&amt->nonce, sizeof(__be32));
925         spin_unlock_bh(&amt->lock);
926
927         amt_send_discovery(amt);
928         spin_lock_bh(&amt->lock);
929 out:
930         mod_delayed_work(amt_wq, &amt->discovery_wq,
931                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
932         spin_unlock_bh(&amt->lock);
933 }
934
935 static void amt_req_work(struct work_struct *work)
936 {
937         struct amt_dev *amt = container_of(to_delayed_work(work),
938                                            struct amt_dev,
939                                            req_wq);
940         u32 exp;
941
942         spin_lock_bh(&amt->lock);
943         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
944                 goto out;
945
946         if (amt->req_cnt > AMT_MAX_REQ_COUNT) {
947                 netdev_dbg(amt->dev, "Gateway is not ready");
948                 amt->qi = AMT_INIT_REQ_TIMEOUT;
949                 amt->ready4 = false;
950                 amt->ready6 = false;
951                 amt->remote_ip = 0;
952                 __amt_update_gw_status(amt, AMT_STATUS_INIT, false);
953                 amt->req_cnt = 0;
954                 goto out;
955         }
956         spin_unlock_bh(&amt->lock);
957
958         amt_send_request(amt, false);
959         amt_send_request(amt, true);
960         spin_lock_bh(&amt->lock);
961         __amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
962         amt->req_cnt++;
963 out:
964         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
965         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
966         spin_unlock_bh(&amt->lock);
967 }
968
969 static bool amt_send_membership_update(struct amt_dev *amt,
970                                        struct sk_buff *skb,
971                                        bool v6)
972 {
973         struct amt_header_membership_update *amtmu;
974         struct socket *sock;
975         struct iphdr *iph;
976         struct flowi4 fl4;
977         struct rtable *rt;
978         int err;
979
980         sock = rcu_dereference_bh(amt->sock);
981         if (!sock)
982                 return true;
983
984         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
985                            sizeof(*iph) + sizeof(struct udphdr));
986         if (err)
987                 return true;
988
989         skb_reset_inner_headers(skb);
990         memset(&fl4, 0, sizeof(struct flowi4));
991         fl4.flowi4_oif         = amt->stream_dev->ifindex;
992         fl4.daddr              = amt->remote_ip;
993         fl4.saddr              = amt->local_ip;
994         fl4.flowi4_tos         = AMT_TOS;
995         fl4.flowi4_proto       = IPPROTO_UDP;
996         rt = ip_route_output_key(amt->net, &fl4);
997         if (IS_ERR(rt)) {
998                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
999                 return true;
1000         }
1001
1002         amtmu                   = skb_push(skb, sizeof(*amtmu));
1003         amtmu->version          = 0;
1004         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1005         amtmu->reserved         = 0;
1006         amtmu->nonce            = amt->nonce;
1007         amtmu->response_mac     = amt->mac;
1008
1009         if (!v6)
1010                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1011         else
1012                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1013         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1014                             fl4.saddr,
1015                             fl4.daddr,
1016                             AMT_TOS,
1017                             ip4_dst_hoplimit(&rt->dst),
1018                             0,
1019                             amt->gw_port,
1020                             amt->relay_port,
1021                             false,
1022                             false);
1023         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1024         return false;
1025 }
1026
1027 static void amt_send_multicast_data(struct amt_dev *amt,
1028                                     const struct sk_buff *oskb,
1029                                     struct amt_tunnel_list *tunnel,
1030                                     bool v6)
1031 {
1032         struct amt_header_mcast_data *amtmd;
1033         struct socket *sock;
1034         struct sk_buff *skb;
1035         struct iphdr *iph;
1036         struct flowi4 fl4;
1037         struct rtable *rt;
1038
1039         sock = rcu_dereference_bh(amt->sock);
1040         if (!sock)
1041                 return;
1042
1043         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1044                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1045         if (!skb)
1046                 return;
1047
1048         skb_reset_inner_headers(skb);
1049         memset(&fl4, 0, sizeof(struct flowi4));
1050         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1051         fl4.daddr              = tunnel->ip4;
1052         fl4.saddr              = amt->local_ip;
1053         fl4.flowi4_proto       = IPPROTO_UDP;
1054         rt = ip_route_output_key(amt->net, &fl4);
1055         if (IS_ERR(rt)) {
1056                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1057                 kfree_skb(skb);
1058                 return;
1059         }
1060
1061         amtmd = skb_push(skb, sizeof(*amtmd));
1062         amtmd->version = 0;
1063         amtmd->reserved = 0;
1064         amtmd->type = AMT_MSG_MULTICAST_DATA;
1065
1066         if (!v6)
1067                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1068         else
1069                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1070         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1071                             fl4.saddr,
1072                             fl4.daddr,
1073                             AMT_TOS,
1074                             ip4_dst_hoplimit(&rt->dst),
1075                             0,
1076                             amt->relay_port,
1077                             tunnel->source_port,
1078                             false,
1079                             false);
1080 }
1081
1082 static bool amt_send_membership_query(struct amt_dev *amt,
1083                                       struct sk_buff *skb,
1084                                       struct amt_tunnel_list *tunnel,
1085                                       bool v6)
1086 {
1087         struct amt_header_membership_query *amtmq;
1088         struct socket *sock;
1089         struct rtable *rt;
1090         struct flowi4 fl4;
1091         int err;
1092
1093         sock = rcu_dereference_bh(amt->sock);
1094         if (!sock)
1095                 return true;
1096
1097         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1098                            sizeof(struct iphdr) + sizeof(struct udphdr));
1099         if (err)
1100                 return true;
1101
1102         skb_reset_inner_headers(skb);
1103         memset(&fl4, 0, sizeof(struct flowi4));
1104         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1105         fl4.daddr              = tunnel->ip4;
1106         fl4.saddr              = amt->local_ip;
1107         fl4.flowi4_tos         = AMT_TOS;
1108         fl4.flowi4_proto       = IPPROTO_UDP;
1109         rt = ip_route_output_key(amt->net, &fl4);
1110         if (IS_ERR(rt)) {
1111                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1112                 return true;
1113         }
1114
1115         amtmq           = skb_push(skb, sizeof(*amtmq));
1116         amtmq->version  = 0;
1117         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1118         amtmq->reserved = 0;
1119         amtmq->l        = 0;
1120         amtmq->g        = 0;
1121         amtmq->nonce    = tunnel->nonce;
1122         amtmq->response_mac = tunnel->mac;
1123
1124         if (!v6)
1125                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1126         else
1127                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1128         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1129                             fl4.saddr,
1130                             fl4.daddr,
1131                             AMT_TOS,
1132                             ip4_dst_hoplimit(&rt->dst),
1133                             0,
1134                             amt->relay_port,
1135                             tunnel->source_port,
1136                             false,
1137                             false);
1138         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1139         return false;
1140 }
1141
1142 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1143 {
1144         struct amt_dev *amt = netdev_priv(dev);
1145         struct amt_tunnel_list *tunnel;
1146         struct amt_group_node *gnode;
1147         union amt_addr group = {0,};
1148 #if IS_ENABLED(CONFIG_IPV6)
1149         struct ipv6hdr *ip6h;
1150         struct mld_msg *mld;
1151 #endif
1152         bool report = false;
1153         struct igmphdr *ih;
1154         bool query = false;
1155         struct iphdr *iph;
1156         bool data = false;
1157         bool v6 = false;
1158         u32 hash;
1159
1160         iph = ip_hdr(skb);
1161         if (iph->version == 4) {
1162                 if (!ipv4_is_multicast(iph->daddr))
1163                         goto free;
1164
1165                 if (!ip_mc_check_igmp(skb)) {
1166                         ih = igmp_hdr(skb);
1167                         switch (ih->type) {
1168                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1169                         case IGMP_HOST_MEMBERSHIP_REPORT:
1170                                 report = true;
1171                                 break;
1172                         case IGMP_HOST_MEMBERSHIP_QUERY:
1173                                 query = true;
1174                                 break;
1175                         default:
1176                                 goto free;
1177                         }
1178                 } else {
1179                         data = true;
1180                 }
1181                 v6 = false;
1182                 group.ip4 = iph->daddr;
1183 #if IS_ENABLED(CONFIG_IPV6)
1184         } else if (iph->version == 6) {
1185                 ip6h = ipv6_hdr(skb);
1186                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1187                         goto free;
1188
1189                 if (!ipv6_mc_check_mld(skb)) {
1190                         mld = (struct mld_msg *)skb_transport_header(skb);
1191                         switch (mld->mld_type) {
1192                         case ICMPV6_MGM_REPORT:
1193                         case ICMPV6_MLD2_REPORT:
1194                                 report = true;
1195                                 break;
1196                         case ICMPV6_MGM_QUERY:
1197                                 query = true;
1198                                 break;
1199                         default:
1200                                 goto free;
1201                         }
1202                 } else {
1203                         data = true;
1204                 }
1205                 v6 = true;
1206                 group.ip6 = ip6h->daddr;
1207 #endif
1208         } else {
1209                 dev->stats.tx_errors++;
1210                 goto free;
1211         }
1212
1213         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1214                 goto free;
1215
1216         skb_pull(skb, sizeof(struct ethhdr));
1217
1218         if (amt->mode == AMT_MODE_GATEWAY) {
1219                 /* Gateway only passes IGMP/MLD packets */
1220                 if (!report)
1221                         goto free;
1222                 if ((!v6 && !amt->ready4) || (v6 && !amt->ready6))
1223                         goto free;
1224                 if (amt_send_membership_update(amt, skb,  v6))
1225                         goto free;
1226                 goto unlock;
1227         } else if (amt->mode == AMT_MODE_RELAY) {
1228                 if (query) {
1229                         tunnel = amt_skb_cb(skb)->tunnel;
1230                         if (!tunnel) {
1231                                 WARN_ON(1);
1232                                 goto free;
1233                         }
1234
1235                         /* Do not forward unexpected query */
1236                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1237                                 goto free;
1238                         goto unlock;
1239                 }
1240
1241                 if (!data)
1242                         goto free;
1243                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1244                         hash = amt_group_hash(tunnel, &group);
1245                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1246                                                  node) {
1247                                 if (!v6) {
1248                                         if (gnode->group_addr.ip4 == iph->daddr)
1249                                                 goto found;
1250 #if IS_ENABLED(CONFIG_IPV6)
1251                                 } else {
1252                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1253                                                             &ip6h->daddr))
1254                                                 goto found;
1255 #endif
1256                                 }
1257                         }
1258                         continue;
1259 found:
1260                         amt_send_multicast_data(amt, skb, tunnel, v6);
1261                 }
1262         }
1263
1264         dev_kfree_skb(skb);
1265         return NETDEV_TX_OK;
1266 free:
1267         dev_kfree_skb(skb);
1268 unlock:
1269         dev->stats.tx_dropped++;
1270         return NETDEV_TX_OK;
1271 }
1272
1273 static int amt_parse_type(struct sk_buff *skb)
1274 {
1275         struct amt_header *amth;
1276
1277         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1278                            sizeof(struct amt_header)))
1279                 return -1;
1280
1281         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1282
1283         if (amth->version != 0)
1284                 return -1;
1285
1286         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1287                 return -1;
1288         return amth->type;
1289 }
1290
1291 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1292 {
1293         struct amt_dev *amt = tunnel->amt;
1294         struct amt_group_node *gnode;
1295         struct hlist_node *t;
1296         int i;
1297
1298         spin_lock_bh(&tunnel->lock);
1299         rcu_read_lock();
1300         for (i = 0; i < amt->hash_buckets; i++)
1301                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1302                         amt_del_group(amt, gnode);
1303         rcu_read_unlock();
1304         spin_unlock_bh(&tunnel->lock);
1305 }
1306
1307 static void amt_tunnel_expire(struct work_struct *work)
1308 {
1309         struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
1310                                                       struct amt_tunnel_list,
1311                                                       gc_wq);
1312         struct amt_dev *amt = tunnel->amt;
1313
1314         spin_lock_bh(&amt->lock);
1315         rcu_read_lock();
1316         list_del_rcu(&tunnel->list);
1317         amt->nr_tunnels--;
1318         amt_clear_groups(tunnel);
1319         rcu_read_unlock();
1320         spin_unlock_bh(&amt->lock);
1321         kfree_rcu(tunnel, rcu);
1322 }
1323
1324 static void amt_cleanup_srcs(struct amt_dev *amt,
1325                              struct amt_tunnel_list *tunnel,
1326                              struct amt_group_node *gnode)
1327 {
1328         struct amt_source_node *snode;
1329         struct hlist_node *t;
1330         int i;
1331
1332         /* Delete old sources */
1333         for (i = 0; i < amt->hash_buckets; i++) {
1334                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1335                         if (snode->flags == AMT_SOURCE_OLD)
1336                                 amt_destroy_source(snode);
1337                 }
1338         }
1339
1340         /* switch from new to old */
1341         for (i = 0; i < amt->hash_buckets; i++)  {
1342                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1343                         snode->flags = AMT_SOURCE_OLD;
1344                         if (!gnode->v6)
1345                                 netdev_dbg(snode->gnode->amt->dev,
1346                                            "Add source as OLD %pI4 from %pI4\n",
1347                                            &snode->source_addr.ip4,
1348                                            &gnode->group_addr.ip4);
1349 #if IS_ENABLED(CONFIG_IPV6)
1350                         else
1351                                 netdev_dbg(snode->gnode->amt->dev,
1352                                            "Add source as OLD %pI6 from %pI6\n",
1353                                            &snode->source_addr.ip6,
1354                                            &gnode->group_addr.ip6);
1355 #endif
1356                 }
1357         }
1358 }
1359
1360 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1361                          struct amt_group_node *gnode, void *grec,
1362                          bool v6)
1363 {
1364         struct igmpv3_grec *igmp_grec;
1365         struct amt_source_node *snode;
1366 #if IS_ENABLED(CONFIG_IPV6)
1367         struct mld2_grec *mld_grec;
1368 #endif
1369         union amt_addr src = {0,};
1370         u16 nsrcs;
1371         u32 hash;
1372         int i;
1373
1374         if (!v6) {
1375                 igmp_grec = (struct igmpv3_grec *)grec;
1376                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1377         } else {
1378 #if IS_ENABLED(CONFIG_IPV6)
1379                 mld_grec = (struct mld2_grec *)grec;
1380                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1381 #else
1382         return;
1383 #endif
1384         }
1385         for (i = 0; i < nsrcs; i++) {
1386                 if (tunnel->nr_sources >= amt->max_sources)
1387                         return;
1388                 if (!v6)
1389                         src.ip4 = igmp_grec->grec_src[i];
1390 #if IS_ENABLED(CONFIG_IPV6)
1391                 else
1392                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1393                                sizeof(struct in6_addr));
1394 #endif
1395                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1396                         continue;
1397
1398                 snode = amt_alloc_snode(gnode, &src);
1399                 if (snode) {
1400                         hash = amt_source_hash(tunnel, &snode->source_addr);
1401                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1402                         tunnel->nr_sources++;
1403                         gnode->nr_sources++;
1404
1405                         if (!gnode->v6)
1406                                 netdev_dbg(snode->gnode->amt->dev,
1407                                            "Add source as NEW %pI4 from %pI4\n",
1408                                            &snode->source_addr.ip4,
1409                                            &gnode->group_addr.ip4);
1410 #if IS_ENABLED(CONFIG_IPV6)
1411                         else
1412                                 netdev_dbg(snode->gnode->amt->dev,
1413                                            "Add source as NEW %pI6 from %pI6\n",
1414                                            &snode->source_addr.ip6,
1415                                            &gnode->group_addr.ip6);
1416 #endif
1417                 }
1418         }
1419 }
1420
/* amt_lookup_act_srcs - combine a group record's source list with the source
 * nodes already attached to @gnode, and apply @act to the selected nodes.
 *
 * Notation: A = source nodes currently in gnode->sources,
 *           B = source addresses listed in @grec.
 *
 * @tunnel: owning tunnel; supplies tunnel->amt and the source hash.
 * @gnode:  group whose source nodes are examined and acted on.
 * @grec:   struct igmpv3_grec when @v6 is false, struct mld2_grec when true.
 *          Without CONFIG_IPV6 an IPv6 record makes this function a no-op.
 * @ops:    which combination to compute:
 *   AMT_OPS_INT     (A*B) for each address in B, act on the node returned by
 *                         amt_lookup_src() with @filter, if any.
 *   AMT_OPS_UNI     (A+B) act on every node in A that passes
 *                         amt_status_filter(@filter), then on each address
 *                         in B found by amt_lookup_src() with @filter.
 *   AMT_OPS_SUB     (A-B) act on every node in A that passes @filter and
 *                         whose address is not listed in B.
 *   AMT_OPS_SUB_REV (B-A) for each address in B that amt_lookup_src() with
 *                         AMT_FILTER_ALL does not find, look it up again with
 *                         @filter and act if found.
 *   Any other value only logs "Invalid type".
 * @filter: node-state filter passed to amt_lookup_src()/amt_status_filter().
 * @act:    action passed to amt_act_src() for every selected node.
 * @v6:     selects how @grec is interpreted.
 *
 * NOTE(review): in AMT_OPS_UNI a node handled by the first pass is acted on
 * again by the second pass if its address is in B and it still passes
 * @filter after the first amt_act_src() — confirm @act cannot be applied
 * twice in practice.
 * NOTE(review): in AMT_OPS_SUB_REV the second lookup only runs after the
 * AMT_FILTER_ALL lookup found nothing for that address; a lookup with a
 * narrower @filter presumably cannot succeed then, so this case may never
 * act — confirm the intended semantics.
 *
 * Design notes (example for an IS_IN report in EXCLUDE mode):
 *
 * Router State   Report Rec'd New Router State
 * ------------   ------------ ----------------
 * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
 *
 * -----------+-----------+-----------+
 *            |    OLD    |    NEW    |
 * -----------+-----------+-----------+
 *    FWD     |     X     |    X+A    |
 * -----------+-----------+-----------+
 *    D_FWD   |     Y     |    Y-A    |
 * -----------+-----------+-----------+
 *    NONE    |           |     A     |
 * -----------+-----------+-----------+
 *
 * a) Received sources are NONE/NEW
 * b) All NONE will be deleted by amt_cleanup_srcs().
 * c) All OLD will be deleted by amt_cleanup_srcs().
 * d) After delete, NEW source will be switched to OLD.
 */
static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
				struct amt_group_node *gnode,
				void *grec,
				enum amt_ops ops,
				enum amt_filter filter,
				enum amt_act act,
				bool v6)
{
	struct amt_dev *amt = tunnel->amt;
	struct amt_source_node *snode;
	struct igmpv3_grec *igmp_grec;
#if IS_ENABLED(CONFIG_IPV6)
	struct mld2_grec *mld_grec;
#endif
	union amt_addr src = {0,};
	struct hlist_node *t;
	u16 nsrcs;
	int i, j;

	/* nsrcs = number of source addresses in the record (B). */
	if (!v6) {
		igmp_grec = (struct igmpv3_grec *)grec;
		nsrcs = ntohs(igmp_grec->grec_nsrcs);
	} else {
#if IS_ENABLED(CONFIG_IPV6)
		mld_grec = (struct mld2_grec *)grec;
		nsrcs = ntohs(mld_grec->grec_nsrcs);
#else
	return;
#endif
	}

	memset(&src, 0, sizeof(union amt_addr));
	switch (ops) {
	case AMT_OPS_INT:
		/* A*B */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_UNI:
		/* A+B */
		/* First pass: every existing node that passes @filter. */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (amt_status_filter(snode, filter))
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		/* Second pass: record addresses still matching @filter. */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_SUB:
		/* A-B */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (!amt_status_filter(snode, filter))
					continue;
				/* Skip the node if its address is listed in B. */
				for (j = 0; j < nsrcs; j++) {
					if (!v6)
						src.ip4 = igmp_grec->grec_src[j];
#if IS_ENABLED(CONFIG_IPV6)
					else
						memcpy(&src.ip6,
						       &mld_grec->grec_src[j],
						       sizeof(struct in6_addr));
#endif
					if (amt_addr_equal(&snode->source_addr,
							   &src))
						goto out_sub;
				}
				amt_act_src(tunnel, gnode, snode, act);
				continue;
out_sub:;
			}
		}
		break;
	case AMT_OPS_SUB_REV:
		/* B-A */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
					       &src);
			if (!snode) {
				snode = amt_lookup_src(tunnel, gnode,
						       filter, &src);
				if (snode)
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		break;
	default:
		netdev_dbg(amt->dev, "Invalid type\n");
		return;
	}
}
1563
/* Apply an IS_IN (MODE_IS_INCLUDE) current-state record to @gnode's sources,
 * following the router-state table for IS_IN reports (RFC 3376 section 6.4.1
 * for IGMPv3).  Each step is one amt_lookup_act_srcs() call, and the steps
 * must run in the order written because each one changes the source states
 * that the later filters see.
 *
 * @amt:       not used here.
 * @grec:      the received group record (IGMPv3 or MLDv2 depending on @v6).
 * @zero_grec: passed instead of @grec for the "FWD/OLD -> FWD/NEW" step;
 *             presumably a record with no sources, so AMT_OPS_UNI acts only
 *             on existing nodes — confirm against the caller.
 */
static void amt_mcast_is_in_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
 */
		/* Update IS_IN (B) as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update INCLUDE (A) as NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* (B)=GMI */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD_NEW,
				    AMT_ACT_GMI,
				    v6);
	} else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
 */
		/* Update (A) in (X, Y) as NONE/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_BOTH,
				    AMT_ACT_STATUS_NONE_NEW,
				    v6);
		/* Update FWD/OLD as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update IS_IN (A) as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update EXCLUDE (, Y-A) as D_FWD_NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
	}
}
1616
/* Apply an IS_EX (MODE_IS_EXCLUDE) group record to @gnode.
 *
 * Each branch quotes the router state transition it implements and then
 * performs it as an ordered series of amt_lookup_act_srcs() passes; later
 * passes filter on statuses set by earlier ones, so the call order matters.
 * In both branches the group timer is re-armed to GMI, and an INCLUDE group
 * is switched to EXCLUDE mode.
 *
 * @amt:       AMT device; supplies the GMI interval and the device reference.
 * @tunnel:    tunnel that owns @gnode.
 * @gnode:     group being updated; its filter_mode selects the branch.
 * @grec:      received group record (IGMPv3 or MLDv2 layout, per @v6).
 * @zero_grec: record used instead of @grec when a pass should act on the
 *             group's existing sources; callers pass igmpv3_zero_grec /
 *             mldv2_zero_grec (presumably a record with no sources - confirm).
 * @v6:        true for an MLDv2 record, false for an IGMPv3 record.
 *
 * The "Delete" actions of the table are not done here.  The callers run
 * amt_cleanup_srcs() on @gnode after this handler returns.
 */
static void amt_mcast_is_ex_handler(struct amt_dev *amt,
                                    struct amt_tunnel_list *tunnel,
                                    struct amt_group_node *gnode,
                                    void *grec, void *zero_grec, bool v6)
{
        if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd  New Router State         Actions
 * ------------   ------------  ----------------         -------
 * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
 *                                                       Delete (A-B)
 *                                                       Group Timer=GMI
 */
                /* EXCLUDE(A*B, ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE(, B-A) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_D_FWD_NEW,
                                    v6);
                /* (B-A)=0; (B-A) sources are D_FWD/NEW after the pass above */
                amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
                                    AMT_FILTER_D_FWD_NEW,
                                    AMT_ACT_GMI_ZERO,
                                    v6);
                /* Group Timer=GMI.  mod_delayed_work() returns false when the
                 * work was idle and has just been queued; only then is a new
                 * device reference taken (presumably dropped by the group
                 * timer work - confirm).
                 */
                if (!mod_delayed_work(amt_wq, &gnode->group_timer,
                                      msecs_to_jiffies(amt_gmi(amt))))
                        dev_hold(amt->dev);
                gnode->filter_mode = MCAST_EXCLUDE;
                /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
        } else {
/* Router State   Report Rec'd  New Router State        Actions
 * ------------   ------------  ----------------        -------
 * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
 *                                                      Delete (X-A)
 *                                                      Delete (Y-A)
 *                                                      Group Timer=GMI
 */
                /* EXCLUDE (A-Y, ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE (, Y*A ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_D_FWD_NEW,
                                    v6);
                /* (A-X-Y)=GMI */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_BOTH_NEW,
                                    AMT_ACT_GMI,
                                    v6);
                /* Group Timer=GMI; a device reference is taken only when the
                 * work was idle and mod_delayed_work() newly queued it.
                 */
                if (!mod_delayed_work(amt_wq, &gnode->group_timer,
                                      msecs_to_jiffies(amt_gmi(amt))))
                        dev_hold(amt->dev);
                /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
        }
}
1680
/* Apply a TO_IN (CHANGE_TO_INCLUDE) group record to @gnode.
 *
 * Each branch quotes the router state transition it implements and then
 * performs it as an ordered series of amt_lookup_act_srcs() passes; later
 * passes filter on statuses set by earlier ones, so the call order matters.
 *
 * @amt:       AMT device (not referenced by this handler).
 * @tunnel:    tunnel that owns @gnode.
 * @gnode:     group being updated; its filter_mode selects the branch.
 * @grec:      received group record (IGMPv3 or MLDv2 layout, per @v6).
 * @zero_grec: record used instead of @grec when a pass should act on the
 *             group's existing sources; callers pass igmpv3_zero_grec /
 *             mldv2_zero_grec (presumably a record with no sources - confirm).
 * @v6:        true for an MLDv2 record, false for an IGMPv3 record.
 *
 * NOTE(review): the "Send Q(...)" actions listed in the tables below are not
 * performed by this function.
 */
static void amt_mcast_to_in_handler(struct amt_dev *amt,
                                    struct amt_tunnel_list *tunnel,
                                    struct amt_group_node *gnode,
                                    void *grec, void *zero_grec, bool v6)
{
        if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
 *                                                     Send Q(G,A-B)
 */
                /* Update TO_IN (B) sources as FWD/NEW */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_NONE_NEW,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* Update INCLUDE (A) sources as FWD/NEW */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* (B)=GMI */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_FWD_NEW,
                                    AMT_ACT_GMI,
                                    v6);
        } else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
 *                                                     Send Q(G,X-A)
 *                                                     Send Q(G)
 */
                /* Update TO_IN (A) sources as FWD/NEW */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_NONE_NEW,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* Update EXCLUDE(X,) sources as FWD/NEW */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE (, Y-A)
                 * (A) are already switched to FWD_NEW.
                 * So, D_FWD/OLD -> D_FWD/NEW is okay.
                 */
                amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_D_FWD_NEW,
                                    v6);
                /* (A)=GMI
                 * Only FWD_NEW will have (A) sources.
                 */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_FWD_NEW,
                                    AMT_ACT_GMI,
                                    v6);
        }
}
1741
/* Apply a TO_EX (CHANGE_TO_EXCLUDE) group record to @gnode.
 *
 * Each branch quotes the router state transition it implements and then
 * performs it as an ordered series of amt_lookup_act_srcs() passes; later
 * passes filter on statuses set by earlier ones, so the call order matters.
 * In both branches the group timer is re-armed to GMI, and an INCLUDE group
 * is switched to EXCLUDE mode.
 *
 * @amt:       AMT device; supplies the GMI interval and the device reference.
 * @tunnel:    tunnel that owns @gnode.
 * @gnode:     group being updated; its filter_mode selects the branch.
 * @grec:      received group record (IGMPv3 or MLDv2 layout, per @v6).
 * @zero_grec: record used instead of @grec when a pass should act on the
 *             group's existing sources; callers pass igmpv3_zero_grec /
 *             mldv2_zero_grec (presumably a record with no sources - confirm).
 * @v6:        true for an MLDv2 record, false for an IGMPv3 record.
 *
 * The "Delete" actions are left to amt_cleanup_srcs(), which the callers run
 * after this handler.  NOTE(review): the "Send Q(...)" actions listed in the
 * tables are not performed by this function.
 */
static void amt_mcast_to_ex_handler(struct amt_dev *amt,
                                    struct amt_tunnel_list *tunnel,
                                    struct amt_group_node *gnode,
                                    void *grec, void *zero_grec, bool v6)
{
        if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
 *                                                     Delete (A-B)
 *                                                     Send Q(G,A*B)
 *                                                     Group Timer=GMI
 */
                /* EXCLUDE (A*B, ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE (, B-A) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_D_FWD_NEW,
                                    v6);
                /* (B-A)=0; (B-A) sources are D_FWD/NEW after the pass above */
                amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
                                    AMT_FILTER_D_FWD_NEW,
                                    AMT_ACT_GMI_ZERO,
                                    v6);
                /* Group Timer=GMI.  mod_delayed_work() returns false when the
                 * work was idle and has just been queued; only then is a new
                 * device reference taken.
                 */
                if (!mod_delayed_work(amt_wq, &gnode->group_timer,
                                      msecs_to_jiffies(amt_gmi(amt))))
                        dev_hold(amt->dev);
                gnode->filter_mode = MCAST_EXCLUDE;
                /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
        } else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
 *                                                     Delete (X-A)
 *                                                     Delete (Y-A)
 *                                                     Send Q(G,A-Y)
 *                                                     Group Timer=GMI
 */
                /* (A-X-Y)=Group Timer: sources of (A) not yet tracked in
                 * either list get the group timer value (AMT_ACT_GT).
                 */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_BOTH,
                                    AMT_ACT_GT,
                                    v6);
                /* EXCLUDE (A-Y, ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE (, Y*A) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_D_FWD_NEW,
                                    v6);
                /* Group Timer=GMI; a device reference is taken only when the
                 * work was idle and mod_delayed_work() newly queued it.
                 */
                if (!mod_delayed_work(amt_wq, &gnode->group_timer,
                                      msecs_to_jiffies(amt_gmi(amt))))
                        dev_hold(amt->dev);
                /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
        }
}
1807
/* Apply an ALLOW (ALLOW_NEW_SOURCES) group record to @gnode.
 *
 * Each branch quotes the router state transition it implements and then
 * performs it as an ordered series of amt_lookup_act_srcs() passes; the final
 * "(x)=GMI" pass in each branch only matches sources that the earlier passes
 * set to FWD/NEW, so the call order matters.  The filter mode is unchanged.
 *
 * @amt:       AMT device (not referenced by this handler).
 * @tunnel:    tunnel that owns @gnode.
 * @gnode:     group being updated; its filter_mode selects the branch.
 * @grec:      received group record (IGMPv3 or MLDv2 layout, per @v6).
 * @zero_grec: unused by this handler; kept for a uniform handler signature.
 * @v6:        true for an MLDv2 record, false for an IGMPv3 record.
 */
static void amt_mcast_allow_handler(struct amt_dev *amt,
                                    struct amt_tunnel_list *tunnel,
                                    struct amt_group_node *gnode,
                                    void *grec, void *zero_grec, bool v6)
{
        if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
 */
                /* INCLUDE (A+B) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* (B)=GMI */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_FWD_NEW,
                                    AMT_ACT_GMI,
                                    v6);
        } else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
 */
                /* EXCLUDE (X+A, ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE (, Y-A) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_D_FWD_NEW,
                                    v6);
                /* (A)=GMI
                 * All (A) source are now FWD/NEW status.
                 */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
                                    AMT_FILTER_FWD_NEW,
                                    AMT_ACT_GMI,
                                    v6);
        }
}
1852
/* Apply a BLOCK (BLOCK_OLD_SOURCES) group record to @gnode.
 *
 * Each branch quotes the router state transition it implements and then
 * performs it as an ordered series of amt_lookup_act_srcs() passes; the order
 * of the calls matters.  The filter mode is unchanged.
 *
 * @amt:       AMT device (not referenced by this handler).
 * @tunnel:    tunnel that owns @gnode.
 * @gnode:     group being updated; its filter_mode selects the branch.
 * @grec:      received group record (IGMPv3 or MLDv2 layout, per @v6).
 * @zero_grec: record used instead of @grec when a pass should act on the
 *             group's existing sources; callers pass igmpv3_zero_grec /
 *             mldv2_zero_grec (presumably a record with no sources - confirm).
 * @v6:        true for an MLDv2 record, false for an IGMPv3 record.
 *
 * NOTE(review): the "Send Q(...)" actions listed in the tables below are not
 * performed by this function.
 */
static void amt_mcast_block_handler(struct amt_dev *amt,
                                    struct amt_tunnel_list *tunnel,
                                    struct amt_group_node *gnode,
                                    void *grec, void *zero_grec, bool v6)
{
        if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
 */
                /* INCLUDE (A): keep every existing forwarding source */
                amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
        } else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
 *                                                     Send Q(G,A-Y)
 */
                /* (A-X-Y)=Group Timer */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_BOTH,
                                    AMT_ACT_GT,
                                    v6);
                /* EXCLUDE (X, ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE (X+(A-Y), ) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_FWD_NEW,
                                    v6);
                /* EXCLUDE (, Y) */
                amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
                                    AMT_FILTER_D_FWD,
                                    AMT_ACT_STATUS_D_FWD_NEW,
                                    v6);
        }
}
1896
1897 /* RFC 3376
1898  * 7.3.2. In the Presence of Older Version Group Members
1899  *
1900  * When Group Compatibility Mode is IGMPv2, a router internally
1901  * translates the following IGMPv2 messages for that group to their
1902  * IGMPv3 equivalents:
1903  *
1904  * IGMPv2 Message                IGMPv3 Equivalent
1905  * --------------                -----------------
1906  * Report                        IS_EX( {} )
1907  * Leave                         TO_IN( {} )
1908  */
1909 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1910                                       struct amt_tunnel_list *tunnel)
1911 {
1912         struct igmphdr *ih = igmp_hdr(skb);
1913         struct iphdr *iph = ip_hdr(skb);
1914         struct amt_group_node *gnode;
1915         union amt_addr group, host;
1916
1917         memset(&group, 0, sizeof(union amt_addr));
1918         group.ip4 = ih->group;
1919         memset(&host, 0, sizeof(union amt_addr));
1920         host.ip4 = iph->saddr;
1921
1922         gnode = amt_lookup_group(tunnel, &group, &host, false);
1923         if (!gnode) {
1924                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1925                 if (!IS_ERR(gnode)) {
1926                         gnode->filter_mode = MCAST_EXCLUDE;
1927                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1928                                               msecs_to_jiffies(amt_gmi(amt))))
1929                                 dev_hold(amt->dev);
1930                 }
1931         }
1932 }
1933
1934 /* RFC 3376
1935  * 7.3.2. In the Presence of Older Version Group Members
1936  *
1937  * When Group Compatibility Mode is IGMPv2, a router internally
1938  * translates the following IGMPv2 messages for that group to their
1939  * IGMPv3 equivalents:
1940  *
1941  * IGMPv2 Message                IGMPv3 Equivalent
1942  * --------------                -----------------
1943  * Report                        IS_EX( {} )
1944  * Leave                         TO_IN( {} )
1945  */
1946 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1947                                      struct amt_tunnel_list *tunnel)
1948 {
1949         struct igmphdr *ih = igmp_hdr(skb);
1950         struct iphdr *iph = ip_hdr(skb);
1951         struct amt_group_node *gnode;
1952         union amt_addr group, host;
1953
1954         memset(&group, 0, sizeof(union amt_addr));
1955         group.ip4 = ih->group;
1956         memset(&host, 0, sizeof(union amt_addr));
1957         host.ip4 = iph->saddr;
1958
1959         gnode = amt_lookup_group(tunnel, &group, &host, false);
1960         if (gnode)
1961                 amt_del_group(amt, gnode);
1962 }
1963
1964 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1965                                       struct amt_tunnel_list *tunnel)
1966 {
1967         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
1968         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
1969         void *zero_grec = (void *)&igmpv3_zero_grec;
1970         struct iphdr *iph = ip_hdr(skb);
1971         struct amt_group_node *gnode;
1972         union amt_addr group, host;
1973         struct igmpv3_grec *grec;
1974         u16 nsrcs;
1975         int i;
1976
1977         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
1978                 len += sizeof(*grec);
1979                 if (!ip_mc_may_pull(skb, len))
1980                         break;
1981
1982                 grec = (void *)(skb->data + len - sizeof(*grec));
1983                 nsrcs = ntohs(grec->grec_nsrcs);
1984
1985                 len += nsrcs * sizeof(__be32);
1986                 if (!ip_mc_may_pull(skb, len))
1987                         break;
1988
1989                 memset(&group, 0, sizeof(union amt_addr));
1990                 group.ip4 = grec->grec_mca;
1991                 memset(&host, 0, sizeof(union amt_addr));
1992                 host.ip4 = iph->saddr;
1993                 gnode = amt_lookup_group(tunnel, &group, &host, false);
1994                 if (!gnode) {
1995                         gnode = amt_add_group(amt, tunnel, &group, &host,
1996                                               false);
1997                         if (IS_ERR(gnode))
1998                                 continue;
1999                 }
2000
2001                 amt_add_srcs(amt, tunnel, gnode, grec, false);
2002                 switch (grec->grec_type) {
2003                 case IGMPV3_MODE_IS_INCLUDE:
2004                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2005                                                 zero_grec, false);
2006                         break;
2007                 case IGMPV3_MODE_IS_EXCLUDE:
2008                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2009                                                 zero_grec, false);
2010                         break;
2011                 case IGMPV3_CHANGE_TO_INCLUDE:
2012                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2013                                                 zero_grec, false);
2014                         break;
2015                 case IGMPV3_CHANGE_TO_EXCLUDE:
2016                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2017                                                 zero_grec, false);
2018                         break;
2019                 case IGMPV3_ALLOW_NEW_SOURCES:
2020                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2021                                                 zero_grec, false);
2022                         break;
2023                 case IGMPV3_BLOCK_OLD_SOURCES:
2024                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2025                                                 zero_grec, false);
2026                         break;
2027                 default:
2028                         break;
2029                 }
2030                 amt_cleanup_srcs(amt, tunnel, gnode);
2031         }
2032 }
2033
2034 /* caller held tunnel->lock */
2035 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2036                                     struct amt_tunnel_list *tunnel)
2037 {
2038         struct igmphdr *ih = igmp_hdr(skb);
2039
2040         switch (ih->type) {
2041         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2042                 amt_igmpv3_report_handler(amt, skb, tunnel);
2043                 break;
2044         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2045                 amt_igmpv2_report_handler(amt, skb, tunnel);
2046                 break;
2047         case IGMP_HOST_LEAVE_MESSAGE:
2048                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2049                 break;
2050         default:
2051                 break;
2052         }
2053 }
2054
2055 #if IS_ENABLED(CONFIG_IPV6)
2056 /* RFC 3810
2057  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2058  *
2059  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2060  * using the MLDv2 protocol for that multicast address.  When Multicast
2061  * Address Compatibility Mode is MLDv1, a router internally translates
2062  * the following MLDv1 messages for that multicast address to their
2063  * MLDv2 equivalents:
2064  *
2065  * MLDv1 Message                 MLDv2 Equivalent
2066  * --------------                -----------------
2067  * Report                        IS_EX( {} )
2068  * Done                          TO_IN( {} )
2069  */
2070 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2071                                      struct amt_tunnel_list *tunnel)
2072 {
2073         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2074         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2075         struct amt_group_node *gnode;
2076         union amt_addr group, host;
2077
2078         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2079         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2080
2081         gnode = amt_lookup_group(tunnel, &group, &host, true);
2082         if (!gnode) {
2083                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2084                 if (!IS_ERR(gnode)) {
2085                         gnode->filter_mode = MCAST_EXCLUDE;
2086                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2087                                               msecs_to_jiffies(amt_gmi(amt))))
2088                                 dev_hold(amt->dev);
2089                 }
2090         }
2091 }
2092
2093 /* RFC 3810
2094  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2095  *
2096  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2097  * using the MLDv2 protocol for that multicast address.  When Multicast
2098  * Address Compatibility Mode is MLDv1, a router internally translates
2099  * the following MLDv1 messages for that multicast address to their
2100  * MLDv2 equivalents:
2101  *
2102  * MLDv1 Message                 MLDv2 Equivalent
2103  * --------------                -----------------
2104  * Report                        IS_EX( {} )
2105  * Done                          TO_IN( {} )
2106  */
2107 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2108                                     struct amt_tunnel_list *tunnel)
2109 {
2110         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2111         struct iphdr *iph = ip_hdr(skb);
2112         struct amt_group_node *gnode;
2113         union amt_addr group, host;
2114
2115         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2116         memset(&host, 0, sizeof(union amt_addr));
2117         host.ip4 = iph->saddr;
2118
2119         gnode = amt_lookup_group(tunnel, &group, &host, true);
2120         if (gnode) {
2121                 amt_del_group(amt, gnode);
2122                 return;
2123         }
2124 }
2125
2126 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2127                                      struct amt_tunnel_list *tunnel)
2128 {
2129         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2130         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2131         void *zero_grec = (void *)&mldv2_zero_grec;
2132         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2133         struct amt_group_node *gnode;
2134         union amt_addr group, host;
2135         struct mld2_grec *grec;
2136         u16 nsrcs;
2137         int i;
2138
2139         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2140                 len += sizeof(*grec);
2141                 if (!ipv6_mc_may_pull(skb, len))
2142                         break;
2143
2144                 grec = (void *)(skb->data + len - sizeof(*grec));
2145                 nsrcs = ntohs(grec->grec_nsrcs);
2146
2147                 len += nsrcs * sizeof(struct in6_addr);
2148                 if (!ipv6_mc_may_pull(skb, len))
2149                         break;
2150
2151                 memset(&group, 0, sizeof(union amt_addr));
2152                 group.ip6 = grec->grec_mca;
2153                 memset(&host, 0, sizeof(union amt_addr));
2154                 host.ip6 = ip6h->saddr;
2155                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2156                 if (!gnode) {
2157                         gnode = amt_add_group(amt, tunnel, &group, &host,
2158                                               ETH_P_IPV6);
2159                         if (IS_ERR(gnode))
2160                                 continue;
2161                 }
2162
2163                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2164                 switch (grec->grec_type) {
2165                 case MLD2_MODE_IS_INCLUDE:
2166                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2167                                                 zero_grec, true);
2168                         break;
2169                 case MLD2_MODE_IS_EXCLUDE:
2170                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2171                                                 zero_grec, true);
2172                         break;
2173                 case MLD2_CHANGE_TO_INCLUDE:
2174                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2175                                                 zero_grec, true);
2176                         break;
2177                 case MLD2_CHANGE_TO_EXCLUDE:
2178                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2179                                                 zero_grec, true);
2180                         break;
2181                 case MLD2_ALLOW_NEW_SOURCES:
2182                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2183                                                 zero_grec, true);
2184                         break;
2185                 case MLD2_BLOCK_OLD_SOURCES:
2186                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2187                                                 zero_grec, true);
2188                         break;
2189                 default:
2190                         break;
2191                 }
2192                 amt_cleanup_srcs(amt, tunnel, gnode);
2193         }
2194 }
2195
2196 /* caller held tunnel->lock */
2197 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2198                                    struct amt_tunnel_list *tunnel)
2199 {
2200         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2201
2202         switch (mld->mld_type) {
2203         case ICMPV6_MGM_REPORT:
2204                 amt_mldv1_report_handler(amt, skb, tunnel);
2205                 break;
2206         case ICMPV6_MLD2_REPORT:
2207                 amt_mldv2_report_handler(amt, skb, tunnel);
2208                 break;
2209         case ICMPV6_MGM_REDUCTION:
2210                 amt_mldv1_leave_handler(amt, skb, tunnel);
2211                 break;
2212         default:
2213                 break;
2214         }
2215 }
2216 #endif
2217
2218 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2219 {
2220         struct amt_header_advertisement *amta;
2221         int hdr_size;
2222
2223         hdr_size = sizeof(*amta) - sizeof(struct amt_header);
2224
2225         if (!pskb_may_pull(skb, hdr_size))
2226                 return true;
2227
2228         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2229         if (!amta->ip4)
2230                 return true;
2231
2232         if (amta->reserved || amta->version)
2233                 return true;
2234
2235         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2236             ipv4_is_zeronet(amta->ip4))
2237                 return true;
2238
2239         amt->remote_ip = amta->ip4;
2240         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2241         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2242
2243         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2244         return false;
2245 }
2246
2247 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2248 {
2249         struct amt_header_mcast_data *amtmd;
2250         int hdr_size, len, err;
2251         struct ethhdr *eth;
2252         struct iphdr *iph;
2253
2254         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2255         if (amtmd->reserved || amtmd->version)
2256                 return true;
2257
2258         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2259         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2260                 return true;
2261         skb_reset_network_header(skb);
2262         skb_push(skb, sizeof(*eth));
2263         skb_reset_mac_header(skb);
2264         skb_pull(skb, sizeof(*eth));
2265         eth = eth_hdr(skb);
2266         iph = ip_hdr(skb);
2267         if (iph->version == 4) {
2268                 if (!ipv4_is_multicast(iph->daddr))
2269                         return true;
2270                 skb->protocol = htons(ETH_P_IP);
2271                 eth->h_proto = htons(ETH_P_IP);
2272                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2273 #if IS_ENABLED(CONFIG_IPV6)
2274         } else if (iph->version == 6) {
2275                 struct ipv6hdr *ip6h;
2276
2277                 ip6h = ipv6_hdr(skb);
2278                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2279                         return true;
2280                 skb->protocol = htons(ETH_P_IPV6);
2281                 eth->h_proto = htons(ETH_P_IPV6);
2282                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2283 #endif
2284         } else {
2285                 return true;
2286         }
2287
2288         skb->pkt_type = PACKET_MULTICAST;
2289         skb->ip_summed = CHECKSUM_NONE;
2290         len = skb->len;
2291         err = gro_cells_receive(&amt->gro_cells, skb);
2292         if (likely(err == NET_RX_SUCCESS))
2293                 dev_sw_netstats_rx_add(amt->dev, len);
2294         else
2295                 amt->dev->stats.rx_dropped++;
2296
2297         return false;
2298 }
2299
2300 static bool amt_membership_query_handler(struct amt_dev *amt,
2301                                          struct sk_buff *skb)
2302 {
2303         struct amt_header_membership_query *amtmq;
2304         struct igmpv3_query *ihv3;
2305         struct ethhdr *eth, *oeth;
2306         struct iphdr *iph;
2307         int hdr_size, len;
2308
2309         hdr_size = sizeof(*amtmq) - sizeof(struct amt_header);
2310
2311         if (!pskb_may_pull(skb, hdr_size))
2312                 return true;
2313
2314         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2315         if (amtmq->reserved || amtmq->version)
2316                 return true;
2317
2318         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr) - sizeof(*eth);
2319         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2320                 return true;
2321         oeth = eth_hdr(skb);
2322         skb_reset_mac_header(skb);
2323         skb_pull(skb, sizeof(*eth));
2324         skb_reset_network_header(skb);
2325         eth = eth_hdr(skb);
2326         iph = ip_hdr(skb);
2327         if (iph->version == 4) {
2328                 if (!ipv4_is_multicast(iph->daddr))
2329                         return true;
2330                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2331                                    sizeof(*ihv3)))
2332                         return true;
2333
2334                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2335                 skb_reset_transport_header(skb);
2336                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2337                 spin_lock_bh(&amt->lock);
2338                 amt->ready4 = true;
2339                 amt->mac = amtmq->response_mac;
2340                 amt->req_cnt = 0;
2341                 amt->qi = ihv3->qqic;
2342                 spin_unlock_bh(&amt->lock);
2343                 skb->protocol = htons(ETH_P_IP);
2344                 eth->h_proto = htons(ETH_P_IP);
2345                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2346 #if IS_ENABLED(CONFIG_IPV6)
2347         } else if (iph->version == 6) {
2348                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2349                 struct mld2_query *mld2q;
2350
2351                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2352                         return true;
2353                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2354                                    sizeof(*mld2q)))
2355                         return true;
2356
2357                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2358                 skb_reset_transport_header(skb);
2359                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2360                 spin_lock_bh(&amt->lock);
2361                 amt->ready6 = true;
2362                 amt->mac = amtmq->response_mac;
2363                 amt->req_cnt = 0;
2364                 amt->qi = mld2q->mld2q_qqic;
2365                 spin_unlock_bh(&amt->lock);
2366                 skb->protocol = htons(ETH_P_IPV6);
2367                 eth->h_proto = htons(ETH_P_IPV6);
2368                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2369 #endif
2370         } else {
2371                 return true;
2372         }
2373
2374         ether_addr_copy(eth->h_source, oeth->h_source);
2375         skb->pkt_type = PACKET_MULTICAST;
2376         skb->ip_summed = CHECKSUM_NONE;
2377         len = skb->len;
2378         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2379                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2380                 dev_sw_netstats_rx_add(amt->dev, len);
2381         } else {
2382                 amt->dev->stats.rx_dropped++;
2383         }
2384
2385         return false;
2386 }
2387
2388 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2389 {
2390         struct amt_header_membership_update *amtmu;
2391         struct amt_tunnel_list *tunnel;
2392         struct udphdr *udph;
2393         struct ethhdr *eth;
2394         struct iphdr *iph;
2395         int len;
2396
2397         iph = ip_hdr(skb);
2398         udph = udp_hdr(skb);
2399
2400         if (__iptunnel_pull_header(skb, sizeof(*udph), skb->protocol,
2401                                    false, false))
2402                 return true;
2403
2404         amtmu = (struct amt_header_membership_update *)skb->data;
2405         if (amtmu->reserved || amtmu->version)
2406                 return true;
2407
2408         skb_pull(skb, sizeof(*amtmu));
2409         skb_reset_network_header(skb);
2410
2411         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2412                 if (tunnel->ip4 == iph->saddr) {
2413                         if ((amtmu->nonce == tunnel->nonce &&
2414                              amtmu->response_mac == tunnel->mac)) {
2415                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2416                                                  msecs_to_jiffies(amt_gmi(amt))
2417                                                                   * 3);
2418                                 goto report;
2419                         } else {
2420                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2421                                 return true;
2422                         }
2423                 }
2424         }
2425
2426         return true;
2427
2428 report:
2429         iph = ip_hdr(skb);
2430         if (iph->version == 4) {
2431                 if (ip_mc_check_igmp(skb)) {
2432                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2433                         return true;
2434                 }
2435
2436                 spin_lock_bh(&tunnel->lock);
2437                 amt_igmp_report_handler(amt, skb, tunnel);
2438                 spin_unlock_bh(&tunnel->lock);
2439
2440                 skb_push(skb, sizeof(struct ethhdr));
2441                 skb_reset_mac_header(skb);
2442                 eth = eth_hdr(skb);
2443                 skb->protocol = htons(ETH_P_IP);
2444                 eth->h_proto = htons(ETH_P_IP);
2445                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2446 #if IS_ENABLED(CONFIG_IPV6)
2447         } else if (iph->version == 6) {
2448                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2449
2450                 if (ipv6_mc_check_mld(skb)) {
2451                         netdev_dbg(amt->dev, "Invalid MLD\n");
2452                         return true;
2453                 }
2454
2455                 spin_lock_bh(&tunnel->lock);
2456                 amt_mld_report_handler(amt, skb, tunnel);
2457                 spin_unlock_bh(&tunnel->lock);
2458
2459                 skb_push(skb, sizeof(struct ethhdr));
2460                 skb_reset_mac_header(skb);
2461                 eth = eth_hdr(skb);
2462                 skb->protocol = htons(ETH_P_IPV6);
2463                 eth->h_proto = htons(ETH_P_IPV6);
2464                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2465 #endif
2466         } else {
2467                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2468                 return true;
2469         }
2470
2471         skb_pull(skb, sizeof(struct ethhdr));
2472         skb->pkt_type = PACKET_MULTICAST;
2473         skb->ip_summed = CHECKSUM_NONE;
2474         len = skb->len;
2475         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2476                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2477                                         true);
2478                 dev_sw_netstats_rx_add(amt->dev, len);
2479         } else {
2480                 amt->dev->stats.rx_dropped++;
2481         }
2482
2483         return false;
2484 }
2485
2486 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2487                                    __be32 daddr, __be16 dport)
2488 {
2489         struct amt_header_advertisement *amta;
2490         int hlen, tlen, offset;
2491         struct socket *sock;
2492         struct udphdr *udph;
2493         struct sk_buff *skb;
2494         struct iphdr *iph;
2495         struct rtable *rt;
2496         struct flowi4 fl4;
2497         u32 len;
2498         int err;
2499
2500         rcu_read_lock();
2501         sock = rcu_dereference(amt->sock);
2502         if (!sock)
2503                 goto out;
2504
2505         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2506                 goto out;
2507
2508         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2509                                    daddr, amt->local_ip,
2510                                    dport, amt->relay_port,
2511                                    IPPROTO_UDP, 0,
2512                                    amt->stream_dev->ifindex);
2513         if (IS_ERR(rt)) {
2514                 amt->dev->stats.tx_errors++;
2515                 goto out;
2516         }
2517
2518         hlen = LL_RESERVED_SPACE(amt->dev);
2519         tlen = amt->dev->needed_tailroom;
2520         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2521         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2522         if (!skb) {
2523                 ip_rt_put(rt);
2524                 amt->dev->stats.tx_errors++;
2525                 goto out;
2526         }
2527
2528         skb->priority = TC_PRIO_CONTROL;
2529         skb_dst_set(skb, &rt->dst);
2530
2531         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2532         skb_reset_network_header(skb);
2533         skb_put(skb, len);
2534         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2535         amta->version   = 0;
2536         amta->type      = AMT_MSG_ADVERTISEMENT;
2537         amta->reserved  = 0;
2538         amta->nonce     = nonce;
2539         amta->ip4       = amt->local_ip;
2540         skb_push(skb, sizeof(*udph));
2541         skb_reset_transport_header(skb);
2542         udph            = udp_hdr(skb);
2543         udph->source    = amt->relay_port;
2544         udph->dest      = dport;
2545         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
2546         udph->check     = 0;
2547         offset = skb_transport_offset(skb);
2548         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2549         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2550                                         sizeof(*udph) + sizeof(*amta),
2551                                         IPPROTO_UDP, skb->csum);
2552
2553         skb_push(skb, sizeof(*iph));
2554         iph             = ip_hdr(skb);
2555         iph->version    = 4;
2556         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2557         iph->tos        = AMT_TOS;
2558         iph->frag_off   = 0;
2559         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2560         iph->daddr      = daddr;
2561         iph->saddr      = amt->local_ip;
2562         iph->protocol   = IPPROTO_UDP;
2563         iph->tot_len    = htons(len);
2564
2565         skb->ip_summed = CHECKSUM_NONE;
2566         ip_select_ident(amt->net, skb, NULL);
2567         ip_send_check(iph);
2568         err = ip_local_out(amt->net, sock->sk, skb);
2569         if (unlikely(net_xmit_eval(err)))
2570                 amt->dev->stats.tx_errors++;
2571
2572 out:
2573         rcu_read_unlock();
2574 }
2575
2576 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2577 {
2578         struct amt_header_discovery *amtd;
2579         struct udphdr *udph;
2580         struct iphdr *iph;
2581
2582         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2583                 return true;
2584
2585         iph = ip_hdr(skb);
2586         udph = udp_hdr(skb);
2587         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2588
2589         if (amtd->reserved || amtd->version)
2590                 return true;
2591
2592         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2593
2594         return false;
2595 }
2596
2597 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2598 {
2599         struct amt_header_request *amtrh;
2600         struct amt_tunnel_list *tunnel;
2601         unsigned long long key;
2602         struct udphdr *udph;
2603         struct iphdr *iph;
2604         u64 mac;
2605         int i;
2606
2607         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2608                 return true;
2609
2610         iph = ip_hdr(skb);
2611         udph = udp_hdr(skb);
2612         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2613
2614         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2615                 return true;
2616
2617         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2618                 if (tunnel->ip4 == iph->saddr)
2619                         goto send;
2620
2621         if (amt->nr_tunnels >= amt->max_tunnels) {
2622                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2623                 return true;
2624         }
2625
2626         tunnel = kzalloc(sizeof(*tunnel) +
2627                          (sizeof(struct hlist_head) * amt->hash_buckets),
2628                          GFP_ATOMIC);
2629         if (!tunnel)
2630                 return true;
2631
2632         tunnel->source_port = udph->source;
2633         tunnel->ip4 = iph->saddr;
2634
2635         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2636         tunnel->amt = amt;
2637         spin_lock_init(&tunnel->lock);
2638         for (i = 0; i < amt->hash_buckets; i++)
2639                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2640
2641         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2642
2643         spin_lock_bh(&amt->lock);
2644         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2645         tunnel->key = amt->key;
2646         amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2647         amt->nr_tunnels++;
2648         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2649                          msecs_to_jiffies(amt_gmi(amt)));
2650         spin_unlock_bh(&amt->lock);
2651
2652 send:
2653         tunnel->nonce = amtrh->nonce;
2654         mac = siphash_3u32((__force u32)tunnel->ip4,
2655                            (__force u32)tunnel->source_port,
2656                            (__force u32)tunnel->nonce,
2657                            &tunnel->key);
2658         tunnel->mac = mac >> 16;
2659
2660         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2661                 return true;
2662
2663         if (!amtrh->p)
2664                 amt_send_igmp_gq(amt, tunnel);
2665         else
2666                 amt_send_mld_gq(amt, tunnel);
2667
2668         return false;
2669 }
2670
2671 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2672 {
2673         struct amt_dev *amt;
2674         struct iphdr *iph;
2675         int type;
2676         bool err;
2677
2678         rcu_read_lock_bh();
2679         amt = rcu_dereference_sk_user_data(sk);
2680         if (!amt) {
2681                 err = true;
2682                 goto drop;
2683         }
2684
2685         skb->dev = amt->dev;
2686         iph = ip_hdr(skb);
2687         type = amt_parse_type(skb);
2688         if (type == -1) {
2689                 err = true;
2690                 goto drop;
2691         }
2692
2693         if (amt->mode == AMT_MODE_GATEWAY) {
2694                 switch (type) {
2695                 case AMT_MSG_ADVERTISEMENT:
2696                         if (iph->saddr != amt->discovery_ip) {
2697                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2698                                 err = true;
2699                                 goto drop;
2700                         }
2701                         err = amt_advertisement_handler(amt, skb);
2702                         break;
2703                 case AMT_MSG_MULTICAST_DATA:
2704                         if (iph->saddr != amt->remote_ip) {
2705                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2706                                 err = true;
2707                                 goto drop;
2708                         }
2709                         err = amt_multicast_data_handler(amt, skb);
2710                         if (err)
2711                                 goto drop;
2712                         else
2713                                 goto out;
2714                 case AMT_MSG_MEMBERSHIP_QUERY:
2715                         if (iph->saddr != amt->remote_ip) {
2716                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2717                                 err = true;
2718                                 goto drop;
2719                         }
2720                         err = amt_membership_query_handler(amt, skb);
2721                         if (err)
2722                                 goto drop;
2723                         else
2724                                 goto out;
2725                 default:
2726                         err = true;
2727                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2728                         break;
2729                 }
2730         } else {
2731                 switch (type) {
2732                 case AMT_MSG_DISCOVERY:
2733                         err = amt_discovery_handler(amt, skb);
2734                         break;
2735                 case AMT_MSG_REQUEST:
2736                         err = amt_request_handler(amt, skb);
2737                         break;
2738                 case AMT_MSG_MEMBERSHIP_UPDATE:
2739                         err = amt_update_handler(amt, skb);
2740                         if (err)
2741                                 goto drop;
2742                         else
2743                                 goto out;
2744                 default:
2745                         err = true;
2746                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2747                         break;
2748                 }
2749         }
2750 drop:
2751         if (err) {
2752                 amt->dev->stats.rx_dropped++;
2753                 kfree_skb(skb);
2754         } else {
2755                 consume_skb(skb);
2756         }
2757 out:
2758         rcu_read_unlock_bh();
2759         return 0;
2760 }
2761
2762 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2763 {
2764         struct amt_dev *amt;
2765         int type;
2766
2767         rcu_read_lock_bh();
2768         amt = rcu_dereference_sk_user_data(sk);
2769         if (!amt)
2770                 goto out;
2771
2772         if (amt->mode != AMT_MODE_GATEWAY)
2773                 goto drop;
2774
2775         type = amt_parse_type(skb);
2776         if (type == -1)
2777                 goto drop;
2778
2779         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2780                    type_str[type]);
2781         switch (type) {
2782         case AMT_MSG_DISCOVERY:
2783                 break;
2784         case AMT_MSG_REQUEST:
2785         case AMT_MSG_MEMBERSHIP_UPDATE:
2786                 if (amt->status >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2787                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2788                 break;
2789         default:
2790                 goto drop;
2791         }
2792 out:
2793         rcu_read_unlock_bh();
2794         return 0;
2795 drop:
2796         rcu_read_unlock_bh();
2797         amt->dev->stats.rx_dropped++;
2798         return 0;
2799 }
2800
2801 static struct socket *amt_create_sock(struct net *net, __be16 port)
2802 {
2803         struct udp_port_cfg udp_conf;
2804         struct socket *sock;
2805         int err;
2806
2807         memset(&udp_conf, 0, sizeof(udp_conf));
2808         udp_conf.family = AF_INET;
2809         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2810
2811         udp_conf.local_udp_port = port;
2812
2813         err = udp_sock_create(net, &udp_conf, &sock);
2814         if (err < 0)
2815                 return ERR_PTR(err);
2816
2817         return sock;
2818 }
2819
2820 static int amt_socket_create(struct amt_dev *amt)
2821 {
2822         struct udp_tunnel_sock_cfg tunnel_cfg;
2823         struct socket *sock;
2824
2825         sock = amt_create_sock(amt->net, amt->relay_port);
2826         if (IS_ERR(sock))
2827                 return PTR_ERR(sock);
2828
2829         /* Mark socket as an encapsulation socket */
2830         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2831         tunnel_cfg.sk_user_data = amt;
2832         tunnel_cfg.encap_type = 1;
2833         tunnel_cfg.encap_rcv = amt_rcv;
2834         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2835         tunnel_cfg.encap_destroy = NULL;
2836         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2837
2838         rcu_assign_pointer(amt->sock, sock);
2839         return 0;
2840 }
2841
2842 static int amt_dev_open(struct net_device *dev)
2843 {
2844         struct amt_dev *amt = netdev_priv(dev);
2845         int err;
2846
2847         amt->ready4 = false;
2848         amt->ready6 = false;
2849
2850         err = amt_socket_create(amt);
2851         if (err)
2852                 return err;
2853
2854         amt->req_cnt = 0;
2855         amt->remote_ip = 0;
2856         get_random_bytes(&amt->key, sizeof(siphash_key_t));
2857
2858         amt->status = AMT_STATUS_INIT;
2859         if (amt->mode == AMT_MODE_GATEWAY) {
2860                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
2861                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
2862         } else if (amt->mode == AMT_MODE_RELAY) {
2863                 mod_delayed_work(amt_wq, &amt->secret_wq,
2864                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
2865         }
2866         return err;
2867 }
2868
2869 static int amt_dev_stop(struct net_device *dev)
2870 {
2871         struct amt_dev *amt = netdev_priv(dev);
2872         struct amt_tunnel_list *tunnel, *tmp;
2873         struct socket *sock;
2874
2875         cancel_delayed_work_sync(&amt->req_wq);
2876         cancel_delayed_work_sync(&amt->discovery_wq);
2877         cancel_delayed_work_sync(&amt->secret_wq);
2878
2879         /* shutdown */
2880         sock = rtnl_dereference(amt->sock);
2881         RCU_INIT_POINTER(amt->sock, NULL);
2882         synchronize_net();
2883         if (sock)
2884                 udp_tunnel_sock_release(sock);
2885
2886         amt->ready4 = false;
2887         amt->ready6 = false;
2888         amt->req_cnt = 0;
2889         amt->remote_ip = 0;
2890
2891         list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
2892                 list_del_rcu(&tunnel->list);
2893                 amt->nr_tunnels--;
2894                 cancel_delayed_work_sync(&tunnel->gc_wq);
2895                 amt_clear_groups(tunnel);
2896                 kfree_rcu(tunnel, rcu);
2897         }
2898
2899         return 0;
2900 }
2901
2902 static const struct device_type amt_type = {
2903         .name = "amt",
2904 };
2905
2906 static int amt_dev_init(struct net_device *dev)
2907 {
2908         struct amt_dev *amt = netdev_priv(dev);
2909         int err;
2910
2911         amt->dev = dev;
2912         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
2913         if (!dev->tstats)
2914                 return -ENOMEM;
2915
2916         err = gro_cells_init(&amt->gro_cells, dev);
2917         if (err) {
2918                 free_percpu(dev->tstats);
2919                 return err;
2920         }
2921
2922         return 0;
2923 }
2924
2925 static void amt_dev_uninit(struct net_device *dev)
2926 {
2927         struct amt_dev *amt = netdev_priv(dev);
2928
2929         gro_cells_destroy(&amt->gro_cells);
2930         free_percpu(dev->tstats);
2931 }
2932
2933 static const struct net_device_ops amt_netdev_ops = {
2934         .ndo_init               = amt_dev_init,
2935         .ndo_uninit             = amt_dev_uninit,
2936         .ndo_open               = amt_dev_open,
2937         .ndo_stop               = amt_dev_stop,
2938         .ndo_start_xmit         = amt_dev_xmit,
2939         .ndo_get_stats64        = dev_get_tstats64,
2940 };
2941
2942 static void amt_link_setup(struct net_device *dev)
2943 {
2944         dev->netdev_ops         = &amt_netdev_ops;
2945         dev->needs_free_netdev  = true;
2946         SET_NETDEV_DEVTYPE(dev, &amt_type);
2947         dev->min_mtu            = ETH_MIN_MTU;
2948         dev->max_mtu            = ETH_MAX_MTU;
2949         dev->type               = ARPHRD_NONE;
2950         dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
2951         dev->hard_header_len    = 0;
2952         dev->addr_len           = 0;
2953         dev->priv_flags         |= IFF_NO_QUEUE;
2954         dev->features           |= NETIF_F_LLTX;
2955         dev->features           |= NETIF_F_GSO_SOFTWARE;
2956         dev->features           |= NETIF_F_NETNS_LOCAL;
2957         dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
2958         dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
2959         dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
2960         eth_hw_addr_random(dev);
2961         eth_zero_addr(dev->broadcast);
2962         ether_setup(dev);
2963 }
2964
2965 static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
2966         [IFLA_AMT_MODE]         = { .type = NLA_U32 },
2967         [IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
2968         [IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
2969         [IFLA_AMT_LINK]         = { .type = NLA_U32 },
2970         [IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
2971         [IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
2972         [IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
2973         [IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
2974 };
2975
2976 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
2977                         struct netlink_ext_ack *extack)
2978 {
2979         if (!data)
2980                 return -EINVAL;
2981
2982         if (!data[IFLA_AMT_LINK]) {
2983                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
2984                                     "Link attribute is required");
2985                 return -EINVAL;
2986         }
2987
2988         if (!data[IFLA_AMT_MODE]) {
2989                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
2990                                     "Mode attribute is required");
2991                 return -EINVAL;
2992         }
2993
2994         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
2995                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
2996                                     "Mode attribute is not valid");
2997                 return -EINVAL;
2998         }
2999
3000         if (!data[IFLA_AMT_LOCAL_IP]) {
3001                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3002                                     "Local attribute is required");
3003                 return -EINVAL;
3004         }
3005
3006         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3007             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3008                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3009                                     "Discovery attribute is required");
3010                 return -EINVAL;
3011         }
3012
3013         return 0;
3014 }
3015
3016 static int amt_newlink(struct net *net, struct net_device *dev,
3017                        struct nlattr *tb[], struct nlattr *data[],
3018                        struct netlink_ext_ack *extack)
3019 {
3020         struct amt_dev *amt = netdev_priv(dev);
3021         int err = -EINVAL;
3022
3023         amt->net = net;
3024         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3025
3026         if (data[IFLA_AMT_MAX_TUNNELS] &&
3027             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3028                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3029         else
3030                 amt->max_tunnels = AMT_MAX_TUNNELS;
3031
3032         spin_lock_init(&amt->lock);
3033         amt->max_groups = AMT_MAX_GROUP;
3034         amt->max_sources = AMT_MAX_SOURCE;
3035         amt->hash_buckets = AMT_HSIZE;
3036         amt->nr_tunnels = 0;
3037         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3038         amt->stream_dev = dev_get_by_index(net,
3039                                            nla_get_u32(data[IFLA_AMT_LINK]));
3040         if (!amt->stream_dev) {
3041                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3042                                     "Can't find stream device");
3043                 return -ENODEV;
3044         }
3045
3046         if (amt->stream_dev->type != ARPHRD_ETHER) {
3047                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3048                                     "Invalid stream device type");
3049                 goto err;
3050         }
3051
3052         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3053         if (ipv4_is_loopback(amt->local_ip) ||
3054             ipv4_is_zeronet(amt->local_ip) ||
3055             ipv4_is_multicast(amt->local_ip)) {
3056                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3057                                     "Invalid Local address");
3058                 goto err;
3059         }
3060
3061         if (data[IFLA_AMT_RELAY_PORT])
3062                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3063         else
3064                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3065
3066         if (data[IFLA_AMT_GATEWAY_PORT])
3067                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3068         else
3069                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3070
3071         if (!amt->relay_port) {
3072                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3073                                     "relay port must not be 0");
3074                 goto err;
3075         }
3076         if (amt->mode == AMT_MODE_RELAY) {
3077                 amt->qrv = amt->net->ipv4.sysctl_igmp_qrv;
3078                 amt->qri = 10;
3079                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3080                                        AMT_RELAY_HLEN;
3081                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3082                 dev->max_mtu = dev->mtu;
3083                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3084         } else {
3085                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3086                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3087                                             "discovery must be set in gateway mode");
3088                         goto err;
3089                 }
3090                 if (!amt->gw_port) {
3091                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3092                                             "gateway port must not be 0");
3093                         goto err;
3094                 }
3095                 amt->remote_ip = 0;
3096                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3097                 if (ipv4_is_loopback(amt->discovery_ip) ||
3098                     ipv4_is_zeronet(amt->discovery_ip) ||
3099                     ipv4_is_multicast(amt->discovery_ip)) {
3100                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3101                                             "discovery must be unicast");
3102                         goto err;
3103                 }
3104
3105                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3106                                        AMT_GW_HLEN;
3107                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3108                 dev->max_mtu = dev->mtu;
3109                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3110         }
3111         amt->qi = AMT_INIT_QUERY_INTERVAL;
3112
3113         err = register_netdevice(dev);
3114         if (err < 0) {
3115                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3116                 goto err;
3117         }
3118
3119         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3120         if (err < 0) {
3121                 unregister_netdevice(dev);
3122                 goto err;
3123         }
3124
3125         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3126         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3127         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3128         INIT_LIST_HEAD(&amt->tunnel_list);
3129
3130         return 0;
3131 err:
3132         dev_put(amt->stream_dev);
3133         return err;
3134 }
3135
/* rtnl_link_ops->dellink handler.
 *
 * Queues @dev on @head for unregistration; the caller finishes the job with
 * unregister_netdevice_many() (see amt_device_event()). Also undoes the
 * netdev_upper_dev_link() done in amt_newlink() and drops the stream device
 * reference obtained there with dev_get_by_index().
 */
static void amt_dellink(struct net_device *dev, struct list_head *head)
{
        struct amt_dev *amt = netdev_priv(dev);

        unregister_netdevice_queue(dev, head);
        /* Break the upper/lower relationship before releasing the lower dev. */
        netdev_upper_dev_unlink(amt->stream_dev, dev);
        /* Balances dev_get_by_index() in amt_newlink(). */
        dev_put(amt->stream_dev);
}
3144
3145 static size_t amt_get_size(const struct net_device *dev)
3146 {
3147         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3148                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3149                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3150                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3151                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3152                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3153                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3154                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3155 }
3156
3157 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3158 {
3159         struct amt_dev *amt = netdev_priv(dev);
3160
3161         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3162                 goto nla_put_failure;
3163         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3164                 goto nla_put_failure;
3165         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3166                 goto nla_put_failure;
3167         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3168                 goto nla_put_failure;
3169         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3170                 goto nla_put_failure;
3171         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3172                 goto nla_put_failure;
3173         if (amt->remote_ip)
3174                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3175                         goto nla_put_failure;
3176         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3177                 goto nla_put_failure;
3178
3179         return 0;
3180
3181 nla_put_failure:
3182         return -EMSGSIZE;
3183 }
3184
3185 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3186         .kind           = "amt",
3187         .maxtype        = IFLA_AMT_MAX,
3188         .policy         = amt_policy,
3189         .priv_size      = sizeof(struct amt_dev),
3190         .setup          = amt_link_setup,
3191         .validate       = amt_validate,
3192         .newlink        = amt_newlink,
3193         .dellink        = amt_dellink,
3194         .get_size       = amt_get_size,
3195         .fill_info      = amt_fill_info,
3196 };
3197
3198 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3199 {
3200         struct net_device *upper_dev;
3201         struct amt_dev *amt;
3202
3203         for_each_netdev(dev_net(dev), upper_dev) {
3204                 if (netif_is_amt(upper_dev)) {
3205                         amt = netdev_priv(upper_dev);
3206                         if (amt->stream_dev == dev)
3207                                 return upper_dev;
3208                 }
3209         }
3210
3211         return NULL;
3212 }
3213
3214 static int amt_device_event(struct notifier_block *unused,
3215                             unsigned long event, void *ptr)
3216 {
3217         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3218         struct net_device *upper_dev;
3219         struct amt_dev *amt;
3220         LIST_HEAD(list);
3221         int new_mtu;
3222
3223         upper_dev = amt_lookup_upper_dev(dev);
3224         if (!upper_dev)
3225                 return NOTIFY_DONE;
3226         amt = netdev_priv(upper_dev);
3227
3228         switch (event) {
3229         case NETDEV_UNREGISTER:
3230                 amt_dellink(amt->dev, &list);
3231                 unregister_netdevice_many(&list);
3232                 break;
3233         case NETDEV_CHANGEMTU:
3234                 if (amt->mode == AMT_MODE_RELAY)
3235                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3236                 else
3237                         new_mtu = dev->mtu - AMT_GW_HLEN;
3238
3239                 dev_set_mtu(amt->dev, new_mtu);
3240                 break;
3241         }
3242
3243         return NOTIFY_DONE;
3244 }
3245
3246 static struct notifier_block amt_notifier_block __read_mostly = {
3247         .notifier_call = amt_device_event,
3248 };
3249
3250 static int __init amt_init(void)
3251 {
3252         int err;
3253
3254         err = register_netdevice_notifier(&amt_notifier_block);
3255         if (err < 0)
3256                 goto err;
3257
3258         err = rtnl_link_register(&amt_link_ops);
3259         if (err < 0)
3260                 goto unregister_notifier;
3261
3262         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 1);
3263         if (!amt_wq) {
3264                 err = -ENOMEM;
3265                 goto rtnl_unregister;
3266         }
3267
3268         spin_lock_init(&source_gc_lock);
3269         spin_lock_bh(&source_gc_lock);
3270         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3271         mod_delayed_work(amt_wq, &source_gc_wq,
3272                          msecs_to_jiffies(AMT_GC_INTERVAL));
3273         spin_unlock_bh(&source_gc_lock);
3274
3275         return 0;
3276
3277 rtnl_unregister:
3278         rtnl_link_unregister(&amt_link_ops);
3279 unregister_notifier:
3280         unregister_netdevice_notifier(&amt_notifier_block);
3281 err:
3282         pr_err("error loading AMT module loaded\n");
3283         return err;
3284 }
3285 late_initcall(amt_init);
3286
/* Module exit: tear down in reverse dependency order. */
static void __exit amt_fini(void)
{
        /* Stop offering the "amt" link kind; this is expected to delete any
         * remaining AMT devices via amt_dellink().
         */
        rtnl_link_unregister(&amt_link_ops);
        unregister_netdevice_notifier(&amt_notifier_block);
        /* Stop the periodic GC, then run one final pass synchronously,
         * presumably to release whatever is still on source_gc_list.
         */
        cancel_delayed_work_sync(&source_gc_wq);
        __amt_source_gc_work();
        /* Nothing can queue work on amt_wq any more. */
        destroy_workqueue(amt_wq);
}
module_exit(amt_fini);
3296
3297 MODULE_LICENSE("GPL");
3298 MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
3299 MODULE_ALIAS_RTNL_LINK("amt");