Merge tag 'pci-v4.18-fixes-2' of git://git.kernel.org/pub/scm/linux/kernel/git/helgaa...
[linux-2.6-microblaze.git] / net / netfilter / nft_ct.c
1 /*
2  * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
3  * Copyright (c) 2016 Pablo Neira Ayuso <pablo@netfilter.org>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  *
9  * Development of this code funded by Astaro AG (http://www.astaro.com/)
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/init.h>
14 #include <linux/module.h>
15 #include <linux/netlink.h>
16 #include <linux/netfilter.h>
17 #include <linux/netfilter/nf_tables.h>
18 #include <net/netfilter/nf_tables.h>
19 #include <net/netfilter/nf_conntrack.h>
20 #include <net/netfilter/nf_conntrack_acct.h>
21 #include <net/netfilter/nf_conntrack_tuple.h>
22 #include <net/netfilter/nf_conntrack_helper.h>
23 #include <net/netfilter/nf_conntrack_ecache.h>
24 #include <net/netfilter/nf_conntrack_labels.h>
25
/* Private data of a "ct" expression; every field is an 8-bit bitfield. */
struct nft_ct {
        /* conntrack key taken from the NFTA_CT_KEY attribute */
        enum nft_ct_keys        key:8;
        /* requested direction; IP_CT_DIR_MAX when no direction was given */
        enum ip_conntrack_dir   dir:8;
        union {
                /* destination register, used by the get expression */
                enum nft_registers      dreg:8;
                /* source register, used by the set expressions */
                enum nft_registers      sreg:8;
        };
};
34
/* Private data of a ct helper object: one optional helper per L3 family. */
struct nft_ct_helper_obj  {
        /* helper used for IPv4 connections, NULL when none was found */
        struct nf_conntrack_helper *helper4;
        /* helper used for IPv6 connections, NULL when none was found */
        struct nf_conntrack_helper *helper6;
        /* layer 4 protocol number from NFTA_CT_HELPER_L4PROTO */
        u8 l4proto;
};
40
41 #ifdef CONFIG_NF_CONNTRACK_ZONES
/* Per-CPU conntrack template that nft_ct_set_zone_eval() attaches to skbs. */
static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template);
/* Number of NFT_CT_ZONE set expressions currently using the templates. */
static unsigned int nft_ct_pcpu_template_refcnt __read_mostly;
44 #endif
45
46 static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c,
47                                    enum nft_ct_keys k,
48                                    enum ip_conntrack_dir d)
49 {
50         if (d < IP_CT_DIR_MAX)
51                 return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) :
52                                            atomic64_read(&c[d].packets);
53
54         return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) +
55                nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY);
56 }
57
/*
 * Evaluate a ct "get" expression: store the value of the configured
 * conntrack key of the packet into the destination register.
 *
 * @expr: expression whose private data (struct nft_ct) selects key,
 *        direction and destination register
 * @regs: register file; the result is written at regs->data[priv->dreg]
 * @pkt:  packet information, pkt->skb carries the conntrack reference
 *
 * The verdict is set to NFT_BREAK when the packet has no conntrack entry
 * (for every key except NFT_CT_STATE), when no helper can be found for
 * NFT_CT_HELPER, or when an address key does not match the entry's L3
 * protocol.
 */
