atlantic: Fix driver resume flow.
[linux-2.6-microblaze.git] / net / ipv6 / seg6_local.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *  SR-IPv6 implementation
4  *
5  *  Authors:
6  *  David Lebrun <david.lebrun@uclouvain.be>
7  *  eBPF support: Mathieu Xhonneux <m.xhonneux@gmail.com>
8  */
9
10 #include <linux/types.h>
11 #include <linux/skbuff.h>
12 #include <linux/net.h>
13 #include <linux/module.h>
14 #include <net/ip.h>
15 #include <net/lwtunnel.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #include <net/ip6_fib.h>
19 #include <net/route.h>
20 #include <net/seg6.h>
21 #include <linux/seg6.h>
22 #include <linux/seg6_local.h>
23 #include <net/addrconf.h>
24 #include <net/ip6_route.h>
25 #include <net/dst_cache.h>
26 #include <net/ip_tunnels.h>
27 #ifdef CONFIG_IPV6_SEG6_HMAC
28 #include <net/seg6_hmac.h>
29 #endif
30 #include <net/seg6_local.h>
31 #include <linux/etherdevice.h>
32 #include <linux/bpf.h>
33
34 #define SEG6_F_ATTR(i)          BIT(i)
35
36 struct seg6_local_lwt;
37
38 /* callbacks used for customizing the creation and destruction of a behavior */
39 struct seg6_local_lwtunnel_ops {
40         int (*build_state)(struct seg6_local_lwt *slwt, const void *cfg,
41                            struct netlink_ext_ack *extack);
42         void (*destroy_state)(struct seg6_local_lwt *slwt);
43 };
44
45 struct seg6_action_desc {
46         int action;
47         unsigned long attrs;
48
49         /* The optattrs field is used for specifying all the optional
50          * attributes supported by a specific behavior.
51          * It means that if one of these attributes is not provided in the
52          * netlink message during the behavior creation, no errors will be
53          * returned to the userspace.
54          *
55          * Each attribute can be only of two types (mutually exclusive):
56          * 1) required or 2) optional.
57          * Every user MUST obey to this rule! If you set an attribute as
58          * required the same attribute CANNOT be set as optional and vice
59          * versa.
60          */
61         unsigned long optattrs;
62
63         int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
64         int static_headroom;
65
66         struct seg6_local_lwtunnel_ops slwt_ops;
67 };
68
69 struct bpf_lwt_prog {
70         struct bpf_prog *prog;
71         char *name;
72 };
73
74 enum seg6_end_dt_mode {
75         DT_INVALID_MODE = -EINVAL,
76         DT_LEGACY_MODE  = 0,
77         DT_VRF_MODE     = 1,
78 };
79
80 struct seg6_end_dt_info {
81         enum seg6_end_dt_mode mode;
82
83         struct net *net;
84         /* VRF device associated to the routing table used by the SRv6
85          * End.DT4/DT6 behavior for routing IPv4/IPv6 packets.
86          */
87         int vrf_ifindex;
88         int vrf_table;
89
90         /* tunneled packet family (IPv4 or IPv6).
91          * Protocol and header length are inferred from family.
92          */
93         u16 family;
94 };
95
96 struct pcpu_seg6_local_counters {
97         u64_stats_t packets;
98         u64_stats_t bytes;
99         u64_stats_t errors;
100
101         struct u64_stats_sync syncp;
102 };
103
104 /* This struct groups all the SRv6 Behavior counters supported so far.
105  *
106  * put_nla_counters() makes use of this data structure to collect all counter
107  * values after the per-CPU counter evaluation has been performed.
108  * Finally, each counter value (in seg6_local_counters) is stored in the
109  * corresponding netlink attribute and sent to user space.
110  *
111  * NB: we don't want to expose this structure to user space!
112  */
113 struct seg6_local_counters {
114         __u64 packets;
115         __u64 bytes;
116         __u64 errors;
117 };
118
119 #define seg6_local_alloc_pcpu_counters(__gfp)                           \
120         __netdev_alloc_pcpu_stats(struct pcpu_seg6_local_counters,      \
121                                   ((__gfp) | __GFP_ZERO))
122
123 #define SEG6_F_LOCAL_COUNTERS   SEG6_F_ATTR(SEG6_LOCAL_COUNTERS)
124
125 struct seg6_local_lwt {
126         int action;
127         struct ipv6_sr_hdr *srh;
128         int table;
129         struct in_addr nh4;
130         struct in6_addr nh6;
131         int iif;
132         int oif;
133         struct bpf_lwt_prog bpf;
134 #ifdef CONFIG_NET_L3_MASTER_DEV
135         struct seg6_end_dt_info dt_info;
136 #endif
137         struct pcpu_seg6_local_counters __percpu *pcpu_counters;
138
139         int headroom;
140         struct seg6_action_desc *desc;
141         /* unlike the required attrs, we have to track the optional attributes
142          * that have been effectively parsed.
143          */
144         unsigned long parsed_optattrs;
145 };
146
147 static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt)
148 {
149         return (struct seg6_local_lwt *)lwt->data;
150 }
151
152 static struct ipv6_sr_hdr *get_srh(struct sk_buff *skb, int flags)
153 {
154         struct ipv6_sr_hdr *srh;
155         int len, srhoff = 0;
156
157         if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, &flags) < 0)
158                 return NULL;
159
160         if (!pskb_may_pull(skb, srhoff + sizeof(*srh)))
161                 return NULL;
162
163         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
164
165         len = (srh->hdrlen + 1) << 3;
166
167         if (!pskb_may_pull(skb, srhoff + len))
168                 return NULL;
169
170         /* note that pskb_may_pull may change pointers in header;
171          * for this reason it is necessary to reload them when needed.
172          */
173         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
174
175         if (!seg6_validate_srh(srh, len, true))
176                 return NULL;
177
178         return srh;
179 }
180
181 static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb)
182 {
183         struct ipv6_sr_hdr *srh;
184
185         srh = get_srh(skb, IP6_FH_F_SKIP_RH);
186         if (!srh)
187                 return NULL;
188
189 #ifdef CONFIG_IPV6_SEG6_HMAC
190         if (!seg6_hmac_validate_skb(skb))
191                 return NULL;
192 #endif
193
194         return srh;
195 }
196
197 static bool decap_and_validate(struct sk_buff *skb, int proto)
198 {
199         struct ipv6_sr_hdr *srh;
200         unsigned int off = 0;
201
202         srh = get_srh(skb, 0);
203         if (srh && srh->segments_left > 0)
204                 return false;
205
206 #ifdef CONFIG_IPV6_SEG6_HMAC
207         if (srh && !seg6_hmac_validate_skb(skb))
208                 return false;
209 #endif
210
211         if (ipv6_find_hdr(skb, &off, proto, NULL, NULL) < 0)
212                 return false;
213
214         if (!pskb_pull(skb, off))
215                 return false;
216
217         skb_postpull_rcsum(skb, skb_network_header(skb), off);
218
219         skb_reset_network_header(skb);
220         skb_reset_transport_header(skb);
221         if (iptunnel_pull_offloads(skb))
222                 return false;
223
224         return true;
225 }
226
227 static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr)
228 {
229         struct in6_addr *addr;
230
231         srh->segments_left--;
232         addr = srh->segments + srh->segments_left;
233         *daddr = *addr;
234 }
235
236 static int
237 seg6_lookup_any_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
238                         u32 tbl_id, bool local_delivery)
239 {
240         struct net *net = dev_net(skb->dev);
241         struct ipv6hdr *hdr = ipv6_hdr(skb);
242         int flags = RT6_LOOKUP_F_HAS_SADDR;
243         struct dst_entry *dst = NULL;
244         struct rt6_info *rt;
245         struct flowi6 fl6;
246         int dev_flags = 0;
247
248         fl6.flowi6_iif = skb->dev->ifindex;
249         fl6.daddr = nhaddr ? *nhaddr : hdr->daddr;
250         fl6.saddr = hdr->saddr;
251         fl6.flowlabel = ip6_flowinfo(hdr);
252         fl6.flowi6_mark = skb->mark;
253         fl6.flowi6_proto = hdr->nexthdr;
254
255         if (nhaddr)
256                 fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;
257
258         if (!tbl_id) {
259                 dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
260         } else {
261                 struct fib6_table *table;
262
263                 table = fib6_get_table(net, tbl_id);
264                 if (!table)
265                         goto out;
266
267                 rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
268                 dst = &rt->dst;
269         }
270
271         /* we want to discard traffic destined for local packet processing,
272          * if @local_delivery is set to false.
273          */
274         if (!local_delivery)
275                 dev_flags |= IFF_LOOPBACK;
276
277         if (dst && (dst->dev->flags & dev_flags) && !dst->error) {
278                 dst_release(dst);
279                 dst = NULL;
280         }
281
282 out:
283         if (!dst) {
284                 rt = net->ipv6.ip6_blk_hole_entry;
285                 dst = &rt->dst;
286                 dst_hold(dst);
287         }
288
289         skb_dst_drop(skb);
290         skb_dst_set(skb, dst);
291         return dst->error;
292 }
293
294 int seg6_lookup_nexthop(struct sk_buff *skb,
295                         struct in6_addr *nhaddr, u32 tbl_id)
296 {
297         return seg6_lookup_any_nexthop(skb, nhaddr, tbl_id, false);
298 }
299
300 /* regular endpoint function */
301 static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt)
302 {
303         struct ipv6_sr_hdr *srh;
304
305         srh = get_and_validate_srh(skb);
306         if (!srh)
307                 goto drop;
308
309         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
310
311         seg6_lookup_nexthop(skb, NULL, 0);
312
313         return dst_input(skb);
314
315 drop:
316         kfree_skb(skb);
317         return -EINVAL;
318 }
319
320 /* regular endpoint, and forward to specified nexthop */
321 static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt)
322 {
323         struct ipv6_sr_hdr *srh;
324
325         srh = get_and_validate_srh(skb);
326         if (!srh)
327                 goto drop;
328
329         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
330
331         seg6_lookup_nexthop(skb, &slwt->nh6, 0);
332
333         return dst_input(skb);
334
335 drop:
336         kfree_skb(skb);
337         return -EINVAL;
338 }
339
340 static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt)
341 {
342         struct ipv6_sr_hdr *srh;
343
344         srh = get_and_validate_srh(skb);
345         if (!srh)
346                 goto drop;
347
348         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
349
350         seg6_lookup_nexthop(skb, NULL, slwt->table);
351
352         return dst_input(skb);
353
354 drop:
355         kfree_skb(skb);
356         return -EINVAL;
357 }
358
359 /* decapsulate and forward inner L2 frame on specified interface */
360 static int input_action_end_dx2(struct sk_buff *skb,
361                                 struct seg6_local_lwt *slwt)
362 {
363         struct net *net = dev_net(skb->dev);
364         struct net_device *odev;
365         struct ethhdr *eth;
366
367         if (!decap_and_validate(skb, IPPROTO_ETHERNET))
368                 goto drop;
369
370         if (!pskb_may_pull(skb, ETH_HLEN))
371                 goto drop;
372
373         skb_reset_mac_header(skb);
374         eth = (struct ethhdr *)skb->data;
375
376         /* To determine the frame's protocol, we assume it is 802.3. This avoids
377          * a call to eth_type_trans(), which is not really relevant for our
378          * use case.
379          */
380         if (!eth_proto_is_802_3(eth->h_proto))
381                 goto drop;
382
383         odev = dev_get_by_index_rcu(net, slwt->oif);
384         if (!odev)
385                 goto drop;
386
387         /* As we accept Ethernet frames, make sure the egress device is of
388          * the correct type.
389          */
390         if (odev->type != ARPHRD_ETHER)
391                 goto drop;
392
393         if (!(odev->flags & IFF_UP) || !netif_carrier_ok(odev))
394                 goto drop;
395
396         skb_orphan(skb);
397
398         if (skb_warn_if_lro(skb))
399                 goto drop;
400
401         skb_forward_csum(skb);
402
403         if (skb->len - ETH_HLEN > odev->mtu)
404                 goto drop;
405
406         skb->dev = odev;
407         skb->protocol = eth->h_proto;
408
409         return dev_queue_xmit(skb);
410
411 drop:
412         kfree_skb(skb);
413         return -EINVAL;
414 }
415
416 /* decapsulate and forward to specified nexthop */
417 static int input_action_end_dx6(struct sk_buff *skb,
418                                 struct seg6_local_lwt *slwt)
419 {
420         struct in6_addr *nhaddr = NULL;
421
422         /* this function accepts IPv6 encapsulated packets, with either
423          * an SRH with SL=0, or no SRH.
424          */
425
426         if (!decap_and_validate(skb, IPPROTO_IPV6))
427                 goto drop;
428
429         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
430                 goto drop;
431
432         /* The inner packet is not associated to any local interface,
433          * so we do not call netif_rx().
434          *
435          * If slwt->nh6 is set to ::, then lookup the nexthop for the
436          * inner packet's DA. Otherwise, use the specified nexthop.
437          */
438
439         if (!ipv6_addr_any(&slwt->nh6))
440                 nhaddr = &slwt->nh6;
441
442         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
443
444         seg6_lookup_nexthop(skb, nhaddr, 0);
445
446         return dst_input(skb);
447 drop:
448         kfree_skb(skb);
449         return -EINVAL;
450 }
451
452 static int input_action_end_dx4(struct sk_buff *skb,
453                                 struct seg6_local_lwt *slwt)
454 {
455         struct iphdr *iph;
456         __be32 nhaddr;
457         int err;
458
459         if (!decap_and_validate(skb, IPPROTO_IPIP))
460                 goto drop;
461
462         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
463                 goto drop;
464
465         skb->protocol = htons(ETH_P_IP);
466
467         iph = ip_hdr(skb);
468
469         nhaddr = slwt->nh4.s_addr ?: iph->daddr;
470
471         skb_dst_drop(skb);
472
473         skb_set_transport_header(skb, sizeof(struct iphdr));
474
475         err = ip_route_input(skb, nhaddr, iph->saddr, 0, skb->dev);
476         if (err)
477                 goto drop;
478
479         return dst_input(skb);
480
481 drop:
482         kfree_skb(skb);
483         return -EINVAL;
484 }
485
486 #ifdef CONFIG_NET_L3_MASTER_DEV
487 static struct net *fib6_config_get_net(const struct fib6_config *fib6_cfg)
488 {
489         const struct nl_info *nli = &fib6_cfg->fc_nlinfo;
490
491         return nli->nl_net;
492 }
493
494 static int __seg6_end_dt_vrf_build(struct seg6_local_lwt *slwt, const void *cfg,
495                                    u16 family, struct netlink_ext_ack *extack)
496 {
497         struct seg6_end_dt_info *info = &slwt->dt_info;
498         int vrf_ifindex;
499         struct net *net;
500
501         net = fib6_config_get_net(cfg);
502
503         /* note that vrf_table was already set by parse_nla_vrftable() */
504         vrf_ifindex = l3mdev_ifindex_lookup_by_table_id(L3MDEV_TYPE_VRF, net,
505                                                         info->vrf_table);
506         if (vrf_ifindex < 0) {
507                 if (vrf_ifindex == -EPERM) {
508                         NL_SET_ERR_MSG(extack,
509                                        "Strict mode for VRF is disabled");
510                 } else if (vrf_ifindex == -ENODEV) {
511                         NL_SET_ERR_MSG(extack,
512                                        "Table has no associated VRF device");
513                 } else {
514                         pr_debug("seg6local: SRv6 End.DT* creation error=%d\n",
515                                  vrf_ifindex);
516                 }
517
518                 return vrf_ifindex;
519         }
520
521         info->net = net;
522         info->vrf_ifindex = vrf_ifindex;
523
524         info->family = family;
525         info->mode = DT_VRF_MODE;
526
527         return 0;
528 }
529
530 /* The SRv6 End.DT4/DT6 behavior extracts the inner (IPv4/IPv6) packet and
531  * routes the IPv4/IPv6 packet by looking at the configured routing table.
532  *
533  * In the SRv6 End.DT4/DT6 use case, we can receive traffic (IPv6+Segment
534  * Routing Header packets) from several interfaces and the outer IPv6
535  * destination address (DA) is used for retrieving the specific instance of the
536  * End.DT4/DT6 behavior that should process the packets.
537  *
538  * However, the inner IPv4/IPv6 packet is not really bound to any receiving
539  * interface and thus the End.DT4/DT6 sets the VRF (associated with the
540  * corresponding routing table) as the *receiving* interface.
541  * In other words, the End.DT4/DT6 processes a packet as if it has been received
542  * directly by the VRF (and not by one of its slave devices, if any).
543  * In this way, the VRF interface is used for routing the IPv4/IPv6 packet in
544  * according to the routing table configured by the End.DT4/DT6 instance.
545  *
546  * This design allows you to get some interesting features like:
547  *  1) the statistics on rx packets;
548  *  2) the possibility to install a packet sniffer on the receiving interface
549  *     (the VRF one) for looking at the incoming packets;
550  *  3) the possibility to leverage the netfilter prerouting hook for the inner
551  *     IPv4 packet.
552  *
553  * This function returns:
554  *  - the sk_buff* when the VRF rcv handler has processed the packet correctly;
555  *  - NULL when the skb is consumed by the VRF rcv handler;
556  *  - a pointer which encodes a negative error number in case of error.
557  *    Note that in this case, the function takes care of freeing the skb.
558  */
559 static struct sk_buff *end_dt_vrf_rcv(struct sk_buff *skb, u16 family,
560                                       struct net_device *dev)
561 {
562         /* based on l3mdev_ip_rcv; we are only interested in the master */
563         if (unlikely(!netif_is_l3_master(dev) && !netif_has_l3_rx_handler(dev)))
564                 goto drop;
565
566         if (unlikely(!dev->l3mdev_ops->l3mdev_l3_rcv))
567                 goto drop;
568
569         /* the decap packet IPv4/IPv6 does not come with any mac header info.
570          * We must unset the mac header to allow the VRF device to rebuild it,
571          * just in case there is a sniffer attached on the device.
572          */
573         skb_unset_mac_header(skb);
574
575         skb = dev->l3mdev_ops->l3mdev_l3_rcv(dev, skb, family);
576         if (!skb)
577                 /* the skb buffer was consumed by the handler */
578                 return NULL;
579
580         /* when a packet is received by a VRF or by one of its slaves, the
581          * master device reference is set into the skb.
582          */
583         if (unlikely(skb->dev != dev || skb->skb_iif != dev->ifindex))
584                 goto drop;
585
586         return skb;
587
588 drop:
589         kfree_skb(skb);
590         return ERR_PTR(-EINVAL);
591 }
592
593 static struct net_device *end_dt_get_vrf_rcu(struct sk_buff *skb,
594                                              struct seg6_end_dt_info *info)
595 {
596         int vrf_ifindex = info->vrf_ifindex;
597         struct net *net = info->net;
598
599         if (unlikely(vrf_ifindex < 0))
600                 goto error;
601
602         if (unlikely(!net_eq(dev_net(skb->dev), net)))
603                 goto error;
604
605         return dev_get_by_index_rcu(net, vrf_ifindex);
606
607 error:
608         return NULL;
609 }
610
611 static struct sk_buff *end_dt_vrf_core(struct sk_buff *skb,
612                                        struct seg6_local_lwt *slwt, u16 family)
613 {
614         struct seg6_end_dt_info *info = &slwt->dt_info;
615         struct net_device *vrf;
616         __be16 protocol;
617         int hdrlen;
618
619         vrf = end_dt_get_vrf_rcu(skb, info);
620         if (unlikely(!vrf))
621                 goto drop;
622
623         switch (family) {
624         case AF_INET:
625                 protocol = htons(ETH_P_IP);
626                 hdrlen = sizeof(struct iphdr);
627                 break;
628         case AF_INET6:
629                 protocol = htons(ETH_P_IPV6);
630                 hdrlen = sizeof(struct ipv6hdr);
631                 break;
632         case AF_UNSPEC:
633                 fallthrough;
634         default:
635                 goto drop;
636         }
637
638         if (unlikely(info->family != AF_UNSPEC && info->family != family)) {
639                 pr_warn_once("seg6local: SRv6 End.DT* family mismatch");
640                 goto drop;
641         }
642
643         skb->protocol = protocol;
644
645         skb_dst_drop(skb);
646
647         skb_set_transport_header(skb, hdrlen);
648
649         return end_dt_vrf_rcv(skb, family, vrf);
650
651 drop:
652         kfree_skb(skb);
653         return ERR_PTR(-EINVAL);
654 }
655
656 static int input_action_end_dt4(struct sk_buff *skb,
657                                 struct seg6_local_lwt *slwt)
658 {
659         struct iphdr *iph;
660         int err;
661
662         if (!decap_and_validate(skb, IPPROTO_IPIP))
663                 goto drop;
664
665         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
666                 goto drop;
667
668         skb = end_dt_vrf_core(skb, slwt, AF_INET);
669         if (!skb)
670                 /* packet has been processed and consumed by the VRF */
671                 return 0;
672
673         if (IS_ERR(skb))
674                 return PTR_ERR(skb);
675
676         iph = ip_hdr(skb);
677
678         err = ip_route_input(skb, iph->daddr, iph->saddr, 0, skb->dev);
679         if (unlikely(err))
680                 goto drop;
681
682         return dst_input(skb);
683
684 drop:
685         kfree_skb(skb);
686         return -EINVAL;
687 }
688
689 static int seg6_end_dt4_build(struct seg6_local_lwt *slwt, const void *cfg,
690                               struct netlink_ext_ack *extack)
691 {
692         return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET, extack);
693 }
694
695 static enum
696 seg6_end_dt_mode seg6_end_dt6_parse_mode(struct seg6_local_lwt *slwt)
697 {
698         unsigned long parsed_optattrs = slwt->parsed_optattrs;
699         bool legacy, vrfmode;
700
701         legacy  = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE));
702         vrfmode = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE));
703
704         if (!(legacy ^ vrfmode))
705                 /* both are absent or present: invalid DT6 mode */
706                 return DT_INVALID_MODE;
707
708         return legacy ? DT_LEGACY_MODE : DT_VRF_MODE;
709 }
710
711 static enum seg6_end_dt_mode seg6_end_dt6_get_mode(struct seg6_local_lwt *slwt)
712 {
713         struct seg6_end_dt_info *info = &slwt->dt_info;
714
715         return info->mode;
716 }
717
718 static int seg6_end_dt6_build(struct seg6_local_lwt *slwt, const void *cfg,
719                               struct netlink_ext_ack *extack)
720 {
721         enum seg6_end_dt_mode mode = seg6_end_dt6_parse_mode(slwt);
722         struct seg6_end_dt_info *info = &slwt->dt_info;
723
724         switch (mode) {
725         case DT_LEGACY_MODE:
726                 info->mode = DT_LEGACY_MODE;
727                 return 0;
728         case DT_VRF_MODE:
729                 return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET6, extack);
730         default:
731                 NL_SET_ERR_MSG(extack, "table or vrftable must be specified");
732                 return -EINVAL;
733         }
734 }
735 #endif
736
737 static int input_action_end_dt6(struct sk_buff *skb,
738                                 struct seg6_local_lwt *slwt)
739 {
740         if (!decap_and_validate(skb, IPPROTO_IPV6))
741                 goto drop;
742
743         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
744                 goto drop;
745
746 #ifdef CONFIG_NET_L3_MASTER_DEV
747         if (seg6_end_dt6_get_mode(slwt) == DT_LEGACY_MODE)
748                 goto legacy_mode;
749
750         /* DT6_VRF_MODE */
751         skb = end_dt_vrf_core(skb, slwt, AF_INET6);
752         if (!skb)
753                 /* packet has been processed and consumed by the VRF */
754                 return 0;
755
756         if (IS_ERR(skb))
757                 return PTR_ERR(skb);
758
759         /* note: this time we do not need to specify the table because the VRF
760          * takes care of selecting the correct table.
761          */
762         seg6_lookup_any_nexthop(skb, NULL, 0, true);
763
764         return dst_input(skb);
765
766 legacy_mode:
767 #endif
768         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
769
770         seg6_lookup_any_nexthop(skb, NULL, slwt->table, true);
771
772         return dst_input(skb);
773
774 drop:
775         kfree_skb(skb);
776         return -EINVAL;
777 }
778
779 #ifdef CONFIG_NET_L3_MASTER_DEV
780 static int seg6_end_dt46_build(struct seg6_local_lwt *slwt, const void *cfg,
781                                struct netlink_ext_ack *extack)
782 {
783         return __seg6_end_dt_vrf_build(slwt, cfg, AF_UNSPEC, extack);
784 }
785
786 static int input_action_end_dt46(struct sk_buff *skb,
787                                  struct seg6_local_lwt *slwt)
788 {
789         unsigned int off = 0;
790         int nexthdr;
791
792         nexthdr = ipv6_find_hdr(skb, &off, -1, NULL, NULL);
793         if (unlikely(nexthdr < 0))
794                 goto drop;
795
796         switch (nexthdr) {
797         case IPPROTO_IPIP:
798                 return input_action_end_dt4(skb, slwt);
799         case IPPROTO_IPV6:
800                 return input_action_end_dt6(skb, slwt);
801         }
802
803 drop:
804         kfree_skb(skb);
805         return -EINVAL;
806 }
807 #endif
808
809 /* push an SRH on top of the current one */
810 static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
811 {
812         struct ipv6_sr_hdr *srh;
813         int err = -EINVAL;
814
815         srh = get_and_validate_srh(skb);
816         if (!srh)
817                 goto drop;
818
819         err = seg6_do_srh_inline(skb, slwt->srh);
820         if (err)
821                 goto drop;
822
823         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
824         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
825
826         seg6_lookup_nexthop(skb, NULL, 0);
827
828         return dst_input(skb);
829
830 drop:
831         kfree_skb(skb);
832         return err;
833 }
834
835 /* encapsulate within an outer IPv6 header and a specified SRH */
836 static int input_action_end_b6_encap(struct sk_buff *skb,
837                                      struct seg6_local_lwt *slwt)
838 {
839         struct ipv6_sr_hdr *srh;
840         int err = -EINVAL;
841
842         srh = get_and_validate_srh(skb);
843         if (!srh)
844                 goto drop;
845
846         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
847
848         skb_reset_inner_headers(skb);
849         skb->encapsulation = 1;
850
851         err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6);
852         if (err)
853                 goto drop;
854
855         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
856         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
857
858         seg6_lookup_nexthop(skb, NULL, 0);
859
860         return dst_input(skb);
861
862 drop:
863         kfree_skb(skb);
864         return err;
865 }
866
867 DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states);
868
869 bool seg6_bpf_has_valid_srh(struct sk_buff *skb)
870 {
871         struct seg6_bpf_srh_state *srh_state =
872                 this_cpu_ptr(&seg6_bpf_srh_states);
873         struct ipv6_sr_hdr *srh = srh_state->srh;
874
875         if (unlikely(srh == NULL))
876                 return false;
877
878         if (unlikely(!srh_state->valid)) {
879                 if ((srh_state->hdrlen & 7) != 0)
880                         return false;
881
882                 srh->hdrlen = (u8)(srh_state->hdrlen >> 3);
883                 if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3, true))
884                         return false;
885
886                 srh_state->valid = true;
887         }
888
889         return true;
890 }
891
892 static int input_action_end_bpf(struct sk_buff *skb,
893                                 struct seg6_local_lwt *slwt)
894 {
895         struct seg6_bpf_srh_state *srh_state =
896                 this_cpu_ptr(&seg6_bpf_srh_states);
897         struct ipv6_sr_hdr *srh;
898         int ret;
899
900         srh = get_and_validate_srh(skb);
901         if (!srh) {
902                 kfree_skb(skb);
903                 return -EINVAL;
904         }
905         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
906
907         /* preempt_disable is needed to protect the per-CPU buffer srh_state,
908          * which is also accessed by the bpf_lwt_seg6_* helpers
909          */
910         preempt_disable();
911         srh_state->srh = srh;
912         srh_state->hdrlen = srh->hdrlen << 3;
913         srh_state->valid = true;
914
915         rcu_read_lock();
916         bpf_compute_data_pointers(skb);
917         ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb);
918         rcu_read_unlock();
919
920         switch (ret) {
921         case BPF_OK:
922         case BPF_REDIRECT:
923                 break;
924         case BPF_DROP:
925                 goto drop;
926         default:
927                 pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret);
928                 goto drop;
929         }
930
931         if (srh_state->srh && !seg6_bpf_has_valid_srh(skb))
932                 goto drop;
933
934         preempt_enable();
935         if (ret != BPF_REDIRECT)
936                 seg6_lookup_nexthop(skb, NULL, 0);
937
938         return dst_input(skb);
939
940 drop:
941         preempt_enable();
942         kfree_skb(skb);
943         return -EINVAL;
944 }
945
946 static struct seg6_action_desc seg6_action_table[] = {
947         {
948                 .action         = SEG6_LOCAL_ACTION_END,
949                 .attrs          = 0,
950                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
951                 .input          = input_action_end,
952         },
953         {
954                 .action         = SEG6_LOCAL_ACTION_END_X,
955                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
956                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
957                 .input          = input_action_end_x,
958         },
959         {
960                 .action         = SEG6_LOCAL_ACTION_END_T,
961                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
962                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
963                 .input          = input_action_end_t,
964         },
965         {
966                 .action         = SEG6_LOCAL_ACTION_END_DX2,
967                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_OIF),
968                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
969                 .input          = input_action_end_dx2,
970         },
971         {
972                 .action         = SEG6_LOCAL_ACTION_END_DX6,
973                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
974                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
975                 .input          = input_action_end_dx6,
976         },
977         {
978                 .action         = SEG6_LOCAL_ACTION_END_DX4,
979                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH4),
980                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
981                 .input          = input_action_end_dx4,
982         },
983         {
984                 .action         = SEG6_LOCAL_ACTION_END_DT4,
985                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
986                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
987 #ifdef CONFIG_NET_L3_MASTER_DEV
988                 .input          = input_action_end_dt4,
989                 .slwt_ops       = {
990                                         .build_state = seg6_end_dt4_build,
991                                   },
992 #endif
993         },
994         {
995                 .action         = SEG6_LOCAL_ACTION_END_DT6,
996 #ifdef CONFIG_NET_L3_MASTER_DEV
997                 .attrs          = 0,
998                 .optattrs       = SEG6_F_LOCAL_COUNTERS         |
999                                   SEG6_F_ATTR(SEG6_LOCAL_TABLE) |
1000                                   SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
1001                 .slwt_ops       = {
1002                                         .build_state = seg6_end_dt6_build,
1003                                   },
1004 #else
1005                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
1006                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1007 #endif
1008                 .input          = input_action_end_dt6,
1009         },
1010         {
1011                 .action         = SEG6_LOCAL_ACTION_END_DT46,
1012                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
1013                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1014 #ifdef CONFIG_NET_L3_MASTER_DEV
1015                 .input          = input_action_end_dt46,
1016                 .slwt_ops       = {
1017                                         .build_state = seg6_end_dt46_build,
1018                                   },
1019 #endif
1020         },
1021         {
1022                 .action         = SEG6_LOCAL_ACTION_END_B6,
1023                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
1024                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1025                 .input          = input_action_end_b6,
1026         },
1027         {
1028                 .action         = SEG6_LOCAL_ACTION_END_B6_ENCAP,
1029                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
1030                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1031                 .input          = input_action_end_b6_encap,
1032                 .static_headroom        = sizeof(struct ipv6hdr),
1033         },
1034         {
1035                 .action         = SEG6_LOCAL_ACTION_END_BPF,
1036                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_BPF),
1037                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1038                 .input          = input_action_end_bpf,
1039         },
1040
1041 };
1042
1043 static struct seg6_action_desc *__get_action_desc(int action)
1044 {
1045         struct seg6_action_desc *desc;
1046         int i, count;
1047
1048         count = ARRAY_SIZE(seg6_action_table);
1049         for (i = 0; i < count; i++) {
1050                 desc = &seg6_action_table[i];
1051                 if (desc->action == action)
1052                         return desc;
1053         }
1054
1055         return NULL;
1056 }
1057
1058 static bool seg6_lwtunnel_counters_enabled(struct seg6_local_lwt *slwt)
1059 {
1060         return slwt->parsed_optattrs & SEG6_F_LOCAL_COUNTERS;
1061 }
1062
1063 static void seg6_local_update_counters(struct seg6_local_lwt *slwt,
1064                                        unsigned int len, int err)
1065 {
1066         struct pcpu_seg6_local_counters *pcounters;
1067
1068         pcounters = this_cpu_ptr(slwt->pcpu_counters);
1069         u64_stats_update_begin(&pcounters->syncp);
1070
1071         if (likely(!err)) {
1072                 u64_stats_inc(&pcounters->packets);
1073                 u64_stats_add(&pcounters->bytes, len);
1074         } else {
1075                 u64_stats_inc(&pcounters->errors);
1076         }
1077
1078         u64_stats_update_end(&pcounters->syncp);
1079 }
1080
1081 static int seg6_local_input(struct sk_buff *skb)
1082 {
1083         struct dst_entry *orig_dst = skb_dst(skb);
1084         struct seg6_action_desc *desc;
1085         struct seg6_local_lwt *slwt;
1086         unsigned int len = skb->len;
1087         int rc;
1088
1089         if (skb->protocol != htons(ETH_P_IPV6)) {
1090                 kfree_skb(skb);
1091                 return -EINVAL;
1092         }
1093
1094         slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
1095         desc = slwt->desc;
1096
1097         rc = desc->input(skb, slwt);
1098
1099         if (!seg6_lwtunnel_counters_enabled(slwt))
1100                 return rc;
1101
1102         seg6_local_update_counters(slwt, len, rc);
1103
1104         return rc;
1105 }
1106
1107 static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
1108         [SEG6_LOCAL_ACTION]     = { .type = NLA_U32 },
1109         [SEG6_LOCAL_SRH]        = { .type = NLA_BINARY },
1110         [SEG6_LOCAL_TABLE]      = { .type = NLA_U32 },
1111         [SEG6_LOCAL_VRFTABLE]   = { .type = NLA_U32 },
1112         [SEG6_LOCAL_NH4]        = { .type = NLA_BINARY,
1113                                     .len = sizeof(struct in_addr) },
1114         [SEG6_LOCAL_NH6]        = { .type = NLA_BINARY,
1115                                     .len = sizeof(struct in6_addr) },
1116         [SEG6_LOCAL_IIF]        = { .type = NLA_U32 },
1117         [SEG6_LOCAL_OIF]        = { .type = NLA_U32 },
1118         [SEG6_LOCAL_BPF]        = { .type = NLA_NESTED },
1119         [SEG6_LOCAL_COUNTERS]   = { .type = NLA_NESTED },
1120 };
1121
1122 static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1123 {
1124         struct ipv6_sr_hdr *srh;
1125         int len;
1126
1127         srh = nla_data(attrs[SEG6_LOCAL_SRH]);
1128         len = nla_len(attrs[SEG6_LOCAL_SRH]);
1129
1130         /* SRH must contain at least one segment */
1131         if (len < sizeof(*srh) + sizeof(struct in6_addr))
1132                 return -EINVAL;
1133
1134         if (!seg6_validate_srh(srh, len, false))
1135                 return -EINVAL;
1136
1137         slwt->srh = kmemdup(srh, len, GFP_KERNEL);
1138         if (!slwt->srh)
1139                 return -ENOMEM;
1140
1141         slwt->headroom += len;
1142
1143         return 0;
1144 }
1145
1146 static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1147 {
1148         struct ipv6_sr_hdr *srh;
1149         struct nlattr *nla;
1150         int len;
1151
1152         srh = slwt->srh;
1153         len = (srh->hdrlen + 1) << 3;
1154
1155         nla = nla_reserve(skb, SEG6_LOCAL_SRH, len);
1156         if (!nla)
1157                 return -EMSGSIZE;
1158
1159         memcpy(nla_data(nla), srh, len);
1160
1161         return 0;
1162 }
1163
1164 static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1165 {
1166         int len = (a->srh->hdrlen + 1) << 3;
1167
1168         if (len != ((b->srh->hdrlen + 1) << 3))
1169                 return 1;
1170
1171         return memcmp(a->srh, b->srh, len);
1172 }
1173
1174 static void destroy_attr_srh(struct seg6_local_lwt *slwt)
1175 {
1176         kfree(slwt->srh);
1177 }
1178
1179 static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1180 {
1181         slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]);
1182
1183         return 0;
1184 }
1185
1186 static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1187 {
1188         if (nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table))
1189                 return -EMSGSIZE;
1190
1191         return 0;
1192 }
1193
1194 static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1195 {
1196         if (a->table != b->table)
1197                 return 1;
1198
1199         return 0;
1200 }
1201
1202 static struct
1203 seg6_end_dt_info *seg6_possible_end_dt_info(struct seg6_local_lwt *slwt)
1204 {
1205 #ifdef CONFIG_NET_L3_MASTER_DEV
1206         return &slwt->dt_info;
1207 #else
1208         return ERR_PTR(-EOPNOTSUPP);
1209 #endif
1210 }
1211
1212 static int parse_nla_vrftable(struct nlattr **attrs,
1213                               struct seg6_local_lwt *slwt)
1214 {
1215         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1216
1217         if (IS_ERR(info))
1218                 return PTR_ERR(info);
1219
1220         info->vrf_table = nla_get_u32(attrs[SEG6_LOCAL_VRFTABLE]);
1221
1222         return 0;
1223 }
1224
1225 static int put_nla_vrftable(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1226 {
1227         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1228
1229         if (IS_ERR(info))
1230                 return PTR_ERR(info);
1231
1232         if (nla_put_u32(skb, SEG6_LOCAL_VRFTABLE, info->vrf_table))
1233                 return -EMSGSIZE;
1234
1235         return 0;
1236 }
1237
1238 static int cmp_nla_vrftable(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1239 {
1240         struct seg6_end_dt_info *info_a = seg6_possible_end_dt_info(a);
1241         struct seg6_end_dt_info *info_b = seg6_possible_end_dt_info(b);
1242
1243         if (info_a->vrf_table != info_b->vrf_table)
1244                 return 1;
1245
1246         return 0;
1247 }
1248
1249 static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1250 {
1251         memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]),
1252                sizeof(struct in_addr));
1253
1254         return 0;
1255 }
1256
1257 static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1258 {
1259         struct nlattr *nla;
1260
1261         nla = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(struct in_addr));
1262         if (!nla)
1263                 return -EMSGSIZE;
1264
1265         memcpy(nla_data(nla), &slwt->nh4, sizeof(struct in_addr));
1266
1267         return 0;
1268 }
1269
1270 static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1271 {
1272         return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr));
1273 }
1274
1275 static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1276 {
1277         memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]),
1278                sizeof(struct in6_addr));
1279
1280         return 0;
1281 }
1282
1283 static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1284 {
1285         struct nlattr *nla;
1286
1287         nla = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(struct in6_addr));
1288         if (!nla)
1289                 return -EMSGSIZE;
1290
1291         memcpy(nla_data(nla), &slwt->nh6, sizeof(struct in6_addr));
1292
1293         return 0;
1294 }
1295
1296 static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1297 {
1298         return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr));
1299 }
1300
1301 static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1302 {
1303         slwt->iif = nla_get_u32(attrs[SEG6_LOCAL_IIF]);
1304
1305         return 0;
1306 }
1307
1308 static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1309 {
1310         if (nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif))
1311                 return -EMSGSIZE;
1312
1313         return 0;
1314 }
1315
1316 static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1317 {
1318         if (a->iif != b->iif)
1319                 return 1;
1320
1321         return 0;
1322 }
1323
1324 static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1325 {
1326         slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]);
1327
1328         return 0;
1329 }
1330
1331 static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1332 {
1333         if (nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif))
1334                 return -EMSGSIZE;
1335
1336         return 0;
1337 }
1338
1339 static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1340 {
1341         if (a->oif != b->oif)
1342                 return 1;
1343
1344         return 0;
1345 }
1346
1347 #define MAX_PROG_NAME 256
1348 static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = {
1349         [SEG6_LOCAL_BPF_PROG]      = { .type = NLA_U32, },
1350         [SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
1351                                        .len = MAX_PROG_NAME },
1352 };
1353
1354 static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1355 {
1356         struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1];
1357         struct bpf_prog *p;
1358         int ret;
1359         u32 fd;
1360
1361         ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_BPF_PROG_MAX,
1362                                           attrs[SEG6_LOCAL_BPF],
1363                                           bpf_prog_policy, NULL);
1364         if (ret < 0)
1365                 return ret;
1366
1367         if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME])
1368                 return -EINVAL;
1369
1370         slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL);
1371         if (!slwt->bpf.name)
1372                 return -ENOMEM;
1373
1374         fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]);
1375         p = bpf_prog_get_type(fd, BPF_PROG_TYPE_LWT_SEG6LOCAL);
1376         if (IS_ERR(p)) {
1377                 kfree(slwt->bpf.name);
1378                 return PTR_ERR(p);
1379         }
1380
1381         slwt->bpf.prog = p;
1382         return 0;
1383 }
1384
1385 static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1386 {
1387         struct nlattr *nest;
1388
1389         if (!slwt->bpf.prog)
1390                 return 0;
1391
1392         nest = nla_nest_start_noflag(skb, SEG6_LOCAL_BPF);
1393         if (!nest)
1394                 return -EMSGSIZE;
1395
1396         if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id))
1397                 return -EMSGSIZE;
1398
1399         if (slwt->bpf.name &&
1400             nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name))
1401                 return -EMSGSIZE;
1402
1403         return nla_nest_end(skb, nest);
1404 }
1405
1406 static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1407 {
1408         if (!a->bpf.name && !b->bpf.name)
1409                 return 0;
1410
1411         if (!a->bpf.name || !b->bpf.name)
1412                 return 1;
1413
1414         return strcmp(a->bpf.name, b->bpf.name);
1415 }
1416
1417 static void destroy_attr_bpf(struct seg6_local_lwt *slwt)
1418 {
1419         kfree(slwt->bpf.name);
1420         if (slwt->bpf.prog)
1421                 bpf_prog_put(slwt->bpf.prog);
1422 }
1423
1424 static const struct
1425 nla_policy seg6_local_counters_policy[SEG6_LOCAL_CNT_MAX + 1] = {
1426         [SEG6_LOCAL_CNT_PACKETS]        = { .type = NLA_U64 },
1427         [SEG6_LOCAL_CNT_BYTES]          = { .type = NLA_U64 },
1428         [SEG6_LOCAL_CNT_ERRORS]         = { .type = NLA_U64 },
1429 };
1430
1431 static int parse_nla_counters(struct nlattr **attrs,
1432                               struct seg6_local_lwt *slwt)
1433 {
1434         struct pcpu_seg6_local_counters __percpu *pcounters;
1435         struct nlattr *tb[SEG6_LOCAL_CNT_MAX + 1];
1436         int ret;
1437
1438         ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_CNT_MAX,
1439                                           attrs[SEG6_LOCAL_COUNTERS],
1440                                           seg6_local_counters_policy, NULL);
1441         if (ret < 0)
1442                 return ret;
1443
1444         /* basic support for SRv6 Behavior counters requires at least:
1445          * packets, bytes and errors.
1446          */
1447         if (!tb[SEG6_LOCAL_CNT_PACKETS] || !tb[SEG6_LOCAL_CNT_BYTES] ||
1448             !tb[SEG6_LOCAL_CNT_ERRORS])
1449                 return -EINVAL;
1450
1451         /* counters are always zero initialized */
1452         pcounters = seg6_local_alloc_pcpu_counters(GFP_KERNEL);
1453         if (!pcounters)
1454                 return -ENOMEM;
1455
1456         slwt->pcpu_counters = pcounters;
1457
1458         return 0;
1459 }
1460
1461 static int seg6_local_fill_nla_counters(struct sk_buff *skb,
1462                                         struct seg6_local_counters *counters)
1463 {
1464         if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_PACKETS, counters->packets,
1465                               SEG6_LOCAL_CNT_PAD))
1466                 return -EMSGSIZE;
1467
1468         if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_BYTES, counters->bytes,
1469                               SEG6_LOCAL_CNT_PAD))
1470                 return -EMSGSIZE;
1471
1472         if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_ERRORS, counters->errors,
1473                               SEG6_LOCAL_CNT_PAD))
1474                 return -EMSGSIZE;
1475
1476         return 0;
1477 }
1478
1479 static int put_nla_counters(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1480 {
1481         struct seg6_local_counters counters = { 0, 0, 0 };
1482         struct nlattr *nest;
1483         int rc, i;
1484
1485         nest = nla_nest_start(skb, SEG6_LOCAL_COUNTERS);
1486         if (!nest)
1487                 return -EMSGSIZE;
1488
1489         for_each_possible_cpu(i) {
1490                 struct pcpu_seg6_local_counters *pcounters;
1491                 u64 packets, bytes, errors;
1492                 unsigned int start;
1493
1494                 pcounters = per_cpu_ptr(slwt->pcpu_counters, i);
1495                 do {
1496                         start = u64_stats_fetch_begin_irq(&pcounters->syncp);
1497
1498                         packets = u64_stats_read(&pcounters->packets);
1499                         bytes = u64_stats_read(&pcounters->bytes);
1500                         errors = u64_stats_read(&pcounters->errors);
1501
1502                 } while (u64_stats_fetch_retry_irq(&pcounters->syncp, start));
1503
1504                 counters.packets += packets;
1505                 counters.bytes += bytes;
1506                 counters.errors += errors;
1507         }
1508
1509         rc = seg6_local_fill_nla_counters(skb, &counters);
1510         if (rc < 0) {
1511                 nla_nest_cancel(skb, nest);
1512                 return rc;
1513         }
1514
1515         return nla_nest_end(skb, nest);
1516 }
1517
1518 static int cmp_nla_counters(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1519 {
1520         /* a and b are equal if both have pcpu_counters set or not */
1521         return (!!((unsigned long)a->pcpu_counters)) ^
1522                 (!!((unsigned long)b->pcpu_counters));
1523 }
1524
1525 static void destroy_attr_counters(struct seg6_local_lwt *slwt)
1526 {
1527         free_percpu(slwt->pcpu_counters);
1528 }
1529
1530 struct seg6_action_param {
1531         int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
1532         int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
1533         int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b);
1534
1535         /* optional destroy() callback useful for releasing resources which
1536          * have been previously acquired in the corresponding parse()
1537          * function.
1538          */
1539         void (*destroy)(struct seg6_local_lwt *slwt);
1540 };
1541
1542 static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
1543         [SEG6_LOCAL_SRH]        = { .parse = parse_nla_srh,
1544                                     .put = put_nla_srh,
1545                                     .cmp = cmp_nla_srh,
1546                                     .destroy = destroy_attr_srh },
1547
1548         [SEG6_LOCAL_TABLE]      = { .parse = parse_nla_table,
1549                                     .put = put_nla_table,
1550                                     .cmp = cmp_nla_table },
1551
1552         [SEG6_LOCAL_NH4]        = { .parse = parse_nla_nh4,
1553                                     .put = put_nla_nh4,
1554                                     .cmp = cmp_nla_nh4 },
1555
1556         [SEG6_LOCAL_NH6]        = { .parse = parse_nla_nh6,
1557                                     .put = put_nla_nh6,
1558                                     .cmp = cmp_nla_nh6 },
1559
1560         [SEG6_LOCAL_IIF]        = { .parse = parse_nla_iif,
1561                                     .put = put_nla_iif,
1562                                     .cmp = cmp_nla_iif },
1563
1564         [SEG6_LOCAL_OIF]        = { .parse = parse_nla_oif,
1565                                     .put = put_nla_oif,
1566                                     .cmp = cmp_nla_oif },
1567
1568         [SEG6_LOCAL_BPF]        = { .parse = parse_nla_bpf,
1569                                     .put = put_nla_bpf,
1570                                     .cmp = cmp_nla_bpf,
1571                                     .destroy = destroy_attr_bpf },
1572
1573         [SEG6_LOCAL_VRFTABLE]   = { .parse = parse_nla_vrftable,
1574                                     .put = put_nla_vrftable,
1575                                     .cmp = cmp_nla_vrftable },
1576
1577         [SEG6_LOCAL_COUNTERS]   = { .parse = parse_nla_counters,
1578                                     .put = put_nla_counters,
1579                                     .cmp = cmp_nla_counters,
1580                                     .destroy = destroy_attr_counters },
1581 };
1582
1583 /* call the destroy() callback (if available) for each set attribute in
1584  * @parsed_attrs, starting from the first attribute up to the @max_parsed
1585  * (excluded) attribute.
1586  */
1587 static void __destroy_attrs(unsigned long parsed_attrs, int max_parsed,
1588                             struct seg6_local_lwt *slwt)
1589 {
1590         struct seg6_action_param *param;
1591         int i;
1592
1593         /* Every required seg6local attribute is identified by an ID which is
1594          * encoded as a flag (i.e: 1 << ID) in the 'attrs' bitmask;
1595          *
1596          * We scan the 'parsed_attrs' bitmask, starting from the first attribute
1597          * up to the @max_parsed (excluded) attribute.
1598          * For each set attribute, we retrieve the corresponding destroy()
1599          * callback. If the callback is not available, then we skip to the next
1600          * attribute; otherwise, we call the destroy() callback.
1601          */
1602         for (i = 0; i < max_parsed; ++i) {
1603                 if (!(parsed_attrs & SEG6_F_ATTR(i)))
1604                         continue;
1605
1606                 param = &seg6_action_params[i];
1607
1608                 if (param->destroy)
1609                         param->destroy(slwt);
1610         }
1611 }
1612
1613 /* release all the resources that may have been acquired during parsing
1614  * operations.
1615  */
1616 static void destroy_attrs(struct seg6_local_lwt *slwt)
1617 {
1618         unsigned long attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1619
1620         __destroy_attrs(attrs, SEG6_LOCAL_MAX + 1, slwt);
1621 }
1622
1623 static int parse_nla_optional_attrs(struct nlattr **attrs,
1624                                     struct seg6_local_lwt *slwt)
1625 {
1626         struct seg6_action_desc *desc = slwt->desc;
1627         unsigned long parsed_optattrs = 0;
1628         struct seg6_action_param *param;
1629         int err, i;
1630
1631         for (i = 0; i < SEG6_LOCAL_MAX + 1; ++i) {
1632                 if (!(desc->optattrs & SEG6_F_ATTR(i)) || !attrs[i])
1633                         continue;
1634
1635                 /* once here, the i-th attribute is provided by the
1636                  * userspace AND it is identified optional as well.
1637                  */
1638                 param = &seg6_action_params[i];
1639
1640                 err = param->parse(attrs, slwt);
1641                 if (err < 0)
1642                         goto parse_optattrs_err;
1643
1644                 /* current attribute has been correctly parsed */
1645                 parsed_optattrs |= SEG6_F_ATTR(i);
1646         }
1647
1648         /* store in the tunnel state all the optional attributed successfully
1649          * parsed.
1650          */
1651         slwt->parsed_optattrs = parsed_optattrs;
1652
1653         return 0;
1654
1655 parse_optattrs_err:
1656         __destroy_attrs(parsed_optattrs, i, slwt);
1657
1658         return err;
1659 }
1660
1661 /* call the custom constructor of the behavior during its initialization phase
1662  * and after that all its attributes have been parsed successfully.
1663  */
1664 static int
1665 seg6_local_lwtunnel_build_state(struct seg6_local_lwt *slwt, const void *cfg,
1666                                 struct netlink_ext_ack *extack)
1667 {
1668         struct seg6_action_desc *desc = slwt->desc;
1669         struct seg6_local_lwtunnel_ops *ops;
1670
1671         ops = &desc->slwt_ops;
1672         if (!ops->build_state)
1673                 return 0;
1674
1675         return ops->build_state(slwt, cfg, extack);
1676 }
1677
1678 /* call the custom destructor of the behavior which is invoked before the
1679  * tunnel is going to be destroyed.
1680  */
1681 static void seg6_local_lwtunnel_destroy_state(struct seg6_local_lwt *slwt)
1682 {
1683         struct seg6_action_desc *desc = slwt->desc;
1684         struct seg6_local_lwtunnel_ops *ops;
1685
1686         ops = &desc->slwt_ops;
1687         if (!ops->destroy_state)
1688                 return;
1689
1690         ops->destroy_state(slwt);
1691 }
1692
1693 static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1694 {
1695         struct seg6_action_param *param;
1696         struct seg6_action_desc *desc;
1697         unsigned long invalid_attrs;
1698         int i, err;
1699
1700         desc = __get_action_desc(slwt->action);
1701         if (!desc)
1702                 return -EINVAL;
1703
1704         if (!desc->input)
1705                 return -EOPNOTSUPP;
1706
1707         slwt->desc = desc;
1708         slwt->headroom += desc->static_headroom;
1709
1710         /* Forcing the desc->optattrs *set* and the desc->attrs *set* to be
1711          * disjoined, this allow us to release acquired resources by optional
1712          * attributes and by required attributes independently from each other
1713          * without any interference.
1714          * In other terms, we are sure that we do not release some the acquired
1715          * resources twice.
1716          *
1717          * Note that if an attribute is configured both as required and as
1718          * optional, it means that the user has messed something up in the
1719          * seg6_action_table. Therefore, this check is required for SRv6
1720          * behaviors to work properly.
1721          */
1722         invalid_attrs = desc->attrs & desc->optattrs;
1723         if (invalid_attrs) {
1724                 WARN_ONCE(1,
1725                           "An attribute cannot be both required AND optional");
1726                 return -EINVAL;
1727         }
1728
1729         /* parse the required attributes */
1730         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1731                 if (desc->attrs & SEG6_F_ATTR(i)) {
1732                         if (!attrs[i])
1733                                 return -EINVAL;
1734
1735                         param = &seg6_action_params[i];
1736
1737                         err = param->parse(attrs, slwt);
1738                         if (err < 0)
1739                                 goto parse_attrs_err;
1740                 }
1741         }
1742
1743         /* parse the optional attributes, if any */
1744         err = parse_nla_optional_attrs(attrs, slwt);
1745         if (err < 0)
1746                 goto parse_attrs_err;
1747
1748         return 0;
1749
1750 parse_attrs_err:
1751         /* release any resource that may have been acquired during the i-1
1752          * parse() operations.
1753          */
1754         __destroy_attrs(desc->attrs, i, slwt);
1755
1756         return err;
1757 }
1758
1759 static int seg6_local_build_state(struct net *net, struct nlattr *nla,
1760                                   unsigned int family, const void *cfg,
1761                                   struct lwtunnel_state **ts,
1762                                   struct netlink_ext_ack *extack)
1763 {
1764         struct nlattr *tb[SEG6_LOCAL_MAX + 1];
1765         struct lwtunnel_state *newts;
1766         struct seg6_local_lwt *slwt;
1767         int err;
1768
1769         if (family != AF_INET6)
1770                 return -EINVAL;
1771
1772         err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_MAX, nla,
1773                                           seg6_local_policy, extack);
1774
1775         if (err < 0)
1776                 return err;
1777
1778         if (!tb[SEG6_LOCAL_ACTION])
1779                 return -EINVAL;
1780
1781         newts = lwtunnel_state_alloc(sizeof(*slwt));
1782         if (!newts)
1783                 return -ENOMEM;
1784
1785         slwt = seg6_local_lwtunnel(newts);
1786         slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]);
1787
1788         err = parse_nla_action(tb, slwt);
1789         if (err < 0)
1790                 goto out_free;
1791
1792         err = seg6_local_lwtunnel_build_state(slwt, cfg, extack);
1793         if (err < 0)
1794                 goto out_destroy_attrs;
1795
1796         newts->type = LWTUNNEL_ENCAP_SEG6_LOCAL;
1797         newts->flags = LWTUNNEL_STATE_INPUT_REDIRECT;
1798         newts->headroom = slwt->headroom;
1799
1800         *ts = newts;
1801
1802         return 0;
1803
1804 out_destroy_attrs:
1805         destroy_attrs(slwt);
1806 out_free:
1807         kfree(newts);
1808         return err;
1809 }
1810
1811 static void seg6_local_destroy_state(struct lwtunnel_state *lwt)
1812 {
1813         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1814
1815         seg6_local_lwtunnel_destroy_state(slwt);
1816
1817         destroy_attrs(slwt);
1818
1819         return;
1820 }
1821
1822 static int seg6_local_fill_encap(struct sk_buff *skb,
1823                                  struct lwtunnel_state *lwt)
1824 {
1825         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1826         struct seg6_action_param *param;
1827         unsigned long attrs;
1828         int i, err;
1829
1830         if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action))
1831                 return -EMSGSIZE;
1832
1833         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1834
1835         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1836                 if (attrs & SEG6_F_ATTR(i)) {
1837                         param = &seg6_action_params[i];
1838                         err = param->put(skb, slwt);
1839                         if (err < 0)
1840                                 return err;
1841                 }
1842         }
1843
1844         return 0;
1845 }
1846
1847 static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
1848 {
1849         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1850         unsigned long attrs;
1851         int nlsize;
1852
1853         nlsize = nla_total_size(4); /* action */
1854
1855         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1856
1857         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_SRH))
1858                 nlsize += nla_total_size((slwt->srh->hdrlen + 1) << 3);
1859
1860         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE))
1861                 nlsize += nla_total_size(4);
1862
1863         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH4))
1864                 nlsize += nla_total_size(4);
1865
1866         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH6))
1867                 nlsize += nla_total_size(16);
1868
1869         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_IIF))
1870                 nlsize += nla_total_size(4);
1871
1872         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_OIF))
1873                 nlsize += nla_total_size(4);
1874
1875         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_BPF))
1876                 nlsize += nla_total_size(sizeof(struct nlattr)) +
1877                        nla_total_size(MAX_PROG_NAME) +
1878                        nla_total_size(4);
1879
1880         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE))
1881                 nlsize += nla_total_size(4);
1882
1883         if (attrs & SEG6_F_LOCAL_COUNTERS)
1884                 nlsize += nla_total_size(0) + /* nest SEG6_LOCAL_COUNTERS */
1885                           /* SEG6_LOCAL_CNT_PACKETS */
1886                           nla_total_size_64bit(sizeof(__u64)) +
1887                           /* SEG6_LOCAL_CNT_BYTES */
1888                           nla_total_size_64bit(sizeof(__u64)) +
1889                           /* SEG6_LOCAL_CNT_ERRORS */
1890                           nla_total_size_64bit(sizeof(__u64));
1891
1892         return nlsize;
1893 }
1894
1895 static int seg6_local_cmp_encap(struct lwtunnel_state *a,
1896                                 struct lwtunnel_state *b)
1897 {
1898         struct seg6_local_lwt *slwt_a, *slwt_b;
1899         struct seg6_action_param *param;
1900         unsigned long attrs_a, attrs_b;
1901         int i;
1902
1903         slwt_a = seg6_local_lwtunnel(a);
1904         slwt_b = seg6_local_lwtunnel(b);
1905
1906         if (slwt_a->action != slwt_b->action)
1907                 return 1;
1908
1909         attrs_a = slwt_a->desc->attrs | slwt_a->parsed_optattrs;
1910         attrs_b = slwt_b->desc->attrs | slwt_b->parsed_optattrs;
1911
1912         if (attrs_a != attrs_b)
1913                 return 1;
1914
1915         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1916                 if (attrs_a & SEG6_F_ATTR(i)) {
1917                         param = &seg6_action_params[i];
1918                         if (param->cmp(slwt_a, slwt_b))
1919                                 return 1;
1920                 }
1921         }
1922
1923         return 0;
1924 }
1925
1926 static const struct lwtunnel_encap_ops seg6_local_ops = {
1927         .build_state    = seg6_local_build_state,
1928         .destroy_state  = seg6_local_destroy_state,
1929         .input          = seg6_local_input,
1930         .fill_encap     = seg6_local_fill_encap,
1931         .get_encap_size = seg6_local_get_encap_size,
1932         .cmp_encap      = seg6_local_cmp_encap,
1933         .owner          = THIS_MODULE,
1934 };
1935
1936 int __init seg6_local_init(void)
1937 {
1938         /* If the max total number of defined attributes is reached, then your
1939          * kernel build stops here.
1940          *
1941          * This check is required to avoid arithmetic overflows when processing
1942          * behavior attributes and the maximum number of defined attributes
1943          * exceeds the allowed value.
1944          */
1945         BUILD_BUG_ON(SEG6_LOCAL_MAX + 1 > BITS_PER_TYPE(unsigned long));
1946
1947         return lwtunnel_encap_add_ops(&seg6_local_ops,
1948                                       LWTUNNEL_ENCAP_SEG6_LOCAL);
1949 }
1950
1951 void seg6_local_exit(void)
1952 {
1953         lwtunnel_encap_del_ops(&seg6_local_ops, LWTUNNEL_ENCAP_SEG6_LOCAL);
1954 }