amt: fix possible null-ptr-deref in amt_rcv()
[linux-2.6-microblaze.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/sch_generic.h>
15 #include <net/net_namespace.h>
16 #include <net/ip.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <net/icmp.h>
20 #include <net/mld.h>
21 #include <net/amt.h>
22 #include <uapi/linux/amt.h>
23 #include <linux/security.h>
24 #include <net/gro_cells.h>
25 #include <net/ipv6.h>
26 #include <net/if_inet6.h>
27 #include <net/ndisc.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/inet_common.h>
31 #include <net/ip6_checksum.h>
32
33 static struct workqueue_struct *amt_wq;
34
35 static HLIST_HEAD(source_gc_list);
36 /* Lock for source_gc_list */
37 static spinlock_t source_gc_lock;
38 static struct delayed_work source_gc_wq;
39 static char *status_str[] = {
40         "AMT_STATUS_INIT",
41         "AMT_STATUS_SENT_DISCOVERY",
42         "AMT_STATUS_RECEIVED_DISCOVERY",
43         "AMT_STATUS_SENT_ADVERTISEMENT",
44         "AMT_STATUS_RECEIVED_ADVERTISEMENT",
45         "AMT_STATUS_SENT_REQUEST",
46         "AMT_STATUS_RECEIVED_REQUEST",
47         "AMT_STATUS_SENT_QUERY",
48         "AMT_STATUS_RECEIVED_QUERY",
49         "AMT_STATUS_SENT_UPDATE",
50         "AMT_STATUS_RECEIVED_UPDATE",
51 };
52
53 static char *type_str[] = {
54         "AMT_MSG_DISCOVERY",
55         "AMT_MSG_ADVERTISEMENT",
56         "AMT_MSG_REQUEST",
57         "AMT_MSG_MEMBERSHIP_QUERY",
58         "AMT_MSG_MEMBERSHIP_UPDATE",
59         "AMT_MSG_MULTICAST_DATA",
60         "AMT_MSG_TEARDOWN",
61 };
62
63 static char *action_str[] = {
64         "AMT_ACT_GMI",
65         "AMT_ACT_GMI_ZERO",
66         "AMT_ACT_GT",
67         "AMT_ACT_STATUS_FWD_NEW",
68         "AMT_ACT_STATUS_D_FWD_NEW",
69         "AMT_ACT_STATUS_NONE_NEW",
70 };
71
72 static struct igmpv3_grec igmpv3_zero_grec;
73
74 #if IS_ENABLED(CONFIG_IPV6)
75 #define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
76 static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
77 static struct mld2_grec mldv2_zero_grec;
78 #endif
79
80 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
81 {
82         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
83                      sizeof_field(struct sk_buff, cb));
84
85         return (struct amt_skb_cb *)((void *)skb->cb +
86                 sizeof(struct qdisc_skb_cb));
87 }
88
89 static void __amt_source_gc_work(void)
90 {
91         struct amt_source_node *snode;
92         struct hlist_head gc_list;
93         struct hlist_node *t;
94
95         spin_lock_bh(&source_gc_lock);
96         hlist_move_list(&source_gc_list, &gc_list);
97         spin_unlock_bh(&source_gc_lock);
98
99         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
100                 hlist_del_rcu(&snode->node);
101                 kfree_rcu(snode, rcu);
102         }
103 }
104
105 static void amt_source_gc_work(struct work_struct *work)
106 {
107         __amt_source_gc_work();
108
109         spin_lock_bh(&source_gc_lock);
110         mod_delayed_work(amt_wq, &source_gc_wq,
111                          msecs_to_jiffies(AMT_GC_INTERVAL));
112         spin_unlock_bh(&source_gc_lock);
113 }
114
115 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
116 {
117         return !memcmp(a, b, sizeof(union amt_addr));
118 }
119
120 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
121 {
122         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
123
124         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
125 }
126
127 static bool amt_status_filter(struct amt_source_node *snode,
128                               enum amt_filter filter)
129 {
130         bool rc = false;
131
132         switch (filter) {
133         case AMT_FILTER_FWD:
134                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
135                     snode->flags == AMT_SOURCE_OLD)
136                         rc = true;
137                 break;
138         case AMT_FILTER_D_FWD:
139                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
140                     snode->flags == AMT_SOURCE_OLD)
141                         rc = true;
142                 break;
143         case AMT_FILTER_FWD_NEW:
144                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
145                     snode->flags == AMT_SOURCE_NEW)
146                         rc = true;
147                 break;
148         case AMT_FILTER_D_FWD_NEW:
149                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
150                     snode->flags == AMT_SOURCE_NEW)
151                         rc = true;
152                 break;
153         case AMT_FILTER_ALL:
154                 rc = true;
155                 break;
156         case AMT_FILTER_NONE_NEW:
157                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
158                     snode->flags == AMT_SOURCE_NEW)
159                         rc = true;
160                 break;
161         case AMT_FILTER_BOTH:
162                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
163                      snode->status == AMT_SOURCE_STATUS_FWD) &&
164                     snode->flags == AMT_SOURCE_OLD)
165                         rc = true;
166                 break;
167         case AMT_FILTER_BOTH_NEW:
168                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
169                      snode->status == AMT_SOURCE_STATUS_FWD) &&
170                     snode->flags == AMT_SOURCE_NEW)
171                         rc = true;
172                 break;
173         default:
174                 WARN_ON_ONCE(1);
175                 break;
176         }
177
178         return rc;
179 }
180
181 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
182                                               struct amt_group_node *gnode,
183                                               enum amt_filter filter,
184                                               union amt_addr *src)
185 {
186         u32 hash = amt_source_hash(tunnel, src);
187         struct amt_source_node *snode;
188
189         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
190                 if (amt_status_filter(snode, filter) &&
191                     amt_addr_equal(&snode->source_addr, src))
192                         return snode;
193
194         return NULL;
195 }
196
197 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
198 {
199         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
200
201         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
202 }
203
204 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
205                                                union amt_addr *group,
206                                                union amt_addr *host,
207                                                bool v6)
208 {
209         u32 hash = amt_group_hash(tunnel, group);
210         struct amt_group_node *gnode;
211
212         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
213                 if (amt_addr_equal(&gnode->group_addr, group) &&
214                     amt_addr_equal(&gnode->host_addr, host) &&
215                     gnode->v6 == v6)
216                         return gnode;
217         }
218
219         return NULL;
220 }
221
222 static void amt_destroy_source(struct amt_source_node *snode)
223 {
224         struct amt_group_node *gnode = snode->gnode;
225         struct amt_tunnel_list *tunnel;
226
227         tunnel = gnode->tunnel_list;
228
229         if (!gnode->v6) {
230                 netdev_dbg(snode->gnode->amt->dev,
231                            "Delete source %pI4 from %pI4\n",
232                            &snode->source_addr.ip4,
233                            &gnode->group_addr.ip4);
234 #if IS_ENABLED(CONFIG_IPV6)
235         } else {
236                 netdev_dbg(snode->gnode->amt->dev,
237                            "Delete source %pI6 from %pI6\n",
238                            &snode->source_addr.ip6,
239                            &gnode->group_addr.ip6);
240 #endif
241         }
242
243         cancel_delayed_work(&snode->source_timer);
244         hlist_del_init_rcu(&snode->node);
245         tunnel->nr_sources--;
246         gnode->nr_sources--;
247         spin_lock_bh(&source_gc_lock);
248         hlist_add_head_rcu(&snode->node, &source_gc_list);
249         spin_unlock_bh(&source_gc_lock);
250 }
251
252 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
253 {
254         struct amt_source_node *snode;
255         struct hlist_node *t;
256         int i;
257
258         if (cancel_delayed_work(&gnode->group_timer))
259                 dev_put(amt->dev);
260         hlist_del_rcu(&gnode->node);
261         gnode->tunnel_list->nr_groups--;
262
263         if (!gnode->v6)
264                 netdev_dbg(amt->dev, "Leave group %pI4\n",
265                            &gnode->group_addr.ip4);
266 #if IS_ENABLED(CONFIG_IPV6)
267         else
268                 netdev_dbg(amt->dev, "Leave group %pI6\n",
269                            &gnode->group_addr.ip6);
270 #endif
271         for (i = 0; i < amt->hash_buckets; i++)
272                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
273                         amt_destroy_source(snode);
274
275         /* tunnel->lock was acquired outside of amt_del_group()
276          * But rcu_read_lock() was acquired too so It's safe.
277          */
278         kfree_rcu(gnode, rcu);
279 }
280
281 /* If a source timer expires with a router filter-mode for the group of
282  * INCLUDE, the router concludes that traffic from this particular
283  * source is no longer desired on the attached network, and deletes the
284  * associated source record.
285  */
286 static void amt_source_work(struct work_struct *work)
287 {
288         struct amt_source_node *snode = container_of(to_delayed_work(work),
289                                                      struct amt_source_node,
290                                                      source_timer);
291         struct amt_group_node *gnode = snode->gnode;
292         struct amt_dev *amt = gnode->amt;
293         struct amt_tunnel_list *tunnel;
294
295         tunnel = gnode->tunnel_list;
296         spin_lock_bh(&tunnel->lock);
297         rcu_read_lock();
298         if (gnode->filter_mode == MCAST_INCLUDE) {
299                 amt_destroy_source(snode);
300                 if (!gnode->nr_sources)
301                         amt_del_group(amt, gnode);
302         } else {
303                 /* When a router filter-mode for a group is EXCLUDE,
304                  * source records are only deleted when the group timer expires
305                  */
306                 snode->status = AMT_SOURCE_STATUS_D_FWD;
307         }
308         rcu_read_unlock();
309         spin_unlock_bh(&tunnel->lock);
310 }
311
312 static void amt_act_src(struct amt_tunnel_list *tunnel,
313                         struct amt_group_node *gnode,
314                         struct amt_source_node *snode,
315                         enum amt_act act)
316 {
317         struct amt_dev *amt = tunnel->amt;
318
319         switch (act) {
320         case AMT_ACT_GMI:
321                 mod_delayed_work(amt_wq, &snode->source_timer,
322                                  msecs_to_jiffies(amt_gmi(amt)));
323                 break;
324         case AMT_ACT_GMI_ZERO:
325                 cancel_delayed_work(&snode->source_timer);
326                 break;
327         case AMT_ACT_GT:
328                 mod_delayed_work(amt_wq, &snode->source_timer,
329                                  gnode->group_timer.timer.expires);
330                 break;
331         case AMT_ACT_STATUS_FWD_NEW:
332                 snode->status = AMT_SOURCE_STATUS_FWD;
333                 snode->flags = AMT_SOURCE_NEW;
334                 break;
335         case AMT_ACT_STATUS_D_FWD_NEW:
336                 snode->status = AMT_SOURCE_STATUS_D_FWD;
337                 snode->flags = AMT_SOURCE_NEW;
338                 break;
339         case AMT_ACT_STATUS_NONE_NEW:
340                 cancel_delayed_work(&snode->source_timer);
341                 snode->status = AMT_SOURCE_STATUS_NONE;
342                 snode->flags = AMT_SOURCE_NEW;
343                 break;
344         default:
345                 WARN_ON_ONCE(1);
346                 return;
347         }
348
349         if (!gnode->v6)
350                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
351                            &snode->source_addr.ip4,
352                            &gnode->group_addr.ip4,
353                            action_str[act]);
354 #if IS_ENABLED(CONFIG_IPV6)
355         else
356                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
357                            &snode->source_addr.ip6,
358                            &gnode->group_addr.ip6,
359                            action_str[act]);
360 #endif
361 }
362
363 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
364                                                union amt_addr *src)
365 {
366         struct amt_source_node *snode;
367
368         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
369         if (!snode)
370                 return NULL;
371
372         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
373         snode->gnode = gnode;
374         snode->status = AMT_SOURCE_STATUS_NONE;
375         snode->flags = AMT_SOURCE_NEW;
376         INIT_HLIST_NODE(&snode->node);
377         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
378
379         return snode;
380 }
381
382 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
383  *
384  *  Router Mode          Filter Timer         Actions/Comments
385  *  -----------       -----------------       ----------------
386  *
387  *    INCLUDE             Not Used            All listeners in
388  *                                            INCLUDE mode.
389  *
390  *    EXCLUDE             Timer > 0           At least one listener
391  *                                            in EXCLUDE mode.
392  *
393  *    EXCLUDE             Timer == 0          No more listeners in
394  *                                            EXCLUDE mode for the
395  *                                            multicast address.
396  *                                            If the Requested List
397  *                                            is empty, delete
398  *                                            Multicast Address
399  *                                            Record.  If not, switch
400  *                                            to INCLUDE filter mode;
401  *                                            the sources in the
402  *                                            Requested List are
403  *                                            moved to the Include
404  *                                            List, and the Exclude
405  *                                            List is deleted.
406  */
407 static void amt_group_work(struct work_struct *work)
408 {
409         struct amt_group_node *gnode = container_of(to_delayed_work(work),
410                                                     struct amt_group_node,
411                                                     group_timer);
412         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
413         struct amt_dev *amt = gnode->amt;
414         struct amt_source_node *snode;
415         bool delete_group = true;
416         struct hlist_node *t;
417         int i, buckets;
418
419         buckets = amt->hash_buckets;
420
421         spin_lock_bh(&tunnel->lock);
422         if (gnode->filter_mode == MCAST_INCLUDE) {
423                 /* Not Used */
424                 spin_unlock_bh(&tunnel->lock);
425                 goto out;
426         }
427
428         rcu_read_lock();
429         for (i = 0; i < buckets; i++) {
430                 hlist_for_each_entry_safe(snode, t,
431                                           &gnode->sources[i], node) {
432                         if (!delayed_work_pending(&snode->source_timer) ||
433                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
434                                 amt_destroy_source(snode);
435                         } else {
436                                 delete_group = false;
437                                 snode->status = AMT_SOURCE_STATUS_FWD;
438                         }
439                 }
440         }
441         if (delete_group)
442                 amt_del_group(amt, gnode);
443         else
444                 gnode->filter_mode = MCAST_INCLUDE;
445         rcu_read_unlock();
446         spin_unlock_bh(&tunnel->lock);
447 out:
448         dev_put(amt->dev);
449 }
450
451 /* Non-existant group is created as INCLUDE {empty}:
452  *
453  * RFC 3376 - 5.1. Action on Change of Interface State
454  *
455  * If no interface state existed for that multicast address before
456  * the change (i.e., the change consisted of creating a new
457  * per-interface record), or if no state exists after the change
458  * (i.e., the change consisted of deleting a per-interface record),
459  * then the "non-existent" state is considered to have a filter mode
460  * of INCLUDE and an empty source list.
461  */
462 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
463                                             struct amt_tunnel_list *tunnel,
464                                             union amt_addr *group,
465                                             union amt_addr *host,
466                                             bool v6)
467 {
468         struct amt_group_node *gnode;
469         u32 hash;
470         int i;
471
472         if (tunnel->nr_groups >= amt->max_groups)
473                 return ERR_PTR(-ENOSPC);
474
475         gnode = kzalloc(sizeof(*gnode) +
476                         (sizeof(struct hlist_head) * amt->hash_buckets),
477                         GFP_ATOMIC);
478         if (unlikely(!gnode))
479                 return ERR_PTR(-ENOMEM);
480
481         gnode->amt = amt;
482         gnode->group_addr = *group;
483         gnode->host_addr = *host;
484         gnode->v6 = v6;
485         gnode->tunnel_list = tunnel;
486         gnode->filter_mode = MCAST_INCLUDE;
487         INIT_HLIST_NODE(&gnode->node);
488         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
489         for (i = 0; i < amt->hash_buckets; i++)
490                 INIT_HLIST_HEAD(&gnode->sources[i]);
491
492         hash = amt_group_hash(tunnel, group);
493         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
494         tunnel->nr_groups++;
495
496         if (!gnode->v6)
497                 netdev_dbg(amt->dev, "Join group %pI4\n",
498                            &gnode->group_addr.ip4);
499 #if IS_ENABLED(CONFIG_IPV6)
500         else
501                 netdev_dbg(amt->dev, "Join group %pI6\n",
502                            &gnode->group_addr.ip6);
503 #endif
504
505         return gnode;
506 }
507
508 static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
509 {
510         u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
511         int hlen = LL_RESERVED_SPACE(amt->dev);
512         int tlen = amt->dev->needed_tailroom;
513         struct igmpv3_query *ihv3;
514         void *csum_start = NULL;
515         __sum16 *csum = NULL;
516         struct sk_buff *skb;
517         struct ethhdr *eth;
518         struct iphdr *iph;
519         unsigned int len;
520         int offset;
521
522         len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
523         skb = netdev_alloc_skb_ip_align(amt->dev, len);
524         if (!skb)
525                 return NULL;
526
527         skb_reserve(skb, hlen);
528         skb_push(skb, sizeof(*eth));
529         skb->protocol = htons(ETH_P_IP);
530         skb_reset_mac_header(skb);
531         skb->priority = TC_PRIO_CONTROL;
532         skb_put(skb, sizeof(*iph));
533         skb_put_data(skb, ra, sizeof(ra));
534         skb_put(skb, sizeof(*ihv3));
535         skb_pull(skb, sizeof(*eth));
536         skb_reset_network_header(skb);
537
538         iph             = ip_hdr(skb);
539         iph->version    = 4;
540         iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
541         iph->tos        = AMT_TOS;
542         iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
543         iph->frag_off   = htons(IP_DF);
544         iph->ttl        = 1;
545         iph->id         = 0;
546         iph->protocol   = IPPROTO_IGMP;
547         iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
548         iph->saddr      = htonl(INADDR_ANY);
549         ip_send_check(iph);
550
551         eth = eth_hdr(skb);
552         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
553         ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
554         eth->h_proto = htons(ETH_P_IP);
555
556         ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
557         skb_reset_transport_header(skb);
558         ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
559         ihv3->code      = 1;
560         ihv3->group     = 0;
561         ihv3->qqic      = amt->qi;
562         ihv3->nsrcs     = 0;
563         ihv3->resv      = 0;
564         ihv3->suppress  = false;
565         ihv3->qrv       = amt->net->ipv4.sysctl_igmp_qrv;
566         ihv3->csum      = 0;
567         csum            = &ihv3->csum;
568         csum_start      = (void *)ihv3;
569         *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
570         offset          = skb_transport_offset(skb);
571         skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
572         skb->ip_summed  = CHECKSUM_NONE;
573
574         skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);
575
576         return skb;
577 }
578
579 static void __amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
580                                    bool validate)
581 {
582         if (validate && amt->status >= status)
583                 return;
584         netdev_dbg(amt->dev, "Update GW status %s -> %s",
585                    status_str[amt->status], status_str[status]);
586         amt->status = status;
587 }
588
589 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
590                                       enum amt_status status,
591                                       bool validate)
592 {
593         if (validate && tunnel->status >= status)
594                 return;
595         netdev_dbg(tunnel->amt->dev,
596                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
597                    &tunnel->ip4, ntohs(tunnel->source_port),
598                    status_str[tunnel->status], status_str[status]);
599         tunnel->status = status;
600 }
601
602 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
603                                  bool validate)
604 {
605         spin_lock_bh(&amt->lock);
606         __amt_update_gw_status(amt, status, validate);
607         spin_unlock_bh(&amt->lock);
608 }
609
610 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
611                                     enum amt_status status, bool validate)
612 {
613         spin_lock_bh(&tunnel->lock);
614         __amt_update_relay_status(tunnel, status, validate);
615         spin_unlock_bh(&tunnel->lock);
616 }
617
618 static void amt_send_discovery(struct amt_dev *amt)
619 {
620         struct amt_header_discovery *amtd;
621         int hlen, tlen, offset;
622         struct socket *sock;
623         struct udphdr *udph;
624         struct sk_buff *skb;
625         struct iphdr *iph;
626         struct rtable *rt;
627         struct flowi4 fl4;
628         u32 len;
629         int err;
630
631         rcu_read_lock();
632         sock = rcu_dereference(amt->sock);
633         if (!sock)
634                 goto out;
635
636         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
637                 goto out;
638
639         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
640                                    amt->discovery_ip, amt->local_ip,
641                                    amt->gw_port, amt->relay_port,
642                                    IPPROTO_UDP, 0,
643                                    amt->stream_dev->ifindex);
644         if (IS_ERR(rt)) {
645                 amt->dev->stats.tx_errors++;
646                 goto out;
647         }
648
649         hlen = LL_RESERVED_SPACE(amt->dev);
650         tlen = amt->dev->needed_tailroom;
651         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
652         skb = netdev_alloc_skb_ip_align(amt->dev, len);
653         if (!skb) {
654                 ip_rt_put(rt);
655                 amt->dev->stats.tx_errors++;
656                 goto out;
657         }
658
659         skb->priority = TC_PRIO_CONTROL;
660         skb_dst_set(skb, &rt->dst);
661
662         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
663         skb_reset_network_header(skb);
664         skb_put(skb, len);
665         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
666         amtd->version   = 0;
667         amtd->type      = AMT_MSG_DISCOVERY;
668         amtd->reserved  = 0;
669         amtd->nonce     = amt->nonce;
670         skb_push(skb, sizeof(*udph));
671         skb_reset_transport_header(skb);
672         udph            = udp_hdr(skb);
673         udph->source    = amt->gw_port;
674         udph->dest      = amt->relay_port;
675         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
676         udph->check     = 0;
677         offset = skb_transport_offset(skb);
678         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
679         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
680                                         sizeof(*udph) + sizeof(*amtd),
681                                         IPPROTO_UDP, skb->csum);
682
683         skb_push(skb, sizeof(*iph));
684         iph             = ip_hdr(skb);
685         iph->version    = 4;
686         iph->ihl        = (sizeof(struct iphdr)) >> 2;
687         iph->tos        = AMT_TOS;
688         iph->frag_off   = 0;
689         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
690         iph->daddr      = amt->discovery_ip;
691         iph->saddr      = amt->local_ip;
692         iph->protocol   = IPPROTO_UDP;
693         iph->tot_len    = htons(len);
694
695         skb->ip_summed = CHECKSUM_NONE;
696         ip_select_ident(amt->net, skb, NULL);
697         ip_send_check(iph);
698         err = ip_local_out(amt->net, sock->sk, skb);
699         if (unlikely(net_xmit_eval(err)))
700                 amt->dev->stats.tx_errors++;
701
702         spin_lock_bh(&amt->lock);
703         __amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
704         spin_unlock_bh(&amt->lock);
705 out:
706         rcu_read_unlock();
707 }
708
709 static void amt_send_request(struct amt_dev *amt, bool v6)
710 {
711         struct amt_header_request *amtrh;
712         int hlen, tlen, offset;
713         struct socket *sock;
714         struct udphdr *udph;
715         struct sk_buff *skb;
716         struct iphdr *iph;
717         struct rtable *rt;
718         struct flowi4 fl4;
719         u32 len;
720         int err;
721
722         rcu_read_lock();
723         sock = rcu_dereference(amt->sock);
724         if (!sock)
725                 goto out;
726
727         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
728                 goto out;
729
730         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
731                                    amt->remote_ip, amt->local_ip,
732                                    amt->gw_port, amt->relay_port,
733                                    IPPROTO_UDP, 0,
734                                    amt->stream_dev->ifindex);
735         if (IS_ERR(rt)) {
736                 amt->dev->stats.tx_errors++;
737                 goto out;
738         }
739
740         hlen = LL_RESERVED_SPACE(amt->dev);
741         tlen = amt->dev->needed_tailroom;
742         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
743         skb = netdev_alloc_skb_ip_align(amt->dev, len);
744         if (!skb) {
745                 ip_rt_put(rt);
746                 amt->dev->stats.tx_errors++;
747                 goto out;
748         }
749
750         skb->priority = TC_PRIO_CONTROL;
751         skb_dst_set(skb, &rt->dst);
752
753         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
754         skb_reset_network_header(skb);
755         skb_put(skb, len);
756         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
757         amtrh->version   = 0;
758         amtrh->type      = AMT_MSG_REQUEST;
759         amtrh->reserved1 = 0;
760         amtrh->p         = v6;
761         amtrh->reserved2 = 0;
762         amtrh->nonce     = amt->nonce;
763         skb_push(skb, sizeof(*udph));
764         skb_reset_transport_header(skb);
765         udph            = udp_hdr(skb);
766         udph->source    = amt->gw_port;
767         udph->dest      = amt->relay_port;
768         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
769         udph->check     = 0;
770         offset = skb_transport_offset(skb);
771         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
772         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
773                                         sizeof(*udph) + sizeof(*amtrh),
774                                         IPPROTO_UDP, skb->csum);
775
776         skb_push(skb, sizeof(*iph));
777         iph             = ip_hdr(skb);
778         iph->version    = 4;
779         iph->ihl        = (sizeof(struct iphdr)) >> 2;
780         iph->tos        = AMT_TOS;
781         iph->frag_off   = 0;
782         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
783         iph->daddr      = amt->remote_ip;
784         iph->saddr      = amt->local_ip;
785         iph->protocol   = IPPROTO_UDP;
786         iph->tot_len    = htons(len);
787
788         skb->ip_summed = CHECKSUM_NONE;
789         ip_select_ident(amt->net, skb, NULL);
790         ip_send_check(iph);
791         err = ip_local_out(amt->net, sock->sk, skb);
792         if (unlikely(net_xmit_eval(err)))
793                 amt->dev->stats.tx_errors++;
794
795 out:
796         rcu_read_unlock();
797 }
798
799 static void amt_send_igmp_gq(struct amt_dev *amt,
800                              struct amt_tunnel_list *tunnel)
801 {
802         struct sk_buff *skb;
803
804         skb = amt_build_igmp_gq(amt);
805         if (!skb)
806                 return;
807
808         amt_skb_cb(skb)->tunnel = tunnel;
809         dev_queue_xmit(skb);
810 }
811
812 #if IS_ENABLED(CONFIG_IPV6)
813 static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
814 {
815         u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
816                                    2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
817         int hlen = LL_RESERVED_SPACE(amt->dev);
818         int tlen = amt->dev->needed_tailroom;
819         struct mld2_query *mld2q;
820         void *csum_start = NULL;
821         struct ipv6hdr *ip6h;
822         struct sk_buff *skb;
823         struct ethhdr *eth;
824         u32 len;
825
826         len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
827         skb = netdev_alloc_skb_ip_align(amt->dev, len);
828         if (!skb)
829                 return NULL;
830
831         skb_reserve(skb, hlen);
832         skb_push(skb, sizeof(*eth));
833         skb_reset_mac_header(skb);
834         eth = eth_hdr(skb);
835         skb->priority = TC_PRIO_CONTROL;
836         skb->protocol = htons(ETH_P_IPV6);
837         skb_put_zero(skb, sizeof(*ip6h));
838         skb_put_data(skb, ra, sizeof(ra));
839         skb_put_zero(skb, sizeof(*mld2q));
840         skb_pull(skb, sizeof(*eth));
841         skb_reset_network_header(skb);
842         ip6h                    = ipv6_hdr(skb);
843         ip6h->payload_len       = htons(sizeof(ra) + sizeof(*mld2q));
844         ip6h->nexthdr           = NEXTHDR_HOP;
845         ip6h->hop_limit         = 1;
846         ip6h->daddr             = mld2_all_node;
847         ip6_flow_hdr(ip6h, 0, 0);
848
849         if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
850                                &ip6h->saddr)) {
851                 amt->dev->stats.tx_errors++;
852                 kfree_skb(skb);
853                 return NULL;
854         }
855
856         eth->h_proto = htons(ETH_P_IPV6);
857         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
858         ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);
859
860         skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
861         skb_reset_transport_header(skb);
862         mld2q                   = (struct mld2_query *)icmp6_hdr(skb);
863         mld2q->mld2q_mrc        = htons(1);
864         mld2q->mld2q_type       = ICMPV6_MGM_QUERY;
865         mld2q->mld2q_code       = 0;
866         mld2q->mld2q_cksum      = 0;
867         mld2q->mld2q_resv1      = 0;
868         mld2q->mld2q_resv2      = 0;
869         mld2q->mld2q_suppress   = 0;
870         mld2q->mld2q_qrv        = amt->qrv;
871         mld2q->mld2q_nsrcs      = 0;
872         mld2q->mld2q_qqic       = amt->qi;
873         csum_start              = (void *)mld2q;
874         mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
875                                              sizeof(*mld2q),
876                                              IPPROTO_ICMPV6,
877                                              csum_partial(csum_start,
878                                                           sizeof(*mld2q), 0));
879
880         skb->ip_summed = CHECKSUM_NONE;
881         skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
882         return skb;
883 }
884
885 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
886 {
887         struct sk_buff *skb;
888
889         skb = amt_build_mld_gq(amt);
890         if (!skb)
891                 return;
892
893         amt_skb_cb(skb)->tunnel = tunnel;
894         dev_queue_xmit(skb);
895 }
896 #else
static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
{
	/* Built without IPv6 support: there is no MLD query to send. */
}
900 #endif
901
902 static void amt_secret_work(struct work_struct *work)
903 {
904         struct amt_dev *amt = container_of(to_delayed_work(work),
905                                            struct amt_dev,
906                                            secret_wq);
907
908         spin_lock_bh(&amt->lock);
909         get_random_bytes(&amt->key, sizeof(siphash_key_t));
910         spin_unlock_bh(&amt->lock);
911         mod_delayed_work(amt_wq, &amt->secret_wq,
912                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
913 }
914
915 static void amt_discovery_work(struct work_struct *work)
916 {
917         struct amt_dev *amt = container_of(to_delayed_work(work),
918                                            struct amt_dev,
919                                            discovery_wq);
920
921         spin_lock_bh(&amt->lock);
922         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
923                 goto out;
924         get_random_bytes(&amt->nonce, sizeof(__be32));
925         spin_unlock_bh(&amt->lock);
926
927         amt_send_discovery(amt);
928         spin_lock_bh(&amt->lock);
929 out:
930         mod_delayed_work(amt_wq, &amt->discovery_wq,
931                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
932         spin_unlock_bh(&amt->lock);
933 }
934
935 static void amt_req_work(struct work_struct *work)
936 {
937         struct amt_dev *amt = container_of(to_delayed_work(work),
938                                            struct amt_dev,
939                                            req_wq);
940         u32 exp;
941
942         spin_lock_bh(&amt->lock);
943         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
944                 goto out;
945
946         if (amt->req_cnt > AMT_MAX_REQ_COUNT) {
947                 netdev_dbg(amt->dev, "Gateway is not ready");
948                 amt->qi = AMT_INIT_REQ_TIMEOUT;
949                 amt->ready4 = false;
950                 amt->ready6 = false;
951                 amt->remote_ip = 0;
952                 __amt_update_gw_status(amt, AMT_STATUS_INIT, false);
953                 amt->req_cnt = 0;
954                 goto out;
955         }
956         spin_unlock_bh(&amt->lock);
957
958         amt_send_request(amt, false);
959         amt_send_request(amt, true);
960         spin_lock_bh(&amt->lock);
961         __amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
962         amt->req_cnt++;
963 out:
964         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
965         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
966         spin_unlock_bh(&amt->lock);
967 }
968
969 static bool amt_send_membership_update(struct amt_dev *amt,
970                                        struct sk_buff *skb,
971                                        bool v6)
972 {
973         struct amt_header_membership_update *amtmu;
974         struct socket *sock;
975         struct iphdr *iph;
976         struct flowi4 fl4;
977         struct rtable *rt;
978         int err;
979
980         sock = rcu_dereference_bh(amt->sock);
981         if (!sock)
982                 return true;
983
984         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
985                            sizeof(*iph) + sizeof(struct udphdr));
986         if (err)
987                 return true;
988
989         skb_reset_inner_headers(skb);
990         memset(&fl4, 0, sizeof(struct flowi4));
991         fl4.flowi4_oif         = amt->stream_dev->ifindex;
992         fl4.daddr              = amt->remote_ip;
993         fl4.saddr              = amt->local_ip;
994         fl4.flowi4_tos         = AMT_TOS;
995         fl4.flowi4_proto       = IPPROTO_UDP;
996         rt = ip_route_output_key(amt->net, &fl4);
997         if (IS_ERR(rt)) {
998                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
999                 return true;
1000         }
1001
1002         amtmu                   = skb_push(skb, sizeof(*amtmu));
1003         amtmu->version          = 0;
1004         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1005         amtmu->reserved         = 0;
1006         amtmu->nonce            = amt->nonce;
1007         amtmu->response_mac     = amt->mac;
1008
1009         if (!v6)
1010                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1011         else
1012                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1013         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1014                             fl4.saddr,
1015                             fl4.daddr,
1016                             AMT_TOS,
1017                             ip4_dst_hoplimit(&rt->dst),
1018                             0,
1019                             amt->gw_port,
1020                             amt->relay_port,
1021                             false,
1022                             false);
1023         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1024         return false;
1025 }
1026
1027 static void amt_send_multicast_data(struct amt_dev *amt,
1028                                     const struct sk_buff *oskb,
1029                                     struct amt_tunnel_list *tunnel,
1030                                     bool v6)
1031 {
1032         struct amt_header_mcast_data *amtmd;
1033         struct socket *sock;
1034         struct sk_buff *skb;
1035         struct iphdr *iph;
1036         struct flowi4 fl4;
1037         struct rtable *rt;
1038
1039         sock = rcu_dereference_bh(amt->sock);
1040         if (!sock)
1041                 return;
1042
1043         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1044                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1045         if (!skb)
1046                 return;
1047
1048         skb_reset_inner_headers(skb);
1049         memset(&fl4, 0, sizeof(struct flowi4));
1050         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1051         fl4.daddr              = tunnel->ip4;
1052         fl4.saddr              = amt->local_ip;
1053         fl4.flowi4_proto       = IPPROTO_UDP;
1054         rt = ip_route_output_key(amt->net, &fl4);
1055         if (IS_ERR(rt)) {
1056                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1057                 kfree_skb(skb);
1058                 return;
1059         }
1060
1061         amtmd = skb_push(skb, sizeof(*amtmd));
1062         amtmd->version = 0;
1063         amtmd->reserved = 0;
1064         amtmd->type = AMT_MSG_MULTICAST_DATA;
1065
1066         if (!v6)
1067                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1068         else
1069                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1070         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1071                             fl4.saddr,
1072                             fl4.daddr,
1073                             AMT_TOS,
1074                             ip4_dst_hoplimit(&rt->dst),
1075                             0,
1076                             amt->relay_port,
1077                             tunnel->source_port,
1078                             false,
1079                             false);
1080 }
1081
1082 static bool amt_send_membership_query(struct amt_dev *amt,
1083                                       struct sk_buff *skb,
1084                                       struct amt_tunnel_list *tunnel,
1085                                       bool v6)
1086 {
1087         struct amt_header_membership_query *amtmq;
1088         struct socket *sock;
1089         struct rtable *rt;
1090         struct flowi4 fl4;
1091         int err;
1092
1093         sock = rcu_dereference_bh(amt->sock);
1094         if (!sock)
1095                 return true;
1096
1097         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1098                            sizeof(struct iphdr) + sizeof(struct udphdr));
1099         if (err)
1100                 return true;
1101
1102         skb_reset_inner_headers(skb);
1103         memset(&fl4, 0, sizeof(struct flowi4));
1104         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1105         fl4.daddr              = tunnel->ip4;
1106         fl4.saddr              = amt->local_ip;
1107         fl4.flowi4_tos         = AMT_TOS;
1108         fl4.flowi4_proto       = IPPROTO_UDP;
1109         rt = ip_route_output_key(amt->net, &fl4);
1110         if (IS_ERR(rt)) {
1111                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1112                 return true;
1113         }
1114
1115         amtmq           = skb_push(skb, sizeof(*amtmq));
1116         amtmq->version  = 0;
1117         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1118         amtmq->reserved = 0;
1119         amtmq->l        = 0;
1120         amtmq->g        = 0;
1121         amtmq->nonce    = tunnel->nonce;
1122         amtmq->response_mac = tunnel->mac;
1123
1124         if (!v6)
1125                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1126         else
1127                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1128         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1129                             fl4.saddr,
1130                             fl4.daddr,
1131                             AMT_TOS,
1132                             ip4_dst_hoplimit(&rt->dst),
1133                             0,
1134                             amt->relay_port,
1135                             tunnel->source_port,
1136                             false,
1137                             false);
1138         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1139         return false;
1140 }
1141
1142 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1143 {
1144         struct amt_dev *amt = netdev_priv(dev);
1145         struct amt_tunnel_list *tunnel;
1146         struct amt_group_node *gnode;
1147         union amt_addr group = {0,};
1148 #if IS_ENABLED(CONFIG_IPV6)
1149         struct ipv6hdr *ip6h;
1150         struct mld_msg *mld;
1151 #endif
1152         bool report = false;
1153         struct igmphdr *ih;
1154         bool query = false;
1155         struct iphdr *iph;
1156         bool data = false;
1157         bool v6 = false;
1158         u32 hash;
1159
1160         iph = ip_hdr(skb);
1161         if (iph->version == 4) {
1162                 if (!ipv4_is_multicast(iph->daddr))
1163                         goto free;
1164
1165                 if (!ip_mc_check_igmp(skb)) {
1166                         ih = igmp_hdr(skb);
1167                         switch (ih->type) {
1168                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1169                         case IGMP_HOST_MEMBERSHIP_REPORT:
1170                                 report = true;
1171                                 break;
1172                         case IGMP_HOST_MEMBERSHIP_QUERY:
1173                                 query = true;
1174                                 break;
1175                         default:
1176                                 goto free;
1177                         }
1178                 } else {
1179                         data = true;
1180                 }
1181                 v6 = false;
1182                 group.ip4 = iph->daddr;
1183 #if IS_ENABLED(CONFIG_IPV6)
1184         } else if (iph->version == 6) {
1185                 ip6h = ipv6_hdr(skb);
1186                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1187                         goto free;
1188
1189                 if (!ipv6_mc_check_mld(skb)) {
1190                         mld = (struct mld_msg *)skb_transport_header(skb);
1191                         switch (mld->mld_type) {
1192                         case ICMPV6_MGM_REPORT:
1193                         case ICMPV6_MLD2_REPORT:
1194                                 report = true;
1195                                 break;
1196                         case ICMPV6_MGM_QUERY:
1197                                 query = true;
1198                                 break;
1199                         default:
1200                                 goto free;
1201                         }
1202                 } else {
1203                         data = true;
1204                 }
1205                 v6 = true;
1206                 group.ip6 = ip6h->daddr;
1207 #endif
1208         } else {
1209                 dev->stats.tx_errors++;
1210                 goto free;
1211         }
1212
1213         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1214                 goto free;
1215
1216         skb_pull(skb, sizeof(struct ethhdr));
1217
1218         if (amt->mode == AMT_MODE_GATEWAY) {
1219                 /* Gateway only passes IGMP/MLD packets */
1220                 if (!report)
1221                         goto free;
1222                 if ((!v6 && !amt->ready4) || (v6 && !amt->ready6))
1223                         goto free;
1224                 if (amt_send_membership_update(amt, skb,  v6))
1225                         goto free;
1226                 goto unlock;
1227         } else if (amt->mode == AMT_MODE_RELAY) {
1228                 if (query) {
1229                         tunnel = amt_skb_cb(skb)->tunnel;
1230                         if (!tunnel) {
1231                                 WARN_ON(1);
1232                                 goto free;
1233                         }
1234
1235                         /* Do not forward unexpected query */
1236                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1237                                 goto free;
1238                         goto unlock;
1239                 }
1240
1241                 if (!data)
1242                         goto free;
1243                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1244                         hash = amt_group_hash(tunnel, &group);
1245                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1246                                                  node) {
1247                                 if (!v6) {
1248                                         if (gnode->group_addr.ip4 == iph->daddr)
1249                                                 goto found;
1250 #if IS_ENABLED(CONFIG_IPV6)
1251                                 } else {
1252                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1253                                                             &ip6h->daddr))
1254                                                 goto found;
1255 #endif
1256                                 }
1257                         }
1258                         continue;
1259 found:
1260                         amt_send_multicast_data(amt, skb, tunnel, v6);
1261                 }
1262         }
1263
1264         dev_kfree_skb(skb);
1265         return NETDEV_TX_OK;
1266 free:
1267         dev_kfree_skb(skb);
1268 unlock:
1269         dev->stats.tx_dropped++;
1270         return NETDEV_TX_OK;
1271 }
1272
1273 static int amt_parse_type(struct sk_buff *skb)
1274 {
1275         struct amt_header *amth;
1276
1277         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1278                            sizeof(struct amt_header)))
1279                 return -1;
1280
1281         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1282
1283         if (amth->version != 0)
1284                 return -1;
1285
1286         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1287                 return -1;
1288         return amth->type;
1289 }
1290
1291 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1292 {
1293         struct amt_dev *amt = tunnel->amt;
1294         struct amt_group_node *gnode;
1295         struct hlist_node *t;
1296         int i;
1297
1298         spin_lock_bh(&tunnel->lock);
1299         rcu_read_lock();
1300         for (i = 0; i < amt->hash_buckets; i++)
1301                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1302                         amt_del_group(amt, gnode);
1303         rcu_read_unlock();
1304         spin_unlock_bh(&tunnel->lock);
1305 }
1306
1307 static void amt_tunnel_expire(struct work_struct *work)
1308 {
1309         struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
1310                                                       struct amt_tunnel_list,
1311                                                       gc_wq);
1312         struct amt_dev *amt = tunnel->amt;
1313
1314         spin_lock_bh(&amt->lock);
1315         rcu_read_lock();
1316         list_del_rcu(&tunnel->list);
1317         amt->nr_tunnels--;
1318         amt_clear_groups(tunnel);
1319         rcu_read_unlock();
1320         spin_unlock_bh(&amt->lock);
1321         kfree_rcu(tunnel, rcu);
1322 }
1323
1324 static void amt_cleanup_srcs(struct amt_dev *amt,
1325                              struct amt_tunnel_list *tunnel,
1326                              struct amt_group_node *gnode)
1327 {
1328         struct amt_source_node *snode;
1329         struct hlist_node *t;
1330         int i;
1331
1332         /* Delete old sources */
1333         for (i = 0; i < amt->hash_buckets; i++) {
1334                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1335                         if (snode->flags == AMT_SOURCE_OLD)
1336                                 amt_destroy_source(snode);
1337                 }
1338         }
1339
1340         /* switch from new to old */
1341         for (i = 0; i < amt->hash_buckets; i++)  {
1342                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1343                         snode->flags = AMT_SOURCE_OLD;
1344                         if (!gnode->v6)
1345                                 netdev_dbg(snode->gnode->amt->dev,
1346                                            "Add source as OLD %pI4 from %pI4\n",
1347                                            &snode->source_addr.ip4,
1348                                            &gnode->group_addr.ip4);
1349 #if IS_ENABLED(CONFIG_IPV6)
1350                         else
1351                                 netdev_dbg(snode->gnode->amt->dev,
1352                                            "Add source as OLD %pI6 from %pI6\n",
1353                                            &snode->source_addr.ip6,
1354                                            &gnode->group_addr.ip6);
1355 #endif
1356                 }
1357         }
1358 }
1359
1360 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1361                          struct amt_group_node *gnode, void *grec,
1362                          bool v6)
1363 {
1364         struct igmpv3_grec *igmp_grec;
1365         struct amt_source_node *snode;
1366 #if IS_ENABLED(CONFIG_IPV6)
1367         struct mld2_grec *mld_grec;
1368 #endif
1369         union amt_addr src = {0,};
1370         u16 nsrcs;
1371         u32 hash;
1372         int i;
1373
1374         if (!v6) {
1375                 igmp_grec = (struct igmpv3_grec *)grec;
1376                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1377         } else {
1378 #if IS_ENABLED(CONFIG_IPV6)
1379                 mld_grec = (struct mld2_grec *)grec;
1380                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1381 #else
1382         return;
1383 #endif
1384         }
1385         for (i = 0; i < nsrcs; i++) {
1386                 if (tunnel->nr_sources >= amt->max_sources)
1387                         return;
1388                 if (!v6)
1389                         src.ip4 = igmp_grec->grec_src[i];
1390 #if IS_ENABLED(CONFIG_IPV6)
1391                 else
1392                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1393                                sizeof(struct in6_addr));
1394 #endif
1395                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1396                         continue;
1397
1398                 snode = amt_alloc_snode(gnode, &src);
1399                 if (snode) {
1400                         hash = amt_source_hash(tunnel, &snode->source_addr);
1401                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1402                         tunnel->nr_sources++;
1403                         gnode->nr_sources++;
1404
1405                         if (!gnode->v6)
1406                                 netdev_dbg(snode->gnode->amt->dev,
1407                                            "Add source as NEW %pI4 from %pI4\n",
1408                                            &snode->source_addr.ip4,
1409                                            &gnode->group_addr.ip4);
1410 #if IS_ENABLED(CONFIG_IPV6)
1411                         else
1412                                 netdev_dbg(snode->gnode->amt->dev,
1413                                            "Add source as NEW %pI6 from %pI6\n",
1414                                            &snode->source_addr.ip6,
1415                                            &gnode->group_addr.ip6);
1416 #endif
1417                 }
1418         }
1419 }
1420
/* Router State   Report Rec'd New Router State
 * ------------   ------------ ----------------
 * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
 *
 * -----------+-----------+-----------+
 *            |    OLD    |    NEW    |
 * -----------+-----------+-----------+
 *    FWD     |     X     |    X+A    |
 * -----------+-----------+-----------+
 *    D_FWD   |     Y     |    Y-A    |
 * -----------+-----------+-----------+
 *    NONE    |           |     A     |
 * -----------+-----------+-----------+
 *
 * a) Received sources are NONE/NEW
 * b) All NONE will be deleted by amt_cleanup_srcs().
 * c) All OLD will be deleted by amt_cleanup_srcs().
 * d) After delete, NEW source will be switched to OLD.
 *
 * amt_lookup_act_srcs() - apply @act to a set of sources of @gnode.
 * @tunnel: tunnel that owns @gnode
 * @gnode:  group whose source lists are examined
 * @grec:   group record from the report: struct igmpv3_grec when !@v6,
 *          struct mld2_grec when @v6.  Its grec_src[] entries form the
 *          "received" set B.
 * @ops:    which sources are acted on (A = existing sources of @gnode
 *          that pass @filter):
 *          AMT_OPS_INT     - each B entry found via amt_lookup_src() with
 *                            @filter (A*B)
 *          AMT_OPS_UNI     - every existing source passing @filter, then
 *                            each B entry found with @filter (A+B)
 *          AMT_OPS_SUB     - existing sources passing @filter whose
 *                            address is not in B (A-B)
 *          AMT_OPS_SUB_REV - B entries with no node under AMT_FILTER_ALL,
 *                            looked up again with @filter (B-A)
 *          anything else only emits a debug message.
 * @filter: state filter handed to amt_lookup_src()/amt_status_filter()
 * @act:    action handed to amt_act_src() for every selected source
 * @v6:     selects the MLD (true) or IGMP (false) record layout; without
 *          CONFIG_IPV6 a v6 record is ignored.
 *
 * NOTE(review): grec_nsrcs is trusted here; presumably the caller has
 * validated that the record really contains that many sources.
 */