static void nft_ct_get_eval(const struct nft_expr *expr,
                            struct nft_regs *regs,
                            const struct nft_pktinfo *pkt)
{
        const struct nft_ct *priv = nft_expr_priv(expr);
        u32 *dest = &regs->data[priv->dreg];
        enum ip_conntrack_info ctinfo;
        const struct nf_conn *ct;
        const struct nf_conn_help *help;
        const struct nf_conntrack_tuple *tuple;
        const struct nf_conntrack_helper *helper;
        unsigned int state;

        ct = nf_ct_get(pkt->skb, &ctinfo);

        /* NFT_CT_STATE is the only key answered without a conntrack entry. */
        switch (priv->key) {
        case NFT_CT_STATE:
                if (ct)
                        state = NF_CT_STATE_BIT(ctinfo);
                else if (ctinfo == IP_CT_UNTRACKED)
                        state = NF_CT_STATE_UNTRACKED_BIT;
                else
                        state = NF_CT_STATE_INVALID_BIT;
                *dest = state;
                return;
        default:
                break;
        }

        /* Every remaining key needs a conntrack entry. */
        if (ct == NULL)
                goto err;

        /* Keys that do not depend on a tuple. */
        switch (priv->key) {
        case NFT_CT_DIRECTION:
                nft_reg_store8(dest, CTINFO2DIR(ctinfo));
                return;
        case NFT_CT_STATUS:
                *dest = ct->status;
                return;
#ifdef CONFIG_NF_CONNTRACK_MARK
        case NFT_CT_MARK:
                *dest = ct->mark;
                return;
#endif
#ifdef CONFIG_NF_CONNTRACK_SECMARK
        case NFT_CT_SECMARK:
                *dest = ct->secmark;
                return;
#endif
        case NFT_CT_EXPIRATION:
                *dest = jiffies_to_msecs(nf_ct_expires(ct));
                return;
        case NFT_CT_HELPER:
                /* The helper is looked up on the master connection. */
                if (ct->master == NULL)
                        goto err;
                help = nfct_help(ct->master);
                if (help == NULL)
                        goto err;
                helper = rcu_dereference(help->helper);
                if (helper == NULL)
                        goto err;
                strncpy((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN);
                return;
#ifdef CONFIG_NF_CONNTRACK_LABELS
        case NFT_CT_LABELS: {
                struct nf_conn_labels *labels = nf_ct_labels_find(ct);

                /* Without a labels extension the register is zeroed. */
                if (labels)
                        memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE);
                else
                        memset(dest, 0, NF_CT_LABELS_MAX_SIZE);
                return;
        }
#endif
        case NFT_CT_BYTES: /* fallthrough */
        case NFT_CT_PKTS: {
                const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
                u64 count = 0;

                /* The count stays 0 when no accounting data is found. */
                if (acct)
                        count = nft_ct_get_eval_counter(acct->counter,
                                                        priv->key, priv->dir);
                memcpy(dest, &count, sizeof(count));
                return;
        }
        case NFT_CT_AVGPKT: {
                const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
                u64 avgcnt = 0, bcnt = 0, pcnt = 0;

                if (acct) {
                        pcnt = nft_ct_get_eval_counter(acct->counter,
                                                       NFT_CT_PKTS, priv->dir);
                        bcnt = nft_ct_get_eval_counter(acct->counter,
                                                       NFT_CT_BYTES, priv->dir);
                        /* Guard the division; the average stays 0 without packets. */
                        if (pcnt != 0)
                                avgcnt = div64_u64(bcnt, pcnt);
                }

                memcpy(dest, &avgcnt, sizeof(avgcnt));
                return;
        }
        case NFT_CT_L3PROTOCOL:
                nft_reg_store8(dest, nf_ct_l3num(ct));
                return;
        case NFT_CT_PROTOCOL:
                nft_reg_store8(dest, nf_ct_protonum(ct));
                return;
#ifdef CONFIG_NF_CONNTRACK_ZONES
        case NFT_CT_ZONE: {
                const struct nf_conntrack_zone *zone = nf_ct_zone(ct);
                u16 zoneid;

                /* Direction-specific id when a direction was configured. */
                if (priv->dir < IP_CT_DIR_MAX)
                        zoneid = nf_ct_zone_id(zone, priv->dir);
                else
                        zoneid = zone->id;

                nft_reg_store16(dest, zoneid);
                return;
        }
#endif
        default:
                break;
        }

        /* The remaining keys read the tuple of the configured direction. */
        tuple = &ct->tuplehash[priv->dir].tuple;
        switch (priv->key) {
        case NFT_CT_SRC:
                /* 4 bytes for an IPv4 entry, 16 bytes otherwise. */
                memcpy(dest, tuple->src.u3.all,
                       nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
                return;
        case NFT_CT_DST:
                memcpy(dest, tuple->dst.u3.all,
                       nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
                return;
        case NFT_CT_PROTO_SRC:
                nft_reg_store16(dest, (__force u16)tuple->src.u.all);
                return;
        case NFT_CT_PROTO_DST:
                nft_reg_store16(dest, (__force u16)tuple->dst.u.all);
                return;
        case NFT_CT_SRC_IP:
                if (nf_ct_l3num(ct) != NFPROTO_IPV4)
                        goto err;
                *dest = tuple->src.u3.ip;
                return;
        case NFT_CT_DST_IP:
                if (nf_ct_l3num(ct) != NFPROTO_IPV4)
                        goto err;
                *dest = tuple->dst.u3.ip;
                return;
        case NFT_CT_SRC_IP6:
                if (nf_ct_l3num(ct) != NFPROTO_IPV6)
                        goto err;
                memcpy(dest, tuple->src.u3.ip6, sizeof(struct in6_addr));
                return;
        case NFT_CT_DST_IP6:
                if (nf_ct_l3num(ct) != NFPROTO_IPV6)
                        goto err;
                memcpy(dest, tuple->dst.u3.ip6, sizeof(struct in6_addr));
                return;
        default:
                break;
        }
        return;
err:
        /* The key cannot be provided for this packet. */
        regs->verdict.code = NFT_BREAK;
}
226
227 #ifdef CONFIG_NF_CONNTRACK_ZONES
228 static void nft_ct_set_zone_eval(const struct nft_expr *expr,
229                                  struct nft_regs *regs,
230                                  const struct nft_pktinfo *pkt)
231 {
232         struct nf_conntrack_zone zone = { .dir = NF_CT_DEFAULT_ZONE_DIR };
233         const struct nft_ct *priv = nft_expr_priv(expr);
234         struct sk_buff *skb = pkt->skb;
235         enum ip_conntrack_info ctinfo;
236         u16 value = nft_reg_load16(&regs->data[priv->sreg]);
237         struct nf_conn *ct;
238
239         ct = nf_ct_get(skb, &ctinfo);
240         if (ct) /* already tracked */
241                 return;
242
243         zone.id = value;
244
245         switch (priv->dir) {
246         case IP_CT_DIR_ORIGINAL:
247                 zone.dir = NF_CT_ZONE_DIR_ORIG;
248                 break;
249         case IP_CT_DIR_REPLY:
250                 zone.dir = NF_CT_ZONE_DIR_REPL;
251                 break;
252         default:
253                 break;
254         }
255
256         ct = this_cpu_read(nft_ct_pcpu_template);
257
258         if (likely(atomic_read(&ct->ct_general.use) == 1)) {
259                 nf_ct_zone_add(ct, &zone);
260         } else {
261                 /* previous skb got queued to userspace */
262                 ct = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC);
263                 if (!ct) {
264                         regs->verdict.code = NF_DROP;
265                         return;
266                 }
267         }
268
269         atomic_inc(&ct->ct_general.use);
270         nf_ct_set(skb, ct, IP_CT_NEW);
271 }
272 #endif
273
274 static void nft_ct_set_eval(const struct nft_expr *expr,
275                             struct nft_regs *regs,
276                             const struct nft_pktinfo *pkt)
277 {
278         const struct nft_ct *priv = nft_expr_priv(expr);
279         struct sk_buff *skb = pkt->skb;
280 #ifdef CONFIG_NF_CONNTRACK_MARK
281         u32 value = regs->data[priv->sreg];
282 #endif
283         enum ip_conntrack_info ctinfo;
284         struct nf_conn *ct;
285
286         ct = nf_ct_get(skb, &ctinfo);
287         if (ct == NULL || nf_ct_is_template(ct))
288                 return;
289
290         switch (priv->key) {
291 #ifdef CONFIG_NF_CONNTRACK_MARK
292         case NFT_CT_MARK:
293                 if (ct->mark != value) {
294                         ct->mark = value;
295                         nf_conntrack_event_cache(IPCT_MARK, ct);
296                 }
297                 break;
298 #endif
299 #ifdef CONFIG_NF_CONNTRACK_LABELS
300         case NFT_CT_LABELS:
301                 nf_connlabels_replace(ct,
302                                       &regs->data[priv->sreg],
303                                       &regs->data[priv->sreg],
304                                       NF_CT_LABELS_MAX_SIZE / sizeof(u32));
305                 break;
306 #endif
307 #ifdef CONFIG_NF_CONNTRACK_EVENTS
308         case NFT_CT_EVENTMASK: {
309                 struct nf_conntrack_ecache *e = nf_ct_ecache_find(ct);
310                 u32 ctmask = regs->data[priv->sreg];
311
312                 if (e) {
313                         if (e->ctmask != ctmask)
314                                 e->ctmask = ctmask;
315                         break;
316                 }
317
318                 if (ctmask && !nf_ct_is_confirmed(ct))
319                         nf_ct_ecache_ext_add(ct, ctmask, 0, GFP_ATOMIC);
320                 break;
321         }
322 #endif
323         default:
324                 break;
325         }
326 }
327
/* Netlink attribute policy for the "ct" expression. */
static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
        [NFTA_CT_DREG]          = { .type = NLA_U32 },  /* destination register (get) */
        [NFTA_CT_KEY]           = { .type = NLA_U32 },  /* conntrack key */
        [NFTA_CT_DIRECTION]     = { .type = NLA_U8 },   /* optional direction */
        [NFTA_CT_SREG]          = { .type = NLA_U32 },  /* source register (set) */
};
334
335 #ifdef CONFIG_NF_CONNTRACK_ZONES
336 static void nft_ct_tmpl_put_pcpu(void)
337 {
338         struct nf_conn *ct;
339         int cpu;
340
341         for_each_possible_cpu(cpu) {
342                 ct = per_cpu(nft_ct_pcpu_template, cpu);
343                 if (!ct)
344                         break;
345                 nf_ct_put(ct);
346                 per_cpu(nft_ct_pcpu_template, cpu) = NULL;
347         }
348 }
349
350 static bool nft_ct_tmpl_alloc_pcpu(void)
351 {
352         struct nf_conntrack_zone zone = { .id = 0 };
353         struct nf_conn *tmp;
354         int cpu;
355
356         if (nft_ct_pcpu_template_refcnt)
357                 return true;
358
359         for_each_possible_cpu(cpu) {
360                 tmp = nf_ct_tmpl_alloc(&init_net, &zone, GFP_KERNEL);
361                 if (!tmp) {
362                         nft_ct_tmpl_put_pcpu();
363                         return false;
364                 }
365
366                 atomic_set(&tmp->ct_general.use, 1);
367                 per_cpu(nft_ct_pcpu_template, cpu) = tmp;
368         }
369
370         return true;
371 }
372 #endif
373
/*
 * Initialize a ct "get" expression from its netlink attributes.
 *
 * @ctx:  expression context (family and network namespace)
 * @expr: expression whose private data (struct nft_ct) is filled in
 * @tb:   parsed attributes, indexed by NFTA_CT_*
 *
 * Determines the number of bytes the key writes, checks that the direction
 * attribute is present or absent as the key demands, validates the
 * destination register and takes a conntrack netns reference.
 *
 * Returns 0 on success, -EINVAL for a missing, unexpected or invalid
 * direction, -EAFNOSUPPORT for NFT_CT_SRC/NFT_CT_DST in an unsupported
 * family, -EOPNOTSUPP for an unsupported key, or the error returned by
 * nft_validate_register_store() or nf_ct_netns_get().
 */
