amt: drop unexpected advertisement message
[linux-2.6-microblaze.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/sch_generic.h>
15 #include <net/net_namespace.h>
16 #include <net/ip.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <net/icmp.h>
20 #include <net/mld.h>
21 #include <net/amt.h>
22 #include <uapi/linux/amt.h>
23 #include <linux/security.h>
24 #include <net/gro_cells.h>
25 #include <net/ipv6.h>
26 #include <net/if_inet6.h>
27 #include <net/ndisc.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/inet_common.h>
31 #include <net/ip6_checksum.h>
32
33 static struct workqueue_struct *amt_wq;
34
35 static HLIST_HEAD(source_gc_list);
36 /* Lock for source_gc_list */
37 static spinlock_t source_gc_lock;
38 static struct delayed_work source_gc_wq;
39 static char *status_str[] = {
40         "AMT_STATUS_INIT",
41         "AMT_STATUS_SENT_DISCOVERY",
42         "AMT_STATUS_RECEIVED_DISCOVERY",
43         "AMT_STATUS_SENT_ADVERTISEMENT",
44         "AMT_STATUS_RECEIVED_ADVERTISEMENT",
45         "AMT_STATUS_SENT_REQUEST",
46         "AMT_STATUS_RECEIVED_REQUEST",
47         "AMT_STATUS_SENT_QUERY",
48         "AMT_STATUS_RECEIVED_QUERY",
49         "AMT_STATUS_SENT_UPDATE",
50         "AMT_STATUS_RECEIVED_UPDATE",
51 };
52
53 static char *type_str[] = {
54         "", /* Type 0 is not defined */
55         "AMT_MSG_DISCOVERY",
56         "AMT_MSG_ADVERTISEMENT",
57         "AMT_MSG_REQUEST",
58         "AMT_MSG_MEMBERSHIP_QUERY",
59         "AMT_MSG_MEMBERSHIP_UPDATE",
60         "AMT_MSG_MULTICAST_DATA",
61         "AMT_MSG_TEARDOWN",
62 };
63
64 static char *action_str[] = {
65         "AMT_ACT_GMI",
66         "AMT_ACT_GMI_ZERO",
67         "AMT_ACT_GT",
68         "AMT_ACT_STATUS_FWD_NEW",
69         "AMT_ACT_STATUS_D_FWD_NEW",
70         "AMT_ACT_STATUS_NONE_NEW",
71 };
72
73 static struct igmpv3_grec igmpv3_zero_grec;
74
75 #if IS_ENABLED(CONFIG_IPV6)
76 #define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
77 static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
78 static struct mld2_grec mldv2_zero_grec;
79 #endif
80
81 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
82 {
83         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
84                      sizeof_field(struct sk_buff, cb));
85
86         return (struct amt_skb_cb *)((void *)skb->cb +
87                 sizeof(struct qdisc_skb_cb));
88 }
89
90 static void __amt_source_gc_work(void)
91 {
92         struct amt_source_node *snode;
93         struct hlist_head gc_list;
94         struct hlist_node *t;
95
96         spin_lock_bh(&source_gc_lock);
97         hlist_move_list(&source_gc_list, &gc_list);
98         spin_unlock_bh(&source_gc_lock);
99
100         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
101                 hlist_del_rcu(&snode->node);
102                 kfree_rcu(snode, rcu);
103         }
104 }
105
106 static void amt_source_gc_work(struct work_struct *work)
107 {
108         __amt_source_gc_work();
109
110         spin_lock_bh(&source_gc_lock);
111         mod_delayed_work(amt_wq, &source_gc_wq,
112                          msecs_to_jiffies(AMT_GC_INTERVAL));
113         spin_unlock_bh(&source_gc_lock);
114 }
115
116 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
117 {
118         return !memcmp(a, b, sizeof(union amt_addr));
119 }
120
121 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
122 {
123         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
124
125         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
126 }
127
128 static bool amt_status_filter(struct amt_source_node *snode,
129                               enum amt_filter filter)
130 {
131         bool rc = false;
132
133         switch (filter) {
134         case AMT_FILTER_FWD:
135                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
136                     snode->flags == AMT_SOURCE_OLD)
137                         rc = true;
138                 break;
139         case AMT_FILTER_D_FWD:
140                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
141                     snode->flags == AMT_SOURCE_OLD)
142                         rc = true;
143                 break;
144         case AMT_FILTER_FWD_NEW:
145                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
146                     snode->flags == AMT_SOURCE_NEW)
147                         rc = true;
148                 break;
149         case AMT_FILTER_D_FWD_NEW:
150                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
151                     snode->flags == AMT_SOURCE_NEW)
152                         rc = true;
153                 break;
154         case AMT_FILTER_ALL:
155                 rc = true;
156                 break;
157         case AMT_FILTER_NONE_NEW:
158                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
159                     snode->flags == AMT_SOURCE_NEW)
160                         rc = true;
161                 break;
162         case AMT_FILTER_BOTH:
163                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
164                      snode->status == AMT_SOURCE_STATUS_FWD) &&
165                     snode->flags == AMT_SOURCE_OLD)
166                         rc = true;
167                 break;
168         case AMT_FILTER_BOTH_NEW:
169                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
170                      snode->status == AMT_SOURCE_STATUS_FWD) &&
171                     snode->flags == AMT_SOURCE_NEW)
172                         rc = true;
173                 break;
174         default:
175                 WARN_ON_ONCE(1);
176                 break;
177         }
178
179         return rc;
180 }
181
182 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
183                                               struct amt_group_node *gnode,
184                                               enum amt_filter filter,
185                                               union amt_addr *src)
186 {
187         u32 hash = amt_source_hash(tunnel, src);
188         struct amt_source_node *snode;
189
190         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
191                 if (amt_status_filter(snode, filter) &&
192                     amt_addr_equal(&snode->source_addr, src))
193                         return snode;
194
195         return NULL;
196 }
197
198 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
199 {
200         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
201
202         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
203 }
204
205 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
206                                                union amt_addr *group,
207                                                union amt_addr *host,
208                                                bool v6)
209 {
210         u32 hash = amt_group_hash(tunnel, group);
211         struct amt_group_node *gnode;
212
213         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
214                 if (amt_addr_equal(&gnode->group_addr, group) &&
215                     amt_addr_equal(&gnode->host_addr, host) &&
216                     gnode->v6 == v6)
217                         return gnode;
218         }
219
220         return NULL;
221 }
222
223 static void amt_destroy_source(struct amt_source_node *snode)
224 {
225         struct amt_group_node *gnode = snode->gnode;
226         struct amt_tunnel_list *tunnel;
227
228         tunnel = gnode->tunnel_list;
229
230         if (!gnode->v6) {
231                 netdev_dbg(snode->gnode->amt->dev,
232                            "Delete source %pI4 from %pI4\n",
233                            &snode->source_addr.ip4,
234                            &gnode->group_addr.ip4);
235 #if IS_ENABLED(CONFIG_IPV6)
236         } else {
237                 netdev_dbg(snode->gnode->amt->dev,
238                            "Delete source %pI6 from %pI6\n",
239                            &snode->source_addr.ip6,
240                            &gnode->group_addr.ip6);
241 #endif
242         }
243
244         cancel_delayed_work(&snode->source_timer);
245         hlist_del_init_rcu(&snode->node);
246         tunnel->nr_sources--;
247         gnode->nr_sources--;
248         spin_lock_bh(&source_gc_lock);
249         hlist_add_head_rcu(&snode->node, &source_gc_list);
250         spin_unlock_bh(&source_gc_lock);
251 }
252
253 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
254 {
255         struct amt_source_node *snode;
256         struct hlist_node *t;
257         int i;
258
259         if (cancel_delayed_work(&gnode->group_timer))
260                 dev_put(amt->dev);
261         hlist_del_rcu(&gnode->node);
262         gnode->tunnel_list->nr_groups--;
263
264         if (!gnode->v6)
265                 netdev_dbg(amt->dev, "Leave group %pI4\n",
266                            &gnode->group_addr.ip4);
267 #if IS_ENABLED(CONFIG_IPV6)
268         else
269                 netdev_dbg(amt->dev, "Leave group %pI6\n",
270                            &gnode->group_addr.ip6);
271 #endif
272         for (i = 0; i < amt->hash_buckets; i++)
273                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
274                         amt_destroy_source(snode);
275
276         /* tunnel->lock was acquired outside of amt_del_group()
277          * But rcu_read_lock() was acquired too so It's safe.
278          */
279         kfree_rcu(gnode, rcu);
280 }
281
282 /* If a source timer expires with a router filter-mode for the group of
283  * INCLUDE, the router concludes that traffic from this particular
284  * source is no longer desired on the attached network, and deletes the
285  * associated source record.
286  */
287 static void amt_source_work(struct work_struct *work)
288 {
289         struct amt_source_node *snode = container_of(to_delayed_work(work),
290                                                      struct amt_source_node,
291                                                      source_timer);
292         struct amt_group_node *gnode = snode->gnode;
293         struct amt_dev *amt = gnode->amt;
294         struct amt_tunnel_list *tunnel;
295
296         tunnel = gnode->tunnel_list;
297         spin_lock_bh(&tunnel->lock);
298         rcu_read_lock();
299         if (gnode->filter_mode == MCAST_INCLUDE) {
300                 amt_destroy_source(snode);
301                 if (!gnode->nr_sources)
302                         amt_del_group(amt, gnode);
303         } else {
304                 /* When a router filter-mode for a group is EXCLUDE,
305                  * source records are only deleted when the group timer expires
306                  */
307                 snode->status = AMT_SOURCE_STATUS_D_FWD;
308         }
309         rcu_read_unlock();
310         spin_unlock_bh(&tunnel->lock);
311 }
312
313 static void amt_act_src(struct amt_tunnel_list *tunnel,
314                         struct amt_group_node *gnode,
315                         struct amt_source_node *snode,
316                         enum amt_act act)
317 {
318         struct amt_dev *amt = tunnel->amt;
319
320         switch (act) {
321         case AMT_ACT_GMI:
322                 mod_delayed_work(amt_wq, &snode->source_timer,
323                                  msecs_to_jiffies(amt_gmi(amt)));
324                 break;
325         case AMT_ACT_GMI_ZERO:
326                 cancel_delayed_work(&snode->source_timer);
327                 break;
328         case AMT_ACT_GT:
329                 mod_delayed_work(amt_wq, &snode->source_timer,
330                                  gnode->group_timer.timer.expires);
331                 break;
332         case AMT_ACT_STATUS_FWD_NEW:
333                 snode->status = AMT_SOURCE_STATUS_FWD;
334                 snode->flags = AMT_SOURCE_NEW;
335                 break;
336         case AMT_ACT_STATUS_D_FWD_NEW:
337                 snode->status = AMT_SOURCE_STATUS_D_FWD;
338                 snode->flags = AMT_SOURCE_NEW;
339                 break;
340         case AMT_ACT_STATUS_NONE_NEW:
341                 cancel_delayed_work(&snode->source_timer);
342                 snode->status = AMT_SOURCE_STATUS_NONE;
343                 snode->flags = AMT_SOURCE_NEW;
344                 break;
345         default:
346                 WARN_ON_ONCE(1);
347                 return;
348         }
349
350         if (!gnode->v6)
351                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
352                            &snode->source_addr.ip4,
353                            &gnode->group_addr.ip4,
354                            action_str[act]);
355 #if IS_ENABLED(CONFIG_IPV6)
356         else
357                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
358                            &snode->source_addr.ip6,
359                            &gnode->group_addr.ip6,
360                            action_str[act]);
361 #endif
362 }
363
364 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
365                                                union amt_addr *src)
366 {
367         struct amt_source_node *snode;
368
369         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
370         if (!snode)
371                 return NULL;
372
373         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
374         snode->gnode = gnode;
375         snode->status = AMT_SOURCE_STATUS_NONE;
376         snode->flags = AMT_SOURCE_NEW;
377         INIT_HLIST_NODE(&snode->node);
378         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
379
380         return snode;
381 }
382
383 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
384  *
385  *  Router Mode          Filter Timer         Actions/Comments
386  *  -----------       -----------------       ----------------
387  *
388  *    INCLUDE             Not Used            All listeners in
389  *                                            INCLUDE mode.
390  *
391  *    EXCLUDE             Timer > 0           At least one listener
392  *                                            in EXCLUDE mode.
393  *
394  *    EXCLUDE             Timer == 0          No more listeners in
395  *                                            EXCLUDE mode for the
396  *                                            multicast address.
397  *                                            If the Requested List
398  *                                            is empty, delete
399  *                                            Multicast Address
400  *                                            Record.  If not, switch
401  *                                            to INCLUDE filter mode;
402  *                                            the sources in the
403  *                                            Requested List are
404  *                                            moved to the Include
405  *                                            List, and the Exclude
406  *                                            List is deleted.
407  */
408 static void amt_group_work(struct work_struct *work)
409 {
410         struct amt_group_node *gnode = container_of(to_delayed_work(work),
411                                                     struct amt_group_node,
412                                                     group_timer);
413         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
414         struct amt_dev *amt = gnode->amt;
415         struct amt_source_node *snode;
416         bool delete_group = true;
417         struct hlist_node *t;
418         int i, buckets;
419
420         buckets = amt->hash_buckets;
421
422         spin_lock_bh(&tunnel->lock);
423         if (gnode->filter_mode == MCAST_INCLUDE) {
424                 /* Not Used */
425                 spin_unlock_bh(&tunnel->lock);
426                 goto out;
427         }
428
429         rcu_read_lock();
430         for (i = 0; i < buckets; i++) {
431                 hlist_for_each_entry_safe(snode, t,
432                                           &gnode->sources[i], node) {
433                         if (!delayed_work_pending(&snode->source_timer) ||
434                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
435                                 amt_destroy_source(snode);
436                         } else {
437                                 delete_group = false;
438                                 snode->status = AMT_SOURCE_STATUS_FWD;
439                         }
440                 }
441         }
442         if (delete_group)
443                 amt_del_group(amt, gnode);
444         else
445                 gnode->filter_mode = MCAST_INCLUDE;
446         rcu_read_unlock();
447         spin_unlock_bh(&tunnel->lock);
448 out:
449         dev_put(amt->dev);
450 }
451
452 /* Non-existant group is created as INCLUDE {empty}:
453  *
454  * RFC 3376 - 5.1. Action on Change of Interface State
455  *
456  * If no interface state existed for that multicast address before
457  * the change (i.e., the change consisted of creating a new
458  * per-interface record), or if no state exists after the change
459  * (i.e., the change consisted of deleting a per-interface record),
460  * then the "non-existent" state is considered to have a filter mode
461  * of INCLUDE and an empty source list.
462  */
463 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
464                                             struct amt_tunnel_list *tunnel,
465                                             union amt_addr *group,
466                                             union amt_addr *host,
467                                             bool v6)
468 {
469         struct amt_group_node *gnode;
470         u32 hash;
471         int i;
472
473         if (tunnel->nr_groups >= amt->max_groups)
474                 return ERR_PTR(-ENOSPC);
475
476         gnode = kzalloc(sizeof(*gnode) +
477                         (sizeof(struct hlist_head) * amt->hash_buckets),
478                         GFP_ATOMIC);
479         if (unlikely(!gnode))
480                 return ERR_PTR(-ENOMEM);
481
482         gnode->amt = amt;
483         gnode->group_addr = *group;
484         gnode->host_addr = *host;
485         gnode->v6 = v6;
486         gnode->tunnel_list = tunnel;
487         gnode->filter_mode = MCAST_INCLUDE;
488         INIT_HLIST_NODE(&gnode->node);
489         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
490         for (i = 0; i < amt->hash_buckets; i++)
491                 INIT_HLIST_HEAD(&gnode->sources[i]);
492
493         hash = amt_group_hash(tunnel, group);
494         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
495         tunnel->nr_groups++;
496
497         if (!gnode->v6)
498                 netdev_dbg(amt->dev, "Join group %pI4\n",
499                            &gnode->group_addr.ip4);
500 #if IS_ENABLED(CONFIG_IPV6)
501         else
502                 netdev_dbg(amt->dev, "Join group %pI6\n",
503                            &gnode->group_addr.ip6);
504 #endif
505
506         return gnode;
507 }
508
509 static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
510 {
511         u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
512         int hlen = LL_RESERVED_SPACE(amt->dev);
513         int tlen = amt->dev->needed_tailroom;
514         struct igmpv3_query *ihv3;
515         void *csum_start = NULL;
516         __sum16 *csum = NULL;
517         struct sk_buff *skb;
518         struct ethhdr *eth;
519         struct iphdr *iph;
520         unsigned int len;
521         int offset;
522
523         len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
524         skb = netdev_alloc_skb_ip_align(amt->dev, len);
525         if (!skb)
526                 return NULL;
527
528         skb_reserve(skb, hlen);
529         skb_push(skb, sizeof(*eth));
530         skb->protocol = htons(ETH_P_IP);
531         skb_reset_mac_header(skb);
532         skb->priority = TC_PRIO_CONTROL;
533         skb_put(skb, sizeof(*iph));
534         skb_put_data(skb, ra, sizeof(ra));
535         skb_put(skb, sizeof(*ihv3));
536         skb_pull(skb, sizeof(*eth));
537         skb_reset_network_header(skb);
538
539         iph             = ip_hdr(skb);
540         iph->version    = 4;
541         iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
542         iph->tos        = AMT_TOS;
543         iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
544         iph->frag_off   = htons(IP_DF);
545         iph->ttl        = 1;
546         iph->id         = 0;
547         iph->protocol   = IPPROTO_IGMP;
548         iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
549         iph->saddr      = htonl(INADDR_ANY);
550         ip_send_check(iph);
551
552         eth = eth_hdr(skb);
553         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
554         ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
555         eth->h_proto = htons(ETH_P_IP);
556
557         ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
558         skb_reset_transport_header(skb);
559         ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
560         ihv3->code      = 1;
561         ihv3->group     = 0;
562         ihv3->qqic      = amt->qi;
563         ihv3->nsrcs     = 0;
564         ihv3->resv      = 0;
565         ihv3->suppress  = false;
566         ihv3->qrv       = READ_ONCE(amt->net->ipv4.sysctl_igmp_qrv);
567         ihv3->csum      = 0;
568         csum            = &ihv3->csum;
569         csum_start      = (void *)ihv3;
570         *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
571         offset          = skb_transport_offset(skb);
572         skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
573         skb->ip_summed  = CHECKSUM_NONE;
574
575         skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);
576
577         return skb;
578 }
579
580 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
581                                  bool validate)
582 {
583         if (validate && amt->status >= status)
584                 return;
585         netdev_dbg(amt->dev, "Update GW status %s -> %s",
586                    status_str[amt->status], status_str[status]);
587         WRITE_ONCE(amt->status, status);
588 }
589
590 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
591                                       enum amt_status status,
592                                       bool validate)
593 {
594         if (validate && tunnel->status >= status)
595                 return;
596         netdev_dbg(tunnel->amt->dev,
597                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
598                    &tunnel->ip4, ntohs(tunnel->source_port),
599                    status_str[tunnel->status], status_str[status]);
600         tunnel->status = status;
601 }
602
603 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
604                                     enum amt_status status, bool validate)
605 {
606         spin_lock_bh(&tunnel->lock);
607         __amt_update_relay_status(tunnel, status, validate);
608         spin_unlock_bh(&tunnel->lock);
609 }
610
611 static void amt_send_discovery(struct amt_dev *amt)
612 {
613         struct amt_header_discovery *amtd;
614         int hlen, tlen, offset;
615         struct socket *sock;
616         struct udphdr *udph;
617         struct sk_buff *skb;
618         struct iphdr *iph;
619         struct rtable *rt;
620         struct flowi4 fl4;
621         u32 len;
622         int err;
623
624         rcu_read_lock();
625         sock = rcu_dereference(amt->sock);
626         if (!sock)
627                 goto out;
628
629         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
630                 goto out;
631
632         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
633                                    amt->discovery_ip, amt->local_ip,
634                                    amt->gw_port, amt->relay_port,
635                                    IPPROTO_UDP, 0,
636                                    amt->stream_dev->ifindex);
637         if (IS_ERR(rt)) {
638                 amt->dev->stats.tx_errors++;
639                 goto out;
640         }
641
642         hlen = LL_RESERVED_SPACE(amt->dev);
643         tlen = amt->dev->needed_tailroom;
644         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
645         skb = netdev_alloc_skb_ip_align(amt->dev, len);
646         if (!skb) {
647                 ip_rt_put(rt);
648                 amt->dev->stats.tx_errors++;
649                 goto out;
650         }
651
652         skb->priority = TC_PRIO_CONTROL;
653         skb_dst_set(skb, &rt->dst);
654
655         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
656         skb_reset_network_header(skb);
657         skb_put(skb, len);
658         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
659         amtd->version   = 0;
660         amtd->type      = AMT_MSG_DISCOVERY;
661         amtd->reserved  = 0;
662         amtd->nonce     = amt->nonce;
663         skb_push(skb, sizeof(*udph));
664         skb_reset_transport_header(skb);
665         udph            = udp_hdr(skb);
666         udph->source    = amt->gw_port;
667         udph->dest      = amt->relay_port;
668         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
669         udph->check     = 0;
670         offset = skb_transport_offset(skb);
671         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
672         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
673                                         sizeof(*udph) + sizeof(*amtd),
674                                         IPPROTO_UDP, skb->csum);
675
676         skb_push(skb, sizeof(*iph));
677         iph             = ip_hdr(skb);
678         iph->version    = 4;
679         iph->ihl        = (sizeof(struct iphdr)) >> 2;
680         iph->tos        = AMT_TOS;
681         iph->frag_off   = 0;
682         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
683         iph->daddr      = amt->discovery_ip;
684         iph->saddr      = amt->local_ip;
685         iph->protocol   = IPPROTO_UDP;
686         iph->tot_len    = htons(len);
687
688         skb->ip_summed = CHECKSUM_NONE;
689         ip_select_ident(amt->net, skb, NULL);
690         ip_send_check(iph);
691         err = ip_local_out(amt->net, sock->sk, skb);
692         if (unlikely(net_xmit_eval(err)))
693                 amt->dev->stats.tx_errors++;
694
695         amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
696 out:
697         rcu_read_unlock();
698 }
699
700 static void amt_send_request(struct amt_dev *amt, bool v6)
701 {
702         struct amt_header_request *amtrh;
703         int hlen, tlen, offset;
704         struct socket *sock;
705         struct udphdr *udph;
706         struct sk_buff *skb;
707         struct iphdr *iph;
708         struct rtable *rt;
709         struct flowi4 fl4;
710         u32 len;
711         int err;
712
713         rcu_read_lock();
714         sock = rcu_dereference(amt->sock);
715         if (!sock)
716                 goto out;
717
718         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
719                 goto out;
720
721         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
722                                    amt->remote_ip, amt->local_ip,
723                                    amt->gw_port, amt->relay_port,
724                                    IPPROTO_UDP, 0,
725                                    amt->stream_dev->ifindex);
726         if (IS_ERR(rt)) {
727                 amt->dev->stats.tx_errors++;
728                 goto out;
729         }
730
731         hlen = LL_RESERVED_SPACE(amt->dev);
732         tlen = amt->dev->needed_tailroom;
733         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
734         skb = netdev_alloc_skb_ip_align(amt->dev, len);
735         if (!skb) {
736                 ip_rt_put(rt);
737                 amt->dev->stats.tx_errors++;
738                 goto out;
739         }
740
741         skb->priority = TC_PRIO_CONTROL;
742         skb_dst_set(skb, &rt->dst);
743
744         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
745         skb_reset_network_header(skb);
746         skb_put(skb, len);
747         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
748         amtrh->version   = 0;
749         amtrh->type      = AMT_MSG_REQUEST;
750         amtrh->reserved1 = 0;
751         amtrh->p         = v6;
752         amtrh->reserved2 = 0;
753         amtrh->nonce     = amt->nonce;
754         skb_push(skb, sizeof(*udph));
755         skb_reset_transport_header(skb);
756         udph            = udp_hdr(skb);
757         udph->source    = amt->gw_port;
758         udph->dest      = amt->relay_port;
759         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
760         udph->check     = 0;
761         offset = skb_transport_offset(skb);
762         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
763         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
764                                         sizeof(*udph) + sizeof(*amtrh),
765                                         IPPROTO_UDP, skb->csum);
766
767         skb_push(skb, sizeof(*iph));
768         iph             = ip_hdr(skb);
769         iph->version    = 4;
770         iph->ihl        = (sizeof(struct iphdr)) >> 2;
771         iph->tos        = AMT_TOS;
772         iph->frag_off   = 0;
773         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
774         iph->daddr      = amt->remote_ip;
775         iph->saddr      = amt->local_ip;
776         iph->protocol   = IPPROTO_UDP;
777         iph->tot_len    = htons(len);
778
779         skb->ip_summed = CHECKSUM_NONE;
780         ip_select_ident(amt->net, skb, NULL);
781         ip_send_check(iph);
782         err = ip_local_out(amt->net, sock->sk, skb);
783         if (unlikely(net_xmit_eval(err)))
784                 amt->dev->stats.tx_errors++;
785
786 out:
787         rcu_read_unlock();
788 }
789
790 static void amt_send_igmp_gq(struct amt_dev *amt,
791                              struct amt_tunnel_list *tunnel)
792 {
793         struct sk_buff *skb;
794
795         skb = amt_build_igmp_gq(amt);
796         if (!skb)
797                 return;
798
799         amt_skb_cb(skb)->tunnel = tunnel;
800         dev_queue_xmit(skb);
801 }
802
803 #if IS_ENABLED(CONFIG_IPV6)
/* Build an MLDv2 General Query for transmission on amt->dev.
 *
 * Layout of the returned skb, starting at skb->data:
 *   Ethernet header | IPv6 header | Hop-by-Hop options (ra[]) |
 *   struct mld2_query (no source addresses)
 *
 * The IPv6 destination is mld2_all_node, the source is selected with
 * ipv6_dev_get_saddr(), and the hop limit is 1. The ICMPv6 checksum is
 * computed here in software (ip_summed = CHECKSUM_NONE). The network
 * header is set to the IPv6 header and the transport header to the MLD
 * message.
 *
 * Returns the new skb, or NULL if allocation fails or no source address
 * can be selected (the latter also increments tx_errors).
 */