static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
				struct amt_group_node *gnode,
				void *grec,
				enum amt_ops ops,
				enum amt_filter filter,
				enum amt_act act,
				bool v6)
{
	struct amt_dev *amt = tunnel->amt;
	struct amt_source_node *snode;
	struct igmpv3_grec *igmp_grec;
#if IS_ENABLED(CONFIG_IPV6)
	struct mld2_grec *mld_grec;
#endif
	union amt_addr src = {0,};
	struct hlist_node *t;
	u16 nsrcs;
	int i, j;

	/* Number of sources carried by the record (network byte order). */
	if (!v6) {
		igmp_grec = (struct igmpv3_grec *)grec;
		nsrcs = ntohs(igmp_grec->grec_nsrcs);
	} else {
#if IS_ENABLED(CONFIG_IPV6)
		mld_grec = (struct mld2_grec *)grec;
		nsrcs = ntohs(mld_grec->grec_nsrcs);
#else
	return;
#endif
	}

	/* Clear the whole union so an IPv4 address compares cleanly. */
	memset(&src, 0, sizeof(union amt_addr));
	switch (ops) {
	case AMT_OPS_INT:
		/* A*B */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_UNI:
		/* A+B */
		/* NOTE(review): a source in both A and B is visited twice if it
		 * still passes @filter after the first amt_act_src(); confirm
		 * that the actions used with AMT_OPS_UNI make this harmless.
		 */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (amt_status_filter(snode, filter))
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_SUB:
		/* A-B */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (!amt_status_filter(snode, filter))
					continue;
				/* Skip this source if the record lists it. */
				for (j = 0; j < nsrcs; j++) {
					if (!v6)
						src.ip4 = igmp_grec->grec_src[j];
#if IS_ENABLED(CONFIG_IPV6)
					else
						memcpy(&src.ip6,
						       &mld_grec->grec_src[j],
						       sizeof(struct in6_addr));
#endif
					if (amt_addr_equal(&snode->source_addr,
							   &src))
						goto out_sub;
				}
				amt_act_src(tunnel, gnode, snode, act);
				continue;
out_sub:;
			}
		}
		break;
	case AMT_OPS_SUB_REV:
		/* B-A */
		/* NOTE(review): the second lookup can only succeed if @filter
		 * selects nodes that AMT_FILTER_ALL does not; confirm against
		 * amt_lookup_src().
		 */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
					       &src);
			if (!snode) {
				snode = amt_lookup_src(tunnel, gnode,
						       filter, &src);
				if (snode)
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		break;
	default:
		netdev_dbg(amt->dev, "Invalid type\n");
		return;
	}
}
1563
1564 static void amt_mcast_is_in_handler(struct amt_dev *amt,
1565                                     struct amt_tunnel_list *tunnel,
1566                                     struct amt_group_node *gnode,
1567                                     void *grec, void *zero_grec, bool v6)
1568 {
1569         if (gnode->filter_mode == MCAST_INCLUDE) {
1570 /* Router State   Report Rec'd New Router State        Actions
1571  * ------------   ------------ ----------------        -------
1572  * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
1573  */
1574                 /* Update IS_IN (B) as FWD/NEW */
1575                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1576                                     AMT_FILTER_NONE_NEW,
1577                                     AMT_ACT_STATUS_FWD_NEW,
1578                                     v6);
1579                 /* Update INCLUDE (A) as NEW */
1580                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1581                                     AMT_FILTER_FWD,
1582                                     AMT_ACT_STATUS_FWD_NEW,
1583                                     v6);
1584                 /* (B)=GMI */
1585                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1586                                     AMT_FILTER_FWD_NEW,
1587                                     AMT_ACT_GMI,
1588                                     v6);
1589         } else {
1590 /* State        Actions
1591  * ------------   ------------ ----------------        -------
1592  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1593  */
1594                 /* Update (A) in (X, Y) as NONE/NEW */
1595                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1596                                     AMT_FILTER_BOTH,
1597                                     AMT_ACT_STATUS_NONE_NEW,
1598                                     v6);
1599                 /* Update FWD/OLD as FWD/NEW */
1600                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1601                                     AMT_FILTER_FWD,
1602                                     AMT_ACT_STATUS_FWD_NEW,
1603                                     v6);
1604                 /* Update IS_IN (A) as FWD/NEW */
1605                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1606                                     AMT_FILTER_NONE_NEW,
1607                                     AMT_ACT_STATUS_FWD_NEW,
1608                                     v6);
1609                 /* Update EXCLUDE (, Y-A) as D_FWD_NEW */
1610                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1611                                     AMT_FILTER_D_FWD,
1612                                     AMT_ACT_STATUS_D_FWD_NEW,
1613                                     v6);
1614         }
1615 }
1616
/* Apply an IS_EX (MODE_IS_EXCLUDE) group record to @gnode.
 *
 * Implements the IS_EX rows of the RFC 3376 6.4.1 / RFC 3810 7.4.1 router
 * state tables.  The amt_lookup_act_srcs() calls are order dependent: later
 * calls select sources by a status that earlier calls assign.  Sources that
 * must be deleted are only marked here; the report handlers call
 * amt_cleanup_srcs() after each record.
 *
 * @amt:       AMT device, used for amt_gmi() and the device reference
 * @tunnel:    tunnel the report arrived on
 * @gnode:     group the record applies to; switched to MCAST_EXCLUDE
 * @grec:      IGMPv3/MLDv2 group record
 * @zero_grec: record with no sources
 * @v6:        true for an MLDv2 record, false for IGMPv3
 */
