Merge tag 'powerpc-5.13-1' of git://git.kernel.org/pub/scm/linux/kernel/git/powerpc...
[linux-2.6-microblaze.git] / net / netfilter / xt_conntrack.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  *      xt_conntrack - Netfilter module to match connection tracking
4  *      information. (Superset of Rusty's minimalistic state match.)
5  *
6  *      (C) 2001  Marc Boucher (marc@mbsi.ca).
7  *      (C) 2006-2012 Patrick McHardy <kaber@trash.net>
8  *      Copyright © CC Computer Consultants GmbH, 2007 - 2008
9  */
10 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
11 #include <linux/module.h>
12 #include <linux/skbuff.h>
13 #include <net/ipv6.h>
14 #include <linux/netfilter/x_tables.h>
15 #include <linux/netfilter/xt_conntrack.h>
16 #include <net/netfilter/nf_conntrack.h>
17
18 MODULE_LICENSE("GPL");
19 MODULE_AUTHOR("Marc Boucher <marc@mbsi.ca>");
20 MODULE_AUTHOR("Jan Engelhardt <jengelh@medozas.de>");
21 MODULE_DESCRIPTION("Xtables: connection tracking state match");
22 MODULE_ALIAS("ipt_conntrack");
23 MODULE_ALIAS("ip6t_conntrack");
24
25 static bool
26 conntrack_addrcmp(const union nf_inet_addr *kaddr,
27                   const union nf_inet_addr *uaddr,
28                   const union nf_inet_addr *umask, unsigned int l3proto)
29 {
30         if (l3proto == NFPROTO_IPV4)
31                 return ((kaddr->ip ^ uaddr->ip) & umask->ip) == 0;
32         else if (l3proto == NFPROTO_IPV6)
33                 return ipv6_masked_addr_cmp(&kaddr->in6, &umask->in6,
34                        &uaddr->in6) == 0;
35         else
36                 return false;
37 }
38
39 static inline bool
40 conntrack_mt_origsrc(const struct nf_conn *ct,
41                      const struct xt_conntrack_mtinfo2 *info,
42                      u_int8_t family)
43 {
44         return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3,
45                &info->origsrc_addr, &info->origsrc_mask, family);
46 }
47
48 static inline bool
49 conntrack_mt_origdst(const struct nf_conn *ct,
50                      const struct xt_conntrack_mtinfo2 *info,
51                      u_int8_t family)
52 {
53         return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u3,
54                &info->origdst_addr, &info->origdst_mask, family);
55 }
56
57 static inline bool
58 conntrack_mt_replsrc(const struct nf_conn *ct,
59                      const struct xt_conntrack_mtinfo2 *info,
60                      u_int8_t family)
61 {
62         return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_REPLY].tuple.src.u3,
63                &info->replsrc_addr, &info->replsrc_mask, family);
64 }
65
66 static inline bool
67 conntrack_mt_repldst(const struct nf_conn *ct,
68                      const struct xt_conntrack_mtinfo2 *info,
69                      u_int8_t family)
70 {
71         return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.u3,
72                &info->repldst_addr, &info->repldst_mask, family);
73 }
74
75 static inline bool
76 ct_proto_port_check(const struct xt_conntrack_mtinfo2 *info,
77                     const struct nf_conn *ct)
78 {
79         const struct nf_conntrack_tuple *tuple;
80
81         tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
82         if ((info->match_flags & XT_CONNTRACK_PROTO) &&
83             (nf_ct_protonum(ct) == info->l4proto) ^
84             !(info->invert_flags & XT_CONNTRACK_PROTO))
85                 return false;
86
87         /* Shortcut to match all recognized protocols by using ->src.all. */
88         if ((info->match_flags & XT_CONNTRACK_ORIGSRC_PORT) &&
89             (tuple->src.u.all == info->origsrc_port) ^
90             !(info->invert_flags & XT_CONNTRACK_ORIGSRC_PORT))
91                 return false;
92
93         if ((info->match_flags & XT_CONNTRACK_ORIGDST_PORT) &&
94             (tuple->dst.u.all == info->origdst_port) ^
95             !(info->invert_flags & XT_CONNTRACK_ORIGDST_PORT))
96                 return false;
97
98         tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple;
99
100         if ((info->match_flags & XT_CONNTRACK_REPLSRC_PORT) &&
101             (tuple->src.u.all == info->replsrc_port) ^
102             !(info->invert_flags & XT_CONNTRACK_REPLSRC_PORT))
103                 return false;
104
105         if ((info->match_flags & XT_CONNTRACK_REPLDST_PORT) &&
106             (tuple->dst.u.all == info->repldst_port) ^
107             !(info->invert_flags & XT_CONNTRACK_REPLDST_PORT))
108                 return false;
109
110         return true;
111 }
112
113 static inline bool
114 port_match(u16 min, u16 max, u16 port, bool invert)
115 {
116         return (port >= min && port <= max) ^ invert;
117 }
118
119 static inline bool
120 ct_proto_port_check_v3(const struct xt_conntrack_mtinfo3 *info,
121                        const struct nf_conn *ct)
122 {
123         const struct nf_conntrack_tuple *tuple;
124
125         tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
126         if ((info->match_flags & XT_CONNTRACK_PROTO) &&
127             (nf_ct_protonum(ct) == info->l4proto) ^
128             !(info->invert_flags & XT_CONNTRACK_PROTO))
129                 return false;
130
131         /* Shortcut to match all recognized protocols by using ->src.all. */
132         if ((info->match_flags & XT_CONNTRACK_ORIGSRC_PORT) &&
133             !port_match(info->origsrc_port, info->origsrc_port_high,
134                         ntohs(tuple->src.u.all),
135                         info->invert_flags & XT_CONNTRACK_ORIGSRC_PORT))
136                 return false;
137
138         if ((info->match_flags & XT_CONNTRACK_ORIGDST_PORT) &&
139             !port_match(info->origdst_port, info->origdst_port_high,
140                         ntohs(tuple->dst.u.all),
141                         info->invert_flags & XT_CONNTRACK_ORIGDST_PORT))
142                 return false;
143
144         tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple;
145
146         if ((info->match_flags & XT_CONNTRACK_REPLSRC_PORT) &&
147             !port_match(info->replsrc_port, info->replsrc_port_high,
148                         ntohs(tuple->src.u.all),
149                         info->invert_flags & XT_CONNTRACK_REPLSRC_PORT))
150                 return false;
151
152         if ((info->match_flags & XT_CONNTRACK_REPLDST_PORT) &&
153             !port_match(info->repldst_port, info->repldst_port_high,
154                         ntohs(tuple->dst.u.all),
155                         info->invert_flags & XT_CONNTRACK_REPLDST_PORT))
156                 return false;
157
158         return true;
159 }
160
161 static bool
162 conntrack_mt(const struct sk_buff *skb, struct xt_action_param *par,
163              u16 state_mask, u16 status_mask)
164 {
165         const struct xt_conntrack_mtinfo2 *info = par->matchinfo;
166         enum ip_conntrack_info ctinfo;
167         const struct nf_conn *ct;
168         unsigned int statebit;
169
170         ct = nf_ct_get(skb, &ctinfo);
171
172         if (ct)
173                 statebit = XT_CONNTRACK_STATE_BIT(ctinfo);
174         else if (ctinfo == IP_CT_UNTRACKED)
175                 statebit = XT_CONNTRACK_STATE_UNTRACKED;
176         else
177                 statebit = XT_CONNTRACK_STATE_INVALID;
178
179         if (info->match_flags & XT_CONNTRACK_STATE) {
180                 if (ct != NULL) {
181                         if (test_bit(IPS_SRC_NAT_BIT, &ct->status))
182                                 statebit |= XT_CONNTRACK_STATE_SNAT;
183                         if (test_bit(IPS_DST_NAT_BIT, &ct->status))
184                                 statebit |= XT_CONNTRACK_STATE_DNAT;
185                 }
186                 if (!!(state_mask & statebit) ^
187                     !(info->invert_flags & XT_CONNTRACK_STATE))
188                         return false;
189         }
190
191         if (ct == NULL)
192                 return info->match_flags & XT_CONNTRACK_STATE;
193         if ((info->match_flags & XT_CONNTRACK_DIRECTION) &&
194             (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) ^
195             !(info->invert_flags & XT_CONNTRACK_DIRECTION))
196                 return false;
197
198         if (info->match_flags & XT_CONNTRACK_ORIGSRC)
199                 if (conntrack_mt_origsrc(ct, info, xt_family(par)) ^
200                     !(info->invert_flags & XT_CONNTRACK_ORIGSRC))
201                         return false;
202
203         if (info->match_flags & XT_CONNTRACK_ORIGDST)
204                 if (conntrack_mt_origdst(ct, info, xt_family(par)) ^
205                     !(info->invert_flags & XT_CONNTRACK_ORIGDST))
206                         return false;
207
208         if (info->match_flags & XT_CONNTRACK_REPLSRC)
209                 if (conntrack_mt_replsrc(ct, info, xt_family(par)) ^
210                     !(info->invert_flags & XT_CONNTRACK_REPLSRC))
211                         return false;
212
213         if (info->match_flags & XT_CONNTRACK_REPLDST)
214                 if (conntrack_mt_repldst(ct, info, xt_family(par)) ^
215                     !(info->invert_flags & XT_CONNTRACK_REPLDST))
216                         return false;
217
218         if (par->match->revision != 3) {
219                 if (!ct_proto_port_check(info, ct))
220                         return false;
221         } else {
222                 if (!ct_proto_port_check_v3(par->matchinfo, ct))
223                         return false;
224         }
225
226         if ((info->match_flags & XT_CONNTRACK_STATUS) &&
227             (!!(status_mask & ct->status) ^
228             !(info->invert_flags & XT_CONNTRACK_STATUS)))
229                 return false;
230
231         if (info->match_flags & XT_CONNTRACK_EXPIRES) {
232                 unsigned long expires = nf_ct_expires(ct) / HZ;
233
234                 if ((expires >= info->expires_min &&
235                     expires <= info->expires_max) ^
236                     !(info->invert_flags & XT_CONNTRACK_EXPIRES))
237                         return false;
238         }
239         return true;
240 }
241
242 static bool
243 conntrack_mt_v1(const struct sk_buff *skb, struct xt_action_param *par)
244 {
245         const struct xt_conntrack_mtinfo1 *info = par->matchinfo;
246
247         return conntrack_mt(skb, par, info->state_mask, info->status_mask);
248 }
249
250 static bool
251 conntrack_mt_v2(const struct sk_buff *skb, struct xt_action_param *par)
252 {
253         const struct xt_conntrack_mtinfo2 *info = par->matchinfo;
254
255         return conntrack_mt(skb, par, info->state_mask, info->status_mask);
256 }
257
258 static bool
259 conntrack_mt_v3(const struct sk_buff *skb, struct xt_action_param *par)
260 {
261         const struct xt_conntrack_mtinfo3 *info = par->matchinfo;
262
263         return conntrack_mt(skb, par, info->state_mask, info->status_mask);
264 }
265
266 static int conntrack_mt_check(const struct xt_mtchk_param *par)
267 {
268         int ret;
269
270         ret = nf_ct_netns_get(par->net, par->family);
271         if (ret < 0)
272                 pr_info_ratelimited("cannot load conntrack support for proto=%u\n",
273                                     par->family);
274         return ret;
275 }
276
277 static void conntrack_mt_destroy(const struct xt_mtdtor_param *par)
278 {
279         nf_ct_netns_put(par->net, par->family);
280 }
281
282 static struct xt_match conntrack_mt_reg[] __read_mostly = {
283         {
284                 .name       = "conntrack",
285                 .revision   = 1,
286                 .family     = NFPROTO_UNSPEC,
287                 .matchsize  = sizeof(struct xt_conntrack_mtinfo1),
288                 .match      = conntrack_mt_v1,
289                 .checkentry = conntrack_mt_check,
290                 .destroy    = conntrack_mt_destroy,
291                 .me         = THIS_MODULE,
292         },
293         {
294                 .name       = "conntrack",
295                 .revision   = 2,
296                 .family     = NFPROTO_UNSPEC,
297                 .matchsize  = sizeof(struct xt_conntrack_mtinfo2),
298                 .match      = conntrack_mt_v2,
299                 .checkentry = conntrack_mt_check,
300                 .destroy    = conntrack_mt_destroy,
301                 .me         = THIS_MODULE,
302         },
303         {
304                 .name       = "conntrack",
305                 .revision   = 3,
306                 .family     = NFPROTO_UNSPEC,
307                 .matchsize  = sizeof(struct xt_conntrack_mtinfo3),
308                 .match      = conntrack_mt_v3,
309                 .checkentry = conntrack_mt_check,
310                 .destroy    = conntrack_mt_destroy,
311                 .me         = THIS_MODULE,
312         },
313 };
314
315 static int __init conntrack_mt_init(void)
316 {
317         return xt_register_matches(conntrack_mt_reg,
318                ARRAY_SIZE(conntrack_mt_reg));
319 }
320
321 static void __exit conntrack_mt_exit(void)
322 {
323         xt_unregister_matches(conntrack_mt_reg, ARRAY_SIZE(conntrack_mt_reg));
324 }
325
326 module_init(conntrack_mt_init);
327 module_exit(conntrack_mt_exit);