drm/i915: Create stolen memory region from local memory
[linux-2.6-microblaze.git] / kernel / static_call.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/init.h>
3 #include <linux/static_call.h>
4 #include <linux/bug.h>
5 #include <linux/smp.h>
6 #include <linux/sort.h>
7 #include <linux/slab.h>
8 #include <linux/module.h>
9 #include <linux/cpu.h>
10 #include <linux/processor.h>
11 #include <asm/sections.h>
12
13 extern struct static_call_site __start_static_call_sites[],
14                                __stop_static_call_sites[];
15 extern struct static_call_tramp_key __start_static_call_tramp_key[],
16                                     __stop_static_call_tramp_key[];
17
18 static bool static_call_initialized;
19
20 /* mutex to protect key modules/sites */
21 static DEFINE_MUTEX(static_call_mutex);
22
23 static void static_call_lock(void)
24 {
25         mutex_lock(&static_call_mutex);
26 }
27
28 static void static_call_unlock(void)
29 {
30         mutex_unlock(&static_call_mutex);
31 }
32
33 static inline void *static_call_addr(struct static_call_site *site)
34 {
35         return (void *)((long)site->addr + (long)&site->addr);
36 }
37
38
39 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
40 {
41         return (struct static_call_key *)
42                 (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
43 }
44
45 /* These assume the key is word-aligned. */
46 static inline bool static_call_is_init(struct static_call_site *site)
47 {
48         return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
49 }
50
51 static inline bool static_call_is_tail(struct static_call_site *site)
52 {
53         return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
54 }
55
56 static inline void static_call_set_init(struct static_call_site *site)
57 {
58         site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
59                     (long)&site->key;
60 }
61
/*
 * sort() comparator: orders sites by the address of their key so that all
 * sites sharing a key become one contiguous run.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *lhs = static_call_key(_a);
	const struct static_call_key *rhs = static_call_key(_b);

	if (lhs == rhs)
		return 0;

	return lhs < rhs ? -1 : 1;
}
77
78 static void static_call_site_swap(void *_a, void *_b, int size)
79 {
80         long delta = (unsigned long)_a - (unsigned long)_b;
81         struct static_call_site *a = _a;
82         struct static_call_site *b = _b;
83         struct static_call_site tmp = *a;
84
85         a->addr = b->addr  - delta;
86         a->key  = b->key   - delta;
87
88         b->addr = tmp.addr + delta;
89         b->key  = tmp.key  + delta;
90 }
91
92 static inline void static_call_sort_entries(struct static_call_site *start,
93                                             struct static_call_site *stop)
94 {
95         sort(start, stop - start, sizeof(struct static_call_site),
96              static_call_site_cmp, static_call_site_swap);
97 }
98
99 static inline bool static_call_key_has_mods(struct static_call_key *key)
100 {
101         return !(key->type & 1);
102 }
103
104 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
105 {
106         if (!static_call_key_has_mods(key))
107                 return NULL;
108
109         return key->mods;
110 }
111
112 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
113 {
114         if (static_call_key_has_mods(key))
115                 return NULL;
116
117         return (struct static_call_site *)(key->type & ~1);
118 }
119
120 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
121 {
122         struct static_call_site *site, *stop;
123         struct static_call_mod *site_mod, first;
124
125         cpus_read_lock();
126         static_call_lock();
127
128         if (key->func == func)
129                 goto done;
130
131         key->func = func;
132
133         arch_static_call_transform(NULL, tramp, func, false);
134
135         /*
136          * If uninitialized, we'll not update the callsites, but they still
137          * point to the trampoline and we just patched that.
138          */
139         if (WARN_ON_ONCE(!static_call_initialized))
140                 goto done;
141
142         first = (struct static_call_mod){
143                 .next = static_call_key_next(key),
144                 .mod = NULL,
145                 .sites = static_call_key_sites(key),
146         };
147
148         for (site_mod = &first; site_mod; site_mod = site_mod->next) {
149                 struct module *mod = site_mod->mod;
150
151                 if (!site_mod->sites) {
152                         /*
153                          * This can happen if the static call key is defined in
154                          * a module which doesn't use it.
155                          *
156                          * It also happens in the has_mods case, where the
157                          * 'first' entry has no sites associated with it.
158                          */
159                         continue;
160                 }
161
162                 stop = __stop_static_call_sites;
163
164 #ifdef CONFIG_MODULES
165                 if (mod) {
166                         stop = mod->static_call_sites +
167                                mod->num_static_call_sites;
168                 }
169 #endif
170
171                 for (site = site_mod->sites;
172                      site < stop && static_call_key(site) == key; site++) {
173                         void *site_addr = static_call_addr(site);
174
175                         if (static_call_is_init(site)) {
176                                 /*
177                                  * Don't write to call sites which were in
178                                  * initmem and have since been freed.
179                                  */
180                                 if (!mod && system_state >= SYSTEM_RUNNING)
181                                         continue;
182                                 if (mod && !within_module_init((unsigned long)site_addr, mod))
183                                         continue;
184                         }
185
186                         if (!kernel_text_address((unsigned long)site_addr)) {
187                                 WARN_ONCE(1, "can't patch static call site at %pS",
188                                           site_addr);
189                                 continue;
190                         }
191
192                         arch_static_call_transform(site_addr, NULL, func,
193                                 static_call_is_tail(site));
194                 }
195         }
196
197 done:
198         static_call_unlock();
199         cpus_read_unlock();
200 }
201 EXPORT_SYMBOL_GPL(__static_call_update);
202
203 static int __static_call_init(struct module *mod,
204                               struct static_call_site *start,
205                               struct static_call_site *stop)
206 {
207         struct static_call_site *site;
208         struct static_call_key *key, *prev_key = NULL;
209         struct static_call_mod *site_mod;
210
211         if (start == stop)
212                 return 0;
213
214         static_call_sort_entries(start, stop);
215
216         for (site = start; site < stop; site++) {
217                 void *site_addr = static_call_addr(site);
218
219                 if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
220                     (!mod && init_section_contains(site_addr, 1)))
221                         static_call_set_init(site);
222
223                 key = static_call_key(site);
224                 if (key != prev_key) {
225                         prev_key = key;
226
227                         /*
228                          * For vmlinux (!mod) avoid the allocation by storing
229                          * the sites pointer in the key itself. Also see
230                          * __static_call_update()'s @first.
231                          *
232                          * This allows architectures (eg. x86) to call
233                          * static_call_init() before memory allocation works.
234                          */
235                         if (!mod) {
236                                 key->sites = site;
237                                 key->type |= 1;
238                                 goto do_transform;
239                         }
240
241                         site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
242                         if (!site_mod)
243                                 return -ENOMEM;
244
245                         /*
246                          * When the key has a direct sites pointer, extract
247                          * that into an explicit struct static_call_mod, so we
248                          * can have a list of modules.
249                          */
250                         if (static_call_key_sites(key)) {
251                                 site_mod->mod = NULL;
252                                 site_mod->next = NULL;
253                                 site_mod->sites = static_call_key_sites(key);
254
255                                 key->mods = site_mod;
256
257                                 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
258                                 if (!site_mod)
259                                         return -ENOMEM;
260                         }
261
262                         site_mod->mod = mod;
263                         site_mod->sites = site;
264                         site_mod->next = static_call_key_next(key);
265                         key->mods = site_mod;
266                 }
267
268 do_transform:
269                 arch_static_call_transform(site_addr, NULL, key->func,
270                                 static_call_is_tail(site));
271         }
272
273         return 0;
274 }
275
276 static int addr_conflict(struct static_call_site *site, void *start, void *end)
277 {
278         unsigned long addr = (unsigned long)static_call_addr(site);
279
280         if (addr <= (unsigned long)end &&
281             addr + CALL_INSN_SIZE > (unsigned long)start)
282                 return 1;
283
284         return 0;
285 }
286
287 static int __static_call_text_reserved(struct static_call_site *iter_start,
288                                        struct static_call_site *iter_stop,
289                                        void *start, void *end)
290 {
291         struct static_call_site *iter = iter_start;
292
293         while (iter < iter_stop) {
294                 if (addr_conflict(iter, start, end))
295                         return 1;
296                 iter++;
297         }
298
299         return 0;
300 }
301
302 #ifdef CONFIG_MODULES
303
304 static int __static_call_mod_text_reserved(void *start, void *end)
305 {
306         struct module *mod;
307         int ret;
308
309         preempt_disable();
310         mod = __module_text_address((unsigned long)start);
311         WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
312         if (!try_module_get(mod))
313                 mod = NULL;
314         preempt_enable();
315
316         if (!mod)
317                 return 0;
318
319         ret = __static_call_text_reserved(mod->static_call_sites,
320                         mod->static_call_sites + mod->num_static_call_sites,
321                         start, end);
322
323         module_put(mod);
324
325         return ret;
326 }
327
328 static unsigned long tramp_key_lookup(unsigned long addr)
329 {
330         struct static_call_tramp_key *start = __start_static_call_tramp_key;
331         struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
332         struct static_call_tramp_key *tramp_key;
333
334         for (tramp_key = start; tramp_key != stop; tramp_key++) {
335                 unsigned long tramp;
336
337                 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
338                 if (tramp == addr)
339                         return (long)tramp_key->key + (long)&tramp_key->key;
340         }
341
342         return 0;
343 }
344
345 static int static_call_add_module(struct module *mod)
346 {
347         struct static_call_site *start = mod->static_call_sites;
348         struct static_call_site *stop = start + mod->num_static_call_sites;
349         struct static_call_site *site;
350
351         for (site = start; site != stop; site++) {
352                 unsigned long s_key = (long)site->key + (long)&site->key;
353                 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
354                 unsigned long key;
355
356                 /*
357                  * Is the key is exported, 'addr' points to the key, which
358                  * means modules are allowed to call static_call_update() on
359                  * it.
360                  *
361                  * Otherwise, the key isn't exported, and 'addr' points to the
362                  * trampoline so we need to lookup the key.
363                  *
364                  * We go through this dance to prevent crazy modules from
365                  * abusing sensitive static calls.
366                  */
367                 if (!kernel_text_address(addr))
368                         continue;
369
370                 key = tramp_key_lookup(addr);
371                 if (!key) {
372                         pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
373                                 static_call_addr(site));
374                         return -EINVAL;
375                 }
376
377                 key |= s_key & STATIC_CALL_SITE_FLAGS;
378                 site->key = key - (long)&site->key;
379         }
380
381         return __static_call_init(mod, start, stop);
382 }
383
384 static void static_call_del_module(struct module *mod)
385 {
386         struct static_call_site *start = mod->static_call_sites;
387         struct static_call_site *stop = mod->static_call_sites +
388                                         mod->num_static_call_sites;
389         struct static_call_key *key, *prev_key = NULL;
390         struct static_call_mod *site_mod, **prev;
391         struct static_call_site *site;
392
393         for (site = start; site < stop; site++) {
394                 key = static_call_key(site);
395                 if (key == prev_key)
396                         continue;
397
398                 prev_key = key;
399
400                 for (prev = &key->mods, site_mod = key->mods;
401                      site_mod && site_mod->mod != mod;
402                      prev = &site_mod->next, site_mod = site_mod->next)
403                         ;
404
405                 if (!site_mod)
406                         continue;
407
408                 *prev = site_mod->next;
409                 kfree(site_mod);
410         }
411 }
412
413 static int static_call_module_notify(struct notifier_block *nb,
414                                      unsigned long val, void *data)
415 {
416         struct module *mod = data;
417         int ret = 0;
418
419         cpus_read_lock();
420         static_call_lock();
421
422         switch (val) {
423         case MODULE_STATE_COMING:
424                 ret = static_call_add_module(mod);
425                 if (ret) {
426                         WARN(1, "Failed to allocate memory for static calls");
427                         static_call_del_module(mod);
428                 }
429                 break;
430         case MODULE_STATE_GOING:
431                 static_call_del_module(mod);
432                 break;
433         }
434
435         static_call_unlock();
436         cpus_read_unlock();
437
438         return notifier_from_errno(ret);
439 }
440
441 static struct notifier_block static_call_module_nb = {
442         .notifier_call = static_call_module_notify,
443 };
444
445 #else
446
/* Without module support there is no module text to conflict with. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	int reserved = 0;

	return reserved;
}
451
452 #endif /* CONFIG_MODULES */
453
454 int static_call_text_reserved(void *start, void *end)
455 {
456         int ret = __static_call_text_reserved(__start_static_call_sites,
457                         __stop_static_call_sites, start, end);
458
459         if (ret)
460                 return ret;
461
462         return __static_call_mod_text_reserved(start, end);
463 }
464
465 int __init static_call_init(void)
466 {
467         int ret;
468
469         if (static_call_initialized)
470                 return 0;
471
472         cpus_read_lock();
473         static_call_lock();
474         ret = __static_call_init(NULL, __start_static_call_sites,
475                                  __stop_static_call_sites);
476         static_call_unlock();
477         cpus_read_unlock();
478
479         if (ret) {
480                 pr_err("Failed to allocate memory for static_call!\n");
481                 BUG();
482         }
483
484         static_call_initialized = true;
485
486 #ifdef CONFIG_MODULES
487         register_module_notifier(&static_call_module_nb);
488 #endif
489         return 0;
490 }
491 early_initcall(static_call_init);
492
/* Shared "return 0" target: a static call aimed here just yields 0. */
long __static_call_return0(void)
{
	return 0L;
}
497
498 #ifdef CONFIG_STATIC_CALL_SELFTEST
499
500 static int func_a(int x)
501 {
502         return x+1;
503 }
504
505 static int func_b(int x)
506 {
507         return x+2;
508 }
509
510 DEFINE_STATIC_CALL(sc_selftest, func_a);
511
512 static struct static_call_data {
513       int (*func)(int);
514       int val;
515       int expect;
516 } static_call_data [] __initdata = {
517       { NULL,   2, 3 },
518       { func_b, 2, 4 },
519       { func_a, 2, 3 }
520 };
521
522 static int __init test_static_call_init(void)
523 {
524       int i;
525
526       for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
527               struct static_call_data *scd = &static_call_data[i];
528
529               if (scd->func)
530                       static_call_update(sc_selftest, scd->func);
531
532               WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
533       }
534
535       return 0;
536 }
537 early_initcall(test_static_call_init);
538
539 #endif /* CONFIG_STATIC_CALL_SELFTEST */