static void amt_mcast_is_ex_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd  New Router State         Actions
 * ------------   ------------  ----------------         -------
 * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
 *                                                       Delete (A-B)
 *                                                       Group Timer=GMI
 */
		/* EXCLUDE(A*B, ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE(, B-A) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
		/* (B-A)=0 */
		amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
				    AMT_FILTER_D_FWD_NEW,
				    AMT_ACT_GMI_ZERO,
				    v6);
		/* Group Timer=GMI.
		 * mod_delayed_work() returns false when the work was idle and
		 * has just been queued; a device reference is taken only then
		 * (presumably released by the group timer handler, which is not
		 * visible here).
		 */
		if (!mod_delayed_work(amt_wq, &gnode->group_timer,
				      msecs_to_jiffies(amt_gmi(amt))))
			dev_hold(amt->dev);
		gnode->filter_mode = MCAST_EXCLUDE;
		/* Delete (A-B) is done later by amt_cleanup_srcs(). */
	} else {
/* Router State   Report Rec'd  New Router State        Actions
 * ------------   ------------  ----------------        -------
 * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
 *                                                      Delete (X-A)
 *                                                      Delete (Y-A)
 *                                                      Group Timer=GMI
 */
		/* EXCLUDE (A-Y, ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE (, Y*A ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
		/* (A-X-Y)=GMI */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_BOTH_NEW,
				    AMT_ACT_GMI,
				    v6);
		/* Group Timer=GMI (dev ref taken only if newly queued) */
		if (!mod_delayed_work(amt_wq, &gnode->group_timer,
				      msecs_to_jiffies(amt_gmi(amt))))
			dev_hold(amt->dev);
		/* Delete (X-A), (Y-A) is done later by amt_cleanup_srcs(). */
	}
}
1680
/* Apply a TO_IN (CHANGE_TO_INCLUDE) group record to @gnode.
 *
 * Implements the TO_IN rows of the RFC 3376 6.4.2 / RFC 3810 7.4.2 router
 * state tables.  The "Send Q(...)" actions listed in the tables are not
 * performed by this function.  The amt_lookup_act_srcs() calls are order
 * dependent: later calls select sources by a status that earlier calls assign.
 *
 * @amt:       AMT device (not used by this handler)
 * @tunnel:    tunnel the report arrived on
 * @gnode:     group the record applies to
 * @grec:      IGMPv3/MLDv2 group record
 * @zero_grec: record with no sources
 * @v6:        true for an MLDv2 record, false for IGMPv3
 */