static int nft_ct_get_init(const struct nft_ctx *ctx,
                           const struct nft_expr *expr,
                           const struct nlattr * const tb[])
{
        struct nft_ct *priv = nft_expr_priv(expr);
        unsigned int len;
        int err;

        priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
        /* IP_CT_DIR_MAX stands for "no direction given". */
        priv->dir = IP_CT_DIR_MAX;
        switch (priv->key) {
        case NFT_CT_DIRECTION:
                if (tb[NFTA_CT_DIRECTION] != NULL)
                        return -EINVAL;
                len = sizeof(u8);
                break;
        case NFT_CT_STATE:
        case NFT_CT_STATUS:
#ifdef CONFIG_NF_CONNTRACK_MARK
        case NFT_CT_MARK:
#endif
#ifdef CONFIG_NF_CONNTRACK_SECMARK
        case NFT_CT_SECMARK:
#endif
        case NFT_CT_EXPIRATION:
                if (tb[NFTA_CT_DIRECTION] != NULL)
                        return -EINVAL;
                len = sizeof(u32);
                break;
#ifdef CONFIG_NF_CONNTRACK_LABELS
        case NFT_CT_LABELS:
                if (tb[NFTA_CT_DIRECTION] != NULL)
                        return -EINVAL;
                len = NF_CT_LABELS_MAX_SIZE;
                break;
#endif
        case NFT_CT_HELPER:
                if (tb[NFTA_CT_DIRECTION] != NULL)
                        return -EINVAL;
                len = NF_CT_HELPER_NAME_LEN;
                break;

        case NFT_CT_L3PROTOCOL:
        case NFT_CT_PROTOCOL:
                /* For compatibility, do not report error if NFTA_CT_DIRECTION
                 * attribute is specified.
                 */
                len = sizeof(u8);
                break;
        case NFT_CT_SRC:
        case NFT_CT_DST:
                /* Tuple keys require a direction. */
                if (tb[NFTA_CT_DIRECTION] == NULL)
                        return -EINVAL;

                /* The address size depends on the family of the context. */
                switch (ctx->family) {
                case NFPROTO_IPV4:
                        len = FIELD_SIZEOF(struct nf_conntrack_tuple,
                                           src.u3.ip);
                        break;
                case NFPROTO_IPV6:
                case NFPROTO_INET:
                        len = FIELD_SIZEOF(struct nf_conntrack_tuple,
                                           src.u3.ip6);
                        break;
                default:
                        return -EAFNOSUPPORT;
                }
                break;
        case NFT_CT_SRC_IP:
        case NFT_CT_DST_IP:
                if (tb[NFTA_CT_DIRECTION] == NULL)
                        return -EINVAL;

                len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip);
                break;
        case NFT_CT_SRC_IP6:
        case NFT_CT_DST_IP6:
                if (tb[NFTA_CT_DIRECTION] == NULL)
                        return -EINVAL;

                len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip6);
                break;
        case NFT_CT_PROTO_SRC:
        case NFT_CT_PROTO_DST:
                if (tb[NFTA_CT_DIRECTION] == NULL)
                        return -EINVAL;
                len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u.all);
                break;
        /* The direction is optional for the counter and zone keys. */
        case NFT_CT_BYTES:
        case NFT_CT_PKTS:
        case NFT_CT_AVGPKT:
                len = sizeof(u64);
                break;
