tools headers UAPI: Sync openat2.h with the kernel sources
[linux-2.6-microblaze.git] / kernel / static_call.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/init.h>
3 #include <linux/static_call.h>
4 #include <linux/bug.h>
5 #include <linux/smp.h>
6 #include <linux/sort.h>
7 #include <linux/slab.h>
8 #include <linux/module.h>
9 #include <linux/cpu.h>
10 #include <linux/processor.h>
11 #include <asm/sections.h>
12
13 extern struct static_call_site __start_static_call_sites[],
14                                __stop_static_call_sites[];
15 extern struct static_call_tramp_key __start_static_call_tramp_key[],
16                                     __stop_static_call_tramp_key[];
17
18 static bool static_call_initialized;
19
20 /* mutex to protect key modules/sites */
21 static DEFINE_MUTEX(static_call_mutex);
22
23 static void static_call_lock(void)
24 {
25         mutex_lock(&static_call_mutex);
26 }
27
28 static void static_call_unlock(void)
29 {
30         mutex_unlock(&static_call_mutex);
31 }
32
33 static inline void *static_call_addr(struct static_call_site *site)
34 {
35         return (void *)((long)site->addr + (long)&site->addr);
36 }
37
38
39 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
40 {
41         return (struct static_call_key *)
42                 (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
43 }
44
45 /* These assume the key is word-aligned. */
46 static inline bool static_call_is_init(struct static_call_site *site)
47 {
48         return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
49 }
50
51 static inline bool static_call_is_tail(struct static_call_site *site)
52 {
53         return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
54 }
55
56 static inline void static_call_set_init(struct static_call_site *site)
57 {
58         site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
59                     (long)&site->key;
60 }
61
/*
 * sort() comparator: orders call sites by the address of the key they
 * reference, so that all sites of one key end up contiguous.
 * Returns -1, 0 or 1.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *lhs = static_call_key(_a);
	const struct static_call_key *rhs = static_call_key(_b);

	return (lhs > rhs) - (lhs < rhs);
}
77
78 static void static_call_site_swap(void *_a, void *_b, int size)
79 {
80         long delta = (unsigned long)_a - (unsigned long)_b;
81         struct static_call_site *a = _a;
82         struct static_call_site *b = _b;
83         struct static_call_site tmp = *a;
84
85         a->addr = b->addr  - delta;
86         a->key  = b->key   - delta;
87
88         b->addr = tmp.addr + delta;
89         b->key  = tmp.key  + delta;
90 }
91
92 static inline void static_call_sort_entries(struct static_call_site *start,
93                                             struct static_call_site *stop)
94 {
95         sort(start, stop - start, sizeof(struct static_call_site),
96              static_call_site_cmp, static_call_site_swap);
97 }
98
99 static inline bool static_call_key_has_mods(struct static_call_key *key)
100 {
101         return !(key->type & 1);
102 }
103
104 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
105 {
106         if (!static_call_key_has_mods(key))
107                 return NULL;
108
109         return key->mods;
110 }
111
112 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
113 {
114         if (static_call_key_has_mods(key))
115                 return NULL;
116
117         return (struct static_call_site *)(key->type & ~1);
118 }
119
120 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
121 {
122         struct static_call_site *site, *stop;
123         struct static_call_mod *site_mod, first;
124
125         cpus_read_lock();
126         static_call_lock();
127
128         if (key->func == func)
129                 goto done;
130
131         key->func = func;
132
133         arch_static_call_transform(NULL, tramp, func, false);
134
135         /*
136          * If uninitialized, we'll not update the callsites, but they still
137          * point to the trampoline and we just patched that.
138          */
139         if (WARN_ON_ONCE(!static_call_initialized))
140                 goto done;
141
142         first = (struct static_call_mod){
143                 .next = static_call_key_next(key),
144                 .mod = NULL,
145                 .sites = static_call_key_sites(key),
146         };
147
148         for (site_mod = &first; site_mod; site_mod = site_mod->next) {
149                 struct module *mod = site_mod->mod;
150
151                 if (!site_mod->sites) {
152                         /*
153                          * This can happen if the static call key is defined in
154                          * a module which doesn't use it.
155                          *
156                          * It also happens in the has_mods case, where the
157                          * 'first' entry has no sites associated with it.
158                          */
159                         continue;
160                 }
161
162                 stop = __stop_static_call_sites;
163
164 #ifdef CONFIG_MODULES
165                 if (mod) {
166                         stop = mod->static_call_sites +
167                                mod->num_static_call_sites;
168                 }
169 #endif
170
171                 for (site = site_mod->sites;
172                      site < stop && static_call_key(site) == key; site++) {
173                         void *site_addr = static_call_addr(site);
174
175                         if (static_call_is_init(site)) {
176                                 /*
177                                  * Don't write to call sites which were in
178                                  * initmem and have since been freed.
179                                  */
180                                 if (!mod && system_state >= SYSTEM_RUNNING)
181                                         continue;
182                                 if (mod && !within_module_init((unsigned long)site_addr, mod))
183                                         continue;
184                         }
185
186                         if (!kernel_text_address((unsigned long)site_addr)) {
187                                 WARN_ONCE(1, "can't patch static call site at %pS",
188                                           site_addr);
189                                 continue;
190                         }
191
192                         arch_static_call_transform(site_addr, NULL, func,
193                                 static_call_is_tail(site));
194                 }
195         }
196
197 done:
198         static_call_unlock();
199         cpus_read_unlock();
200 }
201 EXPORT_SYMBOL_GPL(__static_call_update);
202
203 static int __static_call_init(struct module *mod,
204                               struct static_call_site *start,
205                               struct static_call_site *stop)
206 {
207         struct static_call_site *site;
208         struct static_call_key *key, *prev_key = NULL;
209         struct static_call_mod *site_mod;
210
211         if (start == stop)
212                 return 0;
213
214         static_call_sort_entries(start, stop);
215
216         for (site = start; site < stop; site++) {
217                 void *site_addr = static_call_addr(site);
218
219                 if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
220                     (!mod && init_section_contains(site_addr, 1)))
221                         static_call_set_init(site);
222
223                 key = static_call_key(site);
224                 if (key != prev_key) {
225                         prev_key = key;
226
227                         /*
228                          * For vmlinux (!mod) avoid the allocation by storing
229                          * the sites pointer in the key itself. Also see
230                          * __static_call_update()'s @first.
231                          *
232                          * This allows architectures (eg. x86) to call
233                          * static_call_init() before memory allocation works.
234                          */
235                         if (!mod) {
236                                 key->sites = site;
237                                 key->type |= 1;
238                                 goto do_transform;
239                         }
240
241                         site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
242                         if (!site_mod)
243                                 return -ENOMEM;
244
245                         /*
246                          * When the key has a direct sites pointer, extract
247                          * that into an explicit struct static_call_mod, so we
248                          * can have a list of modules.
249                          */
250                         if (static_call_key_sites(key)) {
251                                 site_mod->mod = NULL;
252                                 site_mod->next = NULL;
253                                 site_mod->sites = static_call_key_sites(key);
254
255                                 key->mods = site_mod;
256
257                                 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
258                                 if (!site_mod)
259                                         return -ENOMEM;
260                         }
261
262                         site_mod->mod = mod;
263                         site_mod->sites = site;
264                         site_mod->next = static_call_key_next(key);
265                         key->mods = site_mod;
266                 }
267
268 do_transform:
269                 arch_static_call_transform(site_addr, NULL, key->func,
270                                 static_call_is_tail(site));
271         }
272
273         return 0;
274 }
275
276 static int addr_conflict(struct static_call_site *site, void *start, void *end)
277 {
278         unsigned long addr = (unsigned long)static_call_addr(site);
279
280         if (addr <= (unsigned long)end &&
281             addr + CALL_INSN_SIZE > (unsigned long)start)
282                 return 1;
283
284         return 0;
285 }
286
287 static int __static_call_text_reserved(struct static_call_site *iter_start,
288                                        struct static_call_site *iter_stop,
289                                        void *start, void *end)
290 {
291         struct static_call_site *iter = iter_start;
292
293         while (iter < iter_stop) {
294                 if (addr_conflict(iter, start, end))
295                         return 1;
296                 iter++;
297         }
298
299         return 0;
300 }
301
302 #ifdef CONFIG_MODULES
303
304 static int __static_call_mod_text_reserved(void *start, void *end)
305 {
306         struct module *mod;
307         int ret;
308
309         preempt_disable();
310         mod = __module_text_address((unsigned long)start);
311         WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
312         if (!try_module_get(mod))
313                 mod = NULL;
314         preempt_enable();
315
316         if (!mod)
317                 return 0;
318
319         ret = __static_call_text_reserved(mod->static_call_sites,
320                         mod->static_call_sites + mod->num_static_call_sites,
321                         start, end);
322
323         module_put(mod);
324
325         return ret;
326 }
327
328 static unsigned long tramp_key_lookup(unsigned long addr)
329 {
330         struct static_call_tramp_key *start = __start_static_call_tramp_key;
331         struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
332         struct static_call_tramp_key *tramp_key;
333
334         for (tramp_key = start; tramp_key != stop; tramp_key++) {
335                 unsigned long tramp;
336
337                 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
338                 if (tramp == addr)
339                         return (long)tramp_key->key + (long)&tramp_key->key;
340         }
341
342         return 0;
343 }
344
345 static int static_call_add_module(struct module *mod)
346 {
347         struct static_call_site *start = mod->static_call_sites;
348         struct static_call_site *stop = start + mod->num_static_call_sites;
349         struct static_call_site *site;
350
351         for (site = start; site != stop; site++) {
352                 unsigned long addr = (unsigned long)static_call_key(site);
353                 unsigned long key;
354
355                 /*
356                  * Is the key is exported, 'addr' points to the key, which
357                  * means modules are allowed to call static_call_update() on
358                  * it.
359                  *
360                  * Otherwise, the key isn't exported, and 'addr' points to the
361                  * trampoline so we need to lookup the key.
362                  *
363                  * We go through this dance to prevent crazy modules from
364                  * abusing sensitive static calls.
365                  */
366                 if (!kernel_text_address(addr))
367                         continue;
368
369                 key = tramp_key_lookup(addr);
370                 if (!key) {
371                         pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
372                                 static_call_addr(site));
373                         return -EINVAL;
374                 }
375
376                 site->key = (key - (long)&site->key) |
377                             (site->key & STATIC_CALL_SITE_FLAGS);
378         }
379
380         return __static_call_init(mod, start, stop);
381 }
382
383 static void static_call_del_module(struct module *mod)
384 {
385         struct static_call_site *start = mod->static_call_sites;
386         struct static_call_site *stop = mod->static_call_sites +
387                                         mod->num_static_call_sites;
388         struct static_call_key *key, *prev_key = NULL;
389         struct static_call_mod *site_mod, **prev;
390         struct static_call_site *site;
391
392         for (site = start; site < stop; site++) {
393                 key = static_call_key(site);
394                 if (key == prev_key)
395                         continue;
396
397                 prev_key = key;
398
399                 for (prev = &key->mods, site_mod = key->mods;
400                      site_mod && site_mod->mod != mod;
401                      prev = &site_mod->next, site_mod = site_mod->next)
402                         ;
403
404                 if (!site_mod)
405                         continue;
406
407                 *prev = site_mod->next;
408                 kfree(site_mod);
409         }
410 }
411
412 static int static_call_module_notify(struct notifier_block *nb,
413                                      unsigned long val, void *data)
414 {
415         struct module *mod = data;
416         int ret = 0;
417
418         cpus_read_lock();
419         static_call_lock();
420
421         switch (val) {
422         case MODULE_STATE_COMING:
423                 ret = static_call_add_module(mod);
424                 if (ret) {
425                         WARN(1, "Failed to allocate memory for static calls");
426                         static_call_del_module(mod);
427                 }
428                 break;
429         case MODULE_STATE_GOING:
430                 static_call_del_module(mod);
431                 break;
432         }
433
434         static_call_unlock();
435         cpus_read_unlock();
436
437         return notifier_from_errno(ret);
438 }
439
440 static struct notifier_block static_call_module_nb = {
441         .notifier_call = static_call_module_notify,
442 };
443
444 #else
445
/* Without module support there are no module call sites to conflict with. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	(void)start;
	(void)end;
	return 0;
}
450
451 #endif /* CONFIG_MODULES */
452
453 int static_call_text_reserved(void *start, void *end)
454 {
455         int ret = __static_call_text_reserved(__start_static_call_sites,
456                         __stop_static_call_sites, start, end);
457
458         if (ret)
459                 return ret;
460
461         return __static_call_mod_text_reserved(start, end);
462 }
463
464 int __init static_call_init(void)
465 {
466         int ret;
467
468         if (static_call_initialized)
469                 return 0;
470
471         cpus_read_lock();
472         static_call_lock();
473         ret = __static_call_init(NULL, __start_static_call_sites,
474                                  __stop_static_call_sites);
475         static_call_unlock();
476         cpus_read_unlock();
477
478         if (ret) {
479                 pr_err("Failed to allocate memory for static_call!\n");
480                 BUG();
481         }
482
483         static_call_initialized = true;
484
485 #ifdef CONFIG_MODULES
486         register_module_notifier(&static_call_module_nb);
487 #endif
488         return 0;
489 }
490 early_initcall(static_call_init);
491
/* Stock function that returns 0 (presumably used as a static call target). */
long __static_call_return0(void)
{
	return 0L;
}
496
497 #ifdef CONFIG_STATIC_CALL_SELFTEST
498
/* Selftest targets, distinguishable by how much they add to @x. */
static int func_a(int x)
{
	return 1 + x;
}

static int func_b(int x)
{
	return 2 + x;
}
508
509 DEFINE_STATIC_CALL(sc_selftest, func_a);
510
511 static struct static_call_data {
512       int (*func)(int);
513       int val;
514       int expect;
515 } static_call_data [] __initdata = {
516       { NULL,   2, 3 },
517       { func_b, 2, 4 },
518       { func_a, 2, 3 }
519 };
520
521 static int __init test_static_call_init(void)
522 {
523       int i;
524
525       for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
526               struct static_call_data *scd = &static_call_data[i];
527
528               if (scd->func)
529                       static_call_update(sc_selftest, scd->func);
530
531               WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
532       }
533
534       return 0;
535 }
536 early_initcall(test_static_call_init);
537
538 #endif /* CONFIG_STATIC_CALL_SELFTEST */