static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
{
	/* Hop-by-Hop options: next header ICMPv6, ext length 0, a Router
	 * Alert option (type, length 2, value 0), then two Pad1 bytes.
	 */
	u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
				   2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
	int hlen = LL_RESERVED_SPACE(amt->dev);
	int tlen = amt->dev->needed_tailroom;
	struct mld2_query *mld2q;
	void *csum_start = NULL;
	struct ipv6hdr *ip6h;
	struct sk_buff *skb;
	struct ethhdr *eth;
	u32 len;

	len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
	skb = netdev_alloc_skb_ip_align(amt->dev, len);
	if (!skb)
		return NULL;

	/* Lay out eth | ip6h | ra | mld2q (all zeroed except ra), then step
	 * past the Ethernet header so the IPv6 header can be filled in.
	 */
	skb_reserve(skb, hlen);
	skb_push(skb, sizeof(*eth));
	skb_reset_mac_header(skb);
	eth = eth_hdr(skb);
	skb->priority = TC_PRIO_CONTROL;
	skb->protocol = htons(ETH_P_IPV6);
	skb_put_zero(skb, sizeof(*ip6h));
	skb_put_data(skb, ra, sizeof(ra));
	skb_put_zero(skb, sizeof(*mld2q));
	skb_pull(skb, sizeof(*eth));
	skb_reset_network_header(skb);
	ip6h			= ipv6_hdr(skb);
	ip6h->payload_len	= htons(sizeof(ra) + sizeof(*mld2q));
	ip6h->nexthdr		= NEXTHDR_HOP;
	ip6h->hop_limit		= 1;
	ip6h->daddr		= mld2_all_node;
	ip6_flow_hdr(ip6h, 0, 0);

	if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
			       &ip6h->saddr)) {
		amt->dev->stats.tx_errors++;
		kfree_skb(skb);
		return NULL;
	}

	eth->h_proto = htons(ETH_P_IPV6);
	ether_addr_copy(eth->h_source, amt->dev->dev_addr);
	ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);

	/* Advance past the IPv6 header and options to the MLD message. */
	skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
	skb_reset_transport_header(skb);
	mld2q			= (struct mld2_query *)icmp6_hdr(skb);
	mld2q->mld2q_mrc	= htons(1);
	mld2q->mld2q_type	= ICMPV6_MGM_QUERY;
	mld2q->mld2q_code	= 0;
	mld2q->mld2q_cksum	= 0;
	mld2q->mld2q_resv1	= 0;
	mld2q->mld2q_resv2	= 0;
	mld2q->mld2q_suppress	= 0;
	mld2q->mld2q_qrv	= amt->qrv;
	mld2q->mld2q_nsrcs	= 0;
	mld2q->mld2q_qqic	= amt->qi;
	csum_start		= (void *)mld2q;
	/* ICMPv6 checksum over the pseudo-header plus the MLD message. */
	mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
					     sizeof(*mld2q),
					     IPPROTO_ICMPV6,
					     csum_partial(csum_start,
							  sizeof(*mld2q), 0));

	skb->ip_summed = CHECKSUM_NONE;
	/* Rewind so skb->data points at the Ethernet header again. */
	skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
	return skb;
}
875
876 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
877 {
878         struct sk_buff *skb;
879
880         skb = amt_build_mld_gq(amt);
881         if (!skb)
882                 return;
883
884         amt_skb_cb(skb)->tunnel = tunnel;
885         dev_queue_xmit(skb);
886 }
887 #else
static void amt_send_mld_gq(struct amt_dev *amt,
			    struct amt_tunnel_list *tunnel)
{
	/* IPv6 support is compiled out: there is no MLD query to send. */
}
891 #endif
892
893 static bool amt_queue_event(struct amt_dev *amt, enum amt_event event,
894                             struct sk_buff *skb)
895 {
896         int index;
897
898         spin_lock_bh(&amt->lock);
899         if (amt->nr_events >= AMT_MAX_EVENTS) {
900                 spin_unlock_bh(&amt->lock);
901                 return 1;
902         }
903
904         index = (amt->event_idx + amt->nr_events) % AMT_MAX_EVENTS;
905         amt->events[index].event = event;
906         amt->events[index].skb = skb;
907         amt->nr_events++;
908         amt->event_idx %= AMT_MAX_EVENTS;
909         queue_work(amt_wq, &amt->event_wq);
910         spin_unlock_bh(&amt->lock);
911
912         return 0;
913 }
914
915 static void amt_secret_work(struct work_struct *work)
916 {
917         struct amt_dev *amt = container_of(to_delayed_work(work),
918                                            struct amt_dev,
919                                            secret_wq);
920
921         spin_lock_bh(&amt->lock);
922         get_random_bytes(&amt->key, sizeof(siphash_key_t));
923         spin_unlock_bh(&amt->lock);
924         mod_delayed_work(amt_wq, &amt->secret_wq,
925                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
926 }
927
928 static void amt_event_send_discovery(struct amt_dev *amt)
929 {
930         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
931                 goto out;
932         get_random_bytes(&amt->nonce, sizeof(__be32));
933
934         amt_send_discovery(amt);
935 out:
936         mod_delayed_work(amt_wq, &amt->discovery_wq,
937                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
938 }
939
940 static void amt_discovery_work(struct work_struct *work)
941 {
942         struct amt_dev *amt = container_of(to_delayed_work(work),
943                                            struct amt_dev,
944                                            discovery_wq);
945
946         if (amt_queue_event(amt, AMT_EVENT_SEND_DISCOVERY, NULL))
947                 mod_delayed_work(amt_wq, &amt->discovery_wq,
948                                  msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
949 }
950
951 static void amt_event_send_request(struct amt_dev *amt)
952 {
953         u32 exp;
954
955         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
956                 goto out;
957
958         if (amt->req_cnt > AMT_MAX_REQ_COUNT) {
959                 netdev_dbg(amt->dev, "Gateway is not ready");
960                 amt->qi = AMT_INIT_REQ_TIMEOUT;
961                 WRITE_ONCE(amt->ready4, false);
962                 WRITE_ONCE(amt->ready6, false);
963                 amt->remote_ip = 0;
964                 amt_update_gw_status(amt, AMT_STATUS_INIT, false);
965                 amt->req_cnt = 0;
966                 amt->nonce = 0;
967                 goto out;
968         }
969
970         if (!amt->req_cnt)
971                 get_random_bytes(&amt->nonce, sizeof(__be32));
972
973         amt_send_request(amt, false);
974         amt_send_request(amt, true);
975         amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
976         amt->req_cnt++;
977 out:
978         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
979         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
980 }
981
982 static void amt_req_work(struct work_struct *work)
983 {
984         struct amt_dev *amt = container_of(to_delayed_work(work),
985                                            struct amt_dev,
986                                            req_wq);
987
988         if (amt_queue_event(amt, AMT_EVENT_SEND_REQUEST, NULL))
989                 mod_delayed_work(amt_wq, &amt->req_wq,
990                                  msecs_to_jiffies(100));
991 }
992
993 static bool amt_send_membership_update(struct amt_dev *amt,
994                                        struct sk_buff *skb,
995                                        bool v6)
996 {
997         struct amt_header_membership_update *amtmu;
998         struct socket *sock;
999         struct iphdr *iph;
1000         struct flowi4 fl4;
1001         struct rtable *rt;
1002         int err;
1003
1004         sock = rcu_dereference_bh(amt->sock);
1005         if (!sock)
1006                 return true;
1007
1008         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
1009                            sizeof(*iph) + sizeof(struct udphdr));
1010         if (err)
1011                 return true;
1012
1013         skb_reset_inner_headers(skb);
1014         memset(&fl4, 0, sizeof(struct flowi4));
1015         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1016         fl4.daddr              = amt->remote_ip;
1017         fl4.saddr              = amt->local_ip;
1018         fl4.flowi4_tos         = AMT_TOS;
1019         fl4.flowi4_proto       = IPPROTO_UDP;
1020         rt = ip_route_output_key(amt->net, &fl4);
1021         if (IS_ERR(rt)) {
1022                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
1023                 return true;
1024         }
1025
1026         amtmu                   = skb_push(skb, sizeof(*amtmu));
1027         amtmu->version          = 0;
1028         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1029         amtmu->reserved         = 0;
1030         amtmu->nonce            = amt->nonce;
1031         amtmu->response_mac     = amt->mac;
1032
1033         if (!v6)
1034                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1035         else
1036                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1037         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1038                             fl4.saddr,
1039                             fl4.daddr,
1040                             AMT_TOS,
1041                             ip4_dst_hoplimit(&rt->dst),
1042                             0,
1043                             amt->gw_port,
1044                             amt->relay_port,
1045                             false,
1046                             false);
1047         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1048         return false;
1049 }
1050
1051 static void amt_send_multicast_data(struct amt_dev *amt,
1052                                     const struct sk_buff *oskb,
1053                                     struct amt_tunnel_list *tunnel,
1054                                     bool v6)
1055 {
1056         struct amt_header_mcast_data *amtmd;
1057         struct socket *sock;
1058         struct sk_buff *skb;
1059         struct iphdr *iph;
1060         struct flowi4 fl4;
1061         struct rtable *rt;
1062
1063         sock = rcu_dereference_bh(amt->sock);
1064         if (!sock)
1065                 return;
1066
1067         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1068                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1069         if (!skb)
1070                 return;
1071
1072         skb_reset_inner_headers(skb);
1073         memset(&fl4, 0, sizeof(struct flowi4));
1074         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1075         fl4.daddr              = tunnel->ip4;
1076         fl4.saddr              = amt->local_ip;
1077         fl4.flowi4_proto       = IPPROTO_UDP;
1078         rt = ip_route_output_key(amt->net, &fl4);
1079         if (IS_ERR(rt)) {
1080                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1081                 kfree_skb(skb);
1082                 return;
1083         }
1084
1085         amtmd = skb_push(skb, sizeof(*amtmd));
1086         amtmd->version = 0;
1087         amtmd->reserved = 0;
1088         amtmd->type = AMT_MSG_MULTICAST_DATA;
1089
1090         if (!v6)
1091                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1092         else
1093                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1094         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1095                             fl4.saddr,
1096                             fl4.daddr,
1097                             AMT_TOS,
1098                             ip4_dst_hoplimit(&rt->dst),
1099                             0,
1100                             amt->relay_port,
1101                             tunnel->source_port,
1102                             false,
1103                             false);
1104 }
1105
1106 static bool amt_send_membership_query(struct amt_dev *amt,
1107                                       struct sk_buff *skb,
1108                                       struct amt_tunnel_list *tunnel,
1109                                       bool v6)
1110 {
1111         struct amt_header_membership_query *amtmq;
1112         struct socket *sock;
1113         struct rtable *rt;
1114         struct flowi4 fl4;
1115         int err;
1116
1117         sock = rcu_dereference_bh(amt->sock);
1118         if (!sock)
1119                 return true;
1120
1121         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1122                            sizeof(struct iphdr) + sizeof(struct udphdr));
1123         if (err)
1124                 return true;
1125
1126         skb_reset_inner_headers(skb);
1127         memset(&fl4, 0, sizeof(struct flowi4));
1128         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1129         fl4.daddr              = tunnel->ip4;
1130         fl4.saddr              = amt->local_ip;
1131         fl4.flowi4_tos         = AMT_TOS;
1132         fl4.flowi4_proto       = IPPROTO_UDP;
1133         rt = ip_route_output_key(amt->net, &fl4);
1134         if (IS_ERR(rt)) {
1135                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1136                 return true;
1137         }
1138
1139         amtmq           = skb_push(skb, sizeof(*amtmq));
1140         amtmq->version  = 0;
1141         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1142         amtmq->reserved = 0;
1143         amtmq->l        = 0;
1144         amtmq->g        = 0;
1145         amtmq->nonce    = tunnel->nonce;
1146         amtmq->response_mac = tunnel->mac;
1147
1148         if (!v6)
1149                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1150         else
1151                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1152         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1153                             fl4.saddr,
1154                             fl4.daddr,
1155                             AMT_TOS,
1156                             ip4_dst_hoplimit(&rt->dst),
1157                             0,
1158                             amt->relay_port,
1159                             tunnel->source_port,
1160                             false,
1161                             false);
1162         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1163         return false;
1164 }
1165
1166 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1167 {
1168         struct amt_dev *amt = netdev_priv(dev);
1169         struct amt_tunnel_list *tunnel;
1170         struct amt_group_node *gnode;
1171         union amt_addr group = {0,};
1172 #if IS_ENABLED(CONFIG_IPV6)
1173         struct ipv6hdr *ip6h;
1174         struct mld_msg *mld;
1175 #endif
1176         bool report = false;
1177         struct igmphdr *ih;
1178         bool query = false;
1179         struct iphdr *iph;
1180         bool data = false;
1181         bool v6 = false;
1182         u32 hash;
1183
1184         iph = ip_hdr(skb);
1185         if (iph->version == 4) {
1186                 if (!ipv4_is_multicast(iph->daddr))
1187                         goto free;
1188
1189                 if (!ip_mc_check_igmp(skb)) {
1190                         ih = igmp_hdr(skb);
1191                         switch (ih->type) {
1192                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1193                         case IGMP_HOST_MEMBERSHIP_REPORT:
1194                                 report = true;
1195                                 break;
1196                         case IGMP_HOST_MEMBERSHIP_QUERY:
1197                                 query = true;
1198                                 break;
1199                         default:
1200                                 goto free;
1201                         }
1202                 } else {
1203                         data = true;
1204                 }
1205                 v6 = false;
1206                 group.ip4 = iph->daddr;
1207 #if IS_ENABLED(CONFIG_IPV6)
1208         } else if (iph->version == 6) {
1209                 ip6h = ipv6_hdr(skb);
1210                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1211                         goto free;
1212
1213                 if (!ipv6_mc_check_mld(skb)) {
1214                         mld = (struct mld_msg *)skb_transport_header(skb);
1215                         switch (mld->mld_type) {
1216                         case ICMPV6_MGM_REPORT:
1217                         case ICMPV6_MLD2_REPORT:
1218                                 report = true;
1219                                 break;
1220                         case ICMPV6_MGM_QUERY:
1221                                 query = true;
1222                                 break;
1223                         default:
1224                                 goto free;
1225                         }
1226                 } else {
1227                         data = true;
1228                 }
1229                 v6 = true;
1230                 group.ip6 = ip6h->daddr;
1231 #endif
1232         } else {
1233                 dev->stats.tx_errors++;
1234                 goto free;
1235         }
1236
1237         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1238                 goto free;
1239
1240         skb_pull(skb, sizeof(struct ethhdr));
1241
1242         if (amt->mode == AMT_MODE_GATEWAY) {
1243                 /* Gateway only passes IGMP/MLD packets */
1244                 if (!report)
1245                         goto free;
1246                 if ((!v6 && !READ_ONCE(amt->ready4)) ||
1247                     (v6 && !READ_ONCE(amt->ready6)))
1248                         goto free;
1249                 if (amt_send_membership_update(amt, skb,  v6))
1250                         goto free;
1251                 goto unlock;
1252         } else if (amt->mode == AMT_MODE_RELAY) {
1253                 if (query) {
1254                         tunnel = amt_skb_cb(skb)->tunnel;
1255                         if (!tunnel) {
1256                                 WARN_ON(1);
1257                                 goto free;
1258                         }
1259
1260                         /* Do not forward unexpected query */
1261                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1262                                 goto free;
1263                         goto unlock;
1264                 }
1265
1266                 if (!data)
1267                         goto free;
1268                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1269                         hash = amt_group_hash(tunnel, &group);
1270                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1271                                                  node) {
1272                                 if (!v6) {
1273                                         if (gnode->group_addr.ip4 == iph->daddr)
1274                                                 goto found;
1275 #if IS_ENABLED(CONFIG_IPV6)
1276                                 } else {
1277                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1278                                                             &ip6h->daddr))
1279                                                 goto found;
1280 #endif
1281                                 }
1282                         }
1283                         continue;
1284 found:
1285                         amt_send_multicast_data(amt, skb, tunnel, v6);
1286                 }
1287         }
1288
1289         dev_kfree_skb(skb);
1290         return NETDEV_TX_OK;
1291 free:
1292         dev_kfree_skb(skb);
1293 unlock:
1294         dev->stats.tx_dropped++;
1295         return NETDEV_TX_OK;
1296 }
1297
1298 static int amt_parse_type(struct sk_buff *skb)
1299 {
1300         struct amt_header *amth;
1301
1302         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1303                            sizeof(struct amt_header)))
1304                 return -1;
1305
1306         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1307
1308         if (amth->version != 0)
1309                 return -1;
1310
1311         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1312                 return -1;
1313         return amth->type;
1314 }
1315
1316 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1317 {
1318         struct amt_dev *amt = tunnel->amt;
1319         struct amt_group_node *gnode;
1320         struct hlist_node *t;
1321         int i;
1322
1323         spin_lock_bh(&tunnel->lock);
1324         rcu_read_lock();
1325         for (i = 0; i < amt->hash_buckets; i++)
1326                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1327                         amt_del_group(amt, gnode);
1328         rcu_read_unlock();
1329         spin_unlock_bh(&tunnel->lock);
1330 }
1331
/* Delayed-work handler for tunnel->gc_wq: remove an expired tunnel.
 *
 * Under amt->lock, the tunnel is unlinked from amt->tunnel_list with
 * list_del_rcu(), amt->nr_tunnels is decremented, and all of its groups
 * are removed via amt_clear_groups(). The tunnel itself is then freed
 * with kfree_rcu().
 */