#ifdef CONFIG_NF_CONNTRACK_ZONES
        case NFT_CT_ZONE:
                len = sizeof(u16);
                break;
#endif
        default:
                return -EOPNOTSUPP;
        }

        /* Only the original and the reply direction are accepted. */
        if (tb[NFTA_CT_DIRECTION] != NULL) {
                priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
                switch (priv->dir) {
                case IP_CT_DIR_ORIGINAL:
                case IP_CT_DIR_REPLY:
                        break;
                default:
                        return -EINVAL;
                }
        }

        priv->dreg = nft_parse_register(tb[NFTA_CT_DREG]);
        err = nft_validate_register_store(ctx, priv->dreg, NULL,
                                          NFT_DATA_VALUE, len);
        if (err < 0)
                return err;

        err = nf_ct_netns_get(ctx->net, ctx->family);
        if (err < 0)
                return err;

        /* Counter keys turn on conntrack accounting for the namespace. */
        if (priv->key == NFT_CT_BYTES ||
            priv->key == NFT_CT_PKTS  ||
            priv->key == NFT_CT_AVGPKT)
                nf_ct_set_acct(ctx->net, true);

        return 0;
}
504
505 static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv)
506 {
507         switch (priv->key) {
508 #ifdef CONFIG_NF_CONNTRACK_LABELS
509         case NFT_CT_LABELS:
510                 nf_connlabels_put(ctx->net);
511                 break;
512 #endif
513 #ifdef CONFIG_NF_CONNTRACK_ZONES
514         case NFT_CT_ZONE:
515                 if (--nft_ct_pcpu_template_refcnt == 0)
516                         nft_ct_tmpl_put_pcpu();
517 #endif
518         default:
519                 break;
520         }
521 }
522
/*
 * Initialize a ct "set" expression from its netlink attributes.
 *
 * @ctx:  expression context (family and network namespace)
 * @expr: expression whose private data (struct nft_ct) is filled in
 * @tb:   parsed attributes, indexed by NFTA_CT_*
 *
 * Takes the per-key resource (connlabel reference or per-CPU zone
 * templates), validates the optional direction and the source register and
 * takes a conntrack netns reference. Failures after the key switch undo the
 * per-key resource through __nft_ct_set_destroy().
 *
 * Returns 0 on success, -EINVAL for an unexpected or invalid direction,
 * -ENOMEM when the zone templates cannot be allocated, -EOPNOTSUPP for an
 * unsupported key, or the error returned by nf_connlabels_get(),
 * nft_validate_register_load() or nf_ct_netns_get().
 */
