1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * xfrm_state.c
4  *
5  * Changes:
6  *      Mitsuru KANDA @USAGI
7  *      Kazunori MIYAZAWA @USAGI
8  *      Kunihiro Ishiguro <kunihiro@ipinfusion.com>
9  *              IPv6 support
10  *      YOSHIFUJI Hideaki @USAGI
11  *              Split up af-specific functions
12  *      Derek Atkins <derek@ihtfp.com>
13  *              Add UDP Encapsulation
14  *
15  */
16
17 #include <linux/workqueue.h>
18 #include <net/xfrm.h>
19 #include <linux/pfkeyv2.h>
20 #include <linux/ipsec.h>
21 #include <linux/module.h>
22 #include <linux/cache.h>
23 #include <linux/audit.h>
24 #include <linux/uaccess.h>
25 #include <linux/ktime.h>
26 #include <linux/slab.h>
27 #include <linux/interrupt.h>
28 #include <linux/kernel.h>
29
30 #include <crypto/aead.h>
31
32 #include "xfrm_hash.h"
33
34 #define xfrm_state_deref_prot(table, net) \
35         rcu_dereference_protected((table), lockdep_is_held(&(net)->xfrm.xfrm_state_lock))
36
37 static void xfrm_state_gc_task(struct work_struct *work);
38
 39 /* Each xfrm_state may be linked to three hash tables:
 40    1. Hash table by (spi,daddr,ah/esp) to find an SA by SPI. (input,ctl)
 41    2. Hash table by (daddr,family,reqid) to find what SAs exist for a given
 42       destination/tunnel endpoint. (output)
 43    3. Hash table by (daddr,saddr,family) to find SAs by source address.
 44  */
45
46 static unsigned int xfrm_state_hashmax __read_mostly = 1 * 1024 * 1024;
47 static __read_mostly seqcount_t xfrm_state_hash_generation = SEQCNT_ZERO(xfrm_state_hash_generation);
48 static struct kmem_cache *xfrm_state_cache __ro_after_init;
49
50 static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task);
51 static HLIST_HEAD(xfrm_state_gc_list);
52
53 static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x)
54 {
55         return refcount_inc_not_zero(&x->refcnt);
56 }
57
58 static inline unsigned int xfrm_dst_hash(struct net *net,
59                                          const xfrm_address_t *daddr,
60                                          const xfrm_address_t *saddr,
61                                          u32 reqid,
62                                          unsigned short family)
63 {
64         return __xfrm_dst_hash(daddr, saddr, reqid, family, net->xfrm.state_hmask);
65 }
66
67 static inline unsigned int xfrm_src_hash(struct net *net,
68                                          const xfrm_address_t *daddr,
69                                          const xfrm_address_t *saddr,
70                                          unsigned short family)
71 {
72         return __xfrm_src_hash(daddr, saddr, family, net->xfrm.state_hmask);
73 }
74
75 static inline unsigned int
76 xfrm_spi_hash(struct net *net, const xfrm_address_t *daddr,
77               __be32 spi, u8 proto, unsigned short family)
78 {
79         return __xfrm_spi_hash(daddr, spi, proto, family, net->xfrm.state_hmask);
80 }
81
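/* Rehash every state on @list into the new bydst/bysrc/byspi tables
 * sized by @nhashmask.  Called from the resize work with
 * net->xfrm.xfrm_state_lock held.
 */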
82 static void xfrm_hash_transfer(struct hlist_head *list,
83                                struct hlist_head *ndsttable,
84                                struct hlist_head *nsrctable,
85                                struct hlist_head *nspitable,
86                                unsigned int nhashmask)
87 {
88         struct hlist_node *tmp;
89         struct xfrm_state *x;
90
91         hlist_for_each_entry_safe(x, tmp, list, bydst) {
92                 unsigned int h;
93
94                 h = __xfrm_dst_hash(&x->id.daddr, &x->props.saddr,
95                                     x->props.reqid, x->props.family,
96                                     nhashmask);
97                 hlist_add_head_rcu(&x->bydst, ndsttable + h);
98
99                 h = __xfrm_src_hash(&x->id.daddr, &x->props.saddr,
100                                     x->props.family,
101                                     nhashmask);
102                 hlist_add_head_rcu(&x->bysrc, nsrctable + h);
103
104                 if (x->id.spi) {
105                         h = __xfrm_spi_hash(&x->id.daddr, x->id.spi,
106                                             x->id.proto, x->props.family,
107                                             nhashmask);
108                         hlist_add_head_rcu(&x->byspi, nspitable + h);
109                 }
110         }
111 }
112
113 static unsigned long xfrm_hash_new_size(unsigned int state_hmask)
114 {
115         return ((state_hmask + 1) << 1) * sizeof(struct hlist_head);
116 }
117
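/* Double the state hash tables: allocate new bydst/bysrc/byspi tables,
 * move every entry over under the state lock and hash-generation
 * seqcount, then free the old tables after an RCU grace period.
 */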
118 static void xfrm_hash_resize(struct work_struct *work)
119 {
120         struct net *net = container_of(work, struct net, xfrm.state_hash_work);
121         struct hlist_head *ndst, *nsrc, *nspi, *odst, *osrc, *ospi;
122         unsigned long nsize, osize;
123         unsigned int nhashmask, ohashmask;
124         int i;
125
126         nsize = xfrm_hash_new_size(net->xfrm.state_hmask);
127         ndst = xfrm_hash_alloc(nsize);
128         if (!ndst)
129                 return;
130         nsrc = xfrm_hash_alloc(nsize);
131         if (!nsrc) {
132                 xfrm_hash_free(ndst, nsize);
133                 return;
134         }
135         nspi = xfrm_hash_alloc(nsize);
136         if (!nspi) {
137                 xfrm_hash_free(ndst, nsize);
138                 xfrm_hash_free(nsrc, nsize);
139                 return;
140         }
141
142         spin_lock_bh(&net->xfrm.xfrm_state_lock);
143         write_seqcount_begin(&xfrm_state_hash_generation);
144
145         nhashmask = (nsize / sizeof(struct hlist_head)) - 1U;
146         odst = xfrm_state_deref_prot(net->xfrm.state_bydst, net);
147         for (i = net->xfrm.state_hmask; i >= 0; i--)
148                 xfrm_hash_transfer(odst + i, ndst, nsrc, nspi, nhashmask);
149
150         osrc = xfrm_state_deref_prot(net->xfrm.state_bysrc, net);
151         ospi = xfrm_state_deref_prot(net->xfrm.state_byspi, net);
152         ohashmask = net->xfrm.state_hmask;
153
154         rcu_assign_pointer(net->xfrm.state_bydst, ndst);
155         rcu_assign_pointer(net->xfrm.state_bysrc, nsrc);
156         rcu_assign_pointer(net->xfrm.state_byspi, nspi);
157         net->xfrm.state_hmask = nhashmask;
158
159         write_seqcount_end(&xfrm_state_hash_generation);
160         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
161
162         osize = (ohashmask + 1) * sizeof(struct hlist_head);
163
164         synchronize_rcu();
165
166         xfrm_hash_free(odst, osize);
167         xfrm_hash_free(osrc, osize);
168         xfrm_hash_free(ospi, osize);
169 }
170
171 static DEFINE_SPINLOCK(xfrm_state_afinfo_lock);
172 static struct xfrm_state_afinfo __rcu *xfrm_state_afinfo[NPROTO];
173
174 static DEFINE_SPINLOCK(xfrm_state_gc_lock);
175
176 int __xfrm_state_delete(struct xfrm_state *x);
177
178 int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol);
179 static bool km_is_alive(const struct km_event *c);
180 void km_state_expired(struct xfrm_state *x, int hard, u32 portid);
181
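/* Register an xfrm_type in the per-family afinfo slot matching
 * type->proto.  The afinfo lookup leaves the RCU read lock held; it is
 * dropped before returning.
 */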
182 int xfrm_register_type(const struct xfrm_type *type, unsigned short family)
183 {
184         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
185         int err = 0;
186
187         if (!afinfo)
188                 return -EAFNOSUPPORT;
189
190 #define X(afi, T, name) do {                    \
191                 WARN_ON((afi)->type_ ## name);  \
192                 (afi)->type_ ## name = (T);     \
193         } while (0)
194
195         switch (type->proto) {
196         case IPPROTO_COMP:
197                 X(afinfo, type, comp);
198                 break;
199         case IPPROTO_AH:
200                 X(afinfo, type, ah);
201                 break;
202         case IPPROTO_ESP:
203                 X(afinfo, type, esp);
204                 break;
205         case IPPROTO_IPIP:
206                 X(afinfo, type, ipip);
207                 break;
208         case IPPROTO_DSTOPTS:
209                 X(afinfo, type, dstopts);
210                 break;
211         case IPPROTO_ROUTING:
212                 X(afinfo, type, routing);
213                 break;
214         case IPPROTO_IPV6:
215                 X(afinfo, type, ipip6);
216                 break;
217         default:
218                 WARN_ON(1);
219                 err = -EPROTONOSUPPORT;
220                 break;
221         }
222 #undef X
223         rcu_read_unlock();
224         return err;
225 }
226 EXPORT_SYMBOL(xfrm_register_type);
227
228 void xfrm_unregister_type(const struct xfrm_type *type, unsigned short family)
229 {
230         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
231
232         if (unlikely(afinfo == NULL))
233                 return;
234
235 #define X(afi, T, name) do {                            \
236                 WARN_ON((afi)->type_ ## name != (T));   \
237                 (afi)->type_ ## name = NULL;            \
238         } while (0)
239
240         switch (type->proto) {
241         case IPPROTO_COMP:
242                 X(afinfo, type, comp);
243                 break;
244         case IPPROTO_AH:
245                 X(afinfo, type, ah);
246                 break;
247         case IPPROTO_ESP:
248                 X(afinfo, type, esp);
249                 break;
250         case IPPROTO_IPIP:
251                 X(afinfo, type, ipip);
252                 break;
253         case IPPROTO_DSTOPTS:
254                 X(afinfo, type, dstopts);
255                 break;
256         case IPPROTO_ROUTING:
257                 X(afinfo, type, routing);
258                 break;
259         case IPPROTO_IPV6:
260                 X(afinfo, type, ipip6);
261                 break;
262         default:
263                 WARN_ON(1);
264                 break;
265         }
266 #undef X
267         rcu_read_unlock();
268 }
269 EXPORT_SYMBOL(xfrm_unregister_type);
270
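/* Look up the xfrm_type registered for @proto/@family and take a module
 * reference on it.  On a miss, try to load "xfrm-type-<family>-<proto>"
 * once and retry.
 */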
271 static const struct xfrm_type *xfrm_get_type(u8 proto, unsigned short family)
272 {
273         const struct xfrm_type *type = NULL;
274         struct xfrm_state_afinfo *afinfo;
275         int modload_attempted = 0;
276
277 retry:
278         afinfo = xfrm_state_get_afinfo(family);
279         if (unlikely(afinfo == NULL))
280                 return NULL;
281
282         switch (proto) {
283         case IPPROTO_COMP:
284                 type = afinfo->type_comp;
285                 break;
286         case IPPROTO_AH:
287                 type = afinfo->type_ah;
288                 break;
289         case IPPROTO_ESP:
290                 type = afinfo->type_esp;
291                 break;
292         case IPPROTO_IPIP:
293                 type = afinfo->type_ipip;
294                 break;
295         case IPPROTO_DSTOPTS:
296                 type = afinfo->type_dstopts;
297                 break;
298         case IPPROTO_ROUTING:
299                 type = afinfo->type_routing;
300                 break;
301         case IPPROTO_IPV6:
302                 type = afinfo->type_ipip6;
303                 break;
304         default:
305                 break;
306         }
307
308         if (unlikely(type && !try_module_get(type->owner)))
309                 type = NULL;
310
311         rcu_read_unlock();
312
313         if (!type && !modload_attempted) {
314                 request_module("xfrm-type-%d-%d", family, proto);
315                 modload_attempted = 1;
316                 goto retry;
317         }
318
319         return type;
320 }
321
322 static void xfrm_put_type(const struct xfrm_type *type)
323 {
324         module_put(type->owner);
325 }
326
327 int xfrm_register_type_offload(const struct xfrm_type_offload *type,
328                                unsigned short family)
329 {
330         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
331         int err = 0;
332
333         if (unlikely(afinfo == NULL))
334                 return -EAFNOSUPPORT;
335
336         switch (type->proto) {
337         case IPPROTO_ESP:
338                 WARN_ON(afinfo->type_offload_esp);
339                 afinfo->type_offload_esp = type;
340                 break;
341         default:
342                 WARN_ON(1);
343                 err = -EPROTONOSUPPORT;
344                 break;
345         }
346
347         rcu_read_unlock();
348         return err;
349 }
350 EXPORT_SYMBOL(xfrm_register_type_offload);
351
352 void xfrm_unregister_type_offload(const struct xfrm_type_offload *type,
353                                   unsigned short family)
354 {
355         struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
356
357         if (unlikely(afinfo == NULL))
358                 return;
359
360         switch (type->proto) {
361         case IPPROTO_ESP:
362                 WARN_ON(afinfo->type_offload_esp != type);
363                 afinfo->type_offload_esp = NULL;
364                 break;
365         default:
366                 WARN_ON(1);
367                 break;
368         }
369         rcu_read_unlock();
370 }
371 EXPORT_SYMBOL(xfrm_unregister_type_offload);
372
373 static const struct xfrm_type_offload *
374 xfrm_get_type_offload(u8 proto, unsigned short family, bool try_load)
375 {
376         const struct xfrm_type_offload *type = NULL;
377         struct xfrm_state_afinfo *afinfo;
378
379 retry:
380         afinfo = xfrm_state_get_afinfo(family);
381         if (unlikely(afinfo == NULL))
382                 return NULL;
383
384         switch (proto) {
385         case IPPROTO_ESP:
386                 type = afinfo->type_offload_esp;
387                 break;
388         default:
389                 break;
390         }
391
392         if ((type && !try_module_get(type->owner)))
393                 type = NULL;
394
395         rcu_read_unlock();
396
397         if (!type && try_load) {
398                 request_module("xfrm-offload-%d-%d", family, proto);
399                 try_load = false;
400                 goto retry;
401         }
402
403         return type;
404 }
405
406 static void xfrm_put_type_offload(const struct xfrm_type_offload *type)
407 {
408         module_put(type->owner);
409 }
410
411 static const struct xfrm_mode xfrm4_mode_map[XFRM_MODE_MAX] = {
412         [XFRM_MODE_BEET] = {
413                 .encap = XFRM_MODE_BEET,
414                 .flags = XFRM_MODE_FLAG_TUNNEL,
415                 .family = AF_INET,
416         },
417         [XFRM_MODE_TRANSPORT] = {
418                 .encap = XFRM_MODE_TRANSPORT,
419                 .family = AF_INET,
420         },
421         [XFRM_MODE_TUNNEL] = {
422                 .encap = XFRM_MODE_TUNNEL,
423                 .flags = XFRM_MODE_FLAG_TUNNEL,
424                 .family = AF_INET,
425         },
426 };
427
428 static const struct xfrm_mode xfrm6_mode_map[XFRM_MODE_MAX] = {
429         [XFRM_MODE_BEET] = {
430                 .encap = XFRM_MODE_BEET,
431                 .flags = XFRM_MODE_FLAG_TUNNEL,
432                 .family = AF_INET6,
433         },
434         [XFRM_MODE_ROUTEOPTIMIZATION] = {
435                 .encap = XFRM_MODE_ROUTEOPTIMIZATION,
436                 .family = AF_INET6,
437         },
438         [XFRM_MODE_TRANSPORT] = {
439                 .encap = XFRM_MODE_TRANSPORT,
440                 .family = AF_INET6,
441         },
442         [XFRM_MODE_TUNNEL] = {
443                 .encap = XFRM_MODE_TUNNEL,
444                 .flags = XFRM_MODE_FLAG_TUNNEL,
445                 .family = AF_INET6,
446         },
447 };
448
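/* Return the mode descriptor for @encap from the per-family mode map,
 * or NULL if the mode is out of range or not supported by @family.
 */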
449 static const struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family)
450 {
451         const struct xfrm_mode *mode;
452
453         if (unlikely(encap >= XFRM_MODE_MAX))
454                 return NULL;
455
456         switch (family) {
457         case AF_INET:
458                 mode = &xfrm4_mode_map[encap];
459                 if (mode->family == family)
460                         return mode;
461                 break;
462         case AF_INET6:
463                 mode = &xfrm6_mode_map[encap];
464                 if (mode->family == family)
465                         return mode;
466                 break;
467         default:
468                 break;
469         }
470
471         return NULL;
472 }
473
474 void xfrm_state_free(struct xfrm_state *x)
475 {
476         kmem_cache_free(xfrm_state_cache, x);
477 }
478 EXPORT_SYMBOL(xfrm_state_free);
479
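/* Final teardown of a state: cancel its timers, free the algorithm and
 * replay data, drop the type/offload module references and return the
 * object to the slab cache.
 */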
480 static void ___xfrm_state_destroy(struct xfrm_state *x)
481 {
482         hrtimer_cancel(&x->mtimer);
483         del_timer_sync(&x->rtimer);
484         kfree(x->aead);
485         kfree(x->aalg);
486         kfree(x->ealg);
487         kfree(x->calg);
488         kfree(x->encap);
489         kfree(x->coaddr);
490         kfree(x->replay_esn);
491         kfree(x->preplay_esn);
492         if (x->type_offload)
493                 xfrm_put_type_offload(x->type_offload);
494         if (x->type) {
495                 x->type->destructor(x);
496                 xfrm_put_type(x->type);
497         }
498         if (x->xfrag.page)
499                 put_page(x->xfrag.page);
500         xfrm_dev_state_free(x);
501         security_xfrm_state_free(x);
502         xfrm_state_free(x);
503 }
504
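/* Garbage-collection work: grab the pending list under the GC lock,
 * wait for an RCU grace period, then destroy each state on it.
 */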
505 static void xfrm_state_gc_task(struct work_struct *work)
506 {
507         struct xfrm_state *x;
508         struct hlist_node *tmp;
509         struct hlist_head gc_list;
510
511         spin_lock_bh(&xfrm_state_gc_lock);
512         hlist_move_list(&xfrm_state_gc_list, &gc_list);
513         spin_unlock_bh(&xfrm_state_gc_lock);
514
515         synchronize_rcu();
516
517         hlist_for_each_entry_safe(x, tmp, &gc_list, gclist)
518                 ___xfrm_state_destroy(x);
519 }
520
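/* Per-state lifetime timer.  Checks the hard/soft add and use limits,
 * notifies the key manager on soft expiry, deletes the state on hard
 * expiry, and otherwise re-arms itself for the next deadline.
 */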
521 static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me)
522 {
523         struct xfrm_state *x = container_of(me, struct xfrm_state, mtimer);
524         enum hrtimer_restart ret = HRTIMER_NORESTART;
525         time64_t now = ktime_get_real_seconds();
526         time64_t next = TIME64_MAX;
527         int warn = 0;
528         int err = 0;
529
530         spin_lock(&x->lock);
531         if (x->km.state == XFRM_STATE_DEAD)
532                 goto out;
533         if (x->km.state == XFRM_STATE_EXPIRED)
534                 goto expired;
535         if (x->lft.hard_add_expires_seconds) {
536                 long tmo = x->lft.hard_add_expires_seconds +
537                         x->curlft.add_time - now;
538                 if (tmo <= 0) {
539                         if (x->xflags & XFRM_SOFT_EXPIRE) {
540                                 /* Hard expire hit without a soft expire first?!
541                                  * Setting a new date could trigger this.
542                                  * Workaround: fix up x->curlft.add_time below.
543                                  */
544                                 x->curlft.add_time = now - x->saved_tmo - 1;
545                                 tmo = x->lft.hard_add_expires_seconds - x->saved_tmo;
546                         } else
547                                 goto expired;
548                 }
549                 if (tmo < next)
550                         next = tmo;
551         }
552         if (x->lft.hard_use_expires_seconds) {
553                 long tmo = x->lft.hard_use_expires_seconds +
554                         (x->curlft.use_time ? : now) - now;
555                 if (tmo <= 0)
556                         goto expired;
557                 if (tmo < next)
558                         next = tmo;
559         }
560         if (x->km.dying)
561                 goto resched;
562         if (x->lft.soft_add_expires_seconds) {
563                 long tmo = x->lft.soft_add_expires_seconds +
564                         x->curlft.add_time - now;
565                 if (tmo <= 0) {
566                         warn = 1;
567                         x->xflags &= ~XFRM_SOFT_EXPIRE;
568                 } else if (tmo < next) {
569                         next = tmo;
570                         x->xflags |= XFRM_SOFT_EXPIRE;
571                         x->saved_tmo = tmo;
572                 }
573         }
574         if (x->lft.soft_use_expires_seconds) {
575                 long tmo = x->lft.soft_use_expires_seconds +
576                         (x->curlft.use_time ? : now) - now;
577                 if (tmo <= 0)
578                         warn = 1;
579                 else if (tmo < next)
580                         next = tmo;
581         }
582
583         x->km.dying = warn;
584         if (warn)
585                 km_state_expired(x, 0, 0);
586 resched:
587         if (next != TIME64_MAX) {
588                 hrtimer_forward_now(&x->mtimer, ktime_set(next, 0));
589                 ret = HRTIMER_RESTART;
590         }
591
592         goto out;
593
594 expired:
595         if (x->km.state == XFRM_STATE_ACQ && x->id.spi == 0)
596                 x->km.state = XFRM_STATE_EXPIRED;
597
598         err = __xfrm_state_delete(x);
599         if (!err)
600                 km_state_expired(x, 1, 0);
601
602         xfrm_audit_state_delete(x, err ? 0 : 1, true);
603
604 out:
605         spin_unlock(&x->lock);
606         return ret;
607 }
608
609 static void xfrm_replay_timer_handler(struct timer_list *t);
610
611 struct xfrm_state *xfrm_state_alloc(struct net *net)
612 {
613         struct xfrm_state *x;
614
615         x = kmem_cache_zalloc(xfrm_state_cache, GFP_ATOMIC);
616
617         if (x) {
618                 write_pnet(&x->xs_net, net);
619                 refcount_set(&x->refcnt, 1);
620                 atomic_set(&x->tunnel_users, 0);
621                 INIT_LIST_HEAD(&x->km.all);
622                 INIT_HLIST_NODE(&x->bydst);
623                 INIT_HLIST_NODE(&x->bysrc);
624                 INIT_HLIST_NODE(&x->byspi);
625                 hrtimer_init(&x->mtimer, CLOCK_BOOTTIME, HRTIMER_MODE_ABS_SOFT);
626                 x->mtimer.function = xfrm_timer_handler;
627                 timer_setup(&x->rtimer, xfrm_replay_timer_handler, 0);
628                 x->curlft.add_time = ktime_get_real_seconds();
629                 x->lft.soft_byte_limit = XFRM_INF;
630                 x->lft.soft_packet_limit = XFRM_INF;
631                 x->lft.hard_byte_limit = XFRM_INF;
632                 x->lft.hard_packet_limit = XFRM_INF;
633                 x->replay_maxage = 0;
634                 x->replay_maxdiff = 0;
635                 spin_lock_init(&x->lock);
636         }
637         return x;
638 }
639 EXPORT_SYMBOL(xfrm_state_alloc);
640
641 void __xfrm_state_destroy(struct xfrm_state *x, bool sync)
642 {
643         WARN_ON(x->km.state != XFRM_STATE_DEAD);
644
645         if (sync) {
646                 synchronize_rcu();
647                 ___xfrm_state_destroy(x);
648         } else {
649                 spin_lock_bh(&xfrm_state_gc_lock);
650                 hlist_add_head(&x->gclist, &xfrm_state_gc_list);
651                 spin_unlock_bh(&xfrm_state_gc_lock);
652                 schedule_work(&xfrm_state_gc_work);
653         }
654 }
655 EXPORT_SYMBOL(__xfrm_state_destroy);
656
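/* Unlink @x from the global list and all hash tables and drop the
 * reference taken at allocation time.  Returns -ESRCH if the state was
 * already dead.  Caller holds x->lock.
 */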
657 int __xfrm_state_delete(struct xfrm_state *x)
658 {
659         struct net *net = xs_net(x);
660         int err = -ESRCH;
661
662         if (x->km.state != XFRM_STATE_DEAD) {
663                 x->km.state = XFRM_STATE_DEAD;
664                 spin_lock(&net->xfrm.xfrm_state_lock);
665                 list_del(&x->km.all);
666                 hlist_del_rcu(&x->bydst);
667                 hlist_del_rcu(&x->bysrc);
668                 if (x->id.spi)
669                         hlist_del_rcu(&x->byspi);
670                 net->xfrm.state_num--;
671                 spin_unlock(&net->xfrm.xfrm_state_lock);
672
673                 if (x->encap_sk)
674                         sock_put(rcu_dereference_raw(x->encap_sk));
675
676                 xfrm_dev_state_delete(x);
677
678                 /* All xfrm_state objects are created by xfrm_state_alloc.
679                  * The xfrm_state_alloc call gives a reference, and that
680                  * is what we are dropping here.
681                  */
682                 xfrm_state_put(x);
683                 err = 0;
684         }
685
686         return err;
687 }
688 EXPORT_SYMBOL(__xfrm_state_delete);
689
690 int xfrm_state_delete(struct xfrm_state *x)
691 {
692         int err;
693
694         spin_lock_bh(&x->lock);
695         err = __xfrm_state_delete(x);
696         spin_unlock_bh(&x->lock);
697
698         return err;
699 }
700 EXPORT_SYMBOL(xfrm_state_delete);
701
702 #ifdef CONFIG_SECURITY_NETWORK_XFRM
703 static inline int
704 xfrm_state_flush_secctx_check(struct net *net, u8 proto, bool task_valid)
705 {
706         int i, err = 0;
707
708         for (i = 0; i <= net->xfrm.state_hmask; i++) {
709                 struct xfrm_state *x;
710
711                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
712                         if (xfrm_id_proto_match(x->id.proto, proto) &&
713                            (err = security_xfrm_state_delete(x)) != 0) {
714                                 xfrm_audit_state_delete(x, 0, task_valid);
715                                 return err;
716                         }
717                 }
718         }
719
720         return err;
721 }
722
723 static inline int
724 xfrm_dev_state_flush_secctx_check(struct net *net, struct net_device *dev, bool task_valid)
725 {
726         int i, err = 0;
727
728         for (i = 0; i <= net->xfrm.state_hmask; i++) {
729                 struct xfrm_state *x;
730                 struct xfrm_state_offload *xso;
731
732                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
733                         xso = &x->xso;
734
735                         if (xso->dev == dev &&
736                            (err = security_xfrm_state_delete(x)) != 0) {
737                                 xfrm_audit_state_delete(x, 0, task_valid);
738                                 return err;
739                         }
740                 }
741         }
742
743         return err;
744 }
745 #else
746 static inline int
747 xfrm_state_flush_secctx_check(struct net *net, u8 proto, bool task_valid)
748 {
749         return 0;
750 }
751
752 static inline int
753 xfrm_dev_state_flush_secctx_check(struct net *net, struct net_device *dev, bool task_valid)
754 {
755         return 0;
756 }
757 #endif
758
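/* Delete every state in @net whose protocol matches @proto, once the
 * security hook has approved the deletions.  Each removal is audited;
 * returns 0 if at least one state was flushed.
 */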
759 int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync)
760 {
761         int i, err = 0, cnt = 0;
762
763         spin_lock_bh(&net->xfrm.xfrm_state_lock);
764         err = xfrm_state_flush_secctx_check(net, proto, task_valid);
765         if (err)
766                 goto out;
767
768         err = -ESRCH;
769         for (i = 0; i <= net->xfrm.state_hmask; i++) {
770                 struct xfrm_state *x;
771 restart:
772                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
773                         if (!xfrm_state_kern(x) &&
774                             xfrm_id_proto_match(x->id.proto, proto)) {
775                                 xfrm_state_hold(x);
776                                 spin_unlock_bh(&net->xfrm.xfrm_state_lock);
777
778                                 err = xfrm_state_delete(x);
779                                 xfrm_audit_state_delete(x, err ? 0 : 1,
780                                                         task_valid);
781                                 if (sync)
782                                         xfrm_state_put_sync(x);
783                                 else
784                                         xfrm_state_put(x);
785                                 if (!err)
786                                         cnt++;
787
788                                 spin_lock_bh(&net->xfrm.xfrm_state_lock);
789                                 goto restart;
790                         }
791                 }
792         }
793 out:
794         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
795         if (cnt)
796                 err = 0;
797
798         return err;
799 }
800 EXPORT_SYMBOL(xfrm_state_flush);
801
802 int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_valid)
803 {
804         int i, err = 0, cnt = 0;
805
806         spin_lock_bh(&net->xfrm.xfrm_state_lock);
807         err = xfrm_dev_state_flush_secctx_check(net, dev, task_valid);
808         if (err)
809                 goto out;
810
811         err = -ESRCH;
812         for (i = 0; i <= net->xfrm.state_hmask; i++) {
813                 struct xfrm_state *x;
814                 struct xfrm_state_offload *xso;
815 restart:
816                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
817                         xso = &x->xso;
818
819                         if (!xfrm_state_kern(x) && xso->dev == dev) {
820                                 xfrm_state_hold(x);
821                                 spin_unlock_bh(&net->xfrm.xfrm_state_lock);
822
823                                 err = xfrm_state_delete(x);
824                                 xfrm_audit_state_delete(x, err ? 0 : 1,
825                                                         task_valid);
826                                 xfrm_state_put(x);
827                                 if (!err)
828                                         cnt++;
829
830                                 spin_lock_bh(&net->xfrm.xfrm_state_lock);
831                                 goto restart;
832                         }
833                 }
834         }
835         if (cnt)
836                 err = 0;
837
838 out:
839         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
840         return err;
841 }
842 EXPORT_SYMBOL(xfrm_dev_state_flush);
843
844 void xfrm_sad_getinfo(struct net *net, struct xfrmk_sadinfo *si)
845 {
846         spin_lock_bh(&net->xfrm.xfrm_state_lock);
847         si->sadcnt = net->xfrm.state_num;
848         si->sadhcnt = net->xfrm.state_hmask + 1;
849         si->sadhmcnt = xfrm_state_hashmax;
850         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
851 }
852 EXPORT_SYMBOL(xfrm_sad_getinfo);
853
854 static void
855 __xfrm4_init_tempsel(struct xfrm_selector *sel, const struct flowi *fl)
856 {
857         const struct flowi4 *fl4 = &fl->u.ip4;
858
859         sel->daddr.a4 = fl4->daddr;
860         sel->saddr.a4 = fl4->saddr;
861         sel->dport = xfrm_flowi_dport(fl, &fl4->uli);
862         sel->dport_mask = htons(0xffff);
863         sel->sport = xfrm_flowi_sport(fl, &fl4->uli);
864         sel->sport_mask = htons(0xffff);
865         sel->family = AF_INET;
866         sel->prefixlen_d = 32;
867         sel->prefixlen_s = 32;
868         sel->proto = fl4->flowi4_proto;
869         sel->ifindex = fl4->flowi4_oif;
870 }
871
872 static void
873 __xfrm6_init_tempsel(struct xfrm_selector *sel, const struct flowi *fl)
874 {
875         const struct flowi6 *fl6 = &fl->u.ip6;
876
877         /* Initialize temporary selector matching only to current session. */
878         *(struct in6_addr *)&sel->daddr = fl6->daddr;
879         *(struct in6_addr *)&sel->saddr = fl6->saddr;
880         sel->dport = xfrm_flowi_dport(fl, &fl6->uli);
881         sel->dport_mask = htons(0xffff);
882         sel->sport = xfrm_flowi_sport(fl, &fl6->uli);
883         sel->sport_mask = htons(0xffff);
884         sel->family = AF_INET6;
885         sel->prefixlen_d = 128;
886         sel->prefixlen_s = 128;
887         sel->proto = fl6->flowi6_proto;
888         sel->ifindex = fl6->flowi6_oif;
889 }
890
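/* Build a temporary state from @tmpl and the flow: the family-specific
 * selector matches only the triggering session, and missing template
 * addresses are filled in from @daddr/@saddr.
 */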
891 static void
892 xfrm_init_tempstate(struct xfrm_state *x, const struct flowi *fl,
893                     const struct xfrm_tmpl *tmpl,
894                     const xfrm_address_t *daddr, const xfrm_address_t *saddr,
895                     unsigned short family)
896 {
897         switch (family) {
898         case AF_INET:
899                 __xfrm4_init_tempsel(&x->sel, fl);
900                 break;
901         case AF_INET6:
902                 __xfrm6_init_tempsel(&x->sel, fl);
903                 break;
904         }
905
906         x->id = tmpl->id;
907
908         switch (tmpl->encap_family) {
909         case AF_INET:
910                 if (x->id.daddr.a4 == 0)
911                         x->id.daddr.a4 = daddr->a4;
912                 x->props.saddr = tmpl->saddr;
913                 if (x->props.saddr.a4 == 0)
914                         x->props.saddr.a4 = saddr->a4;
915                 break;
916         case AF_INET6:
917                 if (ipv6_addr_any((struct in6_addr *)&x->id.daddr))
918                         memcpy(&x->id.daddr, daddr, sizeof(x->sel.daddr));
919                 memcpy(&x->props.saddr, &tmpl->saddr, sizeof(x->props.saddr));
920                 if (ipv6_addr_any((struct in6_addr *)&x->props.saddr))
921                         memcpy(&x->props.saddr, saddr, sizeof(x->props.saddr));
922                 break;
923         }
924
925         x->props.mode = tmpl->mode;
926         x->props.reqid = tmpl->reqid;
927         x->props.family = tmpl->encap_family;
928 }
929
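/* RCU lookup in the byspi hash by (daddr, spi, proto, family) and mark.
 * Returns a referenced state or NULL.
 */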
930 static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
931                                               const xfrm_address_t *daddr,
932                                               __be32 spi, u8 proto,
933                                               unsigned short family)
934 {
935         unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family);
936         struct xfrm_state *x;
937
938         hlist_for_each_entry_rcu(x, net->xfrm.state_byspi + h, byspi) {
939                 if (x->props.family != family ||
940                     x->id.spi       != spi ||
941                     x->id.proto     != proto ||
942                     !xfrm_addr_equal(&x->id.daddr, daddr, family))
943                         continue;
944
945                 if ((mark & x->mark.m) != x->mark.v)
946                         continue;
947                 if (!xfrm_state_hold_rcu(x))
948                         continue;
949                 return x;
950         }
951
952         return NULL;
953 }
954
955 static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark,
956                                                      const xfrm_address_t *daddr,
957                                                      const xfrm_address_t *saddr,
958                                                      u8 proto, unsigned short family)
959 {
960         unsigned int h = xfrm_src_hash(net, daddr, saddr, family);
961         struct xfrm_state *x;
962
963         hlist_for_each_entry_rcu(x, net->xfrm.state_bysrc + h, bysrc) {
964                 if (x->props.family != family ||
965                     x->id.proto     != proto ||
966                     !xfrm_addr_equal(&x->id.daddr, daddr, family) ||
967                     !xfrm_addr_equal(&x->props.saddr, saddr, family))
968                         continue;
969
970                 if ((mark & x->mark.m) != x->mark.v)
971                         continue;
972                 if (!xfrm_state_hold_rcu(x))
973                         continue;
974                 return x;
975         }
976
977         return NULL;
978 }
979
980 static inline struct xfrm_state *
981 __xfrm_state_locate(struct xfrm_state *x, int use_spi, int family)
982 {
983         struct net *net = xs_net(x);
984         u32 mark = x->mark.v & x->mark.m;
985
986         if (use_spi)
987                 return __xfrm_state_lookup(net, mark, &x->id.daddr,
988                                            x->id.spi, x->id.proto, family);
989         else
990                 return __xfrm_state_lookup_byaddr(net, mark,
991                                                   &x->id.daddr,
992                                                   &x->props.saddr,
993                                                   x->id.proto, family);
994 }
995
996 static void xfrm_hash_grow_check(struct net *net, int have_hash_collision)
997 {
998         if (have_hash_collision &&
999             (net->xfrm.state_hmask + 1) < xfrm_state_hashmax &&
1000             net->xfrm.state_num > net->xfrm.state_hmask)
1001                 schedule_work(&net->xfrm.state_hash_work);
1002 }
1003
1004 static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x,
1005                                const struct flowi *fl, unsigned short family,
1006                                struct xfrm_state **best, int *acq_in_progress,
1007                                int *error)
1008 {
1009         /* Resolution logic:
1010          * 1. There is a valid state with a matching selector. Done.
1011          * 2. Valid state with an inappropriate selector. Skip.
1012          *
1013          * Entering area of "sysdeps".
1014          *
1015          * 3. If the state is not valid, its selector is temporary and
1016          *    matches only the session which triggered the previous
1017          *    resolution. The key manager will install a state with a
1018          *    proper selector.
1019          */
1020         if (x->km.state == XFRM_STATE_VALID) {
1021                 if ((x->sel.family &&
1022                      (x->sel.family != family ||
1023                       !xfrm_selector_match(&x->sel, fl, family))) ||
1024                     !security_xfrm_state_pol_flow_match(x, pol,
1025                                                         &fl->u.__fl_common))
1026                         return;
1027
1028                 if (!*best ||
1029                     (*best)->km.dying > x->km.dying ||
1030                     ((*best)->km.dying == x->km.dying &&
1031                      (*best)->curlft.add_time < x->curlft.add_time))
1032                         *best = x;
1033         } else if (x->km.state == XFRM_STATE_ACQ) {
1034                 *acq_in_progress = 1;
1035         } else if (x->km.state == XFRM_STATE_ERROR ||
1036                    x->km.state == XFRM_STATE_EXPIRED) {
1037                 if ((!x->sel.family ||
1038                      (x->sel.family == family &&
1039                       xfrm_selector_match(&x->sel, fl, family))) &&
1040                     security_xfrm_state_pol_flow_match(x, pol,
1041                                                        &fl->u.__fl_common))
1042                         *error = -ESRCH;
1043         }
1044 }
1045
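/* Main output-path SA lookup.  Walk the bydst hash (then the
 * wildcard-source bucket) and pick the best VALID state for the policy
 * and flow.  If nothing matches and no acquire is in progress, allocate
 * a temporary XFRM_STATE_ACQ state and query the key managers.
 */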
1046 struct xfrm_state *
1047 xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
1048                 const struct flowi *fl, struct xfrm_tmpl *tmpl,
1049                 struct xfrm_policy *pol, int *err,
1050                 unsigned short family, u32 if_id)
1051 {
1052         static xfrm_address_t saddr_wildcard = { };
1053         struct net *net = xp_net(pol);
1054         unsigned int h, h_wildcard;
1055         struct xfrm_state *x, *x0, *to_put;
1056         int acquire_in_progress = 0;
1057         int error = 0;
1058         struct xfrm_state *best = NULL;
1059         u32 mark = pol->mark.v & pol->mark.m;
1060         unsigned short encap_family = tmpl->encap_family;
1061         unsigned int sequence;
1062         struct km_event c;
1063
1064         to_put = NULL;
1065
1066         sequence = read_seqcount_begin(&xfrm_state_hash_generation);
1067
1068         rcu_read_lock();
1069         h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family);
1070         hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h, bydst) {
1071                 if (x->props.family == encap_family &&
1072                     x->props.reqid == tmpl->reqid &&
1073                     (mark & x->mark.m) == x->mark.v &&
1074                     x->if_id == if_id &&
1075                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
1076                     xfrm_state_addr_check(x, daddr, saddr, encap_family) &&
1077                     tmpl->mode == x->props.mode &&
1078                     tmpl->id.proto == x->id.proto &&
1079                     (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
1080                         xfrm_state_look_at(pol, x, fl, family,
1081                                            &best, &acquire_in_progress, &error);
1082         }
1083         if (best || acquire_in_progress)
1084                 goto found;
1085
1086         h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, encap_family);
1087         hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h_wildcard, bydst) {
1088                 if (x->props.family == encap_family &&
1089                     x->props.reqid == tmpl->reqid &&
1090                     (mark & x->mark.m) == x->mark.v &&
1091                     x->if_id == if_id &&
1092                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
1093                     xfrm_addr_equal(&x->id.daddr, daddr, encap_family) &&
1094                     tmpl->mode == x->props.mode &&
1095                     tmpl->id.proto == x->id.proto &&
1096                     (tmpl->id.spi == x->id.spi || !tmpl->id.spi))
1097                         xfrm_state_look_at(pol, x, fl, family,
1098                                            &best, &acquire_in_progress, &error);
1099         }
1100
1101 found:
1102         x = best;
1103         if (!x && !error && !acquire_in_progress) {
1104                 if (tmpl->id.spi &&
1105                     (x0 = __xfrm_state_lookup(net, mark, daddr, tmpl->id.spi,
1106                                               tmpl->id.proto, encap_family)) != NULL) {
1107                         to_put = x0;
1108                         error = -EEXIST;
1109                         goto out;
1110                 }
1111
1112                 c.net = net;
1113                 /* If the KMs have no listeners (yet...), avoid allocating an SA
1114                  * for each and every packet - garbage collection might not
1115                  * handle the flood.
1116                  */
1117                 if (!km_is_alive(&c)) {
1118                         error = -ESRCH;
1119                         goto out;
1120                 }
1121
1122                 x = xfrm_state_alloc(net);
1123                 if (x == NULL) {
1124                         error = -ENOMEM;
1125                         goto out;
1126                 }
1127                 /* Initialize temporary state matching only
1128                  * to current session. */
1129                 xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family);
1130                 memcpy(&x->mark, &pol->mark, sizeof(x->mark));
1131                 x->if_id = if_id;
1132
1133                 error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid);
1134                 if (error) {
1135                         x->km.state = XFRM_STATE_DEAD;
1136                         to_put = x;
1137                         x = NULL;
1138                         goto out;
1139                 }
1140
1141                 if (km_query(x, tmpl, pol) == 0) {
1142                         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1143                         x->km.state = XFRM_STATE_ACQ;
1144                         list_add(&x->km.all, &net->xfrm.state_all);
1145                         hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h);
1146                         h = xfrm_src_hash(net, daddr, saddr, encap_family);
1147                         hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h);
1148                         if (x->id.spi) {
1149                                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family);
1150                                 hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h);
1151                         }
1152                         x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires;
1153                         hrtimer_start(&x->mtimer,
1154                                       ktime_set(net->xfrm.sysctl_acq_expires, 0),
1155                                       HRTIMER_MODE_REL_SOFT);
1156                         net->xfrm.state_num++;
1157                         xfrm_hash_grow_check(net, x->bydst.next != NULL);
1158                         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1159                 } else {
1160                         x->km.state = XFRM_STATE_DEAD;
1161                         to_put = x;
1162                         x = NULL;
1163                         error = -ESRCH;
1164                 }
1165         }
1166 out:
1167         if (x) {
1168                 if (!xfrm_state_hold_rcu(x)) {
1169                         *err = -EAGAIN;
1170                         x = NULL;
1171                 }
1172         } else {
1173                 *err = acquire_in_progress ? -EAGAIN : error;
1174         }
1175         rcu_read_unlock();
1176         if (to_put)
1177                 xfrm_state_put(to_put);
1178
1179         if (read_seqcount_retry(&xfrm_state_hash_generation, sequence)) {
1180                 *err = -EAGAIN;
1181                 if (x) {
1182                         xfrm_state_put(x);
1183                         x = NULL;
1184                 }
1185         }
1186
1187         return x;
1188 }
1189
1190 struct xfrm_state *
1191 xfrm_stateonly_find(struct net *net, u32 mark, u32 if_id,
1192                     xfrm_address_t *daddr, xfrm_address_t *saddr,
1193                     unsigned short family, u8 mode, u8 proto, u32 reqid)
1194 {
1195         unsigned int h;
1196         struct xfrm_state *rx = NULL, *x = NULL;
1197
1198         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1199         h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
1200         hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
1201                 if (x->props.family == family &&
1202                     x->props.reqid == reqid &&
1203                     (mark & x->mark.m) == x->mark.v &&
1204                     x->if_id == if_id &&
1205                     !(x->props.flags & XFRM_STATE_WILDRECV) &&
1206                     xfrm_state_addr_check(x, daddr, saddr, family) &&
1207                     mode == x->props.mode &&
1208                     proto == x->id.proto &&
1209                     x->km.state == XFRM_STATE_VALID) {
1210                         rx = x;
1211                         break;
1212                 }
1213         }
1214
1215         if (rx)
1216                 xfrm_state_hold(rx);
1217         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1218
1219
1220         return rx;
1221 }
1222 EXPORT_SYMBOL(xfrm_stateonly_find);
1223
1224 struct xfrm_state *xfrm_state_lookup_byspi(struct net *net, __be32 spi,
1225                                               unsigned short family)
1226 {
1227         struct xfrm_state *x;
1228         struct xfrm_state_walk *w;
1229
1230         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1231         list_for_each_entry(w, &net->xfrm.state_all, all) {
1232                 x = container_of(w, struct xfrm_state, km);
1233                 if (x->props.family != family ||
1234                         x->id.spi != spi)
1235                         continue;
1236
1237                 xfrm_state_hold(x);
1238                 spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1239                 return x;
1240         }
1241         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1242         return NULL;
1243 }
1244 EXPORT_SYMBOL(xfrm_state_lookup_byspi);
1245
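/* Link @x into net->xfrm.state_all and the bydst/bysrc (and, if it has
 * an SPI, byspi) hashes, start its timers and bump the state count.
 * Caller holds net->xfrm.xfrm_state_lock.
 */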
1246 static void __xfrm_state_insert(struct xfrm_state *x)
1247 {
1248         struct net *net = xs_net(x);
1249         unsigned int h;
1250
1251         list_add(&x->km.all, &net->xfrm.state_all);
1252
1253         h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr,
1254                           x->props.reqid, x->props.family);
1255         hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h);
1256
1257         h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family);
1258         hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h);
1259
1260         if (x->id.spi) {
1261                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto,
1262                                   x->props.family);
1263
1264                 hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h);
1265         }
1266
1267         hrtimer_start(&x->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL_SOFT);
1268         if (x->replay_maxage)
1269                 mod_timer(&x->rtimer, jiffies + x->replay_maxage);
1270
1271         net->xfrm.state_num++;
1272
1273         xfrm_hash_grow_check(net, x->bydst.next != NULL);
1274 }
1275
1276 /* net->xfrm.xfrm_state_lock is held */
1277 static void __xfrm_state_bump_genids(struct xfrm_state *xnew)
1278 {
1279         struct net *net = xs_net(xnew);
1280         unsigned short family = xnew->props.family;
1281         u32 reqid = xnew->props.reqid;
1282         struct xfrm_state *x;
1283         unsigned int h;
1284         u32 mark = xnew->mark.v & xnew->mark.m;
1285         u32 if_id = xnew->if_id;
1286
1287         h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family);
1288         hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
1289                 if (x->props.family     == family &&
1290                     x->props.reqid      == reqid &&
1291                     x->if_id            == if_id &&
1292                     (mark & x->mark.m) == x->mark.v &&
1293                     xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) &&
1294                     xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family))
1295                         x->genid++;
1296         }
1297 }
1298
1299 void xfrm_state_insert(struct xfrm_state *x)
1300 {
1301         struct net *net = xs_net(x);
1302
1303         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1304         __xfrm_state_bump_genids(x);
1305         __xfrm_state_insert(x);
1306         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1307 }
1308 EXPORT_SYMBOL(xfrm_state_insert);
1309
1310 /* net->xfrm.xfrm_state_lock is held */
1311 static struct xfrm_state *__find_acq_core(struct net *net,
1312                                           const struct xfrm_mark *m,
1313                                           unsigned short family, u8 mode,
1314                                           u32 reqid, u32 if_id, u8 proto,
1315                                           const xfrm_address_t *daddr,
1316                                           const xfrm_address_t *saddr,
1317                                           int create)
1318 {
1319         unsigned int h = xfrm_dst_hash(net, daddr, saddr, reqid, family);
1320         struct xfrm_state *x;
1321         u32 mark = m->v & m->m;
1322
1323         hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
1324                 if (x->props.reqid  != reqid ||
1325                     x->props.mode   != mode ||
1326                     x->props.family != family ||
1327                     x->km.state     != XFRM_STATE_ACQ ||
1328                     x->id.spi       != 0 ||
1329                     x->id.proto     != proto ||
1330                     (mark & x->mark.m) != x->mark.v ||
1331                     !xfrm_addr_equal(&x->id.daddr, daddr, family) ||
1332                     !xfrm_addr_equal(&x->props.saddr, saddr, family))
1333                         continue;
1334
1335                 xfrm_state_hold(x);
1336                 return x;
1337         }
1338
1339         if (!create)
1340                 return NULL;
1341
1342         x = xfrm_state_alloc(net);
1343         if (likely(x)) {
1344                 switch (family) {
1345                 case AF_INET:
1346                         x->sel.daddr.a4 = daddr->a4;
1347                         x->sel.saddr.a4 = saddr->a4;
1348                         x->sel.prefixlen_d = 32;
1349                         x->sel.prefixlen_s = 32;
1350                         x->props.saddr.a4 = saddr->a4;
1351                         x->id.daddr.a4 = daddr->a4;
1352                         break;
1353
1354                 case AF_INET6:
1355                         x->sel.daddr.in6 = daddr->in6;
1356                         x->sel.saddr.in6 = saddr->in6;
1357                         x->sel.prefixlen_d = 128;
1358                         x->sel.prefixlen_s = 128;
1359                         x->props.saddr.in6 = saddr->in6;
1360                         x->id.daddr.in6 = daddr->in6;
1361                         break;
1362                 }
1363
1364                 x->km.state = XFRM_STATE_ACQ;
1365                 x->id.proto = proto;
1366                 x->props.family = family;
1367                 x->props.mode = mode;
1368                 x->props.reqid = reqid;
1369                 x->if_id = if_id;
1370                 x->mark.v = m->v;
1371                 x->mark.m = m->m;
1372                 x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires;
1373                 xfrm_state_hold(x);
1374                 hrtimer_start(&x->mtimer,
1375                               ktime_set(net->xfrm.sysctl_acq_expires, 0),
1376                               HRTIMER_MODE_REL_SOFT);
1377                 list_add(&x->km.all, &net->xfrm.state_all);
1378                 hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h);
1379                 h = xfrm_src_hash(net, daddr, saddr, family);
1380                 hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h);
1381
1382                 net->xfrm.state_num++;
1383
1384                 xfrm_hash_grow_check(net, x->bydst.next != NULL);
1385         }
1386
1387         return x;
1388 }
1389
1390 static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq);
1391
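/* Insert a fully specified state.  Fails with -EEXIST if an equivalent
 * state is already installed; otherwise any matching ACQ state (found
 * by sequence number or by address/reqid) is deleted and replaced.
 */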
1392 int xfrm_state_add(struct xfrm_state *x)
1393 {
1394         struct net *net = xs_net(x);
1395         struct xfrm_state *x1, *to_put;
1396         int family;
1397         int err;
1398         u32 mark = x->mark.v & x->mark.m;
1399         int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY);
1400
1401         family = x->props.family;
1402
1403         to_put = NULL;
1404
1405         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1406
1407         x1 = __xfrm_state_locate(x, use_spi, family);
1408         if (x1) {
1409                 to_put = x1;
1410                 x1 = NULL;
1411                 err = -EEXIST;
1412                 goto out;
1413         }
1414
1415         if (use_spi && x->km.seq) {
1416                 x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq);
1417                 if (x1 && ((x1->id.proto != x->id.proto) ||
1418                     !xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) {
1419                         to_put = x1;
1420                         x1 = NULL;
1421                 }
1422         }
1423
1424         if (use_spi && !x1)
1425                 x1 = __find_acq_core(net, &x->mark, family, x->props.mode,
1426                                      x->props.reqid, x->if_id, x->id.proto,
1427                                      &x->id.daddr, &x->props.saddr, 0);
1428
1429         __xfrm_state_bump_genids(x);
1430         __xfrm_state_insert(x);
1431         err = 0;
1432
1433 out:
1434         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1435
1436         if (x1) {
1437                 xfrm_state_delete(x1);
1438                 xfrm_state_put(x1);
1439         }
1440
1441         if (to_put)
1442                 xfrm_state_put(to_put);
1443
1444         return err;
1445 }
1446 EXPORT_SYMBOL(xfrm_state_add);
1447
1448 #ifdef CONFIG_XFRM_MIGRATE
1449 static inline int clone_security(struct xfrm_state *x, struct xfrm_sec_ctx *security)
1450 {
1451         struct xfrm_user_sec_ctx *uctx;
1452         int size = sizeof(*uctx) + security->ctx_len;
1453         int err;
1454
1455         uctx = kmalloc(size, GFP_KERNEL);
1456         if (!uctx)
1457                 return -ENOMEM;
1458
1459         uctx->exttype = XFRMA_SEC_CTX;
1460         uctx->len = size;
1461         uctx->ctx_doi = security->ctx_doi;
1462         uctx->ctx_alg = security->ctx_alg;
1463         uctx->ctx_len = security->ctx_len;
1464         memcpy(uctx + 1, security->ctx_str, security->ctx_len);
1465         err = security_xfrm_state_alloc(x, uctx);
1466         kfree(uctx);
1467         if (err)
1468                 return err;
1469
1470         return 0;
1471 }
1472
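/* Duplicate @orig for migration: copy its id, selector, lifetimes,
 * algorithms, encapsulation and security context into a fresh state.
 * Returns NULL on allocation or initialization failure.
 */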
1473 static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig,
1474                                            struct xfrm_encap_tmpl *encap)
1475 {
1476         struct net *net = xs_net(orig);
1477         struct xfrm_state *x = xfrm_state_alloc(net);
1478         if (!x)
1479                 goto out;
1480
1481         memcpy(&x->id, &orig->id, sizeof(x->id));
1482         memcpy(&x->sel, &orig->sel, sizeof(x->sel));
1483         memcpy(&x->lft, &orig->lft, sizeof(x->lft));
1484         x->props.mode = orig->props.mode;
1485         x->props.replay_window = orig->props.replay_window;
1486         x->props.reqid = orig->props.reqid;
1487         x->props.family = orig->props.family;
1488         x->props.saddr = orig->props.saddr;
1489
1490         if (orig->aalg) {
1491                 x->aalg = xfrm_algo_auth_clone(orig->aalg);
1492                 if (!x->aalg)
1493                         goto error;
1494         }
1495         x->props.aalgo = orig->props.aalgo;
1496
1497         if (orig->aead) {
1498                 x->aead = xfrm_algo_aead_clone(orig->aead);
1499                 x->geniv = orig->geniv;
1500                 if (!x->aead)
1501                         goto error;
1502         }
1503         if (orig->ealg) {
1504                 x->ealg = xfrm_algo_clone(orig->ealg);
1505                 if (!x->ealg)
1506                         goto error;
1507         }
1508         x->props.ealgo = orig->props.ealgo;
1509
1510         if (orig->calg) {
1511                 x->calg = xfrm_algo_clone(orig->calg);
1512                 if (!x->calg)
1513                         goto error;
1514         }
1515         x->props.calgo = orig->props.calgo;
1516
1517         if (encap || orig->encap) {
1518                 if (encap)
1519                         x->encap = kmemdup(encap, sizeof(*x->encap),
1520                                         GFP_KERNEL);
1521                 else
1522                         x->encap = kmemdup(orig->encap, sizeof(*x->encap),
1523                                         GFP_KERNEL);
1524
1525                 if (!x->encap)
1526                         goto error;
1527         }
1528
1529         if (orig->security)
1530                 if (clone_security(x, orig->security))
1531                         goto error;
1532
1533         if (orig->coaddr) {
1534                 x->coaddr = kmemdup(orig->coaddr, sizeof(*x->coaddr),
1535                                     GFP_KERNEL);
1536                 if (!x->coaddr)
1537                         goto error;
1538         }
1539
1540         if (orig->replay_esn) {
1541                 if (xfrm_replay_clone(x, orig))
1542                         goto error;
1543         }
1544
1545         memcpy(&x->mark, &orig->mark, sizeof(x->mark));
1546         memcpy(&x->props.smark, &orig->props.smark, sizeof(x->props.smark));
1547
1548         if (xfrm_init_state(x) < 0)
1549                 goto error;
1550
1551         x->props.flags = orig->props.flags;
1552         x->props.extra_flags = orig->props.extra_flags;
1553
1554         x->if_id = orig->if_id;
1555         x->tfcpad = orig->tfcpad;
1556         x->replay_maxdiff = orig->replay_maxdiff;
1557         x->replay_maxage = orig->replay_maxage;
1558         memcpy(&x->curlft, &orig->curlft, sizeof(x->curlft));
1559         x->km.state = orig->km.state;
1560         x->km.seq = orig->km.seq;
1561         x->replay = orig->replay;
1562         x->preplay = orig->preplay;
1563
1564         return x;
1565
1566  error:
1567         xfrm_state_put(x);
1568 out:
1569         return NULL;
1570 }
1571
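/* Find the SA matching the old endpoint of a migrate request.  With a
 * reqid the lookup goes through the by-destination hash, otherwise
 * through the by-source hash; mode, proto and both old addresses must
 * match.  A reference is taken on the state that is returned.
 */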
1572 struct xfrm_state *xfrm_migrate_state_find(struct xfrm_migrate *m, struct net *net)
1573 {
1574         unsigned int h;
1575         struct xfrm_state *x = NULL;
1576
1577         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1578
1579         if (m->reqid) {
1580                 h = xfrm_dst_hash(net, &m->old_daddr, &m->old_saddr,
1581                                   m->reqid, m->old_family);
1582                 hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
1583                         if (x->props.mode != m->mode ||
1584                             x->id.proto != m->proto)
1585                                 continue;
1586                         if (m->reqid && x->props.reqid != m->reqid)
1587                                 continue;
1588                         if (!xfrm_addr_equal(&x->id.daddr, &m->old_daddr,
1589                                              m->old_family) ||
1590                             !xfrm_addr_equal(&x->props.saddr, &m->old_saddr,
1591                                              m->old_family))
1592                                 continue;
1593                         xfrm_state_hold(x);
1594                         break;
1595                 }
1596         } else {
1597                 h = xfrm_src_hash(net, &m->old_daddr, &m->old_saddr,
1598                                   m->old_family);
1599                 hlist_for_each_entry(x, net->xfrm.state_bysrc+h, bysrc) {
1600                         if (x->props.mode != m->mode ||
1601                             x->id.proto != m->proto)
1602                                 continue;
1603                         if (!xfrm_addr_equal(&x->id.daddr, &m->old_daddr,
1604                                              m->old_family) ||
1605                             !xfrm_addr_equal(&x->props.saddr, &m->old_saddr,
1606                                              m->old_family))
1607                                 continue;
1608                         xfrm_state_hold(x);
1609                         break;
1610                 }
1611         }
1612
1613         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1614
1615         return x;
1616 }
1617 EXPORT_SYMBOL(xfrm_migrate_state_find);
1618
1619 struct xfrm_state *xfrm_state_migrate(struct xfrm_state *x,
1620                                       struct xfrm_migrate *m,
1621                                       struct xfrm_encap_tmpl *encap)
1622 {
1623         struct xfrm_state *xc;
1624
1625         xc = xfrm_state_clone(x, encap);
1626         if (!xc)
1627                 return NULL;
1628
1629         memcpy(&xc->id.daddr, &m->new_daddr, sizeof(xc->id.daddr));
1630         memcpy(&xc->props.saddr, &m->new_saddr, sizeof(xc->props.saddr));
1631
1632         /* add state */
1633         if (xfrm_addr_equal(&x->id.daddr, &m->new_daddr, m->new_family)) {
1634                 /* care is needed when the destination address of the
1635                    state is to be updated, as it is part of the triplet */
1636                 xfrm_state_insert(xc);
1637         } else {
1638                 if (xfrm_state_add(xc) < 0)
1639                         goto error;
1640         }
1641
1642         return xc;
1643 error:
1644         xfrm_state_put(xc);
1645         return NULL;
1646 }
1647 EXPORT_SYMBOL(xfrm_state_migrate);
1648 #endif
1649
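/* Update an already installed SA in place.  Returns -ESRCH when no
 * matching state exists and -EEXIST when the match is kernel-owned.
 * A matching larval (ACQ) state is simply replaced by @x; otherwise the
 * encapsulation, care-of address, selector, lifetimes, smark and if_id
 * of the existing state are refreshed and @x itself is dropped.
 */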
1650 int xfrm_state_update(struct xfrm_state *x)
1651 {
1652         struct xfrm_state *x1, *to_put;
1653         int err;
1654         int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY);
1655         struct net *net = xs_net(x);
1656
1657         to_put = NULL;
1658
1659         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1660         x1 = __xfrm_state_locate(x, use_spi, x->props.family);
1661
1662         err = -ESRCH;
1663         if (!x1)
1664                 goto out;
1665
1666         if (xfrm_state_kern(x1)) {
1667                 to_put = x1;
1668                 err = -EEXIST;
1669                 goto out;
1670         }
1671
1672         if (x1->km.state == XFRM_STATE_ACQ) {
1673                 __xfrm_state_insert(x);
1674                 x = NULL;
1675         }
1676         err = 0;
1677
1678 out:
1679         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1680
1681         if (to_put)
1682                 xfrm_state_put(to_put);
1683
1684         if (err)
1685                 return err;
1686
1687         if (!x) {
1688                 xfrm_state_delete(x1);
1689                 xfrm_state_put(x1);
1690                 return 0;
1691         }
1692
1693         err = -EINVAL;
1694         spin_lock_bh(&x1->lock);
1695         if (likely(x1->km.state == XFRM_STATE_VALID)) {
1696                 if (x->encap && x1->encap &&
1697                     x->encap->encap_type == x1->encap->encap_type)
1698                         memcpy(x1->encap, x->encap, sizeof(*x1->encap));
1699                 else if (x->encap || x1->encap)
1700                         goto fail;
1701
1702                 if (x->coaddr && x1->coaddr) {
1703                         memcpy(x1->coaddr, x->coaddr, sizeof(*x1->coaddr));
1704                 }
1705                 if (!use_spi && memcmp(&x1->sel, &x->sel, sizeof(x1->sel)))
1706                         memcpy(&x1->sel, &x->sel, sizeof(x1->sel));
1707                 memcpy(&x1->lft, &x->lft, sizeof(x1->lft));
1708                 x1->km.dying = 0;
1709
1710                 hrtimer_start(&x1->mtimer, ktime_set(1, 0),
1711                               HRTIMER_MODE_REL_SOFT);
1712                 if (x1->curlft.use_time)
1713                         xfrm_state_check_expire(x1);
1714
1715                 if (x->props.smark.m || x->props.smark.v || x->if_id) {
1716                         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1717
1718                         if (x->props.smark.m || x->props.smark.v)
1719                                 x1->props.smark = x->props.smark;
1720
1721                         if (x->if_id)
1722                                 x1->if_id = x->if_id;
1723
1724                         __xfrm_state_bump_genids(x1);
1725                         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1726                 }
1727
1728                 err = 0;
1729                 x->km.state = XFRM_STATE_DEAD;
1730                 __xfrm_state_put(x);
1731         }
1732
1733 fail:
1734         spin_unlock_bh(&x1->lock);
1735
1736         xfrm_state_put(x1);
1737
1738         return err;
1739 }
1740 EXPORT_SYMBOL(xfrm_state_update);
1741
1742 int xfrm_state_check_expire(struct xfrm_state *x)
1743 {
1744         if (!x->curlft.use_time)
1745                 x->curlft.use_time = ktime_get_real_seconds();
1746
1747         if (x->curlft.bytes >= x->lft.hard_byte_limit ||
1748             x->curlft.packets >= x->lft.hard_packet_limit) {
1749                 x->km.state = XFRM_STATE_EXPIRED;
1750                 hrtimer_start(&x->mtimer, 0, HRTIMER_MODE_REL_SOFT);
1751                 return -EINVAL;
1752         }
1753
1754         if (!x->km.dying &&
1755             (x->curlft.bytes >= x->lft.soft_byte_limit ||
1756              x->curlft.packets >= x->lft.soft_packet_limit)) {
1757                 x->km.dying = 1;
1758                 km_state_expired(x, 0, 0);
1759         }
1760         return 0;
1761 }
1762 EXPORT_SYMBOL(xfrm_state_check_expire);
1763
1764 struct xfrm_state *
1765 xfrm_state_lookup(struct net *net, u32 mark, const xfrm_address_t *daddr, __be32 spi,
1766                   u8 proto, unsigned short family)
1767 {
1768         struct xfrm_state *x;
1769
1770         rcu_read_lock();
1771         x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family);
1772         rcu_read_unlock();
1773         return x;
1774 }
1775 EXPORT_SYMBOL(xfrm_state_lookup);
1776
1777 struct xfrm_state *
1778 xfrm_state_lookup_byaddr(struct net *net, u32 mark,
1779                          const xfrm_address_t *daddr, const xfrm_address_t *saddr,
1780                          u8 proto, unsigned short family)
1781 {
1782         struct xfrm_state *x;
1783
1784         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1785         x = __xfrm_state_lookup_byaddr(net, mark, daddr, saddr, proto, family);
1786         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1787         return x;
1788 }
1789 EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
1790
1791 struct xfrm_state *
1792 xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid,
1793               u32 if_id, u8 proto, const xfrm_address_t *daddr,
1794               const xfrm_address_t *saddr, int create, unsigned short family)
1795 {
1796         struct xfrm_state *x;
1797
1798         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1799         x = __find_acq_core(net, mark, family, mode, reqid, if_id, proto, daddr, saddr, create);
1800         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1801
1802         return x;
1803 }
1804 EXPORT_SYMBOL(xfrm_find_acq);
1805
1806 #ifdef CONFIG_XFRM_SUB_POLICY
1807 #if IS_ENABLED(CONFIG_IPV6)
1808 /* distribution counting sort function for xfrm_state and xfrm_tmpl */
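/* Each entry is assigned a class in 1..maxclass by @cmp and the result
 * is a stable counting sort by ascending class: e.g. input classes
 * { 4, 1, 2, 1 } come out as { 1, 1, 2, 4 }, with the relative order of
 * entries in the same class preserved.
 */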
1809 static void
1810 __xfrm6_sort(void **dst, void **src, int n,
1811              int (*cmp)(const void *p), int maxclass)
1812 {
1813         int count[XFRM_MAX_DEPTH] = { };
1814         int class[XFRM_MAX_DEPTH];
1815         int i;
1816
1817         for (i = 0; i < n; i++) {
1818                 int c = cmp(src[i]);
1819
1820                 class[i] = c;
1821                 count[c]++;
1822         }
1823
1824         for (i = 2; i < maxclass; i++)
1825                 count[i] += count[i - 1];
1826
1827         for (i = 0; i < n; i++) {
1828                 dst[count[class[i] - 1]++] = src[i];
1829                 src[i] = NULL;
1830         }
1831 }
1832
1833 /* Rule for xfrm_state:
1834  *
1835  * rule 1: select IPsec transport except AH
1836  * rule 2: select MIPv6 RO or inbound trigger
1837  * rule 3: select IPsec transport AH
1838  * rule 4: select IPsec tunnel
1839  * rule 5: others
1840  */
1841 static int __xfrm6_state_sort_cmp(const void *p)
1842 {
1843         const struct xfrm_state *v = p;
1844
1845         switch (v->props.mode) {
1846         case XFRM_MODE_TRANSPORT:
1847                 if (v->id.proto != IPPROTO_AH)
1848                         return 1;
1849                 else
1850                         return 3;
1851 #if IS_ENABLED(CONFIG_IPV6_MIP6)
1852         case XFRM_MODE_ROUTEOPTIMIZATION:
1853         case XFRM_MODE_IN_TRIGGER:
1854                 return 2;
1855 #endif
1856         case XFRM_MODE_TUNNEL:
1857         case XFRM_MODE_BEET:
1858                 return 4;
1859         }
1860         return 5;
1861 }
1862
1863 /* Rule for xfrm_tmpl:
1864  *
1865  * rule 1: select IPsec transport
1866  * rule 2: select MIPv6 RO or inbound trigger
1867  * rule 3: select IPsec tunnel
1868  * rule 4: others
1869  */
1870 static int __xfrm6_tmpl_sort_cmp(const void *p)
1871 {
1872         const struct xfrm_tmpl *v = p;
1873
1874         switch (v->mode) {
1875         case XFRM_MODE_TRANSPORT:
1876                 return 1;
1877 #if IS_ENABLED(CONFIG_IPV6_MIP6)
1878         case XFRM_MODE_ROUTEOPTIMIZATION:
1879         case XFRM_MODE_IN_TRIGGER:
1880                 return 2;
1881 #endif
1882         case XFRM_MODE_TUNNEL:
1883         case XFRM_MODE_BEET:
1884                 return 3;
1885         }
1886         return 4;
1887 }
1888 #else
1889 static inline int __xfrm6_state_sort_cmp(const void *p) { return 5; }
1890 static inline int __xfrm6_tmpl_sort_cmp(const void *p) { return 4; }
1891
1892 static inline void
1893 __xfrm6_sort(void **dst, void **src, int n,
1894              int (*cmp)(const void *p), int maxclass)
1895 {
1896         int i;
1897
1898         for (i = 0; i < n; i++)
1899                 dst[i] = src[i];
1900 }
1901 #endif /* CONFIG_IPV6 */
1902
1903 void
1904 xfrm_tmpl_sort(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n,
1905                unsigned short family)
1906 {
1907         int i;
1908
1909         if (family == AF_INET6)
1910                 __xfrm6_sort((void **)dst, (void **)src, n,
1911                              __xfrm6_tmpl_sort_cmp, 5);
1912         else
1913                 for (i = 0; i < n; i++)
1914                         dst[i] = src[i];
1915 }
1916
1917 void
1918 xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n,
1919                 unsigned short family)
1920 {
1921         int i;
1922
1923         if (family == AF_INET6)
1924                 __xfrm6_sort((void **)dst, (void **)src, n,
1925                              __xfrm6_state_sort_cmp, 6);
1926         else
1927                 for (i = 0; i < n; i++)
1928                         dst[i] = src[i];
1929 }
1930 #endif
1931
1932 /* Silly enough, but I'm too lazy to build a resolution list */
1933
1934 static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
1935 {
1936         int i;
1937
1938         for (i = 0; i <= net->xfrm.state_hmask; i++) {
1939                 struct xfrm_state *x;
1940
1941                 hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) {
1942                         if (x->km.seq == seq &&
1943                             (mark & x->mark.m) == x->mark.v &&
1944                             x->km.state == XFRM_STATE_ACQ) {
1945                                 xfrm_state_hold(x);
1946                                 return x;
1947                         }
1948                 }
1949         }
1950         return NULL;
1951 }
1952
1953 struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq)
1954 {
1955         struct xfrm_state *x;
1956
1957         spin_lock_bh(&net->xfrm.xfrm_state_lock);
1958         x = __xfrm_find_acq_byseq(net, mark, seq);
1959         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
1960         return x;
1961 }
1962 EXPORT_SYMBOL(xfrm_find_acq_byseq);
1963
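/* Hand out a global, monotonically increasing acquire sequence number.
 * Zero is skipped so that callers can keep treating 0 as "no sequence".
 */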
1964 u32 xfrm_get_acqseq(void)
1965 {
1966         u32 res;
1967         static atomic_t acqseq;
1968
1969         do {
1970                 res = atomic_inc_return(&acqseq);
1971         } while (!res);
1972
1973         return res;
1974 }
1975 EXPORT_SYMBOL(xfrm_get_acqseq);
1976
1977 int verify_spi_info(u8 proto, u32 min, u32 max)
1978 {
1979         switch (proto) {
1980         case IPPROTO_AH:
1981         case IPPROTO_ESP:
1982                 break;
1983
1984         case IPPROTO_COMP:
1985                 /* IPCOMP spi is 16-bits. */
1986                 if (max >= 0x10000)
1987                         return -EINVAL;
1988                 break;
1989
1990         default:
1991                 return -EINVAL;
1992         }
1993
1994         if (min > max)
1995                 return -EINVAL;
1996
1997         return 0;
1998 }
1999 EXPORT_SYMBOL(verify_spi_info);
2000
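/* Assign an SPI from the host-order range [@low, @high] to @x and hash
 * the state into the by-SPI table.  With low == high exactly that value
 * is requested; otherwise up to high - low + 1 random candidates are
 * probed.  Returns 0 on success (or when an SPI is already set) and
 * -ENOENT when no unused SPI was found or the state is already dead.
 */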
2001 int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
2002 {
2003         struct net *net = xs_net(x);
2004         unsigned int h;
2005         struct xfrm_state *x0;
2006         int err = -ENOENT;
2007         __be32 minspi = htonl(low);
2008         __be32 maxspi = htonl(high);
2009         __be32 newspi = 0;
2010         u32 mark = x->mark.v & x->mark.m;
2011
2012         spin_lock_bh(&x->lock);
2013         if (x->km.state == XFRM_STATE_DEAD)
2014                 goto unlock;
2015
2016         err = 0;
2017         if (x->id.spi)
2018                 goto unlock;
2019
2020         err = -ENOENT;
2021
2022         if (minspi == maxspi) {
2023                 x0 = xfrm_state_lookup(net, mark, &x->id.daddr, minspi, x->id.proto, x->props.family);
2024                 if (x0) {
2025                         xfrm_state_put(x0);
2026                         goto unlock;
2027                 }
2028                 newspi = minspi;
2029         } else {
2030                 u32 spi = 0;
2031                 for (h = 0; h < high-low+1; h++) {
2032                         spi = low + prandom_u32()%(high-low+1);
2033                         x0 = xfrm_state_lookup(net, mark, &x->id.daddr, htonl(spi), x->id.proto, x->props.family);
2034                         if (x0 == NULL) {
2035                                 newspi = htonl(spi);
2036                                 break;
2037                         }
2038                         xfrm_state_put(x0);
2039                 }
2040         }
2041         if (newspi) {
2042                 spin_lock_bh(&net->xfrm.xfrm_state_lock);
2043                 x->id.spi = newspi;
2044                 h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family);
2045                 hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h);
2046                 spin_unlock_bh(&net->xfrm.xfrm_state_lock);
2047
2048                 err = 0;
2049         }
2050
2051 unlock:
2052         spin_unlock_bh(&x->lock);
2053
2054         return err;
2055 }
2056 EXPORT_SYMBOL(xfrm_alloc_spi);
2057
2058 static bool __xfrm_state_filter_match(struct xfrm_state *x,
2059                                       struct xfrm_address_filter *filter)
2060 {
2061         if (filter) {
2062                 if ((filter->family == AF_INET ||
2063                      filter->family == AF_INET6) &&
2064                     x->props.family != filter->family)
2065                         return false;
2066
2067                 return addr_match(&x->props.saddr, &filter->saddr,
2068                                   filter->splen) &&
2069                        addr_match(&x->id.daddr, &filter->daddr,
2070                                   filter->dplen);
2071         }
2072         return true;
2073 }
2074
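/* Walk all states of @net and call @func for every one that matches the
 * walk's protocol and address filter.  The walk is resumable: when
 * @func returns non-zero the current position is kept in &walk->all and
 * the next call continues from there.  A typical (illustrative) caller,
 * with dump_one_state() and cb_data standing in for the real dumper:
 *
 *	xfrm_state_walk_init(&walk, IPSEC_PROTO_ANY, NULL);
 *	err = xfrm_state_walk(net, &walk, dump_one_state, cb_data);
 *	xfrm_state_walk_done(&walk, net);
 */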
2075 int xfrm_state_walk(struct net *net, struct xfrm_state_walk *walk,
2076                     int (*func)(struct xfrm_state *, int, void*),
2077                     void *data)
2078 {
2079         struct xfrm_state *state;
2080         struct xfrm_state_walk *x;
2081         int err = 0;
2082
2083         if (walk->seq != 0 && list_empty(&walk->all))
2084                 return 0;
2085
2086         spin_lock_bh(&net->xfrm.xfrm_state_lock);
2087         if (list_empty(&walk->all))
2088                 x = list_first_entry(&net->xfrm.state_all, struct xfrm_state_walk, all);
2089         else
2090                 x = list_first_entry(&walk->all, struct xfrm_state_walk, all);
2091         list_for_each_entry_from(x, &net->xfrm.state_all, all) {
2092                 if (x->state == XFRM_STATE_DEAD)
2093                         continue;
2094                 state = container_of(x, struct xfrm_state, km);
2095                 if (!xfrm_id_proto_match(state->id.proto, walk->proto))
2096                         continue;
2097                 if (!__xfrm_state_filter_match(state, walk->filter))
2098                         continue;
2099                 err = func(state, walk->seq, data);
2100                 if (err) {
2101                         list_move_tail(&walk->all, &x->all);
2102                         goto out;
2103                 }
2104                 walk->seq++;
2105         }
2106         if (walk->seq == 0) {
2107                 err = -ENOENT;
2108                 goto out;
2109         }
2110         list_del_init(&walk->all);
2111 out:
2112         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
2113         return err;
2114 }
2115 EXPORT_SYMBOL(xfrm_state_walk);
2116
2117 void xfrm_state_walk_init(struct xfrm_state_walk *walk, u8 proto,
2118                           struct xfrm_address_filter *filter)
2119 {
2120         INIT_LIST_HEAD(&walk->all);
2121         walk->proto = proto;
2122         walk->state = XFRM_STATE_DEAD;
2123         walk->seq = 0;
2124         walk->filter = filter;
2125 }
2126 EXPORT_SYMBOL(xfrm_state_walk_init);
2127
2128 void xfrm_state_walk_done(struct xfrm_state_walk *walk, struct net *net)
2129 {
2130         kfree(walk->filter);
2131
2132         if (list_empty(&walk->all))
2133                 return;
2134
2135         spin_lock_bh(&net->xfrm.xfrm_state_lock);
2136         list_del(&walk->all);
2137         spin_unlock_bh(&net->xfrm.xfrm_state_lock);
2138 }
2139 EXPORT_SYMBOL(xfrm_state_walk_done);
2140
2141 static void xfrm_replay_timer_handler(struct timer_list *t)
2142 {
2143         struct xfrm_state *x = from_timer(x, t, rtimer);
2144
2145         spin_lock(&x->lock);
2146
2147         if (x->km.state == XFRM_STATE_VALID) {
2148                 if (xfrm_aevent_is_on(xs_net(x)))
2149                         x->repl->notify(x, XFRM_REPLAY_TIMEOUT);
2150                 else
2151                         x->xflags |= XFRM_TIME_DEFER;
2152         }
2153
2154         spin_unlock(&x->lock);
2155 }
2156
2157 static LIST_HEAD(xfrm_km_list);
2158
2159 void km_policy_notify(struct xfrm_policy *xp, int dir, const struct km_event *c)
2160 {
2161         struct xfrm_mgr *km;
2162
2163         rcu_read_lock();
2164         list_for_each_entry_rcu(km, &xfrm_km_list, list)
2165                 if (km->notify_policy)
2166                         km->notify_policy(xp, dir, c);
2167         rcu_read_unlock();
2168 }
2169
2170 void km_state_notify(struct xfrm_state *x, const struct km_event *c)
2171 {
2172         struct xfrm_mgr *km;
2173         rcu_read_lock();
2174         list_for_each_entry_rcu(km, &xfrm_km_list, list)
2175                 if (km->notify)
2176                         km->notify(x, c);
2177         rcu_read_unlock();
2178 }
2179
2180 EXPORT_SYMBOL(km_policy_notify);
2181 EXPORT_SYMBOL(km_state_notify);
2182
2183 void km_state_expired(struct xfrm_state *x, int hard, u32 portid)
2184 {
2185         struct km_event c;
2186
2187         c.data.hard = hard;
2188         c.portid = portid;
2189         c.event = XFRM_MSG_EXPIRE;
2190         km_state_notify(x, &c);
2191 }
2192
2193 EXPORT_SYMBOL(km_state_expired);
2194 /*
2195  * We send to all registered managers regardless of failure;
2196  * we are happy with one success.
2197  */
2198 int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol)
2199 {
2200         int err = -EINVAL, acqret;
2201         struct xfrm_mgr *km;
2202
2203         rcu_read_lock();
2204         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
2205                 acqret = km->acquire(x, t, pol);
2206                 if (!acqret)
2207                         err = acqret;
2208         }
2209         rcu_read_unlock();
2210         return err;
2211 }
2212 EXPORT_SYMBOL(km_query);
2213
2214 int km_new_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, __be16 sport)
2215 {
2216         int err = -EINVAL;
2217         struct xfrm_mgr *km;
2218
2219         rcu_read_lock();
2220         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
2221                 if (km->new_mapping)
2222                         err = km->new_mapping(x, ipaddr, sport);
2223                 if (!err)
2224                         break;
2225         }
2226         rcu_read_unlock();
2227         return err;
2228 }
2229 EXPORT_SYMBOL(km_new_mapping);
2230
2231 void km_policy_expired(struct xfrm_policy *pol, int dir, int hard, u32 portid)
2232 {
2233         struct km_event c;
2234
2235         c.data.hard = hard;
2236         c.portid = portid;
2237         c.event = XFRM_MSG_POLEXPIRE;
2238         km_policy_notify(pol, dir, &c);
2239 }
2240 EXPORT_SYMBOL(km_policy_expired);
2241
2242 #ifdef CONFIG_XFRM_MIGRATE
2243 int km_migrate(const struct xfrm_selector *sel, u8 dir, u8 type,
2244                const struct xfrm_migrate *m, int num_migrate,
2245                const struct xfrm_kmaddress *k,
2246                const struct xfrm_encap_tmpl *encap)
2247 {
2248         int err = -EINVAL;
2249         int ret;
2250         struct xfrm_mgr *km;
2251
2252         rcu_read_lock();
2253         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
2254                 if (km->migrate) {
2255                         ret = km->migrate(sel, dir, type, m, num_migrate, k,
2256                                           encap);
2257                         if (!ret)
2258                                 err = ret;
2259                 }
2260         }
2261         rcu_read_unlock();
2262         return err;
2263 }
2264 EXPORT_SYMBOL(km_migrate);
2265 #endif
2266
2267 int km_report(struct net *net, u8 proto, struct xfrm_selector *sel, xfrm_address_t *addr)
2268 {
2269         int err = -EINVAL;
2270         int ret;
2271         struct xfrm_mgr *km;
2272
2273         rcu_read_lock();
2274         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
2275                 if (km->report) {
2276                         ret = km->report(net, proto, sel, addr);
2277                         if (!ret)
2278                                 err = ret;
2279                 }
2280         }
2281         rcu_read_unlock();
2282         return err;
2283 }
2284 EXPORT_SYMBOL(km_report);
2285
2286 static bool km_is_alive(const struct km_event *c)
2287 {
2288         struct xfrm_mgr *km;
2289         bool is_alive = false;
2290
2291         rcu_read_lock();
2292         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
2293                 if (km->is_alive && km->is_alive(c)) {
2294                         is_alive = true;
2295                         break;
2296                 }
2297         }
2298         rcu_read_unlock();
2299
2300         return is_alive;
2301 }
2302
2303 #if IS_ENABLED(CONFIG_XFRM_USER_COMPAT)
2304 static DEFINE_SPINLOCK(xfrm_translator_lock);
2305 static struct xfrm_translator __rcu *xfrm_translator;
2306
2307 struct xfrm_translator *xfrm_get_translator(void)
2308 {
2309         struct xfrm_translator *xtr;
2310
2311         rcu_read_lock();
2312         xtr = rcu_dereference(xfrm_translator);
2313         if (unlikely(!xtr))
2314                 goto out;
2315         if (!try_module_get(xtr->owner))
2316                 xtr = NULL;
2317 out:
2318         rcu_read_unlock();
2319         return xtr;
2320 }
2321 EXPORT_SYMBOL_GPL(xfrm_get_translator);
2322
2323 void xfrm_put_translator(struct xfrm_translator *xtr)
2324 {
2325         module_put(xtr->owner);
2326 }
2327 EXPORT_SYMBOL_GPL(xfrm_put_translator);
2328
2329 int xfrm_register_translator(struct xfrm_translator *xtr)
2330 {
2331         int err = 0;
2332
2333         spin_lock_bh(&xfrm_translator_lock);
2334         if (unlikely(xfrm_translator != NULL))
2335                 err = -EEXIST;
2336         else
2337                 rcu_assign_pointer(xfrm_translator, xtr);
2338         spin_unlock_bh(&xfrm_translator_lock);
2339
2340         return err;
2341 }
2342 EXPORT_SYMBOL_GPL(xfrm_register_translator);
2343
2344 int xfrm_unregister_translator(struct xfrm_translator *xtr)
2345 {
2346         int err = 0;
2347
2348         spin_lock_bh(&xfrm_translator_lock);
2349         if (likely(xfrm_translator != NULL)) {
2350                 if (rcu_access_pointer(xfrm_translator) != xtr)
2351                         err = -EINVAL;
2352                 else
2353                         RCU_INIT_POINTER(xfrm_translator, NULL);
2354         }
2355         spin_unlock_bh(&xfrm_translator_lock);
2356         synchronize_rcu();
2357
2358         return err;
2359 }
2360 EXPORT_SYMBOL_GPL(xfrm_unregister_translator);
2361 #endif
2362
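/* Install (or, with an empty @optval, clear) a per-socket policy passed
 * in from userspace.  The raw blob is converted by the compat translator
 * when the caller is a compat task, then offered to the registered key
 * managers' compile_policy() hooks; the first manager that accepts it
 * also determines the direction used for the insert.
 */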
2363 int xfrm_user_policy(struct sock *sk, int optname, sockptr_t optval, int optlen)
2364 {
2365         int err;
2366         u8 *data;
2367         struct xfrm_mgr *km;
2368         struct xfrm_policy *pol = NULL;
2369
2370         if (sockptr_is_null(optval) && !optlen) {
2371                 xfrm_sk_policy_insert(sk, XFRM_POLICY_IN, NULL);
2372                 xfrm_sk_policy_insert(sk, XFRM_POLICY_OUT, NULL);
2373                 __sk_dst_reset(sk);
2374                 return 0;
2375         }
2376
2377         if (optlen <= 0 || optlen > PAGE_SIZE)
2378                 return -EMSGSIZE;
2379
2380         data = memdup_sockptr(optval, optlen);
2381         if (IS_ERR(data))
2382                 return PTR_ERR(data);
2383
2384         if (in_compat_syscall()) {
2385                 struct xfrm_translator *xtr = xfrm_get_translator();
2386
2387                 if (!xtr) {
2388                         kfree(data);
2389                         return -EOPNOTSUPP;
2390                 }
2391
2392                 err = xtr->xlate_user_policy_sockptr(&data, optlen);
2393                 xfrm_put_translator(xtr);
2394                 if (err) {
2395                         kfree(data);
2396                         return err;
2397                 }
2398         }
2399
2400         err = -EINVAL;
2401         rcu_read_lock();
2402         list_for_each_entry_rcu(km, &xfrm_km_list, list) {
2403                 pol = km->compile_policy(sk, optname, data,
2404                                          optlen, &err);
2405                 if (err >= 0)
2406                         break;
2407         }
2408         rcu_read_unlock();
2409
2410         if (err >= 0) {
2411                 xfrm_sk_policy_insert(sk, err, pol);
2412                 xfrm_pol_put(pol);
2413                 __sk_dst_reset(sk);
2414                 err = 0;
2415         }
2416
2417         kfree(data);
2418         return err;
2419 }
2420 EXPORT_SYMBOL(xfrm_user_policy);
2421
2422 static DEFINE_SPINLOCK(xfrm_km_lock);
2423
2424 int xfrm_register_km(struct xfrm_mgr *km)
2425 {
2426         spin_lock_bh(&xfrm_km_lock);
2427         list_add_tail_rcu(&km->list, &xfrm_km_list);
2428         spin_unlock_bh(&xfrm_km_lock);
2429         return 0;
2430 }
2431 EXPORT_SYMBOL(xfrm_register_km);
2432
2433 int xfrm_unregister_km(struct xfrm_mgr *km)
2434 {
2435         spin_lock_bh(&xfrm_km_lock);
2436         list_del_rcu(&km->list);
2437         spin_unlock_bh(&xfrm_km_lock);
2438         synchronize_rcu();
2439         return 0;
2440 }
2441 EXPORT_SYMBOL(xfrm_unregister_km);
2442
2443 int xfrm_state_register_afinfo(struct xfrm_state_afinfo *afinfo)
2444 {
2445         int err = 0;
2446
2447         if (WARN_ON(afinfo->family >= NPROTO))
2448                 return -EAFNOSUPPORT;
2449
2450         spin_lock_bh(&xfrm_state_afinfo_lock);
2451         if (unlikely(xfrm_state_afinfo[afinfo->family] != NULL))
2452                 err = -EEXIST;
2453         else
2454                 rcu_assign_pointer(xfrm_state_afinfo[afinfo->family], afinfo);
2455         spin_unlock_bh(&xfrm_state_afinfo_lock);
2456         return err;
2457 }
2458 EXPORT_SYMBOL(xfrm_state_register_afinfo);
2459
2460 int xfrm_state_unregister_afinfo(struct xfrm_state_afinfo *afinfo)
2461 {
2462         int err = 0, family = afinfo->family;
2463
2464         if (WARN_ON(family >= NPROTO))
2465                 return -EAFNOSUPPORT;
2466
2467         spin_lock_bh(&xfrm_state_afinfo_lock);
2468         if (likely(xfrm_state_afinfo[afinfo->family] != NULL)) {
2469                 if (rcu_access_pointer(xfrm_state_afinfo[family]) != afinfo)
2470                         err = -EINVAL;
2471                 else
2472                         RCU_INIT_POINTER(xfrm_state_afinfo[afinfo->family], NULL);
2473         }
2474         spin_unlock_bh(&xfrm_state_afinfo_lock);
2475         synchronize_rcu();
2476         return err;
2477 }
2478 EXPORT_SYMBOL(xfrm_state_unregister_afinfo);
2479
2480 struct xfrm_state_afinfo *xfrm_state_afinfo_get_rcu(unsigned int family)
2481 {
2482         if (unlikely(family >= NPROTO))
2483                 return NULL;
2484
2485         return rcu_dereference(xfrm_state_afinfo[family]);
2486 }
2487 EXPORT_SYMBOL_GPL(xfrm_state_afinfo_get_rcu);
2488
2489 struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family)
2490 {
2491         struct xfrm_state_afinfo *afinfo;
2492         if (unlikely(family >= NPROTO))
2493                 return NULL;
2494         rcu_read_lock();
2495         afinfo = rcu_dereference(xfrm_state_afinfo[family]);
2496         if (unlikely(!afinfo))
2497                 rcu_read_unlock();
2498         return afinfo;
2499 }
2500
2501 void xfrm_flush_gc(void)
2502 {
2503         flush_work(&xfrm_state_gc_work);
2504 }
2505 EXPORT_SYMBOL(xfrm_flush_gc);
2506
2507 /* Temporarily located here until net/xfrm/xfrm_tunnel.c is created */
2508 void xfrm_state_delete_tunnel(struct xfrm_state *x)
2509 {
2510         if (x->tunnel) {
2511                 struct xfrm_state *t = x->tunnel;
2512
2513                 if (atomic_read(&t->tunnel_users) == 2)
2514                         xfrm_state_delete(t);
2515                 atomic_dec(&t->tunnel_users);
2516                 xfrm_state_put_sync(t);
2517                 x->tunnel = NULL;
2518         }
2519 }
2520 EXPORT_SYMBOL(xfrm_state_delete_tunnel);
2521
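/* Largest payload that still fits into @mtu for this state.  For
 * anything but a valid ESP state only the SA header length is
 * subtracted; for ESP the ICV, the padding up to the cipher block size
 * (at least 4 bytes), the two ESP trailer bytes and the inner IP header
 * adjustment for transport/BEET mode are accounted for as well.
 */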
2522 u32 xfrm_state_mtu(struct xfrm_state *x, int mtu)
2523 {
2524         const struct xfrm_type *type = READ_ONCE(x->type);
2525         struct crypto_aead *aead;
2526         u32 blksize, net_adj = 0;
2527
2528         if (x->km.state != XFRM_STATE_VALID ||
2529             !type || type->proto != IPPROTO_ESP)
2530                 return mtu - x->props.header_len;
2531
2532         aead = x->data;
2533         blksize = ALIGN(crypto_aead_blocksize(aead), 4);
2534
2535         switch (x->props.mode) {
2536         case XFRM_MODE_TRANSPORT:
2537         case XFRM_MODE_BEET:
2538                 if (x->props.family == AF_INET)
2539                         net_adj = sizeof(struct iphdr);
2540                 else if (x->props.family == AF_INET6)
2541                         net_adj = sizeof(struct ipv6hdr);
2542                 break;
2543         case XFRM_MODE_TUNNEL:
2544                 break;
2545         default:
2546                 WARN_ON_ONCE(1);
2547                 break;
2548         }
2549
2550         return ((mtu - x->props.header_len - crypto_aead_authsize(aead) -
2551                  net_adj) & ~(blksize - 1)) + net_adj - 2;
2552 }
2553 EXPORT_SYMBOL_GPL(xfrm_state_mtu);
2554
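/* Resolve the inner and outer modes and the transform type for @x from
 * its selector and properties, run the type's init_state() and, when
 * @init_replay is set, set up replay protection.  The caller is expected
 * to mark the state XFRM_STATE_VALID afterwards, as xfrm_init_state()
 * below does.
 */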
2555 int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload)
2556 {
2557         const struct xfrm_mode *inner_mode;
2558         const struct xfrm_mode *outer_mode;
2559         int family = x->props.family;
2560         int err;
2561
2562         if (family == AF_INET &&
2563             xs_net(x)->ipv4.sysctl_ip_no_pmtu_disc)
2564                 x->props.flags |= XFRM_STATE_NOPMTUDISC;
2565
2566         err = -EPROTONOSUPPORT;
2567
2568         if (x->sel.family != AF_UNSPEC) {
2569                 inner_mode = xfrm_get_mode(x->props.mode, x->sel.family);
2570                 if (inner_mode == NULL)
2571                         goto error;
2572
2573                 if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL) &&
2574                     family != x->sel.family)
2575                         goto error;
2576
2577                 x->inner_mode = *inner_mode;
2578         } else {
2579                 const struct xfrm_mode *inner_mode_iaf;
2580                 int iafamily = AF_INET;
2581
2582                 inner_mode = xfrm_get_mode(x->props.mode, x->props.family);
2583                 if (inner_mode == NULL)
2584                         goto error;
2585
2586                 if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL))
2587                         goto error;
2588
2589                 x->inner_mode = *inner_mode;
2590
2591                 if (x->props.family == AF_INET)
2592                         iafamily = AF_INET6;
2593
2594                 inner_mode_iaf = xfrm_get_mode(x->props.mode, iafamily);
2595                 if (inner_mode_iaf) {
2596                         if (inner_mode_iaf->flags & XFRM_MODE_FLAG_TUNNEL)
2597                                 x->inner_mode_iaf = *inner_mode_iaf;
2598                 }
2599         }
2600
2601         x->type = xfrm_get_type(x->id.proto, family);
2602         if (x->type == NULL)
2603                 goto error;
2604
2605         x->type_offload = xfrm_get_type_offload(x->id.proto, family, offload);
2606
2607         err = x->type->init_state(x);
2608         if (err)
2609                 goto error;
2610
2611         outer_mode = xfrm_get_mode(x->props.mode, family);
2612         if (!outer_mode) {
2613                 err = -EPROTONOSUPPORT;
2614                 goto error;
2615         }
2616
2617         x->outer_mode = *outer_mode;
2618         if (init_replay) {
2619                 err = xfrm_init_replay(x);
2620                 if (err)
2621                         goto error;
2622         }
2623
2624 error:
2625         return err;
2626 }
2627
2628 EXPORT_SYMBOL(__xfrm_init_state);
2629
2630 int xfrm_init_state(struct xfrm_state *x)
2631 {
2632         int err;
2633
2634         err = __xfrm_init_state(x, true, false);
2635         if (!err)
2636                 x->km.state = XFRM_STATE_VALID;
2637
2638         return err;
2639 }
2640
2641 EXPORT_SYMBOL(xfrm_init_state);
2642
2643 int __net_init xfrm_state_init(struct net *net)
2644 {
2645         unsigned int sz;
2646
2647         if (net_eq(net, &init_net))
2648                 xfrm_state_cache = KMEM_CACHE(xfrm_state,
2649                                               SLAB_HWCACHE_ALIGN | SLAB_PANIC);
2650
2651         INIT_LIST_HEAD(&net->xfrm.state_all);
2652
2653         sz = sizeof(struct hlist_head) * 8;
2654
2655         net->xfrm.state_bydst = xfrm_hash_alloc(sz);
2656         if (!net->xfrm.state_bydst)
2657                 goto out_bydst;
2658         net->xfrm.state_bysrc = xfrm_hash_alloc(sz);
2659         if (!net->xfrm.state_bysrc)
2660                 goto out_bysrc;
2661         net->xfrm.state_byspi = xfrm_hash_alloc(sz);
2662         if (!net->xfrm.state_byspi)
2663                 goto out_byspi;
2664         net->xfrm.state_hmask = ((sz / sizeof(struct hlist_head)) - 1);
2665
2666         net->xfrm.state_num = 0;
2667         INIT_WORK(&net->xfrm.state_hash_work, xfrm_hash_resize);
2668         spin_lock_init(&net->xfrm.xfrm_state_lock);
2669         return 0;
2670
2671 out_byspi:
2672         xfrm_hash_free(net->xfrm.state_bysrc, sz);
2673 out_bysrc:
2674         xfrm_hash_free(net->xfrm.state_bydst, sz);
2675 out_bydst:
2676         return -ENOMEM;
2677 }
2678
2679 void xfrm_state_fini(struct net *net)
2680 {
2681         unsigned int sz;
2682
2683         flush_work(&net->xfrm.state_hash_work);
2684         flush_work(&xfrm_state_gc_work);
2685         xfrm_state_flush(net, 0, false, true);
2686
2687         WARN_ON(!list_empty(&net->xfrm.state_all));
2688
2689         sz = (net->xfrm.state_hmask + 1) * sizeof(struct hlist_head);
2690         WARN_ON(!hlist_empty(net->xfrm.state_byspi));
2691         xfrm_hash_free(net->xfrm.state_byspi, sz);
2692         WARN_ON(!hlist_empty(net->xfrm.state_bysrc));
2693         xfrm_hash_free(net->xfrm.state_bysrc, sz);
2694         WARN_ON(!hlist_empty(net->xfrm.state_bydst));
2695         xfrm_hash_free(net->xfrm.state_bydst, sz);
2696 }
2697
2698 #ifdef CONFIG_AUDITSYSCALL
2699 static void xfrm_audit_helper_sainfo(struct xfrm_state *x,
2700                                      struct audit_buffer *audit_buf)
2701 {
2702         struct xfrm_sec_ctx *ctx = x->security;
2703         u32 spi = ntohl(x->id.spi);
2704
2705         if (ctx)
2706                 audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s",
2707                                  ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str);
2708
2709         switch (x->props.family) {
2710         case AF_INET:
2711                 audit_log_format(audit_buf, " src=%pI4 dst=%pI4",
2712                                  &x->props.saddr.a4, &x->id.daddr.a4);
2713                 break;
2714         case AF_INET6:
2715                 audit_log_format(audit_buf, " src=%pI6 dst=%pI6",
2716                                  x->props.saddr.a6, x->id.daddr.a6);
2717                 break;
2718         }
2719
2720         audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi);
2721 }
2722
2723 static void xfrm_audit_helper_pktinfo(struct sk_buff *skb, u16 family,
2724                                       struct audit_buffer *audit_buf)
2725 {
2726         const struct iphdr *iph4;
2727         const struct ipv6hdr *iph6;
2728
2729         switch (family) {
2730         case AF_INET:
2731                 iph4 = ip_hdr(skb);
2732                 audit_log_format(audit_buf, " src=%pI4 dst=%pI4",
2733                                  &iph4->saddr, &iph4->daddr);
2734                 break;
2735         case AF_INET6:
2736                 iph6 = ipv6_hdr(skb);
2737                 audit_log_format(audit_buf,
2738                                  " src=%pI6 dst=%pI6 flowlbl=0x%x%02x%02x",
2739                                  &iph6->saddr, &iph6->daddr,
2740                                  iph6->flow_lbl[0] & 0x0f,
2741                                  iph6->flow_lbl[1],
2742                                  iph6->flow_lbl[2]);
2743                 break;
2744         }
2745 }
2746
2747 void xfrm_audit_state_add(struct xfrm_state *x, int result, bool task_valid)
2748 {
2749         struct audit_buffer *audit_buf;
2750
2751         audit_buf = xfrm_audit_start("SAD-add");
2752         if (audit_buf == NULL)
2753                 return;
2754         xfrm_audit_helper_usrinfo(task_valid, audit_buf);
2755         xfrm_audit_helper_sainfo(x, audit_buf);
2756         audit_log_format(audit_buf, " res=%u", result);
2757         audit_log_end(audit_buf);
2758 }
2759 EXPORT_SYMBOL_GPL(xfrm_audit_state_add);
2760
2761 void xfrm_audit_state_delete(struct xfrm_state *x, int result, bool task_valid)
2762 {
2763         struct audit_buffer *audit_buf;
2764
2765         audit_buf = xfrm_audit_start("SAD-delete");
2766         if (audit_buf == NULL)
2767                 return;
2768         xfrm_audit_helper_usrinfo(task_valid, audit_buf);
2769         xfrm_audit_helper_sainfo(x, audit_buf);
2770         audit_log_format(audit_buf, " res=%u", result);
2771         audit_log_end(audit_buf);
2772 }
2773 EXPORT_SYMBOL_GPL(xfrm_audit_state_delete);
2774
2775 void xfrm_audit_state_replay_overflow(struct xfrm_state *x,
2776                                       struct sk_buff *skb)
2777 {
2778         struct audit_buffer *audit_buf;
2779         u32 spi;
2780
2781         audit_buf = xfrm_audit_start("SA-replay-overflow");
2782         if (audit_buf == NULL)
2783                 return;
2784         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2785         /* don't record the sequence number because it's inherent in this kind
2786          * of audit message */
2787         spi = ntohl(x->id.spi);
2788         audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi);
2789         audit_log_end(audit_buf);
2790 }
2791 EXPORT_SYMBOL_GPL(xfrm_audit_state_replay_overflow);
2792
2793 void xfrm_audit_state_replay(struct xfrm_state *x,
2794                              struct sk_buff *skb, __be32 net_seq)
2795 {
2796         struct audit_buffer *audit_buf;
2797         u32 spi;
2798
2799         audit_buf = xfrm_audit_start("SA-replayed-pkt");
2800         if (audit_buf == NULL)
2801                 return;
2802         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2803         spi = ntohl(x->id.spi);
2804         audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2805                          spi, spi, ntohl(net_seq));
2806         audit_log_end(audit_buf);
2807 }
2808 EXPORT_SYMBOL_GPL(xfrm_audit_state_replay);
2809
2810 void xfrm_audit_state_notfound_simple(struct sk_buff *skb, u16 family)
2811 {
2812         struct audit_buffer *audit_buf;
2813
2814         audit_buf = xfrm_audit_start("SA-notfound");
2815         if (audit_buf == NULL)
2816                 return;
2817         xfrm_audit_helper_pktinfo(skb, family, audit_buf);
2818         audit_log_end(audit_buf);
2819 }
2820 EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound_simple);
2821
2822 void xfrm_audit_state_notfound(struct sk_buff *skb, u16 family,
2823                                __be32 net_spi, __be32 net_seq)
2824 {
2825         struct audit_buffer *audit_buf;
2826         u32 spi;
2827
2828         audit_buf = xfrm_audit_start("SA-notfound");
2829         if (audit_buf == NULL)
2830                 return;
2831         xfrm_audit_helper_pktinfo(skb, family, audit_buf);
2832         spi = ntohl(net_spi);
2833         audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2834                          spi, spi, ntohl(net_seq));
2835         audit_log_end(audit_buf);
2836 }
2837 EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound);
2838
2839 void xfrm_audit_state_icvfail(struct xfrm_state *x,
2840                               struct sk_buff *skb, u8 proto)
2841 {
2842         struct audit_buffer *audit_buf;
2843         __be32 net_spi;
2844         __be32 net_seq;
2845
2846         audit_buf = xfrm_audit_start("SA-icv-failure");
2847         if (audit_buf == NULL)
2848                 return;
2849         xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf);
2850         if (xfrm_parse_spi(skb, proto, &net_spi, &net_seq) == 0) {
2851                 u32 spi = ntohl(net_spi);
2852                 audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u",
2853                                  spi, spi, ntohl(net_seq));
2854         }
2855         audit_log_end(audit_buf);
2856 }
2857 EXPORT_SYMBOL_GPL(xfrm_audit_state_icvfail);
2858 #endif /* CONFIG_AUDITSYSCALL */