static void amt_mcast_to_in_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
 *                                                     Send Q(G,A-B)
 */
		/* Update TO_IN (B) sources as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update INCLUDE (A) sources as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* (B)=GMI */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD_NEW,
				    AMT_ACT_GMI,
				    v6);
	} else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
 *                                                     Send Q(G,X-A)
 *                                                     Send Q(G)
 */
		/* Update TO_IN (A) sources as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update EXCLUDE(X,) sources as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE (, Y-A)
		 * (A) sources have already been switched to FWD/NEW, so every
		 * remaining D_FWD/OLD source is in Y-A and can become D_FWD/NEW.
		 */
		amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
		/* (A)=GMI
		 * Only FWD/NEW sources can be in (A) at this point.
		 */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD_NEW,
				    AMT_ACT_GMI,
				    v6);
	}
}
1741
/* Apply a TO_EX (CHANGE_TO_EXCLUDE) group record to @gnode.
 *
 * Implements the TO_EX rows of the RFC 3376 6.4.2 / RFC 3810 7.4.2 router
 * state tables.  The "Send Q(...)" actions listed in the tables are not
 * performed by this function.  Sources that must be deleted are only marked
 * here; the report handlers call amt_cleanup_srcs() after each record.  The
 * amt_lookup_act_srcs() calls are order dependent.
 *
 * @amt:       AMT device, used for amt_gmi() and the device reference
 * @tunnel:    tunnel the report arrived on
 * @gnode:     group the record applies to; ends in MCAST_EXCLUDE
 * @grec:      IGMPv3/MLDv2 group record
 * @zero_grec: record with no sources
 * @v6:        true for an MLDv2 record, false for IGMPv3
 */