static void amt_tunnel_expire(struct work_struct *work)
{
	struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
						      struct amt_tunnel_list,
						      gc_wq);
	struct amt_dev *amt = tunnel->amt;

	spin_lock_bh(&amt->lock);
	rcu_read_lock();
	list_del_rcu(&tunnel->list);
	amt->nr_tunnels--;
	/* amt_clear_groups() takes tunnel->lock, nested inside amt->lock. */
	amt_clear_groups(tunnel);
	rcu_read_unlock();
	spin_unlock_bh(&amt->lock);
	/* Defer the free past an RCU grace period: readers walking
	 * amt->tunnel_list with list_for_each_entry_rcu() may still hold
	 * this tunnel.
	 */
	kfree_rcu(tunnel, rcu);
}
1348
1349 static void amt_cleanup_srcs(struct amt_dev *amt,
1350                              struct amt_tunnel_list *tunnel,
1351                              struct amt_group_node *gnode)
1352 {
1353         struct amt_source_node *snode;
1354         struct hlist_node *t;
1355         int i;
1356
1357         /* Delete old sources */
1358         for (i = 0; i < amt->hash_buckets; i++) {
1359                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1360                         if (snode->flags == AMT_SOURCE_OLD)
1361                                 amt_destroy_source(snode);
1362                 }
1363         }
1364
1365         /* switch from new to old */
1366         for (i = 0; i < amt->hash_buckets; i++)  {
1367                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1368                         snode->flags = AMT_SOURCE_OLD;
1369                         if (!gnode->v6)
1370                                 netdev_dbg(snode->gnode->amt->dev,
1371                                            "Add source as OLD %pI4 from %pI4\n",
1372                                            &snode->source_addr.ip4,
1373                                            &gnode->group_addr.ip4);
1374 #if IS_ENABLED(CONFIG_IPV6)
1375                         else
1376                                 netdev_dbg(snode->gnode->amt->dev,
1377                                            "Add source as OLD %pI6 from %pI6\n",
1378                                            &snode->source_addr.ip6,
1379                                            &gnode->group_addr.ip6);
1380 #endif
1381                 }
1382         }
1383 }
1384
1385 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1386                          struct amt_group_node *gnode, void *grec,
1387                          bool v6)
1388 {
1389         struct igmpv3_grec *igmp_grec;
1390         struct amt_source_node *snode;
1391 #if IS_ENABLED(CONFIG_IPV6)
1392         struct mld2_grec *mld_grec;
1393 #endif
1394         union amt_addr src = {0,};
1395         u16 nsrcs;
1396         u32 hash;
1397         int i;
1398
1399         if (!v6) {
1400                 igmp_grec = (struct igmpv3_grec *)grec;
1401                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1402         } else {
1403 #if IS_ENABLED(CONFIG_IPV6)
1404                 mld_grec = (struct mld2_grec *)grec;
1405                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1406 #else
1407         return;
1408 #endif
1409         }
1410         for (i = 0; i < nsrcs; i++) {
1411                 if (tunnel->nr_sources >= amt->max_sources)
1412                         return;
1413                 if (!v6)
1414                         src.ip4 = igmp_grec->grec_src[i];
1415 #if IS_ENABLED(CONFIG_IPV6)
1416                 else
1417                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1418                                sizeof(struct in6_addr));
1419 #endif
1420                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1421                         continue;
1422
1423                 snode = amt_alloc_snode(gnode, &src);
1424                 if (snode) {
1425                         hash = amt_source_hash(tunnel, &snode->source_addr);
1426                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1427                         tunnel->nr_sources++;
1428                         gnode->nr_sources++;
1429
1430                         if (!gnode->v6)
1431                                 netdev_dbg(snode->gnode->amt->dev,
1432                                            "Add source as NEW %pI4 from %pI4\n",
1433                                            &snode->source_addr.ip4,
1434                                            &gnode->group_addr.ip4);
1435 #if IS_ENABLED(CONFIG_IPV6)
1436                         else
1437                                 netdev_dbg(snode->gnode->amt->dev,
1438                                            "Add source as NEW %pI6 from %pI6\n",
1439                                            &snode->source_addr.ip6,
1440                                            &gnode->group_addr.ip6);
1441 #endif
1442                 }
1443         }
1444 }
1445
1446 /* Router State   Report Rec'd New Router State
1447  * ------------   ------------ ----------------
1448  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
1449  *
1450  * -----------+-----------+-----------+
1451  *            |    OLD    |    NEW    |
1452  * -----------+-----------+-----------+
1453  *    FWD     |     X     |    X+A    |
1454  * -----------+-----------+-----------+
1455  *    D_FWD   |     Y     |    Y-A    |
1456  * -----------+-----------+-----------+
1457  *    NONE    |           |     A     |
1458  * -----------+-----------+-----------+
1459  *
1460  * a) Received sources are NONE/NEW
1461  * b) All NONE will be deleted by amt_cleanup_srcs().
1462  * c) All OLD will be deleted by amt_cleanup_srcs().
1463  * d) After delete, NEW source will be switched to OLD.
1464  */
/* Apply @act, via amt_act_src(), to the sources selected by set
 * operation @ops. The two sets are the sources listed in group record
 * @grec (an igmpv3_grec, or an mld2_grec when @v6) and the sources
 * already attached to @gnode that pass @filter:
 *
 *   AMT_OPS_INT     (A*B): record sources found in @gnode with @filter.
 *   AMT_OPS_UNI     (A+B): every @gnode source passing @filter, then
 *                          every record source found with @filter.
 *   AMT_OPS_SUB     (A-B): @gnode sources passing @filter that are NOT
 *                          listed in the record.
 *   AMT_OPS_SUB_REV (B-A): record sources not found with AMT_FILTER_ALL
 *                          but found with @filter (see note below).
 *
 * This function only acts on existing source nodes and never allocates
 * new ones. Without CONFIG_IPV6, a @v6 call returns immediately.
 */
