Merge branch 'for-5.13/warnings' into for-linus
[linux-2.6-microblaze.git] / net / ipv6 / seg6_local.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *  SR-IPv6 implementation
4  *
5  *  Authors:
6  *  David Lebrun <david.lebrun@uclouvain.be>
7  *  eBPF support: Mathieu Xhonneux <m.xhonneux@gmail.com>
8  */
9
10 #include <linux/types.h>
11 #include <linux/skbuff.h>
12 #include <linux/net.h>
13 #include <linux/module.h>
14 #include <net/ip.h>
15 #include <net/lwtunnel.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #include <net/ip6_fib.h>
19 #include <net/route.h>
20 #include <net/seg6.h>
21 #include <linux/seg6.h>
22 #include <linux/seg6_local.h>
23 #include <net/addrconf.h>
24 #include <net/ip6_route.h>
25 #include <net/dst_cache.h>
26 #include <net/ip_tunnels.h>
27 #ifdef CONFIG_IPV6_SEG6_HMAC
28 #include <net/seg6_hmac.h>
29 #endif
30 #include <net/seg6_local.h>
31 #include <linux/etherdevice.h>
32 #include <linux/bpf.h>
33
34 #define SEG6_F_ATTR(i)          BIT(i)
35
36 struct seg6_local_lwt;
37
38 /* callbacks used for customizing the creation and destruction of a behavior */
39 struct seg6_local_lwtunnel_ops {
40         int (*build_state)(struct seg6_local_lwt *slwt, const void *cfg,
41                            struct netlink_ext_ack *extack);
42         void (*destroy_state)(struct seg6_local_lwt *slwt);
43 };
44
45 struct seg6_action_desc {
46         int action;
47         unsigned long attrs;
48
49         /* The optattrs field is used for specifying all the optional
50          * attributes supported by a specific behavior.
51          * It means that if one of these attributes is not provided in the
52          * netlink message during the behavior creation, no errors will be
53          * returned to the userspace.
54          *
55          * Each attribute can be only of two types (mutually exclusive):
56          * 1) required or 2) optional.
57          * Every user MUST obey to this rule! If you set an attribute as
58          * required the same attribute CANNOT be set as optional and vice
59          * versa.
60          */
61         unsigned long optattrs;
62
63         int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
64         int static_headroom;
65
66         struct seg6_local_lwtunnel_ops slwt_ops;
67 };
68
69 struct bpf_lwt_prog {
70         struct bpf_prog *prog;
71         char *name;
72 };
73
74 enum seg6_end_dt_mode {
75         DT_INVALID_MODE = -EINVAL,
76         DT_LEGACY_MODE  = 0,
77         DT_VRF_MODE     = 1,
78 };
79
80 struct seg6_end_dt_info {
81         enum seg6_end_dt_mode mode;
82
83         struct net *net;
84         /* VRF device associated to the routing table used by the SRv6
85          * End.DT4/DT6 behavior for routing IPv4/IPv6 packets.
86          */
87         int vrf_ifindex;
88         int vrf_table;
89
90         /* tunneled packet proto and family (IPv4 or IPv6) */
91         __be16 proto;
92         u16 family;
93         int hdrlen;
94 };
95
96 struct seg6_local_lwt {
97         int action;
98         struct ipv6_sr_hdr *srh;
99         int table;
100         struct in_addr nh4;
101         struct in6_addr nh6;
102         int iif;
103         int oif;
104         struct bpf_lwt_prog bpf;
105 #ifdef CONFIG_NET_L3_MASTER_DEV
106         struct seg6_end_dt_info dt_info;
107 #endif
108
109         int headroom;
110         struct seg6_action_desc *desc;
111         /* unlike the required attrs, we have to track the optional attributes
112          * that have been effectively parsed.
113          */
114         unsigned long parsed_optattrs;
115 };
116
117 static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt)
118 {
119         return (struct seg6_local_lwt *)lwt->data;
120 }
121
122 static struct ipv6_sr_hdr *get_srh(struct sk_buff *skb)
123 {
124         struct ipv6_sr_hdr *srh;
125         int len, srhoff = 0;
126
127         if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0)
128                 return NULL;
129
130         if (!pskb_may_pull(skb, srhoff + sizeof(*srh)))
131                 return NULL;
132
133         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
134
135         len = (srh->hdrlen + 1) << 3;
136
137         if (!pskb_may_pull(skb, srhoff + len))
138                 return NULL;
139
140         /* note that pskb_may_pull may change pointers in header;
141          * for this reason it is necessary to reload them when needed.
142          */
143         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
144
145         if (!seg6_validate_srh(srh, len, true))
146                 return NULL;
147
148         return srh;
149 }
150
151 static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb)
152 {
153         struct ipv6_sr_hdr *srh;
154
155         srh = get_srh(skb);
156         if (!srh)
157                 return NULL;
158
159         if (srh->segments_left == 0)
160                 return NULL;
161
162 #ifdef CONFIG_IPV6_SEG6_HMAC
163         if (!seg6_hmac_validate_skb(skb))
164                 return NULL;
165 #endif
166
167         return srh;
168 }
169
170 static bool decap_and_validate(struct sk_buff *skb, int proto)
171 {
172         struct ipv6_sr_hdr *srh;
173         unsigned int off = 0;
174
175         srh = get_srh(skb);
176         if (srh && srh->segments_left > 0)
177                 return false;
178
179 #ifdef CONFIG_IPV6_SEG6_HMAC
180         if (srh && !seg6_hmac_validate_skb(skb))
181                 return false;
182 #endif
183
184         if (ipv6_find_hdr(skb, &off, proto, NULL, NULL) < 0)
185                 return false;
186
187         if (!pskb_pull(skb, off))
188                 return false;
189
190         skb_postpull_rcsum(skb, skb_network_header(skb), off);
191
192         skb_reset_network_header(skb);
193         skb_reset_transport_header(skb);
194         if (iptunnel_pull_offloads(skb))
195                 return false;
196
197         return true;
198 }
199
200 static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr)
201 {
202         struct in6_addr *addr;
203
204         srh->segments_left--;
205         addr = srh->segments + srh->segments_left;
206         *daddr = *addr;
207 }
208
209 static int
210 seg6_lookup_any_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
211                         u32 tbl_id, bool local_delivery)
212 {
213         struct net *net = dev_net(skb->dev);
214         struct ipv6hdr *hdr = ipv6_hdr(skb);
215         int flags = RT6_LOOKUP_F_HAS_SADDR;
216         struct dst_entry *dst = NULL;
217         struct rt6_info *rt;
218         struct flowi6 fl6;
219         int dev_flags = 0;
220
221         fl6.flowi6_iif = skb->dev->ifindex;
222         fl6.daddr = nhaddr ? *nhaddr : hdr->daddr;
223         fl6.saddr = hdr->saddr;
224         fl6.flowlabel = ip6_flowinfo(hdr);
225         fl6.flowi6_mark = skb->mark;
226         fl6.flowi6_proto = hdr->nexthdr;
227
228         if (nhaddr)
229                 fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;
230
231         if (!tbl_id) {
232                 dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
233         } else {
234                 struct fib6_table *table;
235
236                 table = fib6_get_table(net, tbl_id);
237                 if (!table)
238                         goto out;
239
240                 rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
241                 dst = &rt->dst;
242         }
243
244         /* we want to discard traffic destined for local packet processing,
245          * if @local_delivery is set to false.
246          */
247         if (!local_delivery)
248                 dev_flags |= IFF_LOOPBACK;
249
250         if (dst && (dst->dev->flags & dev_flags) && !dst->error) {
251                 dst_release(dst);
252                 dst = NULL;
253         }
254
255 out:
256         if (!dst) {
257                 rt = net->ipv6.ip6_blk_hole_entry;
258                 dst = &rt->dst;
259                 dst_hold(dst);
260         }
261
262         skb_dst_drop(skb);
263         skb_dst_set(skb, dst);
264         return dst->error;
265 }
266
267 int seg6_lookup_nexthop(struct sk_buff *skb,
268                         struct in6_addr *nhaddr, u32 tbl_id)
269 {
270         return seg6_lookup_any_nexthop(skb, nhaddr, tbl_id, false);
271 }
272
273 /* regular endpoint function */
274 static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt)
275 {
276         struct ipv6_sr_hdr *srh;
277
278         srh = get_and_validate_srh(skb);
279         if (!srh)
280                 goto drop;
281
282         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
283
284         seg6_lookup_nexthop(skb, NULL, 0);
285
286         return dst_input(skb);
287
288 drop:
289         kfree_skb(skb);
290         return -EINVAL;
291 }
292
293 /* regular endpoint, and forward to specified nexthop */
294 static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt)
295 {
296         struct ipv6_sr_hdr *srh;
297
298         srh = get_and_validate_srh(skb);
299         if (!srh)
300                 goto drop;
301
302         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
303
304         seg6_lookup_nexthop(skb, &slwt->nh6, 0);
305
306         return dst_input(skb);
307
308 drop:
309         kfree_skb(skb);
310         return -EINVAL;
311 }
312
313 static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt)
314 {
315         struct ipv6_sr_hdr *srh;
316
317         srh = get_and_validate_srh(skb);
318         if (!srh)
319                 goto drop;
320
321         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
322
323         seg6_lookup_nexthop(skb, NULL, slwt->table);
324
325         return dst_input(skb);
326
327 drop:
328         kfree_skb(skb);
329         return -EINVAL;
330 }
331
332 /* decapsulate and forward inner L2 frame on specified interface */
333 static int input_action_end_dx2(struct sk_buff *skb,
334                                 struct seg6_local_lwt *slwt)
335 {
336         struct net *net = dev_net(skb->dev);
337         struct net_device *odev;
338         struct ethhdr *eth;
339
340         if (!decap_and_validate(skb, IPPROTO_ETHERNET))
341                 goto drop;
342
343         if (!pskb_may_pull(skb, ETH_HLEN))
344                 goto drop;
345
346         skb_reset_mac_header(skb);
347         eth = (struct ethhdr *)skb->data;
348
349         /* To determine the frame's protocol, we assume it is 802.3. This avoids
350          * a call to eth_type_trans(), which is not really relevant for our
351          * use case.
352          */
353         if (!eth_proto_is_802_3(eth->h_proto))
354                 goto drop;
355
356         odev = dev_get_by_index_rcu(net, slwt->oif);
357         if (!odev)
358                 goto drop;
359
360         /* As we accept Ethernet frames, make sure the egress device is of
361          * the correct type.
362          */
363         if (odev->type != ARPHRD_ETHER)
364                 goto drop;
365
366         if (!(odev->flags & IFF_UP) || !netif_carrier_ok(odev))
367                 goto drop;
368
369         skb_orphan(skb);
370
371         if (skb_warn_if_lro(skb))
372                 goto drop;
373
374         skb_forward_csum(skb);
375
376         if (skb->len - ETH_HLEN > odev->mtu)
377                 goto drop;
378
379         skb->dev = odev;
380         skb->protocol = eth->h_proto;
381
382         return dev_queue_xmit(skb);
383
384 drop:
385         kfree_skb(skb);
386         return -EINVAL;
387 }
388
389 /* decapsulate and forward to specified nexthop */
390 static int input_action_end_dx6(struct sk_buff *skb,
391                                 struct seg6_local_lwt *slwt)
392 {
393         struct in6_addr *nhaddr = NULL;
394
395         /* this function accepts IPv6 encapsulated packets, with either
396          * an SRH with SL=0, or no SRH.
397          */
398
399         if (!decap_and_validate(skb, IPPROTO_IPV6))
400                 goto drop;
401
402         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
403                 goto drop;
404
405         /* The inner packet is not associated to any local interface,
406          * so we do not call netif_rx().
407          *
408          * If slwt->nh6 is set to ::, then lookup the nexthop for the
409          * inner packet's DA. Otherwise, use the specified nexthop.
410          */
411
412         if (!ipv6_addr_any(&slwt->nh6))
413                 nhaddr = &slwt->nh6;
414
415         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
416
417         seg6_lookup_nexthop(skb, nhaddr, 0);
418
419         return dst_input(skb);
420 drop:
421         kfree_skb(skb);
422         return -EINVAL;
423 }
424
425 static int input_action_end_dx4(struct sk_buff *skb,
426                                 struct seg6_local_lwt *slwt)
427 {
428         struct iphdr *iph;
429         __be32 nhaddr;
430         int err;
431
432         if (!decap_and_validate(skb, IPPROTO_IPIP))
433                 goto drop;
434
435         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
436                 goto drop;
437
438         skb->protocol = htons(ETH_P_IP);
439
440         iph = ip_hdr(skb);
441
442         nhaddr = slwt->nh4.s_addr ?: iph->daddr;
443
444         skb_dst_drop(skb);
445
446         skb_set_transport_header(skb, sizeof(struct iphdr));
447
448         err = ip_route_input(skb, nhaddr, iph->saddr, 0, skb->dev);
449         if (err)
450                 goto drop;
451
452         return dst_input(skb);
453
454 drop:
455         kfree_skb(skb);
456         return -EINVAL;
457 }
458
459 #ifdef CONFIG_NET_L3_MASTER_DEV
460 static struct net *fib6_config_get_net(const struct fib6_config *fib6_cfg)
461 {
462         const struct nl_info *nli = &fib6_cfg->fc_nlinfo;
463
464         return nli->nl_net;
465 }
466
467 static int __seg6_end_dt_vrf_build(struct seg6_local_lwt *slwt, const void *cfg,
468                                    u16 family, struct netlink_ext_ack *extack)
469 {
470         struct seg6_end_dt_info *info = &slwt->dt_info;
471         int vrf_ifindex;
472         struct net *net;
473
474         net = fib6_config_get_net(cfg);
475
476         /* note that vrf_table was already set by parse_nla_vrftable() */
477         vrf_ifindex = l3mdev_ifindex_lookup_by_table_id(L3MDEV_TYPE_VRF, net,
478                                                         info->vrf_table);
479         if (vrf_ifindex < 0) {
480                 if (vrf_ifindex == -EPERM) {
481                         NL_SET_ERR_MSG(extack,
482                                        "Strict mode for VRF is disabled");
483                 } else if (vrf_ifindex == -ENODEV) {
484                         NL_SET_ERR_MSG(extack,
485                                        "Table has no associated VRF device");
486                 } else {
487                         pr_debug("seg6local: SRv6 End.DT* creation error=%d\n",
488                                  vrf_ifindex);
489                 }
490
491                 return vrf_ifindex;
492         }
493
494         info->net = net;
495         info->vrf_ifindex = vrf_ifindex;
496
497         switch (family) {
498         case AF_INET:
499                 info->proto = htons(ETH_P_IP);
500                 info->hdrlen = sizeof(struct iphdr);
501                 break;
502         case AF_INET6:
503                 info->proto = htons(ETH_P_IPV6);
504                 info->hdrlen = sizeof(struct ipv6hdr);
505                 break;
506         default:
507                 return -EINVAL;
508         }
509
510         info->family = family;
511         info->mode = DT_VRF_MODE;
512
513         return 0;
514 }
515
516 /* The SRv6 End.DT4/DT6 behavior extracts the inner (IPv4/IPv6) packet and
517  * routes the IPv4/IPv6 packet by looking at the configured routing table.
518  *
519  * In the SRv6 End.DT4/DT6 use case, we can receive traffic (IPv6+Segment
520  * Routing Header packets) from several interfaces and the outer IPv6
521  * destination address (DA) is used for retrieving the specific instance of the
522  * End.DT4/DT6 behavior that should process the packets.
523  *
524  * However, the inner IPv4/IPv6 packet is not really bound to any receiving
525  * interface and thus the End.DT4/DT6 sets the VRF (associated with the
526  * corresponding routing table) as the *receiving* interface.
527  * In other words, the End.DT4/DT6 processes a packet as if it has been received
528  * directly by the VRF (and not by one of its slave devices, if any).
529  * In this way, the VRF interface is used for routing the IPv4/IPv6 packet in
530  * according to the routing table configured by the End.DT4/DT6 instance.
531  *
532  * This design allows you to get some interesting features like:
533  *  1) the statistics on rx packets;
534  *  2) the possibility to install a packet sniffer on the receiving interface
535  *     (the VRF one) for looking at the incoming packets;
536  *  3) the possibility to leverage the netfilter prerouting hook for the inner
537  *     IPv4 packet.
538  *
539  * This function returns:
540  *  - the sk_buff* when the VRF rcv handler has processed the packet correctly;
541  *  - NULL when the skb is consumed by the VRF rcv handler;
542  *  - a pointer which encodes a negative error number in case of error.
543  *    Note that in this case, the function takes care of freeing the skb.
544  */
545 static struct sk_buff *end_dt_vrf_rcv(struct sk_buff *skb, u16 family,
546                                       struct net_device *dev)
547 {
548         /* based on l3mdev_ip_rcv; we are only interested in the master */
549         if (unlikely(!netif_is_l3_master(dev) && !netif_has_l3_rx_handler(dev)))
550                 goto drop;
551
552         if (unlikely(!dev->l3mdev_ops->l3mdev_l3_rcv))
553                 goto drop;
554
555         /* the decap packet IPv4/IPv6 does not come with any mac header info.
556          * We must unset the mac header to allow the VRF device to rebuild it,
557          * just in case there is a sniffer attached on the device.
558          */
559         skb_unset_mac_header(skb);
560
561         skb = dev->l3mdev_ops->l3mdev_l3_rcv(dev, skb, family);
562         if (!skb)
563                 /* the skb buffer was consumed by the handler */
564                 return NULL;
565
566         /* when a packet is received by a VRF or by one of its slaves, the
567          * master device reference is set into the skb.
568          */
569         if (unlikely(skb->dev != dev || skb->skb_iif != dev->ifindex))
570                 goto drop;
571
572         return skb;
573
574 drop:
575         kfree_skb(skb);
576         return ERR_PTR(-EINVAL);
577 }
578
579 static struct net_device *end_dt_get_vrf_rcu(struct sk_buff *skb,
580                                              struct seg6_end_dt_info *info)
581 {
582         int vrf_ifindex = info->vrf_ifindex;
583         struct net *net = info->net;
584
585         if (unlikely(vrf_ifindex < 0))
586                 goto error;
587
588         if (unlikely(!net_eq(dev_net(skb->dev), net)))
589                 goto error;
590
591         return dev_get_by_index_rcu(net, vrf_ifindex);
592
593 error:
594         return NULL;
595 }
596
597 static struct sk_buff *end_dt_vrf_core(struct sk_buff *skb,
598                                        struct seg6_local_lwt *slwt)
599 {
600         struct seg6_end_dt_info *info = &slwt->dt_info;
601         struct net_device *vrf;
602
603         vrf = end_dt_get_vrf_rcu(skb, info);
604         if (unlikely(!vrf))
605                 goto drop;
606
607         skb->protocol = info->proto;
608
609         skb_dst_drop(skb);
610
611         skb_set_transport_header(skb, info->hdrlen);
612
613         return end_dt_vrf_rcv(skb, info->family, vrf);
614
615 drop:
616         kfree_skb(skb);
617         return ERR_PTR(-EINVAL);
618 }
619
620 static int input_action_end_dt4(struct sk_buff *skb,
621                                 struct seg6_local_lwt *slwt)
622 {
623         struct iphdr *iph;
624         int err;
625
626         if (!decap_and_validate(skb, IPPROTO_IPIP))
627                 goto drop;
628
629         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
630                 goto drop;
631
632         skb = end_dt_vrf_core(skb, slwt);
633         if (!skb)
634                 /* packet has been processed and consumed by the VRF */
635                 return 0;
636
637         if (IS_ERR(skb))
638                 return PTR_ERR(skb);
639
640         iph = ip_hdr(skb);
641
642         err = ip_route_input(skb, iph->daddr, iph->saddr, 0, skb->dev);
643         if (unlikely(err))
644                 goto drop;
645
646         return dst_input(skb);
647
648 drop:
649         kfree_skb(skb);
650         return -EINVAL;
651 }
652
653 static int seg6_end_dt4_build(struct seg6_local_lwt *slwt, const void *cfg,
654                               struct netlink_ext_ack *extack)
655 {
656         return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET, extack);
657 }
658
659 static enum
660 seg6_end_dt_mode seg6_end_dt6_parse_mode(struct seg6_local_lwt *slwt)
661 {
662         unsigned long parsed_optattrs = slwt->parsed_optattrs;
663         bool legacy, vrfmode;
664
665         legacy  = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE));
666         vrfmode = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE));
667
668         if (!(legacy ^ vrfmode))
669                 /* both are absent or present: invalid DT6 mode */
670                 return DT_INVALID_MODE;
671
672         return legacy ? DT_LEGACY_MODE : DT_VRF_MODE;
673 }
674
675 static enum seg6_end_dt_mode seg6_end_dt6_get_mode(struct seg6_local_lwt *slwt)
676 {
677         struct seg6_end_dt_info *info = &slwt->dt_info;
678
679         return info->mode;
680 }
681
682 static int seg6_end_dt6_build(struct seg6_local_lwt *slwt, const void *cfg,
683                               struct netlink_ext_ack *extack)
684 {
685         enum seg6_end_dt_mode mode = seg6_end_dt6_parse_mode(slwt);
686         struct seg6_end_dt_info *info = &slwt->dt_info;
687
688         switch (mode) {
689         case DT_LEGACY_MODE:
690                 info->mode = DT_LEGACY_MODE;
691                 return 0;
692         case DT_VRF_MODE:
693                 return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET6, extack);
694         default:
695                 NL_SET_ERR_MSG(extack, "table or vrftable must be specified");
696                 return -EINVAL;
697         }
698 }
699 #endif
700
701 static int input_action_end_dt6(struct sk_buff *skb,
702                                 struct seg6_local_lwt *slwt)
703 {
704         if (!decap_and_validate(skb, IPPROTO_IPV6))
705                 goto drop;
706
707         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
708                 goto drop;
709
710 #ifdef CONFIG_NET_L3_MASTER_DEV
711         if (seg6_end_dt6_get_mode(slwt) == DT_LEGACY_MODE)
712                 goto legacy_mode;
713
714         /* DT6_VRF_MODE */
715         skb = end_dt_vrf_core(skb, slwt);
716         if (!skb)
717                 /* packet has been processed and consumed by the VRF */
718                 return 0;
719
720         if (IS_ERR(skb))
721                 return PTR_ERR(skb);
722
723         /* note: this time we do not need to specify the table because the VRF
724          * takes care of selecting the correct table.
725          */
726         seg6_lookup_any_nexthop(skb, NULL, 0, true);
727
728         return dst_input(skb);
729
730 legacy_mode:
731 #endif
732         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
733
734         seg6_lookup_any_nexthop(skb, NULL, slwt->table, true);
735
736         return dst_input(skb);
737
738 drop:
739         kfree_skb(skb);
740         return -EINVAL;
741 }
742
743 /* push an SRH on top of the current one */
744 static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
745 {
746         struct ipv6_sr_hdr *srh;
747         int err = -EINVAL;
748
749         srh = get_and_validate_srh(skb);
750         if (!srh)
751                 goto drop;
752
753         err = seg6_do_srh_inline(skb, slwt->srh);
754         if (err)
755                 goto drop;
756
757         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
758         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
759
760         seg6_lookup_nexthop(skb, NULL, 0);
761
762         return dst_input(skb);
763
764 drop:
765         kfree_skb(skb);
766         return err;
767 }
768
769 /* encapsulate within an outer IPv6 header and a specified SRH */
770 static int input_action_end_b6_encap(struct sk_buff *skb,
771                                      struct seg6_local_lwt *slwt)
772 {
773         struct ipv6_sr_hdr *srh;
774         int err = -EINVAL;
775
776         srh = get_and_validate_srh(skb);
777         if (!srh)
778                 goto drop;
779
780         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
781
782         skb_reset_inner_headers(skb);
783         skb->encapsulation = 1;
784
785         err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6);
786         if (err)
787                 goto drop;
788
789         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
790         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
791
792         seg6_lookup_nexthop(skb, NULL, 0);
793
794         return dst_input(skb);
795
796 drop:
797         kfree_skb(skb);
798         return err;
799 }
800
801 DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states);
802
803 bool seg6_bpf_has_valid_srh(struct sk_buff *skb)
804 {
805         struct seg6_bpf_srh_state *srh_state =
806                 this_cpu_ptr(&seg6_bpf_srh_states);
807         struct ipv6_sr_hdr *srh = srh_state->srh;
808
809         if (unlikely(srh == NULL))
810                 return false;
811
812         if (unlikely(!srh_state->valid)) {
813                 if ((srh_state->hdrlen & 7) != 0)
814                         return false;
815
816                 srh->hdrlen = (u8)(srh_state->hdrlen >> 3);
817                 if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3, true))
818                         return false;
819
820                 srh_state->valid = true;
821         }
822
823         return true;
824 }
825
826 static int input_action_end_bpf(struct sk_buff *skb,
827                                 struct seg6_local_lwt *slwt)
828 {
829         struct seg6_bpf_srh_state *srh_state =
830                 this_cpu_ptr(&seg6_bpf_srh_states);
831         struct ipv6_sr_hdr *srh;
832         int ret;
833
834         srh = get_and_validate_srh(skb);
835         if (!srh) {
836                 kfree_skb(skb);
837                 return -EINVAL;
838         }
839         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
840
841         /* preempt_disable is needed to protect the per-CPU buffer srh_state,
842          * which is also accessed by the bpf_lwt_seg6_* helpers
843          */
844         preempt_disable();
845         srh_state->srh = srh;
846         srh_state->hdrlen = srh->hdrlen << 3;
847         srh_state->valid = true;
848
849         rcu_read_lock();
850         bpf_compute_data_pointers(skb);
851         ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb);
852         rcu_read_unlock();
853
854         switch (ret) {
855         case BPF_OK:
856         case BPF_REDIRECT:
857                 break;
858         case BPF_DROP:
859                 goto drop;
860         default:
861                 pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret);
862                 goto drop;
863         }
864
865         if (srh_state->srh && !seg6_bpf_has_valid_srh(skb))
866                 goto drop;
867
868         preempt_enable();
869         if (ret != BPF_REDIRECT)
870                 seg6_lookup_nexthop(skb, NULL, 0);
871
872         return dst_input(skb);
873
874 drop:
875         preempt_enable();
876         kfree_skb(skb);
877         return -EINVAL;
878 }
879
880 static struct seg6_action_desc seg6_action_table[] = {
881         {
882                 .action         = SEG6_LOCAL_ACTION_END,
883                 .attrs          = 0,
884                 .input          = input_action_end,
885         },
886         {
887                 .action         = SEG6_LOCAL_ACTION_END_X,
888                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
889                 .input          = input_action_end_x,
890         },
891         {
892                 .action         = SEG6_LOCAL_ACTION_END_T,
893                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
894                 .input          = input_action_end_t,
895         },
896         {
897                 .action         = SEG6_LOCAL_ACTION_END_DX2,
898                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_OIF),
899                 .input          = input_action_end_dx2,
900         },
901         {
902                 .action         = SEG6_LOCAL_ACTION_END_DX6,
903                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
904                 .input          = input_action_end_dx6,
905         },
906         {
907                 .action         = SEG6_LOCAL_ACTION_END_DX4,
908                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH4),
909                 .input          = input_action_end_dx4,
910         },
911         {
912                 .action         = SEG6_LOCAL_ACTION_END_DT4,
913                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
914 #ifdef CONFIG_NET_L3_MASTER_DEV
915                 .input          = input_action_end_dt4,
916                 .slwt_ops       = {
917                                         .build_state = seg6_end_dt4_build,
918                                   },
919 #endif
920         },
921         {
922                 .action         = SEG6_LOCAL_ACTION_END_DT6,
923 #ifdef CONFIG_NET_L3_MASTER_DEV
924                 .attrs          = 0,
925                 .optattrs       = SEG6_F_ATTR(SEG6_LOCAL_TABLE) |
926                                   SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
927                 .slwt_ops       = {
928                                         .build_state = seg6_end_dt6_build,
929                                   },
930 #else
931                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
932 #endif
933                 .input          = input_action_end_dt6,
934         },
935         {
936                 .action         = SEG6_LOCAL_ACTION_END_B6,
937                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
938                 .input          = input_action_end_b6,
939         },
940         {
941                 .action         = SEG6_LOCAL_ACTION_END_B6_ENCAP,
942                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
943                 .input          = input_action_end_b6_encap,
944                 .static_headroom        = sizeof(struct ipv6hdr),
945         },
946         {
947                 .action         = SEG6_LOCAL_ACTION_END_BPF,
948                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_BPF),
949                 .input          = input_action_end_bpf,
950         },
951
952 };
953
954 static struct seg6_action_desc *__get_action_desc(int action)
955 {
956         struct seg6_action_desc *desc;
957         int i, count;
958
959         count = ARRAY_SIZE(seg6_action_table);
960         for (i = 0; i < count; i++) {
961                 desc = &seg6_action_table[i];
962                 if (desc->action == action)
963                         return desc;
964         }
965
966         return NULL;
967 }
968
969 static int seg6_local_input(struct sk_buff *skb)
970 {
971         struct dst_entry *orig_dst = skb_dst(skb);
972         struct seg6_action_desc *desc;
973         struct seg6_local_lwt *slwt;
974
975         if (skb->protocol != htons(ETH_P_IPV6)) {
976                 kfree_skb(skb);
977                 return -EINVAL;
978         }
979
980         slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
981         desc = slwt->desc;
982
983         return desc->input(skb, slwt);
984 }
985
986 static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
987         [SEG6_LOCAL_ACTION]     = { .type = NLA_U32 },
988         [SEG6_LOCAL_SRH]        = { .type = NLA_BINARY },
989         [SEG6_LOCAL_TABLE]      = { .type = NLA_U32 },
990         [SEG6_LOCAL_VRFTABLE]   = { .type = NLA_U32 },
991         [SEG6_LOCAL_NH4]        = { .type = NLA_BINARY,
992                                     .len = sizeof(struct in_addr) },
993         [SEG6_LOCAL_NH6]        = { .type = NLA_BINARY,
994                                     .len = sizeof(struct in6_addr) },
995         [SEG6_LOCAL_IIF]        = { .type = NLA_U32 },
996         [SEG6_LOCAL_OIF]        = { .type = NLA_U32 },
997         [SEG6_LOCAL_BPF]        = { .type = NLA_NESTED },
998 };
999
1000 static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1001 {
1002         struct ipv6_sr_hdr *srh;
1003         int len;
1004
1005         srh = nla_data(attrs[SEG6_LOCAL_SRH]);
1006         len = nla_len(attrs[SEG6_LOCAL_SRH]);
1007
1008         /* SRH must contain at least one segment */
1009         if (len < sizeof(*srh) + sizeof(struct in6_addr))
1010                 return -EINVAL;
1011
1012         if (!seg6_validate_srh(srh, len, false))
1013                 return -EINVAL;
1014
1015         slwt->srh = kmemdup(srh, len, GFP_KERNEL);
1016         if (!slwt->srh)
1017                 return -ENOMEM;
1018
1019         slwt->headroom += len;
1020
1021         return 0;
1022 }
1023
1024 static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1025 {
1026         struct ipv6_sr_hdr *srh;
1027         struct nlattr *nla;
1028         int len;
1029
1030         srh = slwt->srh;
1031         len = (srh->hdrlen + 1) << 3;
1032
1033         nla = nla_reserve(skb, SEG6_LOCAL_SRH, len);
1034         if (!nla)
1035                 return -EMSGSIZE;
1036
1037         memcpy(nla_data(nla), srh, len);
1038
1039         return 0;
1040 }
1041
1042 static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1043 {
1044         int len = (a->srh->hdrlen + 1) << 3;
1045
1046         if (len != ((b->srh->hdrlen + 1) << 3))
1047                 return 1;
1048
1049         return memcmp(a->srh, b->srh, len);
1050 }
1051
1052 static void destroy_attr_srh(struct seg6_local_lwt *slwt)
1053 {
1054         kfree(slwt->srh);
1055 }
1056
1057 static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1058 {
1059         slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]);
1060
1061         return 0;
1062 }
1063
1064 static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1065 {
1066         if (nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table))
1067                 return -EMSGSIZE;
1068
1069         return 0;
1070 }
1071
1072 static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1073 {
1074         if (a->table != b->table)
1075                 return 1;
1076
1077         return 0;
1078 }
1079
1080 static struct
1081 seg6_end_dt_info *seg6_possible_end_dt_info(struct seg6_local_lwt *slwt)
1082 {
1083 #ifdef CONFIG_NET_L3_MASTER_DEV
1084         return &slwt->dt_info;
1085 #else
1086         return ERR_PTR(-EOPNOTSUPP);
1087 #endif
1088 }
1089
1090 static int parse_nla_vrftable(struct nlattr **attrs,
1091                               struct seg6_local_lwt *slwt)
1092 {
1093         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1094
1095         if (IS_ERR(info))
1096                 return PTR_ERR(info);
1097
1098         info->vrf_table = nla_get_u32(attrs[SEG6_LOCAL_VRFTABLE]);
1099
1100         return 0;
1101 }
1102
1103 static int put_nla_vrftable(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1104 {
1105         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1106
1107         if (IS_ERR(info))
1108                 return PTR_ERR(info);
1109
1110         if (nla_put_u32(skb, SEG6_LOCAL_VRFTABLE, info->vrf_table))
1111                 return -EMSGSIZE;
1112
1113         return 0;
1114 }
1115
1116 static int cmp_nla_vrftable(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1117 {
1118         struct seg6_end_dt_info *info_a = seg6_possible_end_dt_info(a);
1119         struct seg6_end_dt_info *info_b = seg6_possible_end_dt_info(b);
1120
1121         if (info_a->vrf_table != info_b->vrf_table)
1122                 return 1;
1123
1124         return 0;
1125 }
1126
1127 static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1128 {
1129         memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]),
1130                sizeof(struct in_addr));
1131
1132         return 0;
1133 }
1134
1135 static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1136 {
1137         struct nlattr *nla;
1138
1139         nla = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(struct in_addr));
1140         if (!nla)
1141                 return -EMSGSIZE;
1142
1143         memcpy(nla_data(nla), &slwt->nh4, sizeof(struct in_addr));
1144
1145         return 0;
1146 }
1147
1148 static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1149 {
1150         return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr));
1151 }
1152
1153 static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1154 {
1155         memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]),
1156                sizeof(struct in6_addr));
1157
1158         return 0;
1159 }
1160
1161 static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1162 {
1163         struct nlattr *nla;
1164
1165         nla = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(struct in6_addr));
1166         if (!nla)
1167                 return -EMSGSIZE;
1168
1169         memcpy(nla_data(nla), &slwt->nh6, sizeof(struct in6_addr));
1170
1171         return 0;
1172 }
1173
1174 static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1175 {
1176         return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr));
1177 }
1178
1179 static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1180 {
1181         slwt->iif = nla_get_u32(attrs[SEG6_LOCAL_IIF]);
1182
1183         return 0;
1184 }
1185
1186 static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1187 {
1188         if (nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif))
1189                 return -EMSGSIZE;
1190
1191         return 0;
1192 }
1193
1194 static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1195 {
1196         if (a->iif != b->iif)
1197                 return 1;
1198
1199         return 0;
1200 }
1201
1202 static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1203 {
1204         slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]);
1205
1206         return 0;
1207 }
1208
1209 static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1210 {
1211         if (nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif))
1212                 return -EMSGSIZE;
1213
1214         return 0;
1215 }
1216
1217 static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1218 {
1219         if (a->oif != b->oif)
1220                 return 1;
1221
1222         return 0;
1223 }
1224
1225 #define MAX_PROG_NAME 256
1226 static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = {
1227         [SEG6_LOCAL_BPF_PROG]      = { .type = NLA_U32, },
1228         [SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
1229                                        .len = MAX_PROG_NAME },
1230 };
1231
1232 static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1233 {
1234         struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1];
1235         struct bpf_prog *p;
1236         int ret;
1237         u32 fd;
1238
1239         ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_BPF_PROG_MAX,
1240                                           attrs[SEG6_LOCAL_BPF],
1241                                           bpf_prog_policy, NULL);
1242         if (ret < 0)
1243                 return ret;
1244
1245         if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME])
1246                 return -EINVAL;
1247
1248         slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL);
1249         if (!slwt->bpf.name)
1250                 return -ENOMEM;
1251
1252         fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]);
1253         p = bpf_prog_get_type(fd, BPF_PROG_TYPE_LWT_SEG6LOCAL);
1254         if (IS_ERR(p)) {
1255                 kfree(slwt->bpf.name);
1256                 return PTR_ERR(p);
1257         }
1258
1259         slwt->bpf.prog = p;
1260         return 0;
1261 }
1262
1263 static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1264 {
1265         struct nlattr *nest;
1266
1267         if (!slwt->bpf.prog)
1268                 return 0;
1269
1270         nest = nla_nest_start_noflag(skb, SEG6_LOCAL_BPF);
1271         if (!nest)
1272                 return -EMSGSIZE;
1273
1274         if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id))
1275                 return -EMSGSIZE;
1276
1277         if (slwt->bpf.name &&
1278             nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name))
1279                 return -EMSGSIZE;
1280
1281         return nla_nest_end(skb, nest);
1282 }
1283
1284 static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1285 {
1286         if (!a->bpf.name && !b->bpf.name)
1287                 return 0;
1288
1289         if (!a->bpf.name || !b->bpf.name)
1290                 return 1;
1291
1292         return strcmp(a->bpf.name, b->bpf.name);
1293 }
1294
1295 static void destroy_attr_bpf(struct seg6_local_lwt *slwt)
1296 {
1297         kfree(slwt->bpf.name);
1298         if (slwt->bpf.prog)
1299                 bpf_prog_put(slwt->bpf.prog);
1300 }
1301
1302 struct seg6_action_param {
1303         int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
1304         int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
1305         int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b);
1306
1307         /* optional destroy() callback useful for releasing resources which
1308          * have been previously acquired in the corresponding parse()
1309          * function.
1310          */
1311         void (*destroy)(struct seg6_local_lwt *slwt);
1312 };
1313
1314 static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
1315         [SEG6_LOCAL_SRH]        = { .parse = parse_nla_srh,
1316                                     .put = put_nla_srh,
1317                                     .cmp = cmp_nla_srh,
1318                                     .destroy = destroy_attr_srh },
1319
1320         [SEG6_LOCAL_TABLE]      = { .parse = parse_nla_table,
1321                                     .put = put_nla_table,
1322                                     .cmp = cmp_nla_table },
1323
1324         [SEG6_LOCAL_NH4]        = { .parse = parse_nla_nh4,
1325                                     .put = put_nla_nh4,
1326                                     .cmp = cmp_nla_nh4 },
1327
1328         [SEG6_LOCAL_NH6]        = { .parse = parse_nla_nh6,
1329                                     .put = put_nla_nh6,
1330                                     .cmp = cmp_nla_nh6 },
1331
1332         [SEG6_LOCAL_IIF]        = { .parse = parse_nla_iif,
1333                                     .put = put_nla_iif,
1334                                     .cmp = cmp_nla_iif },
1335
1336         [SEG6_LOCAL_OIF]        = { .parse = parse_nla_oif,
1337                                     .put = put_nla_oif,
1338                                     .cmp = cmp_nla_oif },
1339
1340         [SEG6_LOCAL_BPF]        = { .parse = parse_nla_bpf,
1341                                     .put = put_nla_bpf,
1342                                     .cmp = cmp_nla_bpf,
1343                                     .destroy = destroy_attr_bpf },
1344
1345         [SEG6_LOCAL_VRFTABLE]   = { .parse = parse_nla_vrftable,
1346                                     .put = put_nla_vrftable,
1347                                     .cmp = cmp_nla_vrftable },
1348
1349 };
1350
1351 /* call the destroy() callback (if available) for each set attribute in
1352  * @parsed_attrs, starting from the first attribute up to the @max_parsed
1353  * (excluded) attribute.
1354  */
1355 static void __destroy_attrs(unsigned long parsed_attrs, int max_parsed,
1356                             struct seg6_local_lwt *slwt)
1357 {
1358         struct seg6_action_param *param;
1359         int i;
1360
1361         /* Every required seg6local attribute is identified by an ID which is
1362          * encoded as a flag (i.e: 1 << ID) in the 'attrs' bitmask;
1363          *
1364          * We scan the 'parsed_attrs' bitmask, starting from the first attribute
1365          * up to the @max_parsed (excluded) attribute.
1366          * For each set attribute, we retrieve the corresponding destroy()
1367          * callback. If the callback is not available, then we skip to the next
1368          * attribute; otherwise, we call the destroy() callback.
1369          */
1370         for (i = 0; i < max_parsed; ++i) {
1371                 if (!(parsed_attrs & SEG6_F_ATTR(i)))
1372                         continue;
1373
1374                 param = &seg6_action_params[i];
1375
1376                 if (param->destroy)
1377                         param->destroy(slwt);
1378         }
1379 }
1380
1381 /* release all the resources that may have been acquired during parsing
1382  * operations.
1383  */
1384 static void destroy_attrs(struct seg6_local_lwt *slwt)
1385 {
1386         unsigned long attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1387
1388         __destroy_attrs(attrs, SEG6_LOCAL_MAX + 1, slwt);
1389 }
1390
1391 static int parse_nla_optional_attrs(struct nlattr **attrs,
1392                                     struct seg6_local_lwt *slwt)
1393 {
1394         struct seg6_action_desc *desc = slwt->desc;
1395         unsigned long parsed_optattrs = 0;
1396         struct seg6_action_param *param;
1397         int err, i;
1398
1399         for (i = 0; i < SEG6_LOCAL_MAX + 1; ++i) {
1400                 if (!(desc->optattrs & SEG6_F_ATTR(i)) || !attrs[i])
1401                         continue;
1402
1403                 /* once here, the i-th attribute is provided by the
1404                  * userspace AND it is identified optional as well.
1405                  */
1406                 param = &seg6_action_params[i];
1407
1408                 err = param->parse(attrs, slwt);
1409                 if (err < 0)
1410                         goto parse_optattrs_err;
1411
1412                 /* current attribute has been correctly parsed */
1413                 parsed_optattrs |= SEG6_F_ATTR(i);
1414         }
1415
1416         /* store in the tunnel state all the optional attributed successfully
1417          * parsed.
1418          */
1419         slwt->parsed_optattrs = parsed_optattrs;
1420
1421         return 0;
1422
1423 parse_optattrs_err:
1424         __destroy_attrs(parsed_optattrs, i, slwt);
1425
1426         return err;
1427 }
1428
1429 /* call the custom constructor of the behavior during its initialization phase
1430  * and after that all its attributes have been parsed successfully.
1431  */
1432 static int
1433 seg6_local_lwtunnel_build_state(struct seg6_local_lwt *slwt, const void *cfg,
1434                                 struct netlink_ext_ack *extack)
1435 {
1436         struct seg6_action_desc *desc = slwt->desc;
1437         struct seg6_local_lwtunnel_ops *ops;
1438
1439         ops = &desc->slwt_ops;
1440         if (!ops->build_state)
1441                 return 0;
1442
1443         return ops->build_state(slwt, cfg, extack);
1444 }
1445
1446 /* call the custom destructor of the behavior which is invoked before the
1447  * tunnel is going to be destroyed.
1448  */
1449 static void seg6_local_lwtunnel_destroy_state(struct seg6_local_lwt *slwt)
1450 {
1451         struct seg6_action_desc *desc = slwt->desc;
1452         struct seg6_local_lwtunnel_ops *ops;
1453
1454         ops = &desc->slwt_ops;
1455         if (!ops->destroy_state)
1456                 return;
1457
1458         ops->destroy_state(slwt);
1459 }
1460
1461 static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1462 {
1463         struct seg6_action_param *param;
1464         struct seg6_action_desc *desc;
1465         unsigned long invalid_attrs;
1466         int i, err;
1467
1468         desc = __get_action_desc(slwt->action);
1469         if (!desc)
1470                 return -EINVAL;
1471
1472         if (!desc->input)
1473                 return -EOPNOTSUPP;
1474
1475         slwt->desc = desc;
1476         slwt->headroom += desc->static_headroom;
1477
1478         /* Forcing the desc->optattrs *set* and the desc->attrs *set* to be
1479          * disjoined, this allow us to release acquired resources by optional
1480          * attributes and by required attributes independently from each other
1481          * without any interfarence.
1482          * In other terms, we are sure that we do not release some the acquired
1483          * resources twice.
1484          *
1485          * Note that if an attribute is configured both as required and as
1486          * optional, it means that the user has messed something up in the
1487          * seg6_action_table. Therefore, this check is required for SRv6
1488          * behaviors to work properly.
1489          */
1490         invalid_attrs = desc->attrs & desc->optattrs;
1491         if (invalid_attrs) {
1492                 WARN_ONCE(1,
1493                           "An attribute cannot be both required AND optional");
1494                 return -EINVAL;
1495         }
1496
1497         /* parse the required attributes */
1498         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1499                 if (desc->attrs & SEG6_F_ATTR(i)) {
1500                         if (!attrs[i])
1501                                 return -EINVAL;
1502
1503                         param = &seg6_action_params[i];
1504
1505                         err = param->parse(attrs, slwt);
1506                         if (err < 0)
1507                                 goto parse_attrs_err;
1508                 }
1509         }
1510
1511         /* parse the optional attributes, if any */
1512         err = parse_nla_optional_attrs(attrs, slwt);
1513         if (err < 0)
1514                 goto parse_attrs_err;
1515
1516         return 0;
1517
1518 parse_attrs_err:
1519         /* release any resource that may have been acquired during the i-1
1520          * parse() operations.
1521          */
1522         __destroy_attrs(desc->attrs, i, slwt);
1523
1524         return err;
1525 }
1526
1527 static int seg6_local_build_state(struct net *net, struct nlattr *nla,
1528                                   unsigned int family, const void *cfg,
1529                                   struct lwtunnel_state **ts,
1530                                   struct netlink_ext_ack *extack)
1531 {
1532         struct nlattr *tb[SEG6_LOCAL_MAX + 1];
1533         struct lwtunnel_state *newts;
1534         struct seg6_local_lwt *slwt;
1535         int err;
1536
1537         if (family != AF_INET6)
1538                 return -EINVAL;
1539
1540         err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_MAX, nla,
1541                                           seg6_local_policy, extack);
1542
1543         if (err < 0)
1544                 return err;
1545
1546         if (!tb[SEG6_LOCAL_ACTION])
1547                 return -EINVAL;
1548
1549         newts = lwtunnel_state_alloc(sizeof(*slwt));
1550         if (!newts)
1551                 return -ENOMEM;
1552
1553         slwt = seg6_local_lwtunnel(newts);
1554         slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]);
1555
1556         err = parse_nla_action(tb, slwt);
1557         if (err < 0)
1558                 goto out_free;
1559
1560         err = seg6_local_lwtunnel_build_state(slwt, cfg, extack);
1561         if (err < 0)
1562                 goto out_destroy_attrs;
1563
1564         newts->type = LWTUNNEL_ENCAP_SEG6_LOCAL;
1565         newts->flags = LWTUNNEL_STATE_INPUT_REDIRECT;
1566         newts->headroom = slwt->headroom;
1567
1568         *ts = newts;
1569
1570         return 0;
1571
1572 out_destroy_attrs:
1573         destroy_attrs(slwt);
1574 out_free:
1575         kfree(newts);
1576         return err;
1577 }
1578
1579 static void seg6_local_destroy_state(struct lwtunnel_state *lwt)
1580 {
1581         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1582
1583         seg6_local_lwtunnel_destroy_state(slwt);
1584
1585         destroy_attrs(slwt);
1586
1587         return;
1588 }
1589
1590 static int seg6_local_fill_encap(struct sk_buff *skb,
1591                                  struct lwtunnel_state *lwt)
1592 {
1593         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1594         struct seg6_action_param *param;
1595         unsigned long attrs;
1596         int i, err;
1597
1598         if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action))
1599                 return -EMSGSIZE;
1600
1601         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1602
1603         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1604                 if (attrs & SEG6_F_ATTR(i)) {
1605                         param = &seg6_action_params[i];
1606                         err = param->put(skb, slwt);
1607                         if (err < 0)
1608                                 return err;
1609                 }
1610         }
1611
1612         return 0;
1613 }
1614
1615 static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
1616 {
1617         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1618         unsigned long attrs;
1619         int nlsize;
1620
1621         nlsize = nla_total_size(4); /* action */
1622
1623         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1624
1625         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_SRH))
1626                 nlsize += nla_total_size((slwt->srh->hdrlen + 1) << 3);
1627
1628         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE))
1629                 nlsize += nla_total_size(4);
1630
1631         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH4))
1632                 nlsize += nla_total_size(4);
1633
1634         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH6))
1635                 nlsize += nla_total_size(16);
1636
1637         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_IIF))
1638                 nlsize += nla_total_size(4);
1639
1640         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_OIF))
1641                 nlsize += nla_total_size(4);
1642
1643         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_BPF))
1644                 nlsize += nla_total_size(sizeof(struct nlattr)) +
1645                        nla_total_size(MAX_PROG_NAME) +
1646                        nla_total_size(4);
1647
1648         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE))
1649                 nlsize += nla_total_size(4);
1650
1651         return nlsize;
1652 }
1653
1654 static int seg6_local_cmp_encap(struct lwtunnel_state *a,
1655                                 struct lwtunnel_state *b)
1656 {
1657         struct seg6_local_lwt *slwt_a, *slwt_b;
1658         struct seg6_action_param *param;
1659         unsigned long attrs_a, attrs_b;
1660         int i;
1661
1662         slwt_a = seg6_local_lwtunnel(a);
1663         slwt_b = seg6_local_lwtunnel(b);
1664
1665         if (slwt_a->action != slwt_b->action)
1666                 return 1;
1667
1668         attrs_a = slwt_a->desc->attrs | slwt_a->parsed_optattrs;
1669         attrs_b = slwt_b->desc->attrs | slwt_b->parsed_optattrs;
1670
1671         if (attrs_a != attrs_b)
1672                 return 1;
1673
1674         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1675                 if (attrs_a & SEG6_F_ATTR(i)) {
1676                         param = &seg6_action_params[i];
1677                         if (param->cmp(slwt_a, slwt_b))
1678                                 return 1;
1679                 }
1680         }
1681
1682         return 0;
1683 }
1684
1685 static const struct lwtunnel_encap_ops seg6_local_ops = {
1686         .build_state    = seg6_local_build_state,
1687         .destroy_state  = seg6_local_destroy_state,
1688         .input          = seg6_local_input,
1689         .fill_encap     = seg6_local_fill_encap,
1690         .get_encap_size = seg6_local_get_encap_size,
1691         .cmp_encap      = seg6_local_cmp_encap,
1692         .owner          = THIS_MODULE,
1693 };
1694
1695 int __init seg6_local_init(void)
1696 {
1697         /* If the max total number of defined attributes is reached, then your
1698          * kernel build stops here.
1699          *
1700          * This check is required to avoid arithmetic overflows when processing
1701          * behavior attributes and the maximum number of defined attributes
1702          * exceeds the allowed value.
1703          */
1704         BUILD_BUG_ON(SEG6_LOCAL_MAX + 1 > BITS_PER_TYPE(unsigned long));
1705
1706         return lwtunnel_encap_add_ops(&seg6_local_ops,
1707                                       LWTUNNEL_ENCAP_SEG6_LOCAL);
1708 }
1709
1710 void seg6_local_exit(void)
1711 {
1712         lwtunnel_encap_del_ops(&seg6_local_ops, LWTUNNEL_ENCAP_SEG6_LOCAL);
1713 }