static void amt_mcast_to_ex_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
 *                                                     Delete (A-B)
 *                                                     Send Q(G,A*B)
 *                                                     Group Timer=GMI
 */
		/* EXCLUDE (A*B, ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE (, B-A) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
		/* (B-A)=0 */
		amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
				    AMT_FILTER_D_FWD_NEW,
				    AMT_ACT_GMI_ZERO,
				    v6);
		/* Group Timer=GMI.
		 * mod_delayed_work() returns false when the work was idle and
		 * has just been queued; a device reference is taken only then.
		 */
		if (!mod_delayed_work(amt_wq, &gnode->group_timer,
				      msecs_to_jiffies(amt_gmi(amt))))
			dev_hold(amt->dev);
		gnode->filter_mode = MCAST_EXCLUDE;
		/* Delete (A-B) is done later by amt_cleanup_srcs(). */
	} else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
 *                                                     Delete (X-A)
 *                                                     Delete (Y-A)
 *                                                     Send Q(G,A-Y)
 *                                                     Group Timer=GMI
 */
		/* (A-X-Y)=Group Timer */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_BOTH,
				    AMT_ACT_GT,
				    v6);
		/* EXCLUDE (A-Y, ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE (, Y*A) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
		/* Group Timer=GMI (dev ref taken only if newly queued) */
		if (!mod_delayed_work(amt_wq, &gnode->group_timer,
				      msecs_to_jiffies(amt_gmi(amt))))
			dev_hold(amt->dev);
		/* Delete (X-A), (Y-A) is done later by amt_cleanup_srcs(). */
	}
}
1807
/* Apply an ALLOW (ALLOW_NEW_SOURCES) group record to @gnode.
 *
 * Implements the ALLOW rows of the RFC 3376 6.4.2 / RFC 3810 7.4.2 router
 * state tables.  The amt_lookup_act_srcs() calls are order dependent: the
 * final (A)=GMI step relies on (A) having been switched to FWD/NEW first.
 *
 * @amt:       AMT device (not used by this handler)
 * @tunnel:    tunnel the report arrived on
 * @gnode:     group the record applies to
 * @grec:      IGMPv3/MLDv2 group record
 * @zero_grec: record with no sources (not used by this handler)
 * @v6:        true for an MLDv2 record, false for IGMPv3
 */
static void amt_mcast_allow_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
 */
		/* INCLUDE (A+B) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* (B)=GMI */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD_NEW,
				    AMT_ACT_GMI,
				    v6);
	} else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
 */
		/* EXCLUDE (X+A, ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE (, Y-A) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
		/* (A)=GMI
		 * All (A) sources now have FWD/NEW status.
		 */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD_NEW,
				    AMT_ACT_GMI,
				    v6);
	}
}
1852
/* Apply a BLOCK (BLOCK_OLD_SOURCES) group record to @gnode.
 *
 * Implements the BLOCK rows of the RFC 3376 6.4.2 / RFC 3810 7.4.2 router
 * state tables.  The "Send Q(...)" actions listed in the tables are not
 * performed by this function.  The amt_lookup_act_srcs() calls are order
 * dependent.
 *
 * @amt:       AMT device (not used by this handler)
 * @tunnel:    tunnel the report arrived on
 * @gnode:     group the record applies to
 * @grec:      IGMPv3/MLDv2 group record
 * @zero_grec: record with no sources
 * @v6:        true for an MLDv2 record, false for IGMPv3
 */