static int nft_ct_set_init(const struct nft_ctx *ctx,
                           const struct nft_expr *expr,
                           const struct nlattr * const tb[])
{
        struct nft_ct *priv = nft_expr_priv(expr);
        unsigned int len;
        int err;

        /* IP_CT_DIR_MAX stands for "no direction given". */
        priv->dir = IP_CT_DIR_MAX;
        priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
        switch (priv->key) {
#ifdef CONFIG_NF_CONNTRACK_MARK
        case NFT_CT_MARK:
                if (tb[NFTA_CT_DIRECTION])
                        return -EINVAL;
                len = FIELD_SIZEOF(struct nf_conn, mark);
                break;
#endif
#ifdef CONFIG_NF_CONNTRACK_LABELS
        case NFT_CT_LABELS:
                if (tb[NFTA_CT_DIRECTION])
                        return -EINVAL;
                len = NF_CT_LABELS_MAX_SIZE;
                /* Request the highest label bit that fits into len bytes. */
                err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1);
                if (err)
                        return err;
                break;
#endif
#ifdef CONFIG_NF_CONNTRACK_ZONES
        case NFT_CT_ZONE:
                /* The direction is optional for the zone key. */
                if (!nft_ct_tmpl_alloc_pcpu())
                        return -ENOMEM;
                nft_ct_pcpu_template_refcnt++;
                len = sizeof(u16);
                break;
#endif
#ifdef CONFIG_NF_CONNTRACK_EVENTS
        case NFT_CT_EVENTMASK:
                if (tb[NFTA_CT_DIRECTION])
                        return -EINVAL;
                len = sizeof(u32);
                break;
#endif
        default:
                return -EOPNOTSUPP;
        }

        /* From here on failures must release the per-key resource. */
        if (tb[NFTA_CT_DIRECTION]) {
                priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
                switch (priv->dir) {
                case IP_CT_DIR_ORIGINAL:
                case IP_CT_DIR_REPLY:
                        break;
                default:
                        err = -EINVAL;
                        goto err1;
                }
        }

        priv->sreg = nft_parse_register(tb[NFTA_CT_SREG]);
        err = nft_validate_register_load(priv->sreg, len);
        if (err < 0)
                goto err1;

        err = nf_ct_netns_get(ctx->net, ctx->family);
        if (err < 0)
                goto err1;

        return 0;

