Merge tag 'pm-5.7-rc2' of git://git.kernel.org/pub/scm/linux/kernel/git/rafael/linux-pm
[linux-2.6-microblaze.git] / net / netfilter / nf_nat_proto.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* (C) 1999-2001 Paul `Rusty' Russell
3  * (C) 2002-2006 Netfilter Core Team <coreteam@netfilter.org>
4  */
5
6 #include <linux/types.h>
7 #include <linux/export.h>
8 #include <linux/init.h>
9 #include <linux/udp.h>
10 #include <linux/tcp.h>
11 #include <linux/icmp.h>
12 #include <linux/icmpv6.h>
13
14 #include <linux/dccp.h>
15 #include <linux/sctp.h>
16 #include <net/sctp/checksum.h>
17
18 #include <linux/netfilter.h>
19 #include <net/netfilter/nf_nat.h>
20
21 #include <linux/ipv6.h>
22 #include <linux/netfilter_ipv6.h>
23 #include <net/checksum.h>
24 #include <net/ip6_checksum.h>
25 #include <net/ip6_route.h>
26 #include <net/xfrm.h>
27 #include <net/ipv6.h>
28
29 #include <net/netfilter/nf_conntrack_core.h>
30 #include <net/netfilter/nf_conntrack.h>
31 #include <linux/netfilter/nfnetlink_conntrack.h>
32
33 static void nf_csum_update(struct sk_buff *skb,
34                            unsigned int iphdroff, __sum16 *check,
35                            const struct nf_conntrack_tuple *t,
36                            enum nf_nat_manip_type maniptype);
37
38 static void
39 __udp_manip_pkt(struct sk_buff *skb,
40                 unsigned int iphdroff, struct udphdr *hdr,
41                 const struct nf_conntrack_tuple *tuple,
42                 enum nf_nat_manip_type maniptype, bool do_csum)
43 {
44         __be16 *portptr, newport;
45
46         if (maniptype == NF_NAT_MANIP_SRC) {
47                 /* Get rid of src port */
48                 newport = tuple->src.u.udp.port;
49                 portptr = &hdr->source;
50         } else {
51                 /* Get rid of dst port */
52                 newport = tuple->dst.u.udp.port;
53                 portptr = &hdr->dest;
54         }
55         if (do_csum) {
56                 nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype);
57                 inet_proto_csum_replace2(&hdr->check, skb, *portptr, newport,
58                                          false);
59                 if (!hdr->check)
60                         hdr->check = CSUM_MANGLED_0;
61         }
62         *portptr = newport;
63 }
64
65 static bool udp_manip_pkt(struct sk_buff *skb,
66                           unsigned int iphdroff, unsigned int hdroff,
67                           const struct nf_conntrack_tuple *tuple,
68                           enum nf_nat_manip_type maniptype)
69 {
70         struct udphdr *hdr;
71         bool do_csum;
72
73         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
74                 return false;
75
76         hdr = (struct udphdr *)(skb->data + hdroff);
77         do_csum = hdr->check || skb->ip_summed == CHECKSUM_PARTIAL;
78
79         __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, do_csum);
80         return true;
81 }
82
83 static bool udplite_manip_pkt(struct sk_buff *skb,
84                               unsigned int iphdroff, unsigned int hdroff,
85                               const struct nf_conntrack_tuple *tuple,
86                               enum nf_nat_manip_type maniptype)
87 {
88 #ifdef CONFIG_NF_CT_PROTO_UDPLITE
89         struct udphdr *hdr;
90
91         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
92                 return false;
93
94         hdr = (struct udphdr *)(skb->data + hdroff);
95         __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, true);
96 #endif
97         return true;
98 }
99
100 static bool
101 sctp_manip_pkt(struct sk_buff *skb,
102                unsigned int iphdroff, unsigned int hdroff,
103                const struct nf_conntrack_tuple *tuple,
104                enum nf_nat_manip_type maniptype)
105 {
106 #ifdef CONFIG_NF_CT_PROTO_SCTP
107         struct sctphdr *hdr;
108         int hdrsize = 8;
109
110         /* This could be an inner header returned in imcp packet; in such
111          * cases we cannot update the checksum field since it is outside
112          * of the 8 bytes of transport layer headers we are guaranteed.
113          */
114         if (skb->len >= hdroff + sizeof(*hdr))
115                 hdrsize = sizeof(*hdr);
116
117         if (skb_ensure_writable(skb, hdroff + hdrsize))
118                 return false;
119
120         hdr = (struct sctphdr *)(skb->data + hdroff);
121
122         if (maniptype == NF_NAT_MANIP_SRC) {
123                 /* Get rid of src port */
124                 hdr->source = tuple->src.u.sctp.port;
125         } else {
126                 /* Get rid of dst port */
127                 hdr->dest = tuple->dst.u.sctp.port;
128         }
129
130         if (hdrsize < sizeof(*hdr))
131                 return true;
132
133         if (skb->ip_summed != CHECKSUM_PARTIAL) {
134                 hdr->checksum = sctp_compute_cksum(skb, hdroff);
135                 skb->ip_summed = CHECKSUM_NONE;
136         }
137
138 #endif
139         return true;
140 }
141
142 static bool
143 tcp_manip_pkt(struct sk_buff *skb,
144               unsigned int iphdroff, unsigned int hdroff,
145               const struct nf_conntrack_tuple *tuple,
146               enum nf_nat_manip_type maniptype)
147 {
148         struct tcphdr *hdr;
149         __be16 *portptr, newport, oldport;
150         int hdrsize = 8; /* TCP connection tracking guarantees this much */
151
152         /* this could be a inner header returned in icmp packet; in such
153            cases we cannot update the checksum field since it is outside of
154            the 8 bytes of transport layer headers we are guaranteed */
155         if (skb->len >= hdroff + sizeof(struct tcphdr))
156                 hdrsize = sizeof(struct tcphdr);
157
158         if (skb_ensure_writable(skb, hdroff + hdrsize))
159                 return false;
160
161         hdr = (struct tcphdr *)(skb->data + hdroff);
162
163         if (maniptype == NF_NAT_MANIP_SRC) {
164                 /* Get rid of src port */
165                 newport = tuple->src.u.tcp.port;
166                 portptr = &hdr->source;
167         } else {
168                 /* Get rid of dst port */
169                 newport = tuple->dst.u.tcp.port;
170                 portptr = &hdr->dest;
171         }
172
173         oldport = *portptr;
174         *portptr = newport;
175
176         if (hdrsize < sizeof(*hdr))
177                 return true;
178
179         nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype);
180         inet_proto_csum_replace2(&hdr->check, skb, oldport, newport, false);
181         return true;
182 }
183
184 static bool
185 dccp_manip_pkt(struct sk_buff *skb,
186                unsigned int iphdroff, unsigned int hdroff,
187                const struct nf_conntrack_tuple *tuple,
188                enum nf_nat_manip_type maniptype)
189 {
190 #ifdef CONFIG_NF_CT_PROTO_DCCP
191         struct dccp_hdr *hdr;
192         __be16 *portptr, oldport, newport;
193         int hdrsize = 8; /* DCCP connection tracking guarantees this much */
194
195         if (skb->len >= hdroff + sizeof(struct dccp_hdr))
196                 hdrsize = sizeof(struct dccp_hdr);
197
198         if (skb_ensure_writable(skb, hdroff + hdrsize))
199                 return false;
200
201         hdr = (struct dccp_hdr *)(skb->data + hdroff);
202
203         if (maniptype == NF_NAT_MANIP_SRC) {
204                 newport = tuple->src.u.dccp.port;
205                 portptr = &hdr->dccph_sport;
206         } else {
207                 newport = tuple->dst.u.dccp.port;
208                 portptr = &hdr->dccph_dport;
209         }
210
211         oldport = *portptr;
212         *portptr = newport;
213
214         if (hdrsize < sizeof(*hdr))
215                 return true;
216
217         nf_csum_update(skb, iphdroff, &hdr->dccph_checksum, tuple, maniptype);
218         inet_proto_csum_replace2(&hdr->dccph_checksum, skb, oldport, newport,
219                                  false);
220 #endif
221         return true;
222 }
223
224 static bool
225 icmp_manip_pkt(struct sk_buff *skb,
226                unsigned int iphdroff, unsigned int hdroff,
227                const struct nf_conntrack_tuple *tuple,
228                enum nf_nat_manip_type maniptype)
229 {
230         struct icmphdr *hdr;
231
232         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
233                 return false;
234
235         hdr = (struct icmphdr *)(skb->data + hdroff);
236         switch (hdr->type) {
237         case ICMP_ECHO:
238         case ICMP_ECHOREPLY:
239         case ICMP_TIMESTAMP:
240         case ICMP_TIMESTAMPREPLY:
241         case ICMP_INFO_REQUEST:
242         case ICMP_INFO_REPLY:
243         case ICMP_ADDRESS:
244         case ICMP_ADDRESSREPLY:
245                 break;
246         default:
247                 return true;
248         }
249         inet_proto_csum_replace2(&hdr->checksum, skb,
250                                  hdr->un.echo.id, tuple->src.u.icmp.id, false);
251         hdr->un.echo.id = tuple->src.u.icmp.id;
252         return true;
253 }
254
255 static bool
256 icmpv6_manip_pkt(struct sk_buff *skb,
257                  unsigned int iphdroff, unsigned int hdroff,
258                  const struct nf_conntrack_tuple *tuple,
259                  enum nf_nat_manip_type maniptype)
260 {
261         struct icmp6hdr *hdr;
262
263         if (skb_ensure_writable(skb, hdroff + sizeof(*hdr)))
264                 return false;
265
266         hdr = (struct icmp6hdr *)(skb->data + hdroff);
267         nf_csum_update(skb, iphdroff, &hdr->icmp6_cksum, tuple, maniptype);
268         if (hdr->icmp6_type == ICMPV6_ECHO_REQUEST ||
269             hdr->icmp6_type == ICMPV6_ECHO_REPLY) {
270                 inet_proto_csum_replace2(&hdr->icmp6_cksum, skb,
271                                          hdr->icmp6_identifier,
272                                          tuple->src.u.icmp.id, false);
273                 hdr->icmp6_identifier = tuple->src.u.icmp.id;
274         }
275         return true;
276 }
277
278 /* manipulate a GRE packet according to maniptype */
279 static bool
280 gre_manip_pkt(struct sk_buff *skb,
281               unsigned int iphdroff, unsigned int hdroff,
282               const struct nf_conntrack_tuple *tuple,
283               enum nf_nat_manip_type maniptype)
284 {
285 #if IS_ENABLED(CONFIG_NF_CT_PROTO_GRE)
286         const struct gre_base_hdr *greh;
287         struct pptp_gre_header *pgreh;
288
289         /* pgreh includes two optional 32bit fields which are not required
290          * to be there.  That's where the magic '8' comes from */
291         if (skb_ensure_writable(skb, hdroff + sizeof(*pgreh) - 8))
292                 return false;
293
294         greh = (void *)skb->data + hdroff;
295         pgreh = (struct pptp_gre_header *)greh;
296
297         /* we only have destination manip of a packet, since 'source key'
298          * is not present in the packet itself */
299         if (maniptype != NF_NAT_MANIP_DST)
300                 return true;
301
302         switch (greh->flags & GRE_VERSION) {
303         case GRE_VERSION_0:
304                 /* We do not currently NAT any GREv0 packets.
305                  * Try to behave like "nf_nat_proto_unknown" */
306                 break;
307         case GRE_VERSION_1:
308                 pr_debug("call_id -> 0x%04x\n", ntohs(tuple->dst.u.gre.key));
309                 pgreh->call_id = tuple->dst.u.gre.key;
310                 break;
311         default:
312                 pr_debug("can't nat unknown GRE version\n");
313                 return false;
314         }
315 #endif
316         return true;
317 }
318
319 static bool l4proto_manip_pkt(struct sk_buff *skb,
320                               unsigned int iphdroff, unsigned int hdroff,
321                               const struct nf_conntrack_tuple *tuple,
322                               enum nf_nat_manip_type maniptype)
323 {
324         switch (tuple->dst.protonum) {
325         case IPPROTO_TCP:
326                 return tcp_manip_pkt(skb, iphdroff, hdroff,
327                                      tuple, maniptype);
328         case IPPROTO_UDP:
329                 return udp_manip_pkt(skb, iphdroff, hdroff,
330                                      tuple, maniptype);
331         case IPPROTO_UDPLITE:
332                 return udplite_manip_pkt(skb, iphdroff, hdroff,
333                                          tuple, maniptype);
334         case IPPROTO_SCTP:
335                 return sctp_manip_pkt(skb, iphdroff, hdroff,
336                                       tuple, maniptype);
337         case IPPROTO_ICMP:
338                 return icmp_manip_pkt(skb, iphdroff, hdroff,
339                                       tuple, maniptype);
340         case IPPROTO_ICMPV6:
341                 return icmpv6_manip_pkt(skb, iphdroff, hdroff,
342                                         tuple, maniptype);
343         case IPPROTO_DCCP:
344                 return dccp_manip_pkt(skb, iphdroff, hdroff,
345                                       tuple, maniptype);
346         case IPPROTO_GRE:
347                 return gre_manip_pkt(skb, iphdroff, hdroff,
348                                      tuple, maniptype);
349         }
350
351         /* If we don't know protocol -- no error, pass it unmodified. */
352         return true;
353 }
354
355 static bool nf_nat_ipv4_manip_pkt(struct sk_buff *skb,
356                                   unsigned int iphdroff,
357                                   const struct nf_conntrack_tuple *target,
358                                   enum nf_nat_manip_type maniptype)
359 {
360         struct iphdr *iph;
361         unsigned int hdroff;
362
363         if (skb_ensure_writable(skb, iphdroff + sizeof(*iph)))
364                 return false;
365
366         iph = (void *)skb->data + iphdroff;
367         hdroff = iphdroff + iph->ihl * 4;
368
369         if (!l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype))
370                 return false;
371         iph = (void *)skb->data + iphdroff;
372
373         if (maniptype == NF_NAT_MANIP_SRC) {
374                 csum_replace4(&iph->check, iph->saddr, target->src.u3.ip);
375                 iph->saddr = target->src.u3.ip;
376         } else {
377                 csum_replace4(&iph->check, iph->daddr, target->dst.u3.ip);
378                 iph->daddr = target->dst.u3.ip;
379         }
380         return true;
381 }
382
383 static bool nf_nat_ipv6_manip_pkt(struct sk_buff *skb,
384                                   unsigned int iphdroff,
385                                   const struct nf_conntrack_tuple *target,
386                                   enum nf_nat_manip_type maniptype)
387 {
388 #if IS_ENABLED(CONFIG_IPV6)
389         struct ipv6hdr *ipv6h;
390         __be16 frag_off;
391         int hdroff;
392         u8 nexthdr;
393
394         if (skb_ensure_writable(skb, iphdroff + sizeof(*ipv6h)))
395                 return false;
396
397         ipv6h = (void *)skb->data + iphdroff;
398         nexthdr = ipv6h->nexthdr;
399         hdroff = ipv6_skip_exthdr(skb, iphdroff + sizeof(*ipv6h),
400                                   &nexthdr, &frag_off);
401         if (hdroff < 0)
402                 goto manip_addr;
403
404         if ((frag_off & htons(~0x7)) == 0 &&
405             !l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype))
406                 return false;
407
408         /* must reload, offset might have changed */
409         ipv6h = (void *)skb->data + iphdroff;
410
411 manip_addr:
412         if (maniptype == NF_NAT_MANIP_SRC)
413                 ipv6h->saddr = target->src.u3.in6;
414         else
415                 ipv6h->daddr = target->dst.u3.in6;
416
417 #endif
418         return true;
419 }
420
421 unsigned int nf_nat_manip_pkt(struct sk_buff *skb, struct nf_conn *ct,
422                               enum nf_nat_manip_type mtype,
423                               enum ip_conntrack_dir dir)
424 {
425         struct nf_conntrack_tuple target;
426
427         /* We are aiming to look like inverse of other direction. */
428         nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple);
429
430         switch (target.src.l3num) {
431         case NFPROTO_IPV6:
432                 if (nf_nat_ipv6_manip_pkt(skb, 0, &target, mtype))
433                         return NF_ACCEPT;
434                 break;
435         case NFPROTO_IPV4:
436                 if (nf_nat_ipv4_manip_pkt(skb, 0, &target, mtype))
437                         return NF_ACCEPT;
438                 break;
439         default:
440                 WARN_ON_ONCE(1);
441                 break;
442         }
443
444         return NF_DROP;
445 }
446
447 static void nf_nat_ipv4_csum_update(struct sk_buff *skb,
448                                     unsigned int iphdroff, __sum16 *check,
449                                     const struct nf_conntrack_tuple *t,
450                                     enum nf_nat_manip_type maniptype)
451 {
452         struct iphdr *iph = (struct iphdr *)(skb->data + iphdroff);
453         __be32 oldip, newip;
454
455         if (maniptype == NF_NAT_MANIP_SRC) {
456                 oldip = iph->saddr;
457                 newip = t->src.u3.ip;
458         } else {
459                 oldip = iph->daddr;
460                 newip = t->dst.u3.ip;
461         }
462         inet_proto_csum_replace4(check, skb, oldip, newip, true);
463 }
464
465 static void nf_nat_ipv6_csum_update(struct sk_buff *skb,
466                                     unsigned int iphdroff, __sum16 *check,
467                                     const struct nf_conntrack_tuple *t,
468                                     enum nf_nat_manip_type maniptype)
469 {
470 #if IS_ENABLED(CONFIG_IPV6)
471         const struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + iphdroff);
472         const struct in6_addr *oldip, *newip;
473
474         if (maniptype == NF_NAT_MANIP_SRC) {
475                 oldip = &ipv6h->saddr;
476                 newip = &t->src.u3.in6;
477         } else {
478                 oldip = &ipv6h->daddr;
479                 newip = &t->dst.u3.in6;
480         }
481         inet_proto_csum_replace16(check, skb, oldip->s6_addr32,
482                                   newip->s6_addr32, true);
483 #endif
484 }
485
486 static void nf_csum_update(struct sk_buff *skb,
487                            unsigned int iphdroff, __sum16 *check,
488                            const struct nf_conntrack_tuple *t,
489                            enum nf_nat_manip_type maniptype)
490 {
491         switch (t->src.l3num) {
492         case NFPROTO_IPV4:
493                 nf_nat_ipv4_csum_update(skb, iphdroff, check, t, maniptype);
494                 return;
495         case NFPROTO_IPV6:
496                 nf_nat_ipv6_csum_update(skb, iphdroff, check, t, maniptype);
497                 return;
498         }
499 }
500
501 static void nf_nat_ipv4_csum_recalc(struct sk_buff *skb,
502                                     u8 proto, void *data, __sum16 *check,
503                                     int datalen, int oldlen)
504 {
505         if (skb->ip_summed != CHECKSUM_PARTIAL) {
506                 const struct iphdr *iph = ip_hdr(skb);
507
508                 skb->ip_summed = CHECKSUM_PARTIAL;
509                 skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) +
510                         ip_hdrlen(skb);
511                 skb->csum_offset = (void *)check - data;
512                 *check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
513                                             proto, 0);
514         } else {
515                 inet_proto_csum_replace2(check, skb,
516                                          htons(oldlen), htons(datalen), true);
517         }
518 }
519
520 #if IS_ENABLED(CONFIG_IPV6)
521 static void nf_nat_ipv6_csum_recalc(struct sk_buff *skb,
522                                     u8 proto, void *data, __sum16 *check,
523                                     int datalen, int oldlen)
524 {
525         if (skb->ip_summed != CHECKSUM_PARTIAL) {
526                 const struct ipv6hdr *ipv6h = ipv6_hdr(skb);
527
528                 skb->ip_summed = CHECKSUM_PARTIAL;
529                 skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) +
530                         (data - (void *)skb->data);
531                 skb->csum_offset = (void *)check - data;
532                 *check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
533                                           datalen, proto, 0);
534         } else {
535                 inet_proto_csum_replace2(check, skb,
536                                          htons(oldlen), htons(datalen), true);
537         }
538 }
539 #endif
540
541 void nf_nat_csum_recalc(struct sk_buff *skb,
542                         u8 nfproto, u8 proto, void *data, __sum16 *check,
543                         int datalen, int oldlen)
544 {
545         switch (nfproto) {
546         case NFPROTO_IPV4:
547                 nf_nat_ipv4_csum_recalc(skb, proto, data, check,
548                                         datalen, oldlen);
549                 return;
550 #if IS_ENABLED(CONFIG_IPV6)
551         case NFPROTO_IPV6:
552                 nf_nat_ipv6_csum_recalc(skb, proto, data, check,
553                                         datalen, oldlen);
554                 return;
555 #endif
556         }
557
558         WARN_ON_ONCE(1);
559 }
560
561 int nf_nat_icmp_reply_translation(struct sk_buff *skb,
562                                   struct nf_conn *ct,
563                                   enum ip_conntrack_info ctinfo,
564                                   unsigned int hooknum)
565 {
566         struct {
567                 struct icmphdr  icmp;
568                 struct iphdr    ip;
569         } *inside;
570         enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
571         enum nf_nat_manip_type manip = HOOK2MANIP(hooknum);
572         unsigned int hdrlen = ip_hdrlen(skb);
573         struct nf_conntrack_tuple target;
574         unsigned long statusbit;
575
576         WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY);
577
578         if (skb_ensure_writable(skb, hdrlen + sizeof(*inside)))
579                 return 0;
580         if (nf_ip_checksum(skb, hooknum, hdrlen, IPPROTO_ICMP))
581                 return 0;
582
583         inside = (void *)skb->data + hdrlen;
584         if (inside->icmp.type == ICMP_REDIRECT) {
585                 if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK)
586                         return 0;
587                 if (ct->status & IPS_NAT_MASK)
588                         return 0;
589         }
590
591         if (manip == NF_NAT_MANIP_SRC)
592                 statusbit = IPS_SRC_NAT;
593         else
594                 statusbit = IPS_DST_NAT;
595
596         /* Invert if this is reply direction */
597         if (dir == IP_CT_DIR_REPLY)
598                 statusbit ^= IPS_NAT_MASK;
599
600         if (!(ct->status & statusbit))
601                 return 1;
602
603         if (!nf_nat_ipv4_manip_pkt(skb, hdrlen + sizeof(inside->icmp),
604                                    &ct->tuplehash[!dir].tuple, !manip))
605                 return 0;
606
607         if (skb->ip_summed != CHECKSUM_PARTIAL) {
608                 /* Reloading "inside" here since manip_pkt may reallocate */
609                 inside = (void *)skb->data + hdrlen;
610                 inside->icmp.checksum = 0;
611                 inside->icmp.checksum =
612                         csum_fold(skb_checksum(skb, hdrlen,
613                                                skb->len - hdrlen, 0));
614         }
615
616         /* Change outer to look like the reply to an incoming packet */
617         nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple);
618         target.dst.protonum = IPPROTO_ICMP;
619         if (!nf_nat_ipv4_manip_pkt(skb, 0, &target, manip))
620                 return 0;
621
622         return 1;
623 }
624 EXPORT_SYMBOL_GPL(nf_nat_icmp_reply_translation);
625
626 static unsigned int
627 nf_nat_ipv4_fn(void *priv, struct sk_buff *skb,
628                const struct nf_hook_state *state)
629 {
630         struct nf_conn *ct;
631         enum ip_conntrack_info ctinfo;
632
633         ct = nf_ct_get(skb, &ctinfo);
634         if (!ct)
635                 return NF_ACCEPT;
636
637         if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) {
638                 if (ip_hdr(skb)->protocol == IPPROTO_ICMP) {
639                         if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo,
640                                                            state->hook))
641                                 return NF_DROP;
642                         else
643                                 return NF_ACCEPT;
644                 }
645         }
646
647         return nf_nat_inet_fn(priv, skb, state);
648 }
649
650 static unsigned int
651 nf_nat_ipv4_in(void *priv, struct sk_buff *skb,
652                const struct nf_hook_state *state)
653 {
654         unsigned int ret;
655         __be32 daddr = ip_hdr(skb)->daddr;
656
657         ret = nf_nat_ipv4_fn(priv, skb, state);
658         if (ret == NF_ACCEPT && daddr != ip_hdr(skb)->daddr)
659                 skb_dst_drop(skb);
660
661         return ret;
662 }
663
/*
 * NAT hook wrapper that, with CONFIG_XFRM, calls nf_xfrm_me_harder() for
 * an accepted, not yet xfrm-transformed packet when the connection's
 * source address or (for non-ICMP) source protocol id differs from the
 * reply tuple's destination.  A failure there turns the verdict into a
 * drop carrying the error code.
 */
static unsigned int
nf_nat_ipv4_out(void *priv, struct sk_buff *skb,
		const struct nf_hook_state *state)
{
#ifdef CONFIG_XFRM
	const struct nf_conntrack_tuple *cur, *rev;
	enum ip_conntrack_info ctinfo;
	enum ip_conntrack_dir dir;
	const struct nf_conn *ct;
	int err;
#endif
	unsigned int ret;

	ret = nf_nat_ipv4_fn(priv, skb, state);
#ifdef CONFIG_XFRM
	if (ret != NF_ACCEPT || (IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED))
		return ret;

	ct = nf_ct_get(skb, &ctinfo);
	if (!ct)
		return ret;

	dir = CTINFO2DIR(ctinfo);
	cur = &ct->tuplehash[dir].tuple;
	rev = &ct->tuplehash[!dir].tuple;

	/* Source side unchanged by NAT: nothing more to do. */
	if (cur->src.u3.ip == rev->dst.u3.ip &&
	    (cur->dst.protonum == IPPROTO_ICMP ||
	     cur->src.u.all == rev->dst.u.all))
		return ret;

	err = nf_xfrm_me_harder(state->net, skb, AF_INET);
	if (err < 0)
		ret = NF_DROP_ERR(err);
#endif
	return ret;
}
700
701 static unsigned int
702 nf_nat_ipv4_local_fn(void *priv, struct sk_buff *skb,
703                      const struct nf_hook_state *state)
704 {
705         const struct nf_conn *ct;
706         enum ip_conntrack_info ctinfo;
707         unsigned int ret;
708         int err;
709
710         ret = nf_nat_ipv4_fn(priv, skb, state);
711         if (ret != NF_ACCEPT)
712                 return ret;
713
714         ct = nf_ct_get(skb, &ctinfo);
715         if (ct) {
716                 enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
717
718                 if (ct->tuplehash[dir].tuple.dst.u3.ip !=
719                     ct->tuplehash[!dir].tuple.src.u3.ip) {
720                         err = ip_route_me_harder(state->net, skb, RTN_UNSPEC);
721                         if (err < 0)
722                                 ret = NF_DROP_ERR(err);
723                 }
724 #ifdef CONFIG_XFRM
725                 else if (!(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) &&
726                          ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP &&
727                          ct->tuplehash[dir].tuple.dst.u.all !=
728                          ct->tuplehash[!dir].tuple.src.u.all) {
729                         err = nf_xfrm_me_harder(state->net, skb, AF_INET);
730                         if (err < 0)
731                                 ret = NF_DROP_ERR(err);
732                 }
733 #endif
734         }
735         return ret;
736 }
737
738 static const struct nf_hook_ops nf_nat_ipv4_ops[] = {
739         /* Before packet filtering, change destination */
740         {
741                 .hook           = nf_nat_ipv4_in,
742                 .pf             = NFPROTO_IPV4,
743                 .hooknum        = NF_INET_PRE_ROUTING,
744                 .priority       = NF_IP_PRI_NAT_DST,
745         },
746         /* After packet filtering, change source */
747         {
748                 .hook           = nf_nat_ipv4_out,
749                 .pf             = NFPROTO_IPV4,
750                 .hooknum        = NF_INET_POST_ROUTING,
751                 .priority       = NF_IP_PRI_NAT_SRC,
752         },
753         /* Before packet filtering, change destination */
754         {
755                 .hook           = nf_nat_ipv4_local_fn,
756                 .pf             = NFPROTO_IPV4,
757                 .hooknum        = NF_INET_LOCAL_OUT,
758                 .priority       = NF_IP_PRI_NAT_DST,
759         },
760         /* After packet filtering, change source */
761         {
762                 .hook           = nf_nat_ipv4_fn,
763                 .pf             = NFPROTO_IPV4,
764                 .hooknum        = NF_INET_LOCAL_IN,
765                 .priority       = NF_IP_PRI_NAT_SRC,
766         },
767 };
768
769 int nf_nat_ipv4_register_fn(struct net *net, const struct nf_hook_ops *ops)
770 {
771         return nf_nat_register_fn(net, ops->pf, ops, nf_nat_ipv4_ops,
772                                   ARRAY_SIZE(nf_nat_ipv4_ops));
773 }
774 EXPORT_SYMBOL_GPL(nf_nat_ipv4_register_fn);
775
776 void nf_nat_ipv4_unregister_fn(struct net *net, const struct nf_hook_ops *ops)
777 {
778         nf_nat_unregister_fn(net, ops->pf, ops, ARRAY_SIZE(nf_nat_ipv4_ops));
779 }
780 EXPORT_SYMBOL_GPL(nf_nat_ipv4_unregister_fn);
781
782 #if IS_ENABLED(CONFIG_IPV6)
783 int nf_nat_icmpv6_reply_translation(struct sk_buff *skb,
784                                     struct nf_conn *ct,
785                                     enum ip_conntrack_info ctinfo,
786                                     unsigned int hooknum,
787                                     unsigned int hdrlen)
788 {
789         struct {
790                 struct icmp6hdr icmp6;
791                 struct ipv6hdr  ip6;
792         } *inside;
793         enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
794         enum nf_nat_manip_type manip = HOOK2MANIP(hooknum);
795         struct nf_conntrack_tuple target;
796         unsigned long statusbit;
797
798         WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY);
799
800         if (skb_ensure_writable(skb, hdrlen + sizeof(*inside)))
801                 return 0;
802         if (nf_ip6_checksum(skb, hooknum, hdrlen, IPPROTO_ICMPV6))
803                 return 0;
804
805         inside = (void *)skb->data + hdrlen;
806         if (inside->icmp6.icmp6_type == NDISC_REDIRECT) {
807                 if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK)
808                         return 0;
809                 if (ct->status & IPS_NAT_MASK)
810                         return 0;
811         }
812
813         if (manip == NF_NAT_MANIP_SRC)
814                 statusbit = IPS_SRC_NAT;
815         else
816                 statusbit = IPS_DST_NAT;
817
818         /* Invert if this is reply direction */
819         if (dir == IP_CT_DIR_REPLY)
820                 statusbit ^= IPS_NAT_MASK;
821
822         if (!(ct->status & statusbit))
823                 return 1;
824
825         if (!nf_nat_ipv6_manip_pkt(skb, hdrlen + sizeof(inside->icmp6),
826                                    &ct->tuplehash[!dir].tuple, !manip))
827                 return 0;
828
829         if (skb->ip_summed != CHECKSUM_PARTIAL) {
830                 struct ipv6hdr *ipv6h = ipv6_hdr(skb);
831
832                 inside = (void *)skb->data + hdrlen;
833                 inside->icmp6.icmp6_cksum = 0;
834                 inside->icmp6.icmp6_cksum =
835                         csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
836                                         skb->len - hdrlen, IPPROTO_ICMPV6,
837                                         skb_checksum(skb, hdrlen,
838                                                      skb->len - hdrlen, 0));
839         }
840
841         nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple);
842         target.dst.protonum = IPPROTO_ICMPV6;
843         if (!nf_nat_ipv6_manip_pkt(skb, 0, &target, manip))
844                 return 0;
845
846         return 1;
847 }
848 EXPORT_SYMBOL_GPL(nf_nat_icmpv6_reply_translation);
849
850 static unsigned int
851 nf_nat_ipv6_fn(void *priv, struct sk_buff *skb,
852                const struct nf_hook_state *state)
853 {
854         struct nf_conn *ct;
855         enum ip_conntrack_info ctinfo;
856         __be16 frag_off;
857         int hdrlen;
858         u8 nexthdr;
859
860         ct = nf_ct_get(skb, &ctinfo);
861         /* Can't track?  It's not due to stress, or conntrack would
862          * have dropped it.  Hence it's the user's responsibilty to
863          * packet filter it out, or implement conntrack/NAT for that
864          * protocol. 8) --RR
865          */
866         if (!ct)
867                 return NF_ACCEPT;
868
869         if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) {
870                 nexthdr = ipv6_hdr(skb)->nexthdr;
871                 hdrlen = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr),
872                                           &nexthdr, &frag_off);
873
874                 if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) {
875                         if (!nf_nat_icmpv6_reply_translation(skb, ct, ctinfo,
876                                                              state->hook,
877                                                              hdrlen))
878                                 return NF_DROP;
879                         else
880                                 return NF_ACCEPT;
881                 }
882         }
883
884         return nf_nat_inet_fn(priv, skb, state);
885 }
886
887 static unsigned int
888 nf_nat_ipv6_in(void *priv, struct sk_buff *skb,
889                const struct nf_hook_state *state)
890 {
891         unsigned int ret;
892         struct in6_addr daddr = ipv6_hdr(skb)->daddr;
893
894         ret = nf_nat_ipv6_fn(priv, skb, state);
895         if (ret != NF_DROP && ret != NF_STOLEN &&
896             ipv6_addr_cmp(&daddr, &ipv6_hdr(skb)->daddr))
897                 skb_dst_drop(skb);
898
899         return ret;
900 }
901
/* POST_ROUTING hook: run NAT and, with CONFIG_XFRM, call
 * nf_xfrm_me_harder() when the source side of the connection differs
 * between the two directions and the packet has not been transformed yet.
 */
static unsigned int
nf_nat_ipv6_out(void *priv, struct sk_buff *skb,
                const struct nf_hook_state *state)
{
        unsigned int verdict = nf_nat_ipv6_fn(priv, skb, state);
#ifdef CONFIG_XFRM
        const struct nf_conntrack_tuple *cur, *rev;
        enum ip_conntrack_info ctinfo;
        enum ip_conntrack_dir dir;
        const struct nf_conn *ct;
        int err;

        if (verdict != NF_ACCEPT)
                return verdict;
        if (IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED)
                return verdict;

        ct = nf_ct_get(skb, &ctinfo);
        if (!ct)
                return verdict;

        dir = CTINFO2DIR(ctinfo);
        cur = &ct->tuplehash[dir].tuple;
        rev = &ct->tuplehash[!dir].tuple;

        /* Source address differs from the other direction's destination,
         * or (for anything but ICMPv6) the source protocol id does.
         */
        if (!nf_inet_addr_cmp(&cur->src.u3, &rev->dst.u3) ||
            (cur->dst.protonum != IPPROTO_ICMPV6 &&
             cur->src.u.all != rev->dst.u.all)) {
                err = nf_xfrm_me_harder(state->net, skb, AF_INET6);
                if (err < 0)
                        verdict = NF_DROP_ERR(err);
        }
#endif

        return verdict;
}
938
939 static unsigned int
940 nf_nat_ipv6_local_fn(void *priv, struct sk_buff *skb,
941                      const struct nf_hook_state *state)
942 {
943         const struct nf_conn *ct;
944         enum ip_conntrack_info ctinfo;
945         unsigned int ret;
946         int err;
947
948         ret = nf_nat_ipv6_fn(priv, skb, state);
949         if (ret != NF_ACCEPT)
950                 return ret;
951
952         ct = nf_ct_get(skb, &ctinfo);
953         if (ct) {
954                 enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
955
956                 if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.dst.u3,
957                                       &ct->tuplehash[!dir].tuple.src.u3)) {
958                         err = nf_ip6_route_me_harder(state->net, skb);
959                         if (err < 0)
960                                 ret = NF_DROP_ERR(err);
961                 }
962 #ifdef CONFIG_XFRM
963                 else if (!(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) &&
964                          ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 &&
965                          ct->tuplehash[dir].tuple.dst.u.all !=
966                          ct->tuplehash[!dir].tuple.src.u.all) {
967                         err = nf_xfrm_me_harder(state->net, skb, AF_INET6);
968                         if (err < 0)
969                                 ret = NF_DROP_ERR(err);
970                 }
971 #endif
972         }
973
974         return ret;
975 }
976
977 static const struct nf_hook_ops nf_nat_ipv6_ops[] = {
978         /* Before packet filtering, change destination */
979         {
980                 .hook           = nf_nat_ipv6_in,
981                 .pf             = NFPROTO_IPV6,
982                 .hooknum        = NF_INET_PRE_ROUTING,
983                 .priority       = NF_IP6_PRI_NAT_DST,
984         },
985         /* After packet filtering, change source */
986         {
987                 .hook           = nf_nat_ipv6_out,
988                 .pf             = NFPROTO_IPV6,
989                 .hooknum        = NF_INET_POST_ROUTING,
990                 .priority       = NF_IP6_PRI_NAT_SRC,
991         },
992         /* Before packet filtering, change destination */
993         {
994                 .hook           = nf_nat_ipv6_local_fn,
995                 .pf             = NFPROTO_IPV6,
996                 .hooknum        = NF_INET_LOCAL_OUT,
997                 .priority       = NF_IP6_PRI_NAT_DST,
998         },
999         /* After packet filtering, change source */
1000         {
1001                 .hook           = nf_nat_ipv6_fn,
1002                 .pf             = NFPROTO_IPV6,
1003                 .hooknum        = NF_INET_LOCAL_IN,
1004                 .priority       = NF_IP6_PRI_NAT_SRC,
1005         },
1006 };
1007
1008 int nf_nat_ipv6_register_fn(struct net *net, const struct nf_hook_ops *ops)
1009 {
1010         return nf_nat_register_fn(net, ops->pf, ops, nf_nat_ipv6_ops,
1011                                   ARRAY_SIZE(nf_nat_ipv6_ops));
1012 }
1013 EXPORT_SYMBOL_GPL(nf_nat_ipv6_register_fn);
1014
1015 void nf_nat_ipv6_unregister_fn(struct net *net, const struct nf_hook_ops *ops)
1016 {
1017         nf_nat_unregister_fn(net, ops->pf, ops, ARRAY_SIZE(nf_nat_ipv6_ops));
1018 }
1019 EXPORT_SYMBOL_GPL(nf_nat_ipv6_unregister_fn);
1020 #endif /* CONFIG_IPV6 */
1021
1022 #if defined(CONFIG_NF_TABLES_INET) && IS_ENABLED(CONFIG_NFT_NAT)
1023 int nf_nat_inet_register_fn(struct net *net, const struct nf_hook_ops *ops)
1024 {
1025         int ret;
1026
1027         if (WARN_ON_ONCE(ops->pf != NFPROTO_INET))
1028                 return -EINVAL;
1029
1030         ret = nf_nat_register_fn(net, NFPROTO_IPV6, ops, nf_nat_ipv6_ops,
1031                                  ARRAY_SIZE(nf_nat_ipv6_ops));
1032         if (ret)
1033                 return ret;
1034
1035         ret = nf_nat_register_fn(net, NFPROTO_IPV4, ops, nf_nat_ipv4_ops,
1036                                  ARRAY_SIZE(nf_nat_ipv4_ops));
1037         if (ret)
1038                 nf_nat_ipv6_unregister_fn(net, ops);
1039
1040         return ret;
1041 }
1042 EXPORT_SYMBOL_GPL(nf_nat_inet_register_fn);
1043
1044 void nf_nat_inet_unregister_fn(struct net *net, const struct nf_hook_ops *ops)
1045 {
1046         nf_nat_unregister_fn(net, NFPROTO_IPV4, ops, ARRAY_SIZE(nf_nat_ipv4_ops));
1047         nf_nat_unregister_fn(net, NFPROTO_IPV6, ops, ARRAY_SIZE(nf_nat_ipv6_ops));
1048 }
1049 EXPORT_SYMBOL_GPL(nf_nat_inet_unregister_fn);
1050 #endif /* NFT INET NAT */