static void amt_mcast_block_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
 */
		/* INCLUDE (A): keep every forwarding source (FWD -> FWD/NEW) */
		amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
	} else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
 *                                                     Send Q(G,A-Y)
 */
		/* (A-X-Y)=Group Timer */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_BOTH,
				    AMT_ACT_GT,
				    v6);
		/* EXCLUDE (X, ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE (X+(A-Y), ) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* EXCLUDE (, Y) */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
	}
}
1896
1897 /* RFC 3376
1898  * 7.3.2. In the Presence of Older Version Group Members
1899  *
1900  * When Group Compatibility Mode is IGMPv2, a router internally
1901  * translates the following IGMPv2 messages for that group to their
1902  * IGMPv3 equivalents:
1903  *
1904  * IGMPv2 Message                IGMPv3 Equivalent
1905  * --------------                -----------------
1906  * Report                        IS_EX( {} )
1907  * Leave                         TO_IN( {} )
1908  */
1909 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1910                                       struct amt_tunnel_list *tunnel)
1911 {
1912         struct igmphdr *ih = igmp_hdr(skb);
1913         struct iphdr *iph = ip_hdr(skb);
1914         struct amt_group_node *gnode;
1915         union amt_addr group, host;
1916
1917         memset(&group, 0, sizeof(union amt_addr));
1918         group.ip4 = ih->group;
1919         memset(&host, 0, sizeof(union amt_addr));
1920         host.ip4 = iph->saddr;
1921
1922         gnode = amt_lookup_group(tunnel, &group, &host, false);
1923         if (!gnode) {
1924                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1925                 if (!IS_ERR(gnode)) {
1926                         gnode->filter_mode = MCAST_EXCLUDE;
1927                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1928                                               msecs_to_jiffies(amt_gmi(amt))))
1929                                 dev_hold(amt->dev);
1930                 }
1931         }
1932 }
1933
1934 /* RFC 3376
1935  * 7.3.2. In the Presence of Older Version Group Members
1936  *
1937  * When Group Compatibility Mode is IGMPv2, a router internally
1938  * translates the following IGMPv2 messages for that group to their
1939  * IGMPv3 equivalents:
1940  *
1941  * IGMPv2 Message                IGMPv3 Equivalent
1942  * --------------                -----------------
1943  * Report                        IS_EX( {} )
1944  * Leave                         TO_IN( {} )
1945  */
1946 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1947                                      struct amt_tunnel_list *tunnel)
1948 {
1949         struct igmphdr *ih = igmp_hdr(skb);
1950         struct iphdr *iph = ip_hdr(skb);
1951         struct amt_group_node *gnode;
1952         union amt_addr group, host;
1953
1954         memset(&group, 0, sizeof(union amt_addr));
1955         group.ip4 = ih->group;
1956         memset(&host, 0, sizeof(union amt_addr));
1957         host.ip4 = iph->saddr;
1958
1959         gnode = amt_lookup_group(tunnel, &group, &host, false);
1960         if (gnode)
1961                 amt_del_group(amt, gnode);
1962 }
1963
1964 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1965                                       struct amt_tunnel_list *tunnel)
1966 {
1967         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
1968         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
1969         void *zero_grec = (void *)&igmpv3_zero_grec;
1970         struct iphdr *iph = ip_hdr(skb);
1971         struct amt_group_node *gnode;
1972         union amt_addr group, host;
1973         struct igmpv3_grec *grec;
1974         u16 nsrcs;
1975         int i;
1976
1977         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
1978                 len += sizeof(*grec);
1979                 if (!ip_mc_may_pull(skb, len))
1980                         break;
1981
1982                 grec = (void *)(skb->data + len - sizeof(*grec));
1983                 nsrcs = ntohs(grec->grec_nsrcs);
1984
1985                 len += nsrcs * sizeof(__be32);
1986                 if (!ip_mc_may_pull(skb, len))
1987                         break;
1988
1989                 memset(&group, 0, sizeof(union amt_addr));
1990                 group.ip4 = grec->grec_mca;
1991                 memset(&host, 0, sizeof(union amt_addr));
1992                 host.ip4 = iph->saddr;
1993                 gnode = amt_lookup_group(tunnel, &group, &host, false);
1994                 if (!gnode) {
1995                         gnode = amt_add_group(amt, tunnel, &group, &host,
1996                                               false);
1997                         if (IS_ERR(gnode))
1998                                 continue;
1999                 }
2000
2001                 amt_add_srcs(amt, tunnel, gnode, grec, false);
2002                 switch (grec->grec_type) {
2003                 case IGMPV3_MODE_IS_INCLUDE:
2004                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2005                                                 zero_grec, false);
2006                         break;
2007                 case IGMPV3_MODE_IS_EXCLUDE:
2008                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2009                                                 zero_grec, false);
2010                         break;
2011                 case IGMPV3_CHANGE_TO_INCLUDE:
2012                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2013                                                 zero_grec, false);
2014                         break;
2015                 case IGMPV3_CHANGE_TO_EXCLUDE:
2016                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2017                                                 zero_grec, false);
2018                         break;
2019                 case IGMPV3_ALLOW_NEW_SOURCES:
2020                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2021                                                 zero_grec, false);
2022                         break;
2023                 case IGMPV3_BLOCK_OLD_SOURCES:
2024                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2025                                                 zero_grec, false);
2026                         break;
2027                 default:
2028                         break;
2029                 }
2030                 amt_cleanup_srcs(amt, tunnel, gnode);
2031         }
2032 }
2033
2034 /* caller held tunnel->lock */
2035 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2036                                     struct amt_tunnel_list *tunnel)
2037 {
2038         struct igmphdr *ih = igmp_hdr(skb);
2039
2040         switch (ih->type) {
2041         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2042                 amt_igmpv3_report_handler(amt, skb, tunnel);
2043                 break;
2044         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2045                 amt_igmpv2_report_handler(amt, skb, tunnel);
2046                 break;
2047         case IGMP_HOST_LEAVE_MESSAGE:
2048                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2049                 break;
2050         default:
2051                 break;
2052         }
2053 }
2054
2055 #if IS_ENABLED(CONFIG_IPV6)
2056 /* RFC 3810
2057  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2058  *
2059  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2060  * using the MLDv2 protocol for that multicast address.  When Multicast
2061  * Address Compatibility Mode is MLDv1, a router internally translates
2062  * the following MLDv1 messages for that multicast address to their
2063  * MLDv2 equivalents:
2064  *
2065  * MLDv1 Message                 MLDv2 Equivalent
2066  * --------------                -----------------
2067  * Report                        IS_EX( {} )
2068  * Done                          TO_IN( {} )
2069  */
2070 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2071                                      struct amt_tunnel_list *tunnel)
2072 {
2073         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2074         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2075         struct amt_group_node *gnode;
2076         union amt_addr group, host;
2077
2078         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2079         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2080
2081         gnode = amt_lookup_group(tunnel, &group, &host, true);
2082         if (!gnode) {
2083                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2084                 if (!IS_ERR(gnode)) {
2085                         gnode->filter_mode = MCAST_EXCLUDE;
2086                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2087                                               msecs_to_jiffies(amt_gmi(amt))))
2088                                 dev_hold(amt->dev);
2089                 }
2090         }
2091 }
2092
2093 /* RFC 3810
2094  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2095  *
2096  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2097  * using the MLDv2 protocol for that multicast address.  When Multicast
2098  * Address Compatibility Mode is MLDv1, a router internally translates
2099  * the following MLDv1 messages for that multicast address to their
2100  * MLDv2 equivalents:
2101  *
2102  * MLDv1 Message                 MLDv2 Equivalent
2103  * --------------                -----------------
2104  * Report                        IS_EX( {} )
2105  * Done                          TO_IN( {} )
2106  */
2107 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2108                                     struct amt_tunnel_list *tunnel)
2109 {
2110         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2111         struct iphdr *iph = ip_hdr(skb);
2112         struct amt_group_node *gnode;
2113         union amt_addr group, host;
2114
2115         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2116         memset(&host, 0, sizeof(union amt_addr));
2117         host.ip4 = iph->saddr;
2118
2119         gnode = amt_lookup_group(tunnel, &group, &host, true);
2120         if (gnode) {
2121                 amt_del_group(amt, gnode);
2122                 return;
2123         }
2124 }
2125
2126 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2127                                      struct amt_tunnel_list *tunnel)
2128 {
2129         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2130         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2131         void *zero_grec = (void *)&mldv2_zero_grec;
2132         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2133         struct amt_group_node *gnode;
2134         union amt_addr group, host;
2135         struct mld2_grec *grec;
2136         u16 nsrcs;
2137         int i;
2138
2139         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2140                 len += sizeof(*grec);
2141                 if (!ipv6_mc_may_pull(skb, len))
2142                         break;
2143
2144                 grec = (void *)(skb->data + len - sizeof(*grec));
2145                 nsrcs = ntohs(grec->grec_nsrcs);
2146
2147                 len += nsrcs * sizeof(struct in6_addr);
2148                 if (!ipv6_mc_may_pull(skb, len))
2149                         break;
2150
2151                 memset(&group, 0, sizeof(union amt_addr));
2152                 group.ip6 = grec->grec_mca;
2153                 memset(&host, 0, sizeof(union amt_addr));
2154                 host.ip6 = ip6h->saddr;
2155                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2156                 if (!gnode) {
2157                         gnode = amt_add_group(amt, tunnel, &group, &host,
2158                                               ETH_P_IPV6);
2159                         if (IS_ERR(gnode))
2160                                 continue;
2161                 }
2162
2163                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2164                 switch (grec->grec_type) {
2165                 case MLD2_MODE_IS_INCLUDE:
2166                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2167                                                 zero_grec, true);
2168                         break;
2169                 case MLD2_MODE_IS_EXCLUDE:
2170                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2171                                                 zero_grec, true);
2172                         break;
2173                 case MLD2_CHANGE_TO_INCLUDE:
2174                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2175                                                 zero_grec, true);
2176                         break;
2177                 case MLD2_CHANGE_TO_EXCLUDE:
2178                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2179                                                 zero_grec, true);
2180                         break;
2181                 case MLD2_ALLOW_NEW_SOURCES:
2182                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2183                                                 zero_grec, true);
2184                         break;
2185                 case MLD2_BLOCK_OLD_SOURCES:
2186                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2187                                                 zero_grec, true);
2188                         break;
2189                 default:
2190                         break;
2191                 }
2192                 amt_cleanup_srcs(amt, tunnel, gnode);
2193         }
2194 }
2195
2196 /* caller held tunnel->lock */
2197 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2198                                    struct amt_tunnel_list *tunnel)
2199 {
2200         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2201
2202         switch (mld->mld_type) {
2203         case ICMPV6_MGM_REPORT:
2204                 amt_mldv1_report_handler(amt, skb, tunnel);
2205                 break;
2206         case ICMPV6_MLD2_REPORT:
2207                 amt_mldv2_report_handler(amt, skb, tunnel);
2208                 break;
2209         case ICMPV6_MGM_REDUCTION:
2210                 amt_mldv1_leave_handler(amt, skb, tunnel);
2211                 break;
2212         default:
2213                 break;
2214         }
2215 }
2216 #endif
2217
2218 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2219 {
2220         struct amt_header_advertisement *amta;
2221         int hdr_size;
2222
2223         hdr_size = sizeof(*amta) + sizeof(struct udphdr);
2224         if (!pskb_may_pull(skb, hdr_size))
2225                 return true;
2226
2227         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2228         if (!amta->ip4)
2229                 return true;
2230
2231         if (amta->reserved || amta->version)
2232                 return true;
2233
2234         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2235             ipv4_is_zeronet(amta->ip4))
2236                 return true;
2237
2238         amt->remote_ip = amta->ip4;
2239         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2240         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2241
2242         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2243         return false;
2244 }
2245
2246 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2247 {
2248         struct amt_header_mcast_data *amtmd;
2249         int hdr_size, len, err;
2250         struct ethhdr *eth;
2251         struct iphdr *iph;
2252
2253         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2254         if (!pskb_may_pull(skb, hdr_size))
2255                 return true;
2256
2257         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2258         if (amtmd->reserved || amtmd->version)
2259                 return true;
2260
2261         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2262                 return true;
2263
2264         skb_reset_network_header(skb);
2265         skb_push(skb, sizeof(*eth));
2266         skb_reset_mac_header(skb);
2267         skb_pull(skb, sizeof(*eth));
2268         eth = eth_hdr(skb);
2269
2270         if (!pskb_may_pull(skb, sizeof(*iph)))
2271                 return true;
2272         iph = ip_hdr(skb);
2273
2274         if (iph->version == 4) {
2275                 if (!ipv4_is_multicast(iph->daddr))
2276                         return true;
2277                 skb->protocol = htons(ETH_P_IP);
2278                 eth->h_proto = htons(ETH_P_IP);
2279                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2280 #if IS_ENABLED(CONFIG_IPV6)
2281         } else if (iph->version == 6) {
2282                 struct ipv6hdr *ip6h;
2283
2284                 if (!pskb_may_pull(skb, sizeof(*ip6h)))
2285                         return true;
2286
2287                 ip6h = ipv6_hdr(skb);
2288                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2289                         return true;
2290                 skb->protocol = htons(ETH_P_IPV6);
2291                 eth->h_proto = htons(ETH_P_IPV6);
2292                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2293 #endif
2294         } else {
2295                 return true;
2296         }
2297
2298         skb->pkt_type = PACKET_MULTICAST;
2299         skb->ip_summed = CHECKSUM_NONE;
2300         len = skb->len;
2301         err = gro_cells_receive(&amt->gro_cells, skb);
2302         if (likely(err == NET_RX_SUCCESS))
2303                 dev_sw_netstats_rx_add(amt->dev, len);
2304         else
2305                 amt->dev->stats.rx_dropped++;
2306
2307         return false;
2308 }
2309
2310 static bool amt_membership_query_handler(struct amt_dev *amt,
2311                                          struct sk_buff *skb)
2312 {
2313         struct amt_header_membership_query *amtmq;
2314         struct igmpv3_query *ihv3;
2315         struct ethhdr *eth, *oeth;
2316         struct iphdr *iph;
2317         int hdr_size, len;
2318
2319         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr);
2320         if (!pskb_may_pull(skb, hdr_size))
2321                 return true;
2322
2323         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2324         if (amtmq->reserved || amtmq->version)
2325                 return true;
2326
2327         hdr_size -= sizeof(*eth);
2328         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2329                 return true;
2330
2331         oeth = eth_hdr(skb);
2332         skb_reset_mac_header(skb);
2333         skb_pull(skb, sizeof(*eth));
2334         skb_reset_network_header(skb);
2335         eth = eth_hdr(skb);
2336         if (!pskb_may_pull(skb, sizeof(*iph)))
2337                 return true;
2338
2339         iph = ip_hdr(skb);
2340         if (iph->version == 4) {
2341                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2342                                    sizeof(*ihv3)))
2343                         return true;
2344
2345                 if (!ipv4_is_multicast(iph->daddr))
2346                         return true;
2347
2348                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2349                 skb_reset_transport_header(skb);
2350                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2351                 spin_lock_bh(&amt->lock);
2352                 amt->ready4 = true;
2353                 amt->mac = amtmq->response_mac;
2354                 amt->req_cnt = 0;
2355                 amt->qi = ihv3->qqic;
2356                 spin_unlock_bh(&amt->lock);
2357                 skb->protocol = htons(ETH_P_IP);
2358                 eth->h_proto = htons(ETH_P_IP);
2359                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2360 #if IS_ENABLED(CONFIG_IPV6)
2361         } else if (iph->version == 6) {
2362                 struct mld2_query *mld2q;
2363                 struct ipv6hdr *ip6h;
2364
2365                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2366                                    sizeof(*mld2q)))
2367                         return true;
2368
2369                 ip6h = ipv6_hdr(skb);
2370                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2371                         return true;
2372
2373                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2374                 skb_reset_transport_header(skb);
2375                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2376                 spin_lock_bh(&amt->lock);
2377                 amt->ready6 = true;
2378                 amt->mac = amtmq->response_mac;
2379                 amt->req_cnt = 0;
2380                 amt->qi = mld2q->mld2q_qqic;
2381                 spin_unlock_bh(&amt->lock);
2382                 skb->protocol = htons(ETH_P_IPV6);
2383                 eth->h_proto = htons(ETH_P_IPV6);
2384                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2385 #endif
2386         } else {
2387                 return true;
2388         }
2389
2390         ether_addr_copy(eth->h_source, oeth->h_source);
2391         skb->pkt_type = PACKET_MULTICAST;
2392         skb->ip_summed = CHECKSUM_NONE;
2393         len = skb->len;
2394         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2395                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2396                 dev_sw_netstats_rx_add(amt->dev, len);
2397         } else {
2398                 amt->dev->stats.rx_dropped++;
2399         }
2400
2401         return false;
2402 }
2403
2404 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2405 {
2406         struct amt_header_membership_update *amtmu;
2407         struct amt_tunnel_list *tunnel;
2408         struct ethhdr *eth;
2409         struct iphdr *iph;
2410         int len, hdr_size;
2411
2412         iph = ip_hdr(skb);
2413
2414         hdr_size = sizeof(*amtmu) + sizeof(struct udphdr);
2415         if (!pskb_may_pull(skb, hdr_size))
2416                 return true;
2417
2418         amtmu = (struct amt_header_membership_update *)(udp_hdr(skb) + 1);
2419         if (amtmu->reserved || amtmu->version)
2420                 return true;
2421
2422         if (iptunnel_pull_header(skb, hdr_size, skb->protocol, false))
2423                 return true;
2424
2425         skb_reset_network_header(skb);
2426
2427         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2428                 if (tunnel->ip4 == iph->saddr) {
2429                         if ((amtmu->nonce == tunnel->nonce &&
2430                              amtmu->response_mac == tunnel->mac)) {
2431                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2432                                                  msecs_to_jiffies(amt_gmi(amt))
2433                                                                   * 3);
2434                                 goto report;
2435                         } else {
2436                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2437                                 return true;
2438                         }
2439                 }
2440         }
2441
2442         return true;
2443
2444 report:
2445         if (!pskb_may_pull(skb, sizeof(*iph)))
2446                 return true;
2447
2448         iph = ip_hdr(skb);
2449         if (iph->version == 4) {
2450                 if (ip_mc_check_igmp(skb)) {
2451                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2452                         return true;
2453                 }
2454
2455                 spin_lock_bh(&tunnel->lock);
2456                 amt_igmp_report_handler(amt, skb, tunnel);
2457                 spin_unlock_bh(&tunnel->lock);
2458
2459                 skb_push(skb, sizeof(struct ethhdr));
2460                 skb_reset_mac_header(skb);
2461                 eth = eth_hdr(skb);
2462                 skb->protocol = htons(ETH_P_IP);
2463                 eth->h_proto = htons(ETH_P_IP);
2464                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2465 #if IS_ENABLED(CONFIG_IPV6)
2466         } else if (iph->version == 6) {
2467                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2468
2469                 if (ipv6_mc_check_mld(skb)) {
2470                         netdev_dbg(amt->dev, "Invalid MLD\n");
2471                         return true;
2472                 }
2473
2474                 spin_lock_bh(&tunnel->lock);
2475                 amt_mld_report_handler(amt, skb, tunnel);
2476                 spin_unlock_bh(&tunnel->lock);
2477
2478                 skb_push(skb, sizeof(struct ethhdr));
2479                 skb_reset_mac_header(skb);
2480                 eth = eth_hdr(skb);
2481                 skb->protocol = htons(ETH_P_IPV6);
2482                 eth->h_proto = htons(ETH_P_IPV6);
2483                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2484 #endif
2485         } else {
2486                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2487                 return true;
2488         }
2489
2490         skb_pull(skb, sizeof(struct ethhdr));
2491         skb->pkt_type = PACKET_MULTICAST;
2492         skb->ip_summed = CHECKSUM_NONE;
2493         len = skb->len;
2494         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2495                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2496                                         true);
2497                 dev_sw_netstats_rx_add(amt->dev, len);
2498         } else {
2499                 amt->dev->stats.rx_dropped++;
2500         }
2501
2502         return false;
2503 }
2504
2505 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2506                                    __be32 daddr, __be16 dport)
2507 {
2508         struct amt_header_advertisement *amta;
2509         int hlen, tlen, offset;
2510         struct socket *sock;
2511         struct udphdr *udph;
2512         struct sk_buff *skb;
2513         struct iphdr *iph;
2514         struct rtable *rt;
2515         struct flowi4 fl4;
2516         u32 len;
2517         int err;
2518
2519         rcu_read_lock();
2520         sock = rcu_dereference(amt->sock);
2521         if (!sock)
2522                 goto out;
2523
2524         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2525                 goto out;
2526
2527         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2528                                    daddr, amt->local_ip,
2529                                    dport, amt->relay_port,
2530                                    IPPROTO_UDP, 0,
2531                                    amt->stream_dev->ifindex);
2532         if (IS_ERR(rt)) {
2533                 amt->dev->stats.tx_errors++;
2534                 goto out;
2535         }
2536
2537         hlen = LL_RESERVED_SPACE(amt->dev);
2538         tlen = amt->dev->needed_tailroom;
2539         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2540         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2541         if (!skb) {
2542                 ip_rt_put(rt);
2543                 amt->dev->stats.tx_errors++;
2544                 goto out;
2545         }
2546
2547         skb->priority = TC_PRIO_CONTROL;
2548         skb_dst_set(skb, &rt->dst);
2549
2550         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2551         skb_reset_network_header(skb);
2552         skb_put(skb, len);
2553         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2554         amta->version   = 0;
2555         amta->type      = AMT_MSG_ADVERTISEMENT;
2556         amta->reserved  = 0;
2557         amta->nonce     = nonce;
2558         amta->ip4       = amt->local_ip;
2559         skb_push(skb, sizeof(*udph));
2560         skb_reset_transport_header(skb);
2561         udph            = udp_hdr(skb);
2562         udph->source    = amt->relay_port;
2563         udph->dest      = dport;
2564         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
2565         udph->check     = 0;
2566         offset = skb_transport_offset(skb);
2567         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2568         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2569                                         sizeof(*udph) + sizeof(*amta),
2570                                         IPPROTO_UDP, skb->csum);
2571
2572         skb_push(skb, sizeof(*iph));
2573         iph             = ip_hdr(skb);
2574         iph->version    = 4;
2575         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2576         iph->tos        = AMT_TOS;
2577         iph->frag_off   = 0;
2578         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2579         iph->daddr      = daddr;
2580         iph->saddr      = amt->local_ip;
2581         iph->protocol   = IPPROTO_UDP;
2582         iph->tot_len    = htons(len);
2583
2584         skb->ip_summed = CHECKSUM_NONE;
2585         ip_select_ident(amt->net, skb, NULL);
2586         ip_send_check(iph);
2587         err = ip_local_out(amt->net, sock->sk, skb);
2588         if (unlikely(net_xmit_eval(err)))
2589                 amt->dev->stats.tx_errors++;
2590
2591 out:
2592         rcu_read_unlock();
2593 }
2594
2595 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2596 {
2597         struct amt_header_discovery *amtd;
2598         struct udphdr *udph;
2599         struct iphdr *iph;
2600
2601         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2602                 return true;
2603
2604         iph = ip_hdr(skb);
2605         udph = udp_hdr(skb);
2606         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2607
2608         if (amtd->reserved || amtd->version)
2609                 return true;
2610
2611         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2612
2613         return false;
2614 }
2615
2616 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2617 {
2618         struct amt_header_request *amtrh;
2619         struct amt_tunnel_list *tunnel;
2620         unsigned long long key;
2621         struct udphdr *udph;
2622         struct iphdr *iph;
2623         u64 mac;
2624         int i;
2625
2626         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2627                 return true;
2628
2629         iph = ip_hdr(skb);
2630         udph = udp_hdr(skb);
2631         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2632
2633         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2634                 return true;
2635
2636         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2637                 if (tunnel->ip4 == iph->saddr)
2638                         goto send;
2639
2640         if (amt->nr_tunnels >= amt->max_tunnels) {
2641                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2642                 return true;
2643         }
2644
2645         tunnel = kzalloc(sizeof(*tunnel) +
2646                          (sizeof(struct hlist_head) * amt->hash_buckets),
2647                          GFP_ATOMIC);
2648         if (!tunnel)
2649                 return true;
2650
2651         tunnel->source_port = udph->source;
2652         tunnel->ip4 = iph->saddr;
2653
2654         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2655         tunnel->amt = amt;
2656         spin_lock_init(&tunnel->lock);
2657         for (i = 0; i < amt->hash_buckets; i++)
2658                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2659
2660         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2661
2662         spin_lock_bh(&amt->lock);
2663         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2664         tunnel->key = amt->key;
2665         amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2666         amt->nr_tunnels++;
2667         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2668                          msecs_to_jiffies(amt_gmi(amt)));
2669         spin_unlock_bh(&amt->lock);
2670
2671 send:
2672         tunnel->nonce = amtrh->nonce;
2673         mac = siphash_3u32((__force u32)tunnel->ip4,
2674                            (__force u32)tunnel->source_port,
2675                            (__force u32)tunnel->nonce,
2676                            &tunnel->key);
2677         tunnel->mac = mac >> 16;
2678
2679         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2680                 return true;
2681
2682         if (!amtrh->p)
2683                 amt_send_igmp_gq(amt, tunnel);
2684         else
2685                 amt_send_mld_gq(amt, tunnel);
2686
2687         return false;
2688 }
2689
2690 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2691 {
2692         struct amt_dev *amt;
2693         struct iphdr *iph;
2694         int type;
2695         bool err;
2696
2697         rcu_read_lock_bh();
2698         amt = rcu_dereference_sk_user_data(sk);
2699         if (!amt) {
2700                 err = true;
2701                 kfree_skb(skb);
2702                 goto out;
2703         }
2704
2705         skb->dev = amt->dev;
2706         iph = ip_hdr(skb);
2707         type = amt_parse_type(skb);
2708         if (type == -1) {
2709                 err = true;
2710                 goto drop;
2711         }
2712
2713         if (amt->mode == AMT_MODE_GATEWAY) {
2714                 switch (type) {
2715                 case AMT_MSG_ADVERTISEMENT:
2716                         if (iph->saddr != amt->discovery_ip) {
2717                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2718                                 err = true;
2719                                 goto drop;
2720                         }
2721                         err = amt_advertisement_handler(amt, skb);
2722                         break;
2723                 case AMT_MSG_MULTICAST_DATA:
2724                         if (iph->saddr != amt->remote_ip) {
2725                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2726                                 err = true;
2727                                 goto drop;
2728                         }
2729                         err = amt_multicast_data_handler(amt, skb);
2730                         if (err)
2731                                 goto drop;
2732                         else
2733                                 goto out;
2734                 case AMT_MSG_MEMBERSHIP_QUERY:
2735                         if (iph->saddr != amt->remote_ip) {
2736                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2737                                 err = true;
2738                                 goto drop;
2739                         }
2740                         err = amt_membership_query_handler(amt, skb);
2741                         if (err)
2742                                 goto drop;
2743                         else
2744                                 goto out;
2745                 default:
2746                         err = true;
2747                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2748                         break;
2749                 }
2750         } else {
2751                 switch (type) {
2752                 case AMT_MSG_DISCOVERY:
2753                         err = amt_discovery_handler(amt, skb);
2754                         break;
2755                 case AMT_MSG_REQUEST:
2756                         err = amt_request_handler(amt, skb);
2757                         break;
2758                 case AMT_MSG_MEMBERSHIP_UPDATE:
2759                         err = amt_update_handler(amt, skb);
2760                         if (err)
2761                                 goto drop;
2762                         else
2763                                 goto out;
2764                 default:
2765                         err = true;
2766                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2767                         break;
2768                 }
2769         }
2770 drop:
2771         if (err) {
2772                 amt->dev->stats.rx_dropped++;
2773                 kfree_skb(skb);
2774         } else {
2775                 consume_skb(skb);
2776         }
2777 out:
2778         rcu_read_unlock_bh();
2779         return 0;
2780 }
2781
2782 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2783 {
2784         struct amt_dev *amt;
2785         int type;
2786
2787         rcu_read_lock_bh();
2788         amt = rcu_dereference_sk_user_data(sk);
2789         if (!amt)
2790                 goto out;
2791
2792         if (amt->mode != AMT_MODE_GATEWAY)
2793                 goto drop;
2794
2795         type = amt_parse_type(skb);
2796         if (type == -1)
2797                 goto drop;
2798
2799         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2800                    type_str[type]);
2801         switch (type) {
2802         case AMT_MSG_DISCOVERY:
2803                 break;
2804         case AMT_MSG_REQUEST:
2805         case AMT_MSG_MEMBERSHIP_UPDATE:
2806                 if (amt->status >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2807                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2808                 break;
2809         default:
2810                 goto drop;
2811         }
2812 out:
2813         rcu_read_unlock_bh();
2814         return 0;
2815 drop:
2816         rcu_read_unlock_bh();
2817         amt->dev->stats.rx_dropped++;
2818         return 0;
2819 }
2820
2821 static struct socket *amt_create_sock(struct net *net, __be16 port)
2822 {
2823         struct udp_port_cfg udp_conf;
2824         struct socket *sock;
2825         int err;
2826
2827         memset(&udp_conf, 0, sizeof(udp_conf));
2828         udp_conf.family = AF_INET;
2829         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2830
2831         udp_conf.local_udp_port = port;
2832
2833         err = udp_sock_create(net, &udp_conf, &sock);
2834         if (err < 0)
2835                 return ERR_PTR(err);
2836
2837         return sock;
2838 }
2839
2840 static int amt_socket_create(struct amt_dev *amt)
2841 {
2842         struct udp_tunnel_sock_cfg tunnel_cfg;
2843         struct socket *sock;
2844
2845         sock = amt_create_sock(amt->net, amt->relay_port);
2846         if (IS_ERR(sock))
2847                 return PTR_ERR(sock);
2848
2849         /* Mark socket as an encapsulation socket */
2850         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2851         tunnel_cfg.sk_user_data = amt;
2852         tunnel_cfg.encap_type = 1;
2853         tunnel_cfg.encap_rcv = amt_rcv;
2854         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2855         tunnel_cfg.encap_destroy = NULL;
2856         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2857
2858         rcu_assign_pointer(amt->sock, sock);
2859         return 0;
2860 }
2861
2862 static int amt_dev_open(struct net_device *dev)
2863 {
2864         struct amt_dev *amt = netdev_priv(dev);
2865         int err;
2866
2867         amt->ready4 = false;
2868         amt->ready6 = false;
2869
2870         err = amt_socket_create(amt);
2871         if (err)
2872                 return err;
2873
2874         amt->req_cnt = 0;
2875         amt->remote_ip = 0;
2876         get_random_bytes(&amt->key, sizeof(siphash_key_t));
2877
2878         amt->status = AMT_STATUS_INIT;
2879         if (amt->mode == AMT_MODE_GATEWAY) {
2880                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
2881                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
2882         } else if (amt->mode == AMT_MODE_RELAY) {
2883                 mod_delayed_work(amt_wq, &amt->secret_wq,
2884                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
2885         }
2886         return err;
2887 }
2888
2889 static int amt_dev_stop(struct net_device *dev)
2890 {
2891         struct amt_dev *amt = netdev_priv(dev);
2892         struct amt_tunnel_list *tunnel, *tmp;
2893         struct socket *sock;
2894
2895         cancel_delayed_work_sync(&amt->req_wq);
2896         cancel_delayed_work_sync(&amt->discovery_wq);
2897         cancel_delayed_work_sync(&amt->secret_wq);
2898
2899         /* shutdown */
2900         sock = rtnl_dereference(amt->sock);
2901         RCU_INIT_POINTER(amt->sock, NULL);
2902         synchronize_net();
2903         if (sock)
2904                 udp_tunnel_sock_release(sock);
2905
2906         amt->ready4 = false;
2907         amt->ready6 = false;
2908         amt->req_cnt = 0;
2909         amt->remote_ip = 0;
2910
2911         list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
2912                 list_del_rcu(&tunnel->list);
2913                 amt->nr_tunnels--;
2914                 cancel_delayed_work_sync(&tunnel->gc_wq);
2915                 amt_clear_groups(tunnel);
2916                 kfree_rcu(tunnel, rcu);
2917         }
2918
2919         return 0;
2920 }
2921
2922 static const struct device_type amt_type = {
2923         .name = "amt",
2924 };
2925
2926 static int amt_dev_init(struct net_device *dev)
2927 {
2928         struct amt_dev *amt = netdev_priv(dev);
2929         int err;
2930
2931         amt->dev = dev;
2932         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
2933         if (!dev->tstats)
2934                 return -ENOMEM;
2935
2936         err = gro_cells_init(&amt->gro_cells, dev);
2937         if (err) {
2938                 free_percpu(dev->tstats);
2939                 return err;
2940         }
2941
2942         return 0;
2943 }
2944
2945 static void amt_dev_uninit(struct net_device *dev)
2946 {
2947         struct amt_dev *amt = netdev_priv(dev);
2948
2949         gro_cells_destroy(&amt->gro_cells);
2950         free_percpu(dev->tstats);
2951 }
2952
2953 static const struct net_device_ops amt_netdev_ops = {
2954         .ndo_init               = amt_dev_init,
2955         .ndo_uninit             = amt_dev_uninit,
2956         .ndo_open               = amt_dev_open,
2957         .ndo_stop               = amt_dev_stop,
2958         .ndo_start_xmit         = amt_dev_xmit,
2959         .ndo_get_stats64        = dev_get_tstats64,
2960 };
2961
2962 static void amt_link_setup(struct net_device *dev)
2963 {
2964         dev->netdev_ops         = &amt_netdev_ops;
2965         dev->needs_free_netdev  = true;
2966         SET_NETDEV_DEVTYPE(dev, &amt_type);
2967         dev->min_mtu            = ETH_MIN_MTU;
2968         dev->max_mtu            = ETH_MAX_MTU;
2969         dev->type               = ARPHRD_NONE;
2970         dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
2971         dev->hard_header_len    = 0;
2972         dev->addr_len           = 0;
2973         dev->priv_flags         |= IFF_NO_QUEUE;
2974         dev->features           |= NETIF_F_LLTX;
2975         dev->features           |= NETIF_F_GSO_SOFTWARE;
2976         dev->features           |= NETIF_F_NETNS_LOCAL;
2977         dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
2978         dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
2979         dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
2980         eth_hw_addr_random(dev);
2981         eth_zero_addr(dev->broadcast);
2982         ether_setup(dev);
2983 }
2984
2985 static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
2986         [IFLA_AMT_MODE]         = { .type = NLA_U32 },
2987         [IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
2988         [IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
2989         [IFLA_AMT_LINK]         = { .type = NLA_U32 },
2990         [IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
2991         [IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
2992         [IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
2993         [IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
2994 };
2995
2996 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
2997                         struct netlink_ext_ack *extack)
2998 {
2999         if (!data)
3000                 return -EINVAL;
3001
3002         if (!data[IFLA_AMT_LINK]) {
3003                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
3004                                     "Link attribute is required");
3005                 return -EINVAL;
3006         }
3007
3008         if (!data[IFLA_AMT_MODE]) {
3009                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3010                                     "Mode attribute is required");
3011                 return -EINVAL;
3012         }
3013
3014         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
3015                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3016                                     "Mode attribute is not valid");
3017                 return -EINVAL;
3018         }
3019
3020         if (!data[IFLA_AMT_LOCAL_IP]) {
3021                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3022                                     "Local attribute is required");
3023                 return -EINVAL;
3024         }
3025
3026         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3027             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3028                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3029                                     "Discovery attribute is required");
3030                 return -EINVAL;
3031         }
3032
3033         return 0;
3034 }
3035
3036 static int amt_newlink(struct net *net, struct net_device *dev,
3037                        struct nlattr *tb[], struct nlattr *data[],
3038                        struct netlink_ext_ack *extack)
3039 {
3040         struct amt_dev *amt = netdev_priv(dev);
3041         int err = -EINVAL;
3042
3043         amt->net = net;
3044         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3045
3046         if (data[IFLA_AMT_MAX_TUNNELS] &&
3047             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3048                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3049         else
3050                 amt->max_tunnels = AMT_MAX_TUNNELS;
3051
3052         spin_lock_init(&amt->lock);
3053         amt->max_groups = AMT_MAX_GROUP;
3054         amt->max_sources = AMT_MAX_SOURCE;
3055         amt->hash_buckets = AMT_HSIZE;
3056         amt->nr_tunnels = 0;
3057         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3058         amt->stream_dev = dev_get_by_index(net,
3059                                            nla_get_u32(data[IFLA_AMT_LINK]));
3060         if (!amt->stream_dev) {
3061                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3062                                     "Can't find stream device");
3063                 return -ENODEV;
3064         }
3065
3066         if (amt->stream_dev->type != ARPHRD_ETHER) {
3067                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3068                                     "Invalid stream device type");
3069                 goto err;
3070         }
3071
3072         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3073         if (ipv4_is_loopback(amt->local_ip) ||
3074             ipv4_is_zeronet(amt->local_ip) ||
3075             ipv4_is_multicast(amt->local_ip)) {
3076                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3077                                     "Invalid Local address");
3078                 goto err;
3079         }
3080
3081         if (data[IFLA_AMT_RELAY_PORT])
3082                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3083         else
3084                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3085
3086         if (data[IFLA_AMT_GATEWAY_PORT])
3087                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3088         else
3089                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3090
3091         if (!amt->relay_port) {
3092                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3093                                     "relay port must not be 0");
3094                 goto err;
3095         }
3096         if (amt->mode == AMT_MODE_RELAY) {
3097                 amt->qrv = amt->net->ipv4.sysctl_igmp_qrv;
3098                 amt->qri = 10;
3099                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3100                                        AMT_RELAY_HLEN;
3101                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3102                 dev->max_mtu = dev->mtu;
3103                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3104         } else {
3105                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3106                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3107                                             "discovery must be set in gateway mode");
3108                         goto err;
3109                 }
3110                 if (!amt->gw_port) {
3111                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3112                                             "gateway port must not be 0");
3113                         goto err;
3114                 }
3115                 amt->remote_ip = 0;
3116                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3117                 if (ipv4_is_loopback(amt->discovery_ip) ||
3118                     ipv4_is_zeronet(amt->discovery_ip) ||
3119                     ipv4_is_multicast(amt->discovery_ip)) {
3120                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3121                                             "discovery must be unicast");
3122                         goto err;
3123                 }
3124
3125                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3126                                        AMT_GW_HLEN;
3127                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3128                 dev->max_mtu = dev->mtu;
3129                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3130         }
3131         amt->qi = AMT_INIT_QUERY_INTERVAL;
3132
3133         err = register_netdevice(dev);
3134         if (err < 0) {
3135                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3136                 goto err;
3137         }
3138
3139         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3140         if (err < 0) {
3141                 unregister_netdevice(dev);
3142                 goto err;
3143         }
3144
3145         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3146         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3147         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3148         INIT_LIST_HEAD(&amt->tunnel_list);
3149
3150         return 0;
3151 err:
3152         dev_put(amt->stream_dev);
3153         return err;
3154 }
3155
3156 static void amt_dellink(struct net_device *dev, struct list_head *head)
3157 {
3158         struct amt_dev *amt = netdev_priv(dev);
3159
3160         unregister_netdevice_queue(dev, head);
3161         netdev_upper_dev_unlink(amt->stream_dev, dev);
3162         dev_put(amt->stream_dev);
3163 }
3164
3165 static size_t amt_get_size(const struct net_device *dev)
3166 {
3167         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3168                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3169                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3170                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3171                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3172                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3173                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3174                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3175 }
3176
3177 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3178 {
3179         struct amt_dev *amt = netdev_priv(dev);
3180
3181         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3182                 goto nla_put_failure;
3183         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3184                 goto nla_put_failure;
3185         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3186                 goto nla_put_failure;
3187         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3188                 goto nla_put_failure;
3189         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3190                 goto nla_put_failure;
3191         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3192                 goto nla_put_failure;
3193         if (amt->remote_ip)
3194                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3195                         goto nla_put_failure;
3196         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3197                 goto nla_put_failure;
3198
3199         return 0;
3200
3201 nla_put_failure:
3202         return -EMSGSIZE;
3203 }
3204
3205 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3206         .kind           = "amt",
3207         .maxtype        = IFLA_AMT_MAX,
3208         .policy         = amt_policy,
3209         .priv_size      = sizeof(struct amt_dev),
3210         .setup          = amt_link_setup,
3211         .validate       = amt_validate,
3212         .newlink        = amt_newlink,
3213         .dellink        = amt_dellink,
3214         .get_size       = amt_get_size,
3215         .fill_info      = amt_fill_info,
3216 };
3217
3218 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3219 {
3220         struct net_device *upper_dev;
3221         struct amt_dev *amt;
3222
3223         for_each_netdev(dev_net(dev), upper_dev) {
3224                 if (netif_is_amt(upper_dev)) {
3225                         amt = netdev_priv(upper_dev);
3226                         if (amt->stream_dev == dev)
3227                                 return upper_dev;
3228                 }
3229         }
3230
3231         return NULL;
3232 }
3233
3234 static int amt_device_event(struct notifier_block *unused,
3235                             unsigned long event, void *ptr)
3236 {
3237         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3238         struct net_device *upper_dev;
3239         struct amt_dev *amt;
3240         LIST_HEAD(list);
3241         int new_mtu;
3242
3243         upper_dev = amt_lookup_upper_dev(dev);
3244         if (!upper_dev)
3245                 return NOTIFY_DONE;
3246         amt = netdev_priv(upper_dev);
3247
3248         switch (event) {
3249         case NETDEV_UNREGISTER:
3250                 amt_dellink(amt->dev, &list);
3251                 unregister_netdevice_many(&list);
3252                 break;
3253         case NETDEV_CHANGEMTU:
3254                 if (amt->mode == AMT_MODE_RELAY)
3255                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3256                 else
3257                         new_mtu = dev->mtu - AMT_GW_HLEN;
3258
3259                 dev_set_mtu(amt->dev, new_mtu);
3260                 break;
3261         }
3262
3263         return NOTIFY_DONE;
3264 }
3265
/* Netdevice notifier registered in amt_init(); amt_device_event() reacts
 * to events on devices that serve as an amt stream device.
 */