err1:
        __nft_ct_set_destroy(ctx, priv);
        return err;
}
597
/*
 * Destroy a ct "get" expression: drop the conntrack netns reference for the
 * namespace and family of the context.
 */
static void nft_ct_get_destroy(const struct nft_ctx *ctx,
                               const struct nft_expr *expr)
{
        nf_ct_netns_put(ctx->net, ctx->family);
}
603
604 static void nft_ct_set_destroy(const struct nft_ctx *ctx,
605                                const struct nft_expr *expr)
606 {
607         struct nft_ct *priv = nft_expr_priv(expr);
608
609         __nft_ct_set_destroy(ctx, priv);
610         nf_ct_netns_put(ctx->net, ctx->family);
611 }
612
613 static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr)
614 {
615         const struct nft_ct *priv = nft_expr_priv(expr);
616
617         if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg))
618                 goto nla_put_failure;
619         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
620                 goto nla_put_failure;
621
622         switch (priv->key) {
623         case NFT_CT_SRC:
624         case NFT_CT_DST:
625         case NFT_CT_SRC_IP:
626         case NFT_CT_DST_IP:
627         case NFT_CT_SRC_IP6:
628         case NFT_CT_DST_IP6:
629         case NFT_CT_PROTO_SRC:
630         case NFT_CT_PROTO_DST:
631                 if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
632                         goto nla_put_failure;
633                 break;
634         case NFT_CT_BYTES:
635         case NFT_CT_PKTS:
636         case NFT_CT_AVGPKT:
637         case NFT_CT_ZONE:
638                 if (priv->dir < IP_CT_DIR_MAX &&
639                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
640                         goto nla_put_failure;
641                 break;
642         default:
643                 break;
644         }
645
646         return 0;
647
648 nla_put_failure:
649         return -1;
650 }
651
652 static int nft_ct_set_dump(struct sk_buff *skb, const struct nft_expr *expr)
653 {
654         const struct nft_ct *priv = nft_expr_priv(expr);
655
656         if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg))
657                 goto nla_put_failure;
658         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
659                 goto nla_put_failure;
660
661         switch (priv->key) {
662         case NFT_CT_ZONE:
663                 if (priv->dir < IP_CT_DIR_MAX &&
664                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
665                         goto nla_put_failure;
666                 break;
667         default:
668                 break;
669         }
670
671         return 0;
672
673 nla_put_failure:
674         return -1;
675 }
676
/* Forward declaration: the ops tables below point back at the type. */
static struct nft_expr_type nft_ct_type;
/* Operations of the ct "get" expression (selected when NFTA_CT_DREG is set). */
static const struct nft_expr_ops nft_ct_get_ops = {
        .type           = &nft_ct_type,
        .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
        .eval           = nft_ct_get_eval,
        .init           = nft_ct_get_init,
        .destroy        = nft_ct_get_destroy,
        .dump           = nft_ct_get_dump,
};
686
/* Operations of the ct "set" expression (selected when NFTA_CT_SREG is set). */
static const struct nft_expr_ops nft_ct_set_ops = {
        .type           = &nft_ct_type,
        .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
        .eval           = nft_ct_set_eval,
        .init           = nft_ct_set_init,
        .destroy        = nft_ct_set_destroy,
        .dump           = nft_ct_set_dump,
};
695
696 #ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Operations of the ct "set" expression for the NFT_CT_ZONE key; differs
 * from nft_ct_set_ops only in the eval callback.
 */