static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
				struct amt_group_node *gnode,
				void *grec,
				enum amt_ops ops,
				enum amt_filter filter,
				enum amt_act act,
				bool v6)
{
	struct amt_dev *amt = tunnel->amt;
	struct amt_source_node *snode;
	struct igmpv3_grec *igmp_grec;
#if IS_ENABLED(CONFIG_IPV6)
	struct mld2_grec *mld_grec;
#endif
	union amt_addr src = {0,};
	struct hlist_node *t;
	u16 nsrcs;
	int i, j;

	if (!v6) {
		igmp_grec = (struct igmpv3_grec *)grec;
		nsrcs = ntohs(igmp_grec->grec_nsrcs);
	} else {
#if IS_ENABLED(CONFIG_IPV6)
		mld_grec = (struct mld2_grec *)grec;
		nsrcs = ntohs(mld_grec->grec_nsrcs);
#else
	return;
#endif
	}

	memset(&src, 0, sizeof(union amt_addr));
	switch (ops) {
	case AMT_OPS_INT:
		/* A*B */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_UNI:
		/* A+B */
		/* NOTE(review): a source that is in both sets can reach
		 * amt_act_src() twice, unless the first call changes it so it
		 * no longer passes @filter. Confirm this is intended.
		 */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (amt_status_filter(snode, filter))
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_SUB:
		/* A-B */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (!amt_status_filter(snode, filter))
					continue;
				/* Skip snode if it appears in the record. */
				for (j = 0; j < nsrcs; j++) {
					if (!v6)
						src.ip4 = igmp_grec->grec_src[j];
#if IS_ENABLED(CONFIG_IPV6)
					else
						memcpy(&src.ip6,
						       &mld_grec->grec_src[j],
						       sizeof(struct in6_addr));
#endif
					if (amt_addr_equal(&snode->source_addr,
							   &src))
						goto out_sub;
				}
				amt_act_src(tunnel, gnode, snode, act);
				continue;
out_sub:;
			}
		}
		break;
	case AMT_OPS_SUB_REV:
		/* B-A */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
					       &src);
			/* NOTE(review): if AMT_FILTER_ALL accepts every node,
			 * as its name suggests, any node the filtered lookup
			 * below could find would already have been found
			 * above. In that case this branch never acts. Confirm
			 * the intended B-A semantics.
			 */
			if (!snode) {
				snode = amt_lookup_src(tunnel, gnode,
						       filter, &src);
				if (snode)
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		break;
	default:
		netdev_dbg(amt->dev, "Invalid type\n");
		return;
	}
}
1588
1589 static void amt_mcast_is_in_handler(struct amt_dev *amt,
1590                                     struct amt_tunnel_list *tunnel,
1591                                     struct amt_group_node *gnode,
1592                                     void *grec, void *zero_grec, bool v6)
1593 {
1594         if (gnode->filter_mode == MCAST_INCLUDE) {
1595 /* Router State   Report Rec'd New Router State        Actions
1596  * ------------   ------------ ----------------        -------
1597  * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
1598  */
1599                 /* Update IS_IN (B) as FWD/NEW */
1600                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1601                                     AMT_FILTER_NONE_NEW,
1602                                     AMT_ACT_STATUS_FWD_NEW,
1603                                     v6);
1604                 /* Update INCLUDE (A) as NEW */
1605                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1606                                     AMT_FILTER_FWD,
1607                                     AMT_ACT_STATUS_FWD_NEW,
1608                                     v6);
1609                 /* (B)=GMI */
1610                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1611                                     AMT_FILTER_FWD_NEW,
1612                                     AMT_ACT_GMI,
1613                                     v6);
1614         } else {
1615 /* State        Actions
1616  * ------------   ------------ ----------------        -------
1617  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1618  */
1619                 /* Update (A) in (X, Y) as NONE/NEW */
1620                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1621                                     AMT_FILTER_BOTH,
1622                                     AMT_ACT_STATUS_NONE_NEW,
1623                                     v6);
1624                 /* Update FWD/OLD as FWD/NEW */
1625                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1626                                     AMT_FILTER_FWD,
1627                                     AMT_ACT_STATUS_FWD_NEW,
1628                                     v6);
1629                 /* Update IS_IN (A) as FWD/NEW */
1630                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1631                                     AMT_FILTER_NONE_NEW,
1632                                     AMT_ACT_STATUS_FWD_NEW,
1633                                     v6);
1634                 /* Update EXCLUDE (, Y-A) as D_FWD_NEW */
1635                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1636                                     AMT_FILTER_D_FWD,
1637                                     AMT_ACT_STATUS_D_FWD_NEW,
1638                                     v6);
1639         }
1640 }
1641
1642 static void amt_mcast_is_ex_handler(struct amt_dev *amt,
1643                                     struct amt_tunnel_list *tunnel,
1644                                     struct amt_group_node *gnode,
1645                                     void *grec, void *zero_grec, bool v6)
1646 {
1647         if (gnode->filter_mode == MCAST_INCLUDE) {
1648 /* Router State   Report Rec'd  New Router State         Actions
1649  * ------------   ------------  ----------------         -------
1650  * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
1651  *                                                       Delete (A-B)
1652  *                                                       Group Timer=GMI
1653  */
1654                 /* EXCLUDE(A*B, ) */
1655                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1656                                     AMT_FILTER_FWD,
1657                                     AMT_ACT_STATUS_FWD_NEW,
1658                                     v6);
1659                 /* EXCLUDE(, B-A) */
1660                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1661                                     AMT_FILTER_FWD,
1662                                     AMT_ACT_STATUS_D_FWD_NEW,
1663                                     v6);
1664                 /* (B-A)=0 */
1665                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1666                                     AMT_FILTER_D_FWD_NEW,
1667                                     AMT_ACT_GMI_ZERO,
1668                                     v6);
1669                 /* Group Timer=GMI */
1670                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1671                                       msecs_to_jiffies(amt_gmi(amt))))
1672                         dev_hold(amt->dev);
1673                 gnode->filter_mode = MCAST_EXCLUDE;
1674                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1675         } else {
1676 /* Router State   Report Rec'd  New Router State        Actions
1677  * ------------   ------------  ----------------        -------
1678  * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
1679  *                                                      Delete (X-A)
1680  *                                                      Delete (Y-A)
1681  *                                                      Group Timer=GMI
1682  */
1683                 /* EXCLUDE (A-Y, ) */
1684                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1685                                     AMT_FILTER_D_FWD,
1686                                     AMT_ACT_STATUS_FWD_NEW,
1687                                     v6);
1688                 /* EXCLUDE (, Y*A ) */
1689                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1690                                     AMT_FILTER_D_FWD,
1691                                     AMT_ACT_STATUS_D_FWD_NEW,
1692                                     v6);
1693                 /* (A-X-Y)=GMI */
1694                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1695                                     AMT_FILTER_BOTH_NEW,
1696                                     AMT_ACT_GMI,
1697                                     v6);
1698                 /* Group Timer=GMI */
1699                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1700                                       msecs_to_jiffies(amt_gmi(amt))))
1701                         dev_hold(amt->dev);
1702                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1703         }
1704 }
1705
1706 static void amt_mcast_to_in_handler(struct amt_dev *amt,
1707                                     struct amt_tunnel_list *tunnel,
1708                                     struct amt_group_node *gnode,
1709                                     void *grec, void *zero_grec, bool v6)
1710 {
1711         if (gnode->filter_mode == MCAST_INCLUDE) {
1712 /* Router State   Report Rec'd New Router State        Actions
1713  * ------------   ------------ ----------------        -------
1714  * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
1715  *                                                     Send Q(G,A-B)
1716  */
1717                 /* Update TO_IN (B) sources as FWD/NEW */
1718                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1719                                     AMT_FILTER_NONE_NEW,
1720                                     AMT_ACT_STATUS_FWD_NEW,
1721                                     v6);
1722                 /* Update INCLUDE (A) sources as NEW */
1723                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1724                                     AMT_FILTER_FWD,
1725                                     AMT_ACT_STATUS_FWD_NEW,
1726                                     v6);
1727                 /* (B)=GMI */
1728                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1729                                     AMT_FILTER_FWD_NEW,
1730                                     AMT_ACT_GMI,
1731                                     v6);
1732         } else {
1733 /* Router State   Report Rec'd New Router State        Actions
1734  * ------------   ------------ ----------------        -------
1735  * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1736  *                                                     Send Q(G,X-A)
1737  *                                                     Send Q(G)
1738  */
1739                 /* Update TO_IN (A) sources as FWD/NEW */
1740                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1741                                     AMT_FILTER_NONE_NEW,
1742                                     AMT_ACT_STATUS_FWD_NEW,
1743                                     v6);
1744                 /* Update EXCLUDE(X,) sources as FWD/NEW */
1745                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1746                                     AMT_FILTER_FWD,
1747                                     AMT_ACT_STATUS_FWD_NEW,
1748                                     v6);
1749                 /* EXCLUDE (, Y-A)
1750                  * (A) are already switched to FWD_NEW.
1751                  * So, D_FWD/OLD -> D_FWD/NEW is okay.
1752                  */
1753                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1754                                     AMT_FILTER_D_FWD,
1755                                     AMT_ACT_STATUS_D_FWD_NEW,
1756                                     v6);
1757                 /* (A)=GMI
1758                  * Only FWD_NEW will have (A) sources.
1759                  */
1760                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1761                                     AMT_FILTER_FWD_NEW,
1762                                     AMT_ACT_GMI,
1763                                     v6);
1764         }
1765 }
1766
1767 static void amt_mcast_to_ex_handler(struct amt_dev *amt,
1768                                     struct amt_tunnel_list *tunnel,
1769                                     struct amt_group_node *gnode,
1770                                     void *grec, void *zero_grec, bool v6)
1771 {
1772         if (gnode->filter_mode == MCAST_INCLUDE) {
1773 /* Router State   Report Rec'd New Router State        Actions
1774  * ------------   ------------ ----------------        -------
1775  * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
1776  *                                                     Delete (A-B)
1777  *                                                     Send Q(G,A*B)
1778  *                                                     Group Timer=GMI
1779  */
1780                 /* EXCLUDE (A*B, ) */
1781                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1782                                     AMT_FILTER_FWD,
1783                                     AMT_ACT_STATUS_FWD_NEW,
1784                                     v6);
1785                 /* EXCLUDE (, B-A) */
1786                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1787                                     AMT_FILTER_FWD,
1788                                     AMT_ACT_STATUS_D_FWD_NEW,
1789                                     v6);
1790                 /* (B-A)=0 */
1791                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1792                                     AMT_FILTER_D_FWD_NEW,
1793                                     AMT_ACT_GMI_ZERO,
1794                                     v6);
1795                 /* Group Timer=GMI */
1796                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1797                                       msecs_to_jiffies(amt_gmi(amt))))
1798                         dev_hold(amt->dev);
1799                 gnode->filter_mode = MCAST_EXCLUDE;
1800                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1801         } else {
1802 /* Router State   Report Rec'd New Router State        Actions
1803  * ------------   ------------ ----------------        -------
1804  * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
1805  *                                                     Delete (X-A)
1806  *                                                     Delete (Y-A)
1807  *                                                     Send Q(G,A-Y)
1808  *                                                     Group Timer=GMI
1809  */
1810                 /* Update (A-X-Y) as NONE/OLD */
1811                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1812                                     AMT_FILTER_BOTH,
1813                                     AMT_ACT_GT,
1814                                     v6);
1815                 /* EXCLUDE (A-Y, ) */
1816                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1817                                     AMT_FILTER_D_FWD,
1818                                     AMT_ACT_STATUS_FWD_NEW,
1819                                     v6);
1820                 /* EXCLUDE (, Y*A) */
1821                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1822                                     AMT_FILTER_D_FWD,
1823                                     AMT_ACT_STATUS_D_FWD_NEW,
1824                                     v6);
1825                 /* Group Timer=GMI */
1826                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1827                                       msecs_to_jiffies(amt_gmi(amt))))
1828                         dev_hold(amt->dev);
1829                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1830         }
1831 }
1832
1833 static void amt_mcast_allow_handler(struct amt_dev *amt,
1834                                     struct amt_tunnel_list *tunnel,
1835                                     struct amt_group_node *gnode,
1836                                     void *grec, void *zero_grec, bool v6)
1837 {
1838         if (gnode->filter_mode == MCAST_INCLUDE) {
1839 /* Router State   Report Rec'd New Router State        Actions
1840  * ------------   ------------ ----------------        -------
1841  * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
1842  */
1843                 /* INCLUDE (A+B) */
1844                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1845                                     AMT_FILTER_FWD,
1846                                     AMT_ACT_STATUS_FWD_NEW,
1847                                     v6);
1848                 /* (B)=GMI */
1849                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1850                                     AMT_FILTER_FWD_NEW,
1851                                     AMT_ACT_GMI,
1852                                     v6);
1853         } else {
1854 /* Router State   Report Rec'd New Router State        Actions
1855  * ------------   ------------ ----------------        -------
1856  * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1857  */
1858                 /* EXCLUDE (X+A, ) */
1859                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1860                                     AMT_FILTER_FWD,
1861                                     AMT_ACT_STATUS_FWD_NEW,
1862                                     v6);
1863                 /* EXCLUDE (, Y-A) */
1864                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1865                                     AMT_FILTER_D_FWD,
1866                                     AMT_ACT_STATUS_D_FWD_NEW,
1867                                     v6);
1868                 /* (A)=GMI
1869                  * All (A) source are now FWD/NEW status.
1870                  */
1871                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1872                                     AMT_FILTER_FWD_NEW,
1873                                     AMT_ACT_GMI,
1874                                     v6);
1875         }
1876 }
1877
1878 static void amt_mcast_block_handler(struct amt_dev *amt,
1879                                     struct amt_tunnel_list *tunnel,
1880                                     struct amt_group_node *gnode,
1881                                     void *grec, void *zero_grec, bool v6)
1882 {
1883         if (gnode->filter_mode == MCAST_INCLUDE) {
1884 /* Router State   Report Rec'd New Router State        Actions
1885  * ------------   ------------ ----------------        -------
1886  * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
1887  */
1888                 /* INCLUDE (A) */
1889                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1890                                     AMT_FILTER_FWD,
1891                                     AMT_ACT_STATUS_FWD_NEW,
1892                                     v6);
1893         } else {
1894 /* Router State   Report Rec'd New Router State        Actions
1895  * ------------   ------------ ----------------        -------
1896  * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
1897  *                                                     Send Q(G,A-Y)
1898  */
1899                 /* (A-X-Y)=Group Timer */
1900                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1901                                     AMT_FILTER_BOTH,
1902                                     AMT_ACT_GT,
1903                                     v6);
1904                 /* EXCLUDE (X, ) */
1905                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1906                                     AMT_FILTER_FWD,
1907                                     AMT_ACT_STATUS_FWD_NEW,
1908                                     v6);
1909                 /* EXCLUDE (X+(A-Y) */
1910                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1911                                     AMT_FILTER_D_FWD,
1912                                     AMT_ACT_STATUS_FWD_NEW,
1913                                     v6);
1914                 /* EXCLUDE (, Y) */
1915                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1916                                     AMT_FILTER_D_FWD,
1917                                     AMT_ACT_STATUS_D_FWD_NEW,
1918                                     v6);
1919         }
1920 }
1921
1922 /* RFC 3376
1923  * 7.3.2. In the Presence of Older Version Group Members
1924  *
1925  * When Group Compatibility Mode is IGMPv2, a router internally
1926  * translates the following IGMPv2 messages for that group to their
1927  * IGMPv3 equivalents:
1928  *
1929  * IGMPv2 Message                IGMPv3 Equivalent
1930  * --------------                -----------------
1931  * Report                        IS_EX( {} )
1932  * Leave                         TO_IN( {} )
1933  */
1934 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1935                                       struct amt_tunnel_list *tunnel)
1936 {
1937         struct igmphdr *ih = igmp_hdr(skb);
1938         struct iphdr *iph = ip_hdr(skb);
1939         struct amt_group_node *gnode;
1940         union amt_addr group, host;
1941
1942         memset(&group, 0, sizeof(union amt_addr));
1943         group.ip4 = ih->group;
1944         memset(&host, 0, sizeof(union amt_addr));
1945         host.ip4 = iph->saddr;
1946
1947         gnode = amt_lookup_group(tunnel, &group, &host, false);
1948         if (!gnode) {
1949                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1950                 if (!IS_ERR(gnode)) {
1951                         gnode->filter_mode = MCAST_EXCLUDE;
1952                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1953                                               msecs_to_jiffies(amt_gmi(amt))))
1954                                 dev_hold(amt->dev);
1955                 }
1956         }
1957 }
1958
1959 /* RFC 3376
1960  * 7.3.2. In the Presence of Older Version Group Members
1961  *
1962  * When Group Compatibility Mode is IGMPv2, a router internally
1963  * translates the following IGMPv2 messages for that group to their
1964  * IGMPv3 equivalents:
1965  *
1966  * IGMPv2 Message                IGMPv3 Equivalent
1967  * --------------                -----------------
1968  * Report                        IS_EX( {} )
1969  * Leave                         TO_IN( {} )
1970  */
1971 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1972                                      struct amt_tunnel_list *tunnel)
1973 {
1974         struct igmphdr *ih = igmp_hdr(skb);
1975         struct iphdr *iph = ip_hdr(skb);
1976         struct amt_group_node *gnode;
1977         union amt_addr group, host;
1978
1979         memset(&group, 0, sizeof(union amt_addr));
1980         group.ip4 = ih->group;
1981         memset(&host, 0, sizeof(union amt_addr));
1982         host.ip4 = iph->saddr;
1983
1984         gnode = amt_lookup_group(tunnel, &group, &host, false);
1985         if (gnode)
1986                 amt_del_group(amt, gnode);
1987 }
1988
1989 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1990                                       struct amt_tunnel_list *tunnel)
1991 {
1992         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
1993         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
1994         void *zero_grec = (void *)&igmpv3_zero_grec;
1995         struct iphdr *iph = ip_hdr(skb);
1996         struct amt_group_node *gnode;
1997         union amt_addr group, host;
1998         struct igmpv3_grec *grec;
1999         u16 nsrcs;
2000         int i;
2001
2002         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
2003                 len += sizeof(*grec);
2004                 if (!ip_mc_may_pull(skb, len))
2005                         break;
2006
2007                 grec = (void *)(skb->data + len - sizeof(*grec));
2008                 nsrcs = ntohs(grec->grec_nsrcs);
2009
2010                 len += nsrcs * sizeof(__be32);
2011                 if (!ip_mc_may_pull(skb, len))
2012                         break;
2013
2014                 memset(&group, 0, sizeof(union amt_addr));
2015                 group.ip4 = grec->grec_mca;
2016                 memset(&host, 0, sizeof(union amt_addr));
2017                 host.ip4 = iph->saddr;
2018                 gnode = amt_lookup_group(tunnel, &group, &host, false);
2019                 if (!gnode) {
2020                         gnode = amt_add_group(amt, tunnel, &group, &host,
2021                                               false);
2022                         if (IS_ERR(gnode))
2023                                 continue;
2024                 }
2025
2026                 amt_add_srcs(amt, tunnel, gnode, grec, false);
2027                 switch (grec->grec_type) {
2028                 case IGMPV3_MODE_IS_INCLUDE:
2029                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2030                                                 zero_grec, false);
2031                         break;
2032                 case IGMPV3_MODE_IS_EXCLUDE:
2033                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2034                                                 zero_grec, false);
2035                         break;
2036                 case IGMPV3_CHANGE_TO_INCLUDE:
2037                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2038                                                 zero_grec, false);
2039                         break;
2040                 case IGMPV3_CHANGE_TO_EXCLUDE:
2041                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2042                                                 zero_grec, false);
2043                         break;
2044                 case IGMPV3_ALLOW_NEW_SOURCES:
2045                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2046                                                 zero_grec, false);
2047                         break;
2048                 case IGMPV3_BLOCK_OLD_SOURCES:
2049                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2050                                                 zero_grec, false);
2051                         break;
2052                 default:
2053                         break;
2054                 }
2055                 amt_cleanup_srcs(amt, tunnel, gnode);
2056         }
2057 }
2058
2059 /* caller held tunnel->lock */
2060 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2061                                     struct amt_tunnel_list *tunnel)
2062 {
2063         struct igmphdr *ih = igmp_hdr(skb);
2064
2065         switch (ih->type) {
2066         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2067                 amt_igmpv3_report_handler(amt, skb, tunnel);
2068                 break;
2069         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2070                 amt_igmpv2_report_handler(amt, skb, tunnel);
2071                 break;
2072         case IGMP_HOST_LEAVE_MESSAGE:
2073                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2074                 break;
2075         default:
2076                 break;
2077         }
2078 }
2079
2080 #if IS_ENABLED(CONFIG_IPV6)
2081 /* RFC 3810
2082  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2083  *
2084  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2085  * using the MLDv2 protocol for that multicast address.  When Multicast
2086  * Address Compatibility Mode is MLDv1, a router internally translates
2087  * the following MLDv1 messages for that multicast address to their
2088  * MLDv2 equivalents:
2089  *
2090  * MLDv1 Message                 MLDv2 Equivalent
2091  * --------------                -----------------
2092  * Report                        IS_EX( {} )
2093  * Done                          TO_IN( {} )
2094  */
2095 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2096                                      struct amt_tunnel_list *tunnel)
2097 {
2098         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2099         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2100         struct amt_group_node *gnode;
2101         union amt_addr group, host;
2102
2103         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2104         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2105
2106         gnode = amt_lookup_group(tunnel, &group, &host, true);
2107         if (!gnode) {
2108                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2109                 if (!IS_ERR(gnode)) {
2110                         gnode->filter_mode = MCAST_EXCLUDE;
2111                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2112                                               msecs_to_jiffies(amt_gmi(amt))))
2113                                 dev_hold(amt->dev);
2114                 }
2115         }
2116 }
2117
2118 /* RFC 3810
2119  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2120  *
2121  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2122  * using the MLDv2 protocol for that multicast address.  When Multicast
2123  * Address Compatibility Mode is MLDv1, a router internally translates
2124  * the following MLDv1 messages for that multicast address to their
2125  * MLDv2 equivalents:
2126  *
2127  * MLDv1 Message                 MLDv2 Equivalent
2128  * --------------                -----------------
2129  * Report                        IS_EX( {} )
2130  * Done                          TO_IN( {} )
2131  */
2132 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2133                                     struct amt_tunnel_list *tunnel)
2134 {
2135         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2136         struct iphdr *iph = ip_hdr(skb);
2137         struct amt_group_node *gnode;
2138         union amt_addr group, host;
2139
2140         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2141         memset(&host, 0, sizeof(union amt_addr));
2142         host.ip4 = iph->saddr;
2143
2144         gnode = amt_lookup_group(tunnel, &group, &host, true);
2145         if (gnode) {
2146                 amt_del_group(amt, gnode);
2147                 return;
2148         }
2149 }
2150
2151 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2152                                      struct amt_tunnel_list *tunnel)
2153 {
2154         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2155         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2156         void *zero_grec = (void *)&mldv2_zero_grec;
2157         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2158         struct amt_group_node *gnode;
2159         union amt_addr group, host;
2160         struct mld2_grec *grec;
2161         u16 nsrcs;
2162         int i;
2163
2164         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2165                 len += sizeof(*grec);
2166                 if (!ipv6_mc_may_pull(skb, len))
2167                         break;
2168
2169                 grec = (void *)(skb->data + len - sizeof(*grec));
2170                 nsrcs = ntohs(grec->grec_nsrcs);
2171
2172                 len += nsrcs * sizeof(struct in6_addr);
2173                 if (!ipv6_mc_may_pull(skb, len))
2174                         break;
2175
2176                 memset(&group, 0, sizeof(union amt_addr));
2177                 group.ip6 = grec->grec_mca;
2178                 memset(&host, 0, sizeof(union amt_addr));
2179                 host.ip6 = ip6h->saddr;
2180                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2181                 if (!gnode) {
2182                         gnode = amt_add_group(amt, tunnel, &group, &host,
2183                                               ETH_P_IPV6);
2184                         if (IS_ERR(gnode))
2185                                 continue;
2186                 }
2187
2188                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2189                 switch (grec->grec_type) {
2190                 case MLD2_MODE_IS_INCLUDE:
2191                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2192                                                 zero_grec, true);
2193                         break;
2194                 case MLD2_MODE_IS_EXCLUDE:
2195                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2196                                                 zero_grec, true);
2197                         break;
2198                 case MLD2_CHANGE_TO_INCLUDE:
2199                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2200                                                 zero_grec, true);
2201                         break;
2202                 case MLD2_CHANGE_TO_EXCLUDE:
2203                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2204                                                 zero_grec, true);
2205                         break;
2206                 case MLD2_ALLOW_NEW_SOURCES:
2207                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2208                                                 zero_grec, true);
2209                         break;
2210                 case MLD2_BLOCK_OLD_SOURCES:
2211                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2212                                                 zero_grec, true);
2213                         break;
2214                 default:
2215                         break;
2216                 }
2217                 amt_cleanup_srcs(amt, tunnel, gnode);
2218         }
2219 }
2220
2221 /* caller held tunnel->lock */
2222 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2223                                    struct amt_tunnel_list *tunnel)
2224 {
2225         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2226
2227         switch (mld->mld_type) {
2228         case ICMPV6_MGM_REPORT:
2229                 amt_mldv1_report_handler(amt, skb, tunnel);
2230                 break;
2231         case ICMPV6_MLD2_REPORT:
2232                 amt_mldv2_report_handler(amt, skb, tunnel);
2233                 break;
2234         case ICMPV6_MGM_REDUCTION:
2235                 amt_mldv1_leave_handler(amt, skb, tunnel);
2236                 break;
2237         default:
2238                 break;
2239         }
2240 }
2241 #endif
2242
2243 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2244 {
2245         struct amt_header_advertisement *amta;
2246         int hdr_size;
2247
2248         hdr_size = sizeof(*amta) + sizeof(struct udphdr);
2249         if (!pskb_may_pull(skb, hdr_size))
2250                 return true;
2251
2252         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2253         if (!amta->ip4)
2254                 return true;
2255
2256         if (amta->reserved || amta->version)
2257                 return true;
2258
2259         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2260             ipv4_is_zeronet(amta->ip4))
2261                 return true;
2262
2263         if (amt->status != AMT_STATUS_SENT_DISCOVERY ||
2264             amt->nonce != amta->nonce)
2265                 return true;
2266
2267         amt->remote_ip = amta->ip4;
2268         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2269         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2270
2271         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2272         return false;
2273 }
2274
2275 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2276 {
2277         struct amt_header_mcast_data *amtmd;
2278         int hdr_size, len, err;
2279         struct ethhdr *eth;
2280         struct iphdr *iph;
2281
2282         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2283         if (!pskb_may_pull(skb, hdr_size))
2284                 return true;
2285
2286         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2287         if (amtmd->reserved || amtmd->version)
2288                 return true;
2289
2290         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2291                 return true;
2292
2293         skb_reset_network_header(skb);
2294         skb_push(skb, sizeof(*eth));
2295         skb_reset_mac_header(skb);
2296         skb_pull(skb, sizeof(*eth));
2297         eth = eth_hdr(skb);
2298
2299         if (!pskb_may_pull(skb, sizeof(*iph)))
2300                 return true;
2301         iph = ip_hdr(skb);
2302
2303         if (iph->version == 4) {
2304                 if (!ipv4_is_multicast(iph->daddr))
2305                         return true;
2306                 skb->protocol = htons(ETH_P_IP);
2307                 eth->h_proto = htons(ETH_P_IP);
2308                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2309 #if IS_ENABLED(CONFIG_IPV6)
2310         } else if (iph->version == 6) {
2311                 struct ipv6hdr *ip6h;
2312
2313                 if (!pskb_may_pull(skb, sizeof(*ip6h)))
2314                         return true;
2315
2316                 ip6h = ipv6_hdr(skb);
2317                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2318                         return true;
2319                 skb->protocol = htons(ETH_P_IPV6);
2320                 eth->h_proto = htons(ETH_P_IPV6);
2321                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2322 #endif
2323         } else {
2324                 return true;
2325         }
2326
2327         skb->pkt_type = PACKET_MULTICAST;
2328         skb->ip_summed = CHECKSUM_NONE;
2329         len = skb->len;
2330         err = gro_cells_receive(&amt->gro_cells, skb);
2331         if (likely(err == NET_RX_SUCCESS))
2332                 dev_sw_netstats_rx_add(amt->dev, len);
2333         else
2334                 amt->dev->stats.rx_dropped++;
2335
2336         return false;
2337 }
2338
2339 static bool amt_membership_query_handler(struct amt_dev *amt,
2340                                          struct sk_buff *skb)
2341 {
2342         struct amt_header_membership_query *amtmq;
2343         struct igmpv3_query *ihv3;
2344         struct ethhdr *eth, *oeth;
2345         struct iphdr *iph;
2346         int hdr_size, len;
2347
2348         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr);
2349         if (!pskb_may_pull(skb, hdr_size))
2350                 return true;
2351
2352         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2353         if (amtmq->reserved || amtmq->version)
2354                 return true;
2355
2356         hdr_size -= sizeof(*eth);
2357         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2358                 return true;
2359
2360         oeth = eth_hdr(skb);
2361         skb_reset_mac_header(skb);
2362         skb_pull(skb, sizeof(*eth));
2363         skb_reset_network_header(skb);
2364         eth = eth_hdr(skb);
2365         if (!pskb_may_pull(skb, sizeof(*iph)))
2366                 return true;
2367
2368         iph = ip_hdr(skb);
2369         if (iph->version == 4) {
2370                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2371                                    sizeof(*ihv3)))
2372                         return true;
2373
2374                 if (!ipv4_is_multicast(iph->daddr))
2375                         return true;
2376
2377                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2378                 skb_reset_transport_header(skb);
2379                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2380                 WRITE_ONCE(amt->ready4, true);
2381                 amt->mac = amtmq->response_mac;
2382                 amt->req_cnt = 0;
2383                 amt->qi = ihv3->qqic;
2384                 skb->protocol = htons(ETH_P_IP);
2385                 eth->h_proto = htons(ETH_P_IP);
2386                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2387 #if IS_ENABLED(CONFIG_IPV6)
2388         } else if (iph->version == 6) {
2389                 struct mld2_query *mld2q;
2390                 struct ipv6hdr *ip6h;
2391
2392                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2393                                    sizeof(*mld2q)))
2394                         return true;
2395
2396                 ip6h = ipv6_hdr(skb);
2397                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2398                         return true;
2399
2400                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2401                 skb_reset_transport_header(skb);
2402                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2403                 WRITE_ONCE(amt->ready6, true);
2404                 amt->mac = amtmq->response_mac;
2405                 amt->req_cnt = 0;
2406                 amt->qi = mld2q->mld2q_qqic;
2407                 skb->protocol = htons(ETH_P_IPV6);
2408                 eth->h_proto = htons(ETH_P_IPV6);
2409                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2410 #endif
2411         } else {
2412                 return true;
2413         }
2414
2415         ether_addr_copy(eth->h_source, oeth->h_source);
2416         skb->pkt_type = PACKET_MULTICAST;
2417         skb->ip_summed = CHECKSUM_NONE;
2418         len = skb->len;
2419         local_bh_disable();
2420         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2421                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2422                 dev_sw_netstats_rx_add(amt->dev, len);
2423         } else {
2424                 amt->dev->stats.rx_dropped++;
2425         }
2426         local_bh_enable();
2427
2428         return false;
2429 }
2430
2431 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2432 {
2433         struct amt_header_membership_update *amtmu;
2434         struct amt_tunnel_list *tunnel;
2435         struct ethhdr *eth;
2436         struct iphdr *iph;
2437         int len, hdr_size;
2438
2439         iph = ip_hdr(skb);
2440
2441         hdr_size = sizeof(*amtmu) + sizeof(struct udphdr);
2442         if (!pskb_may_pull(skb, hdr_size))
2443                 return true;
2444
2445         amtmu = (struct amt_header_membership_update *)(udp_hdr(skb) + 1);
2446         if (amtmu->reserved || amtmu->version)
2447                 return true;
2448
2449         if (iptunnel_pull_header(skb, hdr_size, skb->protocol, false))
2450                 return true;
2451
2452         skb_reset_network_header(skb);
2453
2454         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2455                 if (tunnel->ip4 == iph->saddr) {
2456                         if ((amtmu->nonce == tunnel->nonce &&
2457                              amtmu->response_mac == tunnel->mac)) {
2458                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2459                                                  msecs_to_jiffies(amt_gmi(amt))
2460                                                                   * 3);
2461                                 goto report;
2462                         } else {
2463                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2464                                 return true;
2465                         }
2466                 }
2467         }
2468
2469         return true;
2470
2471 report:
2472         if (!pskb_may_pull(skb, sizeof(*iph)))
2473                 return true;
2474
2475         iph = ip_hdr(skb);
2476         if (iph->version == 4) {
2477                 if (ip_mc_check_igmp(skb)) {
2478                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2479                         return true;
2480                 }
2481
2482                 spin_lock_bh(&tunnel->lock);
2483                 amt_igmp_report_handler(amt, skb, tunnel);
2484                 spin_unlock_bh(&tunnel->lock);
2485
2486                 skb_push(skb, sizeof(struct ethhdr));
2487                 skb_reset_mac_header(skb);
2488                 eth = eth_hdr(skb);
2489                 skb->protocol = htons(ETH_P_IP);
2490                 eth->h_proto = htons(ETH_P_IP);
2491                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2492 #if IS_ENABLED(CONFIG_IPV6)
2493         } else if (iph->version == 6) {
2494                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2495
2496                 if (ipv6_mc_check_mld(skb)) {
2497                         netdev_dbg(amt->dev, "Invalid MLD\n");
2498                         return true;
2499                 }
2500
2501                 spin_lock_bh(&tunnel->lock);
2502                 amt_mld_report_handler(amt, skb, tunnel);
2503                 spin_unlock_bh(&tunnel->lock);
2504
2505                 skb_push(skb, sizeof(struct ethhdr));
2506                 skb_reset_mac_header(skb);
2507                 eth = eth_hdr(skb);
2508                 skb->protocol = htons(ETH_P_IPV6);
2509                 eth->h_proto = htons(ETH_P_IPV6);
2510                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2511 #endif
2512         } else {
2513                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2514                 return true;
2515         }
2516
2517         skb_pull(skb, sizeof(struct ethhdr));
2518         skb->pkt_type = PACKET_MULTICAST;
2519         skb->ip_summed = CHECKSUM_NONE;
2520         len = skb->len;
2521         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2522                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2523                                         true);
2524                 dev_sw_netstats_rx_add(amt->dev, len);
2525         } else {
2526                 amt->dev->stats.rx_dropped++;
2527         }
2528
2529         return false;
2530 }
2531
2532 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2533                                    __be32 daddr, __be16 dport)
2534 {
2535         struct amt_header_advertisement *amta;
2536         int hlen, tlen, offset;
2537         struct socket *sock;
2538         struct udphdr *udph;
2539         struct sk_buff *skb;
2540         struct iphdr *iph;
2541         struct rtable *rt;
2542         struct flowi4 fl4;
2543         u32 len;
2544         int err;
2545
2546         rcu_read_lock();
2547         sock = rcu_dereference(amt->sock);
2548         if (!sock)
2549                 goto out;
2550
2551         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2552                 goto out;
2553
2554         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2555                                    daddr, amt->local_ip,
2556                                    dport, amt->relay_port,
2557                                    IPPROTO_UDP, 0,
2558                                    amt->stream_dev->ifindex);
2559         if (IS_ERR(rt)) {
2560                 amt->dev->stats.tx_errors++;
2561                 goto out;
2562         }
2563
2564         hlen = LL_RESERVED_SPACE(amt->dev);
2565         tlen = amt->dev->needed_tailroom;
2566         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2567         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2568         if (!skb) {
2569                 ip_rt_put(rt);
2570                 amt->dev->stats.tx_errors++;
2571                 goto out;
2572         }
2573
2574         skb->priority = TC_PRIO_CONTROL;
2575         skb_dst_set(skb, &rt->dst);
2576
2577         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2578         skb_reset_network_header(skb);
2579         skb_put(skb, len);
2580         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2581         amta->version   = 0;
2582         amta->type      = AMT_MSG_ADVERTISEMENT;
2583         amta->reserved  = 0;
2584         amta->nonce     = nonce;
2585         amta->ip4       = amt->local_ip;
2586         skb_push(skb, sizeof(*udph));
2587         skb_reset_transport_header(skb);
2588         udph            = udp_hdr(skb);
2589         udph->source    = amt->relay_port;
2590         udph->dest      = dport;
2591         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
2592         udph->check     = 0;
2593         offset = skb_transport_offset(skb);
2594         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2595         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2596                                         sizeof(*udph) + sizeof(*amta),
2597                                         IPPROTO_UDP, skb->csum);
2598
2599         skb_push(skb, sizeof(*iph));
2600         iph             = ip_hdr(skb);
2601         iph->version    = 4;
2602         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2603         iph->tos        = AMT_TOS;
2604         iph->frag_off   = 0;
2605         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2606         iph->daddr      = daddr;
2607         iph->saddr      = amt->local_ip;
2608         iph->protocol   = IPPROTO_UDP;
2609         iph->tot_len    = htons(len);
2610
2611         skb->ip_summed = CHECKSUM_NONE;
2612         ip_select_ident(amt->net, skb, NULL);
2613         ip_send_check(iph);
2614         err = ip_local_out(amt->net, sock->sk, skb);
2615         if (unlikely(net_xmit_eval(err)))
2616                 amt->dev->stats.tx_errors++;
2617
2618 out:
2619         rcu_read_unlock();
2620 }
2621
2622 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2623 {
2624         struct amt_header_discovery *amtd;
2625         struct udphdr *udph;
2626         struct iphdr *iph;
2627
2628         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2629                 return true;
2630
2631         iph = ip_hdr(skb);
2632         udph = udp_hdr(skb);
2633         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2634
2635         if (amtd->reserved || amtd->version)
2636                 return true;
2637
2638         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2639
2640         return false;
2641 }
2642
2643 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2644 {
2645         struct amt_header_request *amtrh;
2646         struct amt_tunnel_list *tunnel;
2647         unsigned long long key;
2648         struct udphdr *udph;
2649         struct iphdr *iph;
2650         u64 mac;
2651         int i;
2652
2653         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2654                 return true;
2655
2656         iph = ip_hdr(skb);
2657         udph = udp_hdr(skb);
2658         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2659
2660         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2661                 return true;
2662
2663         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2664                 if (tunnel->ip4 == iph->saddr)
2665                         goto send;
2666
2667         if (amt->nr_tunnels >= amt->max_tunnels) {
2668                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2669                 return true;
2670         }
2671
2672         tunnel = kzalloc(sizeof(*tunnel) +
2673                          (sizeof(struct hlist_head) * amt->hash_buckets),
2674                          GFP_ATOMIC);
2675         if (!tunnel)
2676                 return true;
2677
2678         tunnel->source_port = udph->source;
2679         tunnel->ip4 = iph->saddr;
2680
2681         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2682         tunnel->amt = amt;
2683         spin_lock_init(&tunnel->lock);
2684         for (i = 0; i < amt->hash_buckets; i++)
2685                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2686
2687         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2688
2689         spin_lock_bh(&amt->lock);
2690         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2691         tunnel->key = amt->key;
2692         amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2693         amt->nr_tunnels++;
2694         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2695                          msecs_to_jiffies(amt_gmi(amt)));
2696         spin_unlock_bh(&amt->lock);
2697
2698 send:
2699         tunnel->nonce = amtrh->nonce;
2700         mac = siphash_3u32((__force u32)tunnel->ip4,
2701                            (__force u32)tunnel->source_port,
2702                            (__force u32)tunnel->nonce,
2703                            &tunnel->key);
2704         tunnel->mac = mac >> 16;
2705
2706         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2707                 return true;
2708
2709         if (!amtrh->p)
2710                 amt_send_igmp_gq(amt, tunnel);
2711         else
2712                 amt_send_mld_gq(amt, tunnel);
2713
2714         return false;
2715 }
2716
2717 static void amt_gw_rcv(struct amt_dev *amt, struct sk_buff *skb)
2718 {
2719         int type = amt_parse_type(skb);
2720         int err = 1;
2721
2722         if (type == -1)
2723                 goto drop;
2724
2725         if (amt->mode == AMT_MODE_GATEWAY) {
2726                 switch (type) {
2727                 case AMT_MSG_ADVERTISEMENT:
2728                         err = amt_advertisement_handler(amt, skb);
2729                         break;
2730                 case AMT_MSG_MEMBERSHIP_QUERY:
2731                         err = amt_membership_query_handler(amt, skb);
2732                         if (!err)
2733                                 return;
2734                         break;
2735                 default:
2736                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2737                         break;
2738                 }
2739         }
2740 drop:
2741         if (err) {
2742                 amt->dev->stats.rx_dropped++;
2743                 kfree_skb(skb);
2744         } else {
2745                 consume_skb(skb);
2746         }
2747 }
2748
2749 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2750 {
2751         struct amt_dev *amt;
2752         struct iphdr *iph;
2753         int type;
2754         bool err;
2755
2756         rcu_read_lock_bh();
2757         amt = rcu_dereference_sk_user_data(sk);
2758         if (!amt) {
2759                 err = true;
2760                 kfree_skb(skb);
2761                 goto out;
2762         }
2763
2764         skb->dev = amt->dev;
2765         iph = ip_hdr(skb);
2766         type = amt_parse_type(skb);
2767         if (type == -1) {
2768                 err = true;
2769                 goto drop;
2770         }
2771
2772         if (amt->mode == AMT_MODE_GATEWAY) {
2773                 switch (type) {
2774                 case AMT_MSG_ADVERTISEMENT:
2775                         if (iph->saddr != amt->discovery_ip) {
2776                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2777                                 err = true;
2778                                 goto drop;
2779                         }
2780                         if (amt_queue_event(amt, AMT_EVENT_RECEIVE, skb)) {
2781                                 netdev_dbg(amt->dev, "AMT Event queue full\n");
2782                                 err = true;
2783                                 goto drop;
2784                         }
2785                         goto out;
2786                 case AMT_MSG_MULTICAST_DATA:
2787                         if (iph->saddr != amt->remote_ip) {
2788                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2789                                 err = true;
2790                                 goto drop;
2791                         }
2792                         err = amt_multicast_data_handler(amt, skb);
2793                         if (err)
2794                                 goto drop;
2795                         else
2796                                 goto out;
2797                 case AMT_MSG_MEMBERSHIP_QUERY:
2798                         if (iph->saddr != amt->remote_ip) {
2799                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2800                                 err = true;
2801                                 goto drop;
2802                         }
2803                         if (amt_queue_event(amt, AMT_EVENT_RECEIVE, skb)) {
2804                                 netdev_dbg(amt->dev, "AMT Event queue full\n");
2805                                 err = true;
2806                                 goto drop;
2807                         }
2808                         goto out;
2809                 default:
2810                         err = true;
2811                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2812                         break;
2813                 }
2814         } else {
2815                 switch (type) {
2816                 case AMT_MSG_DISCOVERY:
2817                         err = amt_discovery_handler(amt, skb);
2818                         break;
2819                 case AMT_MSG_REQUEST:
2820                         err = amt_request_handler(amt, skb);
2821                         break;
2822                 case AMT_MSG_MEMBERSHIP_UPDATE:
2823                         err = amt_update_handler(amt, skb);
2824                         if (err)
2825                                 goto drop;
2826                         else
2827                                 goto out;
2828                 default:
2829                         err = true;
2830                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2831                         break;
2832                 }
2833         }
2834 drop:
2835         if (err) {
2836                 amt->dev->stats.rx_dropped++;
2837                 kfree_skb(skb);
2838         } else {
2839                 consume_skb(skb);
2840         }
2841 out:
2842         rcu_read_unlock_bh();
2843         return 0;
2844 }
2845
2846 static void amt_event_work(struct work_struct *work)
2847 {
2848         struct amt_dev *amt = container_of(work, struct amt_dev, event_wq);
2849         struct sk_buff *skb;
2850         u8 event;
2851         int i;
2852
2853         for (i = 0; i < AMT_MAX_EVENTS; i++) {
2854                 spin_lock_bh(&amt->lock);
2855                 if (amt->nr_events == 0) {
2856                         spin_unlock_bh(&amt->lock);
2857                         return;
2858                 }
2859                 event = amt->events[amt->event_idx].event;
2860                 skb = amt->events[amt->event_idx].skb;
2861                 amt->events[amt->event_idx].event = AMT_EVENT_NONE;
2862                 amt->events[amt->event_idx].skb = NULL;
2863                 amt->nr_events--;
2864                 amt->event_idx++;
2865                 amt->event_idx %= AMT_MAX_EVENTS;
2866                 spin_unlock_bh(&amt->lock);
2867
2868                 switch (event) {
2869                 case AMT_EVENT_RECEIVE:
2870                         amt_gw_rcv(amt, skb);
2871                         break;
2872                 case AMT_EVENT_SEND_DISCOVERY:
2873                         amt_event_send_discovery(amt);
2874                         break;
2875                 case AMT_EVENT_SEND_REQUEST:
2876                         amt_event_send_request(amt);
2877                         break;
2878                 default:
2879                         if (skb)
2880                                 kfree_skb(skb);
2881                         break;
2882                 }
2883         }
2884 }
2885
2886 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2887 {
2888         struct amt_dev *amt;
2889         int type;
2890
2891         rcu_read_lock_bh();
2892         amt = rcu_dereference_sk_user_data(sk);
2893         if (!amt)
2894                 goto out;
2895
2896         if (amt->mode != AMT_MODE_GATEWAY)
2897                 goto drop;
2898
2899         type = amt_parse_type(skb);
2900         if (type == -1)
2901                 goto drop;
2902
2903         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2904                    type_str[type]);
2905         switch (type) {
2906         case AMT_MSG_DISCOVERY:
2907                 break;
2908         case AMT_MSG_REQUEST:
2909         case AMT_MSG_MEMBERSHIP_UPDATE:
2910                 if (READ_ONCE(amt->status) >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2911                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2912                 break;
2913         default:
2914                 goto drop;
2915         }
2916 out:
2917         rcu_read_unlock_bh();
2918         return 0;
2919 drop:
2920         rcu_read_unlock_bh();
2921         amt->dev->stats.rx_dropped++;
2922         return 0;
2923 }
2924
2925 static struct socket *amt_create_sock(struct net *net, __be16 port)
2926 {
2927         struct udp_port_cfg udp_conf;
2928         struct socket *sock;
2929         int err;
2930
2931         memset(&udp_conf, 0, sizeof(udp_conf));
2932         udp_conf.family = AF_INET;
2933         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2934
2935         udp_conf.local_udp_port = port;
2936
2937         err = udp_sock_create(net, &udp_conf, &sock);
2938         if (err < 0)
2939                 return ERR_PTR(err);
2940
2941         return sock;
2942 }
2943
2944 static int amt_socket_create(struct amt_dev *amt)
2945 {
2946         struct udp_tunnel_sock_cfg tunnel_cfg;
2947         struct socket *sock;
2948
2949         sock = amt_create_sock(amt->net, amt->relay_port);
2950         if (IS_ERR(sock))
2951                 return PTR_ERR(sock);
2952
2953         /* Mark socket as an encapsulation socket */
2954         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2955         tunnel_cfg.sk_user_data = amt;
2956         tunnel_cfg.encap_type = 1;
2957         tunnel_cfg.encap_rcv = amt_rcv;
2958         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2959         tunnel_cfg.encap_destroy = NULL;
2960         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2961
2962         rcu_assign_pointer(amt->sock, sock);
2963         return 0;
2964 }
2965
2966 static int amt_dev_open(struct net_device *dev)
2967 {
2968         struct amt_dev *amt = netdev_priv(dev);
2969         int err;
2970
2971         amt->ready4 = false;
2972         amt->ready6 = false;
2973         amt->event_idx = 0;
2974         amt->nr_events = 0;
2975
2976         err = amt_socket_create(amt);
2977         if (err)
2978                 return err;
2979
2980         amt->req_cnt = 0;
2981         amt->remote_ip = 0;
2982         amt->nonce = 0;
2983         get_random_bytes(&amt->key, sizeof(siphash_key_t));
2984
2985         amt->status = AMT_STATUS_INIT;
2986         if (amt->mode == AMT_MODE_GATEWAY) {
2987                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
2988                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
2989         } else if (amt->mode == AMT_MODE_RELAY) {
2990                 mod_delayed_work(amt_wq, &amt->secret_wq,
2991                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
2992         }
2993         return err;
2994 }
2995
2996 static int amt_dev_stop(struct net_device *dev)
2997 {
2998         struct amt_dev *amt = netdev_priv(dev);
2999         struct amt_tunnel_list *tunnel, *tmp;
3000         struct socket *sock;
3001         struct sk_buff *skb;
3002         int i;
3003
3004         cancel_delayed_work_sync(&amt->req_wq);
3005         cancel_delayed_work_sync(&amt->discovery_wq);
3006         cancel_delayed_work_sync(&amt->secret_wq);
3007
3008         /* shutdown */
3009         sock = rtnl_dereference(amt->sock);
3010         RCU_INIT_POINTER(amt->sock, NULL);
3011         synchronize_net();
3012         if (sock)
3013                 udp_tunnel_sock_release(sock);
3014
3015         cancel_work_sync(&amt->event_wq);
3016         for (i = 0; i < AMT_MAX_EVENTS; i++) {
3017                 skb = amt->events[i].skb;
3018                 if (skb)
3019                         kfree_skb(skb);
3020                 amt->events[i].event = AMT_EVENT_NONE;
3021                 amt->events[i].skb = NULL;
3022         }
3023
3024         amt->ready4 = false;
3025         amt->ready6 = false;
3026         amt->req_cnt = 0;
3027         amt->remote_ip = 0;
3028
3029         list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
3030                 list_del_rcu(&tunnel->list);
3031                 amt->nr_tunnels--;
3032                 cancel_delayed_work_sync(&tunnel->gc_wq);
3033                 amt_clear_groups(tunnel);
3034                 kfree_rcu(tunnel, rcu);
3035         }
3036
3037         return 0;
3038 }
3039
3040 static const struct device_type amt_type = {
3041         .name = "amt",
3042 };
3043
3044 static int amt_dev_init(struct net_device *dev)
3045 {
3046         struct amt_dev *amt = netdev_priv(dev);
3047         int err;
3048
3049         amt->dev = dev;
3050         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
3051         if (!dev->tstats)
3052                 return -ENOMEM;
3053
3054         err = gro_cells_init(&amt->gro_cells, dev);
3055         if (err) {
3056                 free_percpu(dev->tstats);
3057                 return err;
3058         }
3059
3060         return 0;
3061 }
3062
3063 static void amt_dev_uninit(struct net_device *dev)
3064 {
3065         struct amt_dev *amt = netdev_priv(dev);
3066
3067         gro_cells_destroy(&amt->gro_cells);
3068         free_percpu(dev->tstats);
3069 }
3070
3071 static const struct net_device_ops amt_netdev_ops = {
3072         .ndo_init               = amt_dev_init,
3073         .ndo_uninit             = amt_dev_uninit,
3074         .ndo_open               = amt_dev_open,
3075         .ndo_stop               = amt_dev_stop,
3076         .ndo_start_xmit         = amt_dev_xmit,
3077         .ndo_get_stats64        = dev_get_tstats64,
3078 };
3079
3080 static void amt_link_setup(struct net_device *dev)
3081 {
3082         dev->netdev_ops         = &amt_netdev_ops;
3083         dev->needs_free_netdev  = true;
3084         SET_NETDEV_DEVTYPE(dev, &amt_type);
3085         dev->min_mtu            = ETH_MIN_MTU;
3086         dev->max_mtu            = ETH_MAX_MTU;
3087         dev->type               = ARPHRD_NONE;
3088         dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
3089         dev->hard_header_len    = 0;
3090         dev->addr_len           = 0;
3091         dev->priv_flags         |= IFF_NO_QUEUE;
3092         dev->features           |= NETIF_F_LLTX;
3093         dev->features           |= NETIF_F_GSO_SOFTWARE;
3094         dev->features           |= NETIF_F_NETNS_LOCAL;
3095         dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
3096         dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
3097         dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
3098         eth_hw_addr_random(dev);
3099         eth_zero_addr(dev->broadcast);
3100         ether_setup(dev);
3101 }
3102
3103 static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
3104         [IFLA_AMT_MODE]         = { .type = NLA_U32 },
3105         [IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
3106         [IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
3107         [IFLA_AMT_LINK]         = { .type = NLA_U32 },
3108         [IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
3109         [IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
3110         [IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
3111         [IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
3112 };
3113
3114 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
3115                         struct netlink_ext_ack *extack)
3116 {
3117         if (!data)
3118                 return -EINVAL;
3119
3120         if (!data[IFLA_AMT_LINK]) {
3121                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
3122                                     "Link attribute is required");
3123                 return -EINVAL;
3124         }
3125
3126         if (!data[IFLA_AMT_MODE]) {
3127                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3128                                     "Mode attribute is required");
3129                 return -EINVAL;
3130         }
3131
3132         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
3133                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3134                                     "Mode attribute is not valid");
3135                 return -EINVAL;
3136         }
3137
3138         if (!data[IFLA_AMT_LOCAL_IP]) {
3139                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3140                                     "Local attribute is required");
3141                 return -EINVAL;
3142         }
3143
3144         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3145             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3146                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3147                                     "Discovery attribute is required");
3148                 return -EINVAL;
3149         }
3150
3151         return 0;
3152 }
3153
3154 static int amt_newlink(struct net *net, struct net_device *dev,
3155                        struct nlattr *tb[], struct nlattr *data[],
3156                        struct netlink_ext_ack *extack)
3157 {
3158         struct amt_dev *amt = netdev_priv(dev);
3159         int err = -EINVAL;
3160
3161         amt->net = net;
3162         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3163
3164         if (data[IFLA_AMT_MAX_TUNNELS] &&
3165             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3166                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3167         else
3168                 amt->max_tunnels = AMT_MAX_TUNNELS;
3169
3170         spin_lock_init(&amt->lock);
3171         amt->max_groups = AMT_MAX_GROUP;
3172         amt->max_sources = AMT_MAX_SOURCE;
3173         amt->hash_buckets = AMT_HSIZE;
3174         amt->nr_tunnels = 0;
3175         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3176         amt->stream_dev = dev_get_by_index(net,
3177                                            nla_get_u32(data[IFLA_AMT_LINK]));
3178         if (!amt->stream_dev) {
3179                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3180                                     "Can't find stream device");
3181                 return -ENODEV;
3182         }
3183
3184         if (amt->stream_dev->type != ARPHRD_ETHER) {
3185                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3186                                     "Invalid stream device type");
3187                 goto err;
3188         }
3189
3190         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3191         if (ipv4_is_loopback(amt->local_ip) ||
3192             ipv4_is_zeronet(amt->local_ip) ||
3193             ipv4_is_multicast(amt->local_ip)) {
3194                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3195                                     "Invalid Local address");
3196                 goto err;
3197         }
3198
3199         if (data[IFLA_AMT_RELAY_PORT])
3200                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3201         else
3202                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3203
3204         if (data[IFLA_AMT_GATEWAY_PORT])
3205                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3206         else
3207                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3208
3209         if (!amt->relay_port) {
3210                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3211                                     "relay port must not be 0");
3212                 goto err;
3213         }
3214         if (amt->mode == AMT_MODE_RELAY) {
3215                 amt->qrv = READ_ONCE(amt->net->ipv4.sysctl_igmp_qrv);
3216                 amt->qri = 10;
3217                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3218                                        AMT_RELAY_HLEN;
3219                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3220                 dev->max_mtu = dev->mtu;
3221                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3222         } else {
3223                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3224                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3225                                             "discovery must be set in gateway mode");
3226                         goto err;
3227                 }
3228                 if (!amt->gw_port) {
3229                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3230                                             "gateway port must not be 0");
3231                         goto err;
3232                 }
3233                 amt->remote_ip = 0;
3234                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3235                 if (ipv4_is_loopback(amt->discovery_ip) ||
3236                     ipv4_is_zeronet(amt->discovery_ip) ||
3237                     ipv4_is_multicast(amt->discovery_ip)) {
3238                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3239                                             "discovery must be unicast");
3240                         goto err;
3241                 }
3242
3243                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3244                                        AMT_GW_HLEN;
3245                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3246                 dev->max_mtu = dev->mtu;
3247                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3248         }
3249         amt->qi = AMT_INIT_QUERY_INTERVAL;
3250
3251         err = register_netdevice(dev);
3252         if (err < 0) {
3253                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3254                 goto err;
3255         }
3256
3257         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3258         if (err < 0) {
3259                 unregister_netdevice(dev);
3260                 goto err;
3261         }
3262
3263         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3264         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3265         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3266         INIT_WORK(&amt->event_wq, amt_event_work);
3267         INIT_LIST_HEAD(&amt->tunnel_list);
3268         return 0;
3269 err:
3270         dev_put(amt->stream_dev);
3271         return err;
3272 }
3273
3274 static void amt_dellink(struct net_device *dev, struct list_head *head)
3275 {
3276         struct amt_dev *amt = netdev_priv(dev);
3277
3278         unregister_netdevice_queue(dev, head);
3279         netdev_upper_dev_unlink(amt->stream_dev, dev);
3280         dev_put(amt->stream_dev);
3281 }
3282
3283 static size_t amt_get_size(const struct net_device *dev)
3284 {
3285         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3286                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3287                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3288                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3289                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3290                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3291                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3292                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3293 }
3294
3295 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3296 {
3297         struct amt_dev *amt = netdev_priv(dev);
3298
3299         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3300                 goto nla_put_failure;
3301         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3302                 goto nla_put_failure;
3303         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3304                 goto nla_put_failure;
3305         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3306                 goto nla_put_failure;
3307         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3308                 goto nla_put_failure;
3309         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3310                 goto nla_put_failure;
3311         if (amt->remote_ip)
3312                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3313                         goto nla_put_failure;
3314         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3315                 goto nla_put_failure;
3316
3317         return 0;
3318
3319 nla_put_failure:
3320         return -EMSGSIZE;
3321 }
3322
3323 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3324         .kind           = "amt",
3325         .maxtype        = IFLA_AMT_MAX,
3326         .policy         = amt_policy,
3327         .priv_size      = sizeof(struct amt_dev),
3328         .setup          = amt_link_setup,
3329         .validate       = amt_validate,
3330         .newlink        = amt_newlink,
3331         .dellink        = amt_dellink,
3332         .get_size       = amt_get_size,
3333         .fill_info      = amt_fill_info,
3334 };
3335
3336 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3337 {
3338         struct net_device *upper_dev;
3339         struct amt_dev *amt;
3340
3341         for_each_netdev(dev_net(dev), upper_dev) {
3342                 if (netif_is_amt(upper_dev)) {
3343                         amt = netdev_priv(upper_dev);
3344                         if (amt->stream_dev == dev)
3345                                 return upper_dev;
3346                 }
3347         }
3348
3349         return NULL;
3350 }
3351
3352 static int amt_device_event(struct notifier_block *unused,
3353                             unsigned long event, void *ptr)
3354 {
3355         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3356         struct net_device *upper_dev;
3357         struct amt_dev *amt;
3358         LIST_HEAD(list);
3359         int new_mtu;
3360
3361         upper_dev = amt_lookup_upper_dev(dev);
3362         if (!upper_dev)
3363                 return NOTIFY_DONE;
3364         amt = netdev_priv(upper_dev);
3365
3366         switch (event) {
3367         case NETDEV_UNREGISTER:
3368                 amt_dellink(amt->dev, &list);
3369                 unregister_netdevice_many(&list);
3370                 break;
3371         case NETDEV_CHANGEMTU:
3372                 if (amt->mode == AMT_MODE_RELAY)
3373                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3374                 else
3375                         new_mtu = dev->mtu - AMT_GW_HLEN;
3376
3377                 dev_set_mtu(amt->dev, new_mtu);
3378                 break;
3379         }
3380
3381         return NOTIFY_DONE;
3382 }
3383
3384 static struct notifier_block amt_notifier_block __read_mostly = {
3385         .notifier_call = amt_device_event,
3386 };
3387
3388 static int __init amt_init(void)
3389 {
3390         int err;
3391
3392         err = register_netdevice_notifier(&amt_notifier_block);
3393         if (err < 0)
3394                 goto err;
3395
3396         err = rtnl_link_register(&amt_link_ops);
3397         if (err < 0)
3398                 goto unregister_notifier;
3399
3400         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 0);
3401         if (!amt_wq) {
3402                 err = -ENOMEM;
3403                 goto rtnl_unregister;
3404         }
3405
3406         spin_lock_init(&source_gc_lock);
3407         spin_lock_bh(&source_gc_lock);
3408         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3409         mod_delayed_work(amt_wq, &source_gc_wq,
3410                          msecs_to_jiffies(AMT_GC_INTERVAL));
3411         spin_unlock_bh(&source_gc_lock);
3412
3413         return 0;
3414
3415 rtnl_unregister:
3416         rtnl_link_unregister(&amt_link_ops);
3417 unregister_notifier:
3418         unregister_netdevice_notifier(&amt_notifier_block);
3419 err:
3420         pr_err("error loading AMT module loaded\n");
3421         return err;
3422 }
3423 late_initcall(amt_init);
3424
/*
 * amt_fini - module exit.
 *
 * Teardown order matters: the rtnl link ops and the netdevice notifier are
 * unregistered first. The periodic source GC work is then cancelled
 * (waiting for a running instance), one last GC pass is run synchronously
 * via __amt_source_gc_work(), and only then is the workqueue destroyed.
 * NOTE(review): the final pass presumably frees whatever is still queued on
 * source_gc_list; confirm against __amt_source_gc_work().
 */
static void __exit amt_fini(void)
{
        rtnl_link_unregister(&amt_link_ops);
        unregister_netdevice_notifier(&amt_notifier_block);
        /* stop the periodic GC before running it one final time by hand */
        cancel_delayed_work_sync(&source_gc_wq);
        __amt_source_gc_work();
        destroy_workqueue(amt_wq);
}
module_exit(amt_fini);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
MODULE_ALIAS_RTNL_LINK("amt");