static struct notifier_block amt_notifier_block __read_mostly = {
        .notifier_call = amt_device_event,
};
3269
3270 static int __init amt_init(void)
3271 {
3272         int err;
3273
3274         err = register_netdevice_notifier(&amt_notifier_block);
3275         if (err < 0)
3276                 goto err;
3277
3278         err = rtnl_link_register(&amt_link_ops);
3279         if (err < 0)
3280                 goto unregister_notifier;
3281
3282         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 1);
3283         if (!amt_wq) {
3284                 err = -ENOMEM;
3285                 goto rtnl_unregister;
3286         }
3287
3288         spin_lock_init(&source_gc_lock);
3289         spin_lock_bh(&source_gc_lock);
3290         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3291         mod_delayed_work(amt_wq, &source_gc_wq,
3292                          msecs_to_jiffies(AMT_GC_INTERVAL));
3293         spin_unlock_bh(&source_gc_lock);
3294
3295         return 0;
3296
3297 rtnl_unregister:
3298         rtnl_link_unregister(&amt_link_ops);
3299 unregister_notifier:
3300         unregister_netdevice_notifier(&amt_notifier_block);
3301 err:
3302         pr_err("error loading AMT module loaded\n");
3303         return err;
3304 }
3305 late_initcall(amt_init);
3306
/* Module exit.  The order matters:
 *  1. rtnl_link_unregister() removes the link ops and destroys every
 *     existing amt device (via ->dellink),
 *  2. stop listening to netdevice events,
 *  3. cancel the periodic source gc and run one final pass synchronously
 *     (presumably draining whatever is left on source_gc_list - confirm
 *     against __amt_source_gc_work()),
 *  4. only then destroy amt_wq, which the gc work is queued on.
 */
static void __exit amt_fini(void)
{
        rtnl_link_unregister(&amt_link_ops);
        unregister_netdevice_notifier(&amt_notifier_block);
        cancel_delayed_work_sync(&source_gc_wq);
        __amt_source_gc_work();
        destroy_workqueue(amt_wq);
}
module_exit(amt_fini);
3316
3317 MODULE_LICENSE("GPL");
3318 MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
3319 MODULE_ALIAS_RTNL_LINK("amt");