static const struct nft_expr_ops nft_ct_set_zone_ops = {
        .type           = &nft_ct_type,
        .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
        .eval           = nft_ct_set_zone_eval,
        .init           = nft_ct_set_init,
        .destroy        = nft_ct_set_destroy,
        .dump           = nft_ct_set_dump,
};
705 #endif
706
707 static const struct nft_expr_ops *
708 nft_ct_select_ops(const struct nft_ctx *ctx,
709                     const struct nlattr * const tb[])
710 {
711         if (tb[NFTA_CT_KEY] == NULL)
712                 return ERR_PTR(-EINVAL);
713
714         if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG])
715                 return ERR_PTR(-EINVAL);
716
717         if (tb[NFTA_CT_DREG])
718                 return &nft_ct_get_ops;
719
720         if (tb[NFTA_CT_SREG]) {
721 #ifdef CONFIG_NF_CONNTRACK_ZONES
722                 if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE))
723                         return &nft_ct_set_zone_ops;
724 #endif
725                 return &nft_ct_set_ops;
726         }
727
728         return ERR_PTR(-EINVAL);
729 }
730
/* The "ct" expression type; the ops table is chosen per expression. */
static struct nft_expr_type nft_ct_type __read_mostly = {
        .name           = "ct",
        .select_ops     = nft_ct_select_ops,
        .policy         = nft_ct_policy,
        .maxattr        = NFTA_CT_MAX,
        .owner          = THIS_MODULE,
};
738
739 static void nft_notrack_eval(const struct nft_expr *expr,
740                              struct nft_regs *regs,
741                              const struct nft_pktinfo *pkt)
742 {
743         struct sk_buff *skb = pkt->skb;
744         enum ip_conntrack_info ctinfo;
745         struct nf_conn *ct;
746
747         ct = nf_ct_get(pkt->skb, &ctinfo);
748         /* Previously seen (loopback or untracked)?  Ignore. */
749         if (ct || ctinfo == IP_CT_UNTRACKED)
750                 return;
751
752         nf_ct_set(skb, ct, IP_CT_UNTRACKED);
753 }
754
/* Forward declaration: the ops table below points back at the type. */
static struct nft_expr_type nft_notrack_type;
/* Operations of the "notrack" expression; it has no private data. */
static const struct nft_expr_ops nft_notrack_ops = {
        .type           = &nft_notrack_type,
        .size           = NFT_EXPR_SIZE(0),
        .eval           = nft_notrack_eval,
};

/* The "notrack" expression type with its single ops table. */
static struct nft_expr_type nft_notrack_type __read_mostly = {
        .name           = "notrack",
        .ops            = &nft_notrack_ops,
        .owner          = THIS_MODULE,
};
767
768 static int nft_ct_helper_obj_init(const struct nft_ctx *ctx,
769                                   const struct nlattr * const tb[],
770                                   struct nft_object *obj)
771 {
772         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
773         struct nf_conntrack_helper *help4, *help6;
774         char name[NF_CT_HELPER_NAME_LEN];
775         int family = ctx->family;
776
777         if (!tb[NFTA_CT_HELPER_NAME] || !tb[NFTA_CT_HELPER_L4PROTO])
778                 return -EINVAL;
779
780         priv->l4proto = nla_get_u8(tb[NFTA_CT_HELPER_L4PROTO]);
781         if (!priv->l4proto)
782                 return -ENOENT;
783
784         nla_strlcpy(name, tb[NFTA_CT_HELPER_NAME], sizeof(name));
785
786         if (tb[NFTA_CT_HELPER_L3PROTO])
787                 family = ntohs(nla_get_be16(tb[NFTA_CT_HELPER_L3PROTO]));
788
789         help4 = NULL;
790         help6 = NULL;
791
792         switch (family) {
793         case NFPROTO_IPV4:
794                 if (ctx->family == NFPROTO_IPV6)
795                         return -EINVAL;
796
797                 help4 = nf_conntrack_helper_try_module_get(name, family,
798                                                            priv->l4proto);
799                 break;
800         case NFPROTO_IPV6:
801                 if (ctx->family == NFPROTO_IPV4)
802                         return -EINVAL;
803
804                 help6 = nf_conntrack_helper_try_module_get(name, family,
805                                                            priv->l4proto);
806                 break;
807         case NFPROTO_NETDEV: /* fallthrough */
808         case NFPROTO_BRIDGE: /* same */
809         case NFPROTO_INET:
810                 help4 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV4,
811                                                            priv->l4proto);
812                 help6 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV6,
813                                                            priv->l4proto);
814                 break;
815         default:
816                 return -EAFNOSUPPORT;
817         }
818
819         /* && is intentional; only error if INET found neither ipv4 or ipv6 */
820         if (!help4 && !help6)
821                 return -ENOENT;
822
823         priv->helper4 = help4;
824         priv->helper6 = help6;
825
826         return 0;
827 }
828
829 static void nft_ct_helper_obj_destroy(const struct nft_ctx *ctx,
830                                       struct nft_object *obj)
831 {
832         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
833
834         if (priv->helper4)
835                 nf_conntrack_helper_put(priv->helper4);
836         if (priv->helper6)
837                 nf_conntrack_helper_put(priv->helper6);
838 }
839
840 static void nft_ct_helper_obj_eval(struct nft_object *obj,
841                                    struct nft_regs *regs,
842                                    const struct nft_pktinfo *pkt)
843 {
844         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
845         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
846         struct nf_conntrack_helper *to_assign = NULL;
847         struct nf_conn_help *help;
848
849         if (!ct ||
850             nf_ct_is_confirmed(ct) ||
851             nf_ct_is_template(ct) ||
852             priv->l4proto != nf_ct_protonum(ct))
853                 return;
854
855         switch (nf_ct_l3num(ct)) {
856         case NFPROTO_IPV4:
857                 to_assign = priv->helper4;
858                 break;
859         case NFPROTO_IPV6:
860                 to_assign = priv->helper6;
861                 break;
862         default:
863                 WARN_ON_ONCE(1);
864                 return;
865         }
866
867         if (!to_assign)
868                 return;
869
870         if (test_bit(IPS_HELPER_BIT, &ct->status))
871                 return;
872
873         help = nf_ct_helper_ext_add(ct, to_assign, GFP_ATOMIC);
874         if (help) {
875                 rcu_assign_pointer(help->helper, to_assign);
876                 set_bit(IPS_HELPER_BIT, &ct->status);
877         }
878 }
879
880 static int nft_ct_helper_obj_dump(struct sk_buff *skb,
881                                   struct nft_object *obj, bool reset)
882 {
883         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
884         const struct nf_conntrack_helper *helper;
885         u16 family;
886
887         if (priv->helper4 && priv->helper6) {
888                 family = NFPROTO_INET;
889                 helper = priv->helper4;
890         } else if (priv->helper6) {
891                 family = NFPROTO_IPV6;
892                 helper = priv->helper6;
893         } else {
894                 family = NFPROTO_IPV4;
895                 helper = priv->helper4;
896         }
897
898         if (nla_put_string(skb, NFTA_CT_HELPER_NAME, helper->name))
899                 return -1;
900
901         if (nla_put_u8(skb, NFTA_CT_HELPER_L4PROTO, priv->l4proto))
902                 return -1;
903
904         if (nla_put_be16(skb, NFTA_CT_HELPER_L3PROTO, htons(family)))
905                 return -1;
906
907         return 0;
908 }
909
/* Netlink attribute policy for the ct helper object. */
static const struct nla_policy nft_ct_helper_policy[NFTA_CT_HELPER_MAX + 1] = {
        /* helper name, limited to NF_CT_HELPER_NAME_LEN - 1 characters */
        [NFTA_CT_HELPER_NAME] = { .type = NLA_STRING,
                                  .len = NF_CT_HELPER_NAME_LEN - 1 },
        /* optional L3 family, read as a big-endian 16-bit value */
        [NFTA_CT_HELPER_L3PROTO] = { .type = NLA_U16 },
        /* L4 protocol number */
        [NFTA_CT_HELPER_L4PROTO] = { .type = NLA_U8 },
};
916
/* Forward declaration: the ops table below points back at the type. */
static struct nft_object_type nft_ct_helper_obj_type;
/* Operations of the ct helper object. */
static const struct nft_object_ops nft_ct_helper_obj_ops = {
        .type           = &nft_ct_helper_obj_type,
        .size           = sizeof(struct nft_ct_helper_obj),
        .eval           = nft_ct_helper_obj_eval,
        .init           = nft_ct_helper_obj_init,
        .destroy        = nft_ct_helper_obj_destroy,
        .dump           = nft_ct_helper_obj_dump,
};

/* The ct helper object type, registered as NFT_OBJECT_CT_HELPER. */
static struct nft_object_type nft_ct_helper_obj_type __read_mostly = {
        .type           = NFT_OBJECT_CT_HELPER,
        .ops            = &nft_ct_helper_obj_ops,
        .maxattr        = NFTA_CT_HELPER_MAX,
        .policy         = nft_ct_helper_policy,
        .owner          = THIS_MODULE,
};
934
935 static int __init nft_ct_module_init(void)
936 {
937         int err;
938
939         BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
940
941         err = nft_register_expr(&nft_ct_type);
942         if (err < 0)
943                 return err;
944
945         err = nft_register_expr(&nft_notrack_type);
946         if (err < 0)
947                 goto err1;
948
949         err = nft_register_obj(&nft_ct_helper_obj_type);
950         if (err < 0)
951                 goto err2;
952
953         return 0;
954
955 err2:
956         nft_unregister_expr(&nft_notrack_type);
957 err1:
958         nft_unregister_expr(&nft_ct_type);
959         return err;
960 }
961
/*
 * Module exit: unregister the object and expression types in the reverse
 * order of their registration in nft_ct_module_init().
 */
static void __exit nft_ct_module_exit(void)
{
        nft_unregister_obj(&nft_ct_helper_obj_type);
        nft_unregister_expr(&nft_notrack_type);
        nft_unregister_expr(&nft_ct_type);
}
968
969 module_init(nft_ct_module_init);
970 module_exit(nft_ct_module_exit);
971
972 MODULE_LICENSE("GPL");
973 MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>");
974 MODULE_ALIAS_NFT_EXPR("ct");
975 MODULE_ALIAS_NFT_EXPR("notrack");
976 MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_HELPER);