selftests: vm: pkeys: add helpers for pkey bits
[linux-2.6-microblaze.git] / tools / testing / selftests / vm / protection_keys.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Tests Memory Protection Keys (see Documentation/vm/protection-keys.txt)
4  *
5  * There are examples in here of:
6  *  * how to set protection keys on memory
7  *  * how to set/clear bits in pkey registers (the rights register)
8  *  * how to handle SEGV_PKUERR signals and extract pkey-relevant
9  *    information from the siginfo
10  *
11  * Things to add:
12  *      make sure KSM and KSM COW breaking works
13  *      prefault pages in at malloc, or not
14  *      protect MPX bounds tables with protection keys?
15  *      make sure VMA splitting/merging is working correctly
16  *      OOMs can destroy mm->mmap (see exit_mmap()), so make sure it is immune to pkeys
17  *      look for pkey "leaks" where it is still set on a VMA but "freed" back to the kernel
18  *      do a plain mprotect() to a mprotect_pkey() area and make sure the pkey sticks
19  *
20  * Compile like this:
21  *      gcc      -o protection_keys    -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
22  *      gcc -m32 -o protection_keys_32 -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
23  */
24 #define _GNU_SOURCE
25 #define __SANE_USERSPACE_TYPES__
26 #include <errno.h>
27 #include <linux/futex.h>
28 #include <sys/time.h>
29 #include <sys/syscall.h>
30 #include <string.h>
31 #include <stdio.h>
32 #include <stdint.h>
33 #include <stdbool.h>
34 #include <signal.h>
35 #include <assert.h>
36 #include <stdlib.h>
37 #include <ucontext.h>
38 #include <sys/mman.h>
39 #include <sys/types.h>
40 #include <sys/wait.h>
41 #include <sys/stat.h>
42 #include <fcntl.h>
43 #include <unistd.h>
44 #include <sys/ptrace.h>
45 #include <setjmp.h>
46
47 #include "pkey-helpers.h"
48
49 int iteration_nr = 1;
50 int test_nr;
51
52 u64 shadow_pkey_reg;
53 int dprint_in_signal;
54 char dprint_in_signal_buffer[DPRINT_IN_SIGNAL_BUF_SIZE];
55
56 void cat_into_file(char *str, char *file)
57 {
58         int fd = open(file, O_RDWR);
59         int ret;
60
61         dprintf2("%s(): writing '%s' to '%s'\n", __func__, str, file);
62         /*
63          * these need to be raw because they are called under
64          * pkey_assert()
65          */
66         if (fd < 0) {
67                 fprintf(stderr, "error opening '%s'\n", str);
68                 perror("error: ");
69                 exit(__LINE__);
70         }
71
72         ret = write(fd, str, strlen(str));
73         if (ret != strlen(str)) {
74                 perror("write to file failed");
75                 fprintf(stderr, "filename: '%s' str: '%s'\n", file, str);
76                 exit(__LINE__);
77         }
78         close(fd);
79 }
80
81 #if CONTROL_TRACING > 0
82 static int warned_tracing;
83 int tracing_root_ok(void)
84 {
85         if (geteuid() != 0) {
86                 if (!warned_tracing)
87                         fprintf(stderr, "WARNING: not run as root, "
88                                         "can not do tracing control\n");
89                 warned_tracing = 1;
90                 return 0;
91         }
92         return 1;
93 }
94 #endif
95
96 void tracing_on(void)
97 {
98 #if CONTROL_TRACING > 0
99 #define TRACEDIR "/sys/kernel/debug/tracing"
100         char pidstr[32];
101
102         if (!tracing_root_ok())
103                 return;
104
105         sprintf(pidstr, "%d", getpid());
106         cat_into_file("0", TRACEDIR "/tracing_on");
107         cat_into_file("\n", TRACEDIR "/trace");
108         if (1) {
109                 cat_into_file("function_graph", TRACEDIR "/current_tracer");
110                 cat_into_file("1", TRACEDIR "/options/funcgraph-proc");
111         } else {
112                 cat_into_file("nop", TRACEDIR "/current_tracer");
113         }
114         cat_into_file(pidstr, TRACEDIR "/set_ftrace_pid");
115         cat_into_file("1", TRACEDIR "/tracing_on");
116         dprintf1("enabled tracing\n");
117 #endif
118 }
119
120 void tracing_off(void)
121 {
122 #if CONTROL_TRACING > 0
123         if (!tracing_root_ok())
124                 return;
125         cat_into_file("0", "/sys/kernel/debug/tracing/tracing_on");
126 #endif
127 }
128
129 void abort_hooks(void)
130 {
131         fprintf(stderr, "running %s()...\n", __func__);
132         tracing_off();
133 #ifdef SLEEP_ON_ABORT
134         sleep(SLEEP_ON_ABORT);
135 #endif
136 }
137
138 /*
139  * This attempts to have roughly a page of instructions followed by a few
140  * instructions that do a write, and another page of instructions.  That
141  * way, we are pretty sure that the write is in the second page of
142  * instructions and has at least a page of padding behind it.
143  *
144  * *That* lets us be sure to madvise() away the write instruction, which
145  * will then fault, which makes sure that the fault code handles
146  * execute-only memory properly.
147  */
148 __attribute__((__aligned__(PAGE_SIZE)))
149 void lots_o_noops_around_write(int *write_to_me)
150 {
151         dprintf3("running %s()\n", __func__);
152         __page_o_noops();
153         /* Assume this happens in the second page of instructions: */
154         *write_to_me = __LINE__;
155         /* pad out by another page: */
156         __page_o_noops();
157         dprintf3("%s() done\n", __func__);
158 }
159
160 void dump_mem(void *dumpme, int len_bytes)
161 {
162         char *c = (void *)dumpme;
163         int i;
164
165         for (i = 0; i < len_bytes; i += sizeof(u64)) {
166                 u64 *ptr = (u64 *)(c + i);
167                 dprintf1("dump[%03d][@%p]: %016llx\n", i, ptr, *ptr);
168         }
169 }
170
171 /* Failed address bound checks: */
172 #ifndef SEGV_BNDERR
173 # define SEGV_BNDERR            3
174 #endif
175
176 #ifndef SEGV_PKUERR
177 # define SEGV_PKUERR            4
178 #endif
179
180 static char *si_code_str(int si_code)
181 {
182         if (si_code == SEGV_MAPERR)
183                 return "SEGV_MAPERR";
184         if (si_code == SEGV_ACCERR)
185                 return "SEGV_ACCERR";
186         if (si_code == SEGV_BNDERR)
187                 return "SEGV_BNDERR";
188         if (si_code == SEGV_PKUERR)
189                 return "SEGV_PKUERR";
190         return "UNKNOWN";
191 }
192
193 int pkey_faults;
194 int last_si_pkey = -1;
195 void signal_handler(int signum, siginfo_t *si, void *vucontext)
196 {
197         ucontext_t *uctxt = vucontext;
198         int trapno;
199         unsigned long ip;
200         char *fpregs;
201         u32 *pkey_reg_ptr;
202         u64 siginfo_pkey;
203         u32 *si_pkey_ptr;
204         int pkey_reg_offset;
205         fpregset_t fpregset;
206
207         dprint_in_signal = 1;
208         dprintf1(">>>>===============SIGSEGV============================\n");
209         dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
210                         __func__, __LINE__,
211                         __read_pkey_reg(), shadow_pkey_reg);
212
213         trapno = uctxt->uc_mcontext.gregs[REG_TRAPNO];
214         ip = uctxt->uc_mcontext.gregs[REG_IP_IDX];
215         fpregset = uctxt->uc_mcontext.fpregs;
216         fpregs = (void *)fpregset;
217
218         dprintf2("%s() trapno: %d ip: 0x%016lx info->si_code: %s/%d\n",
219                         __func__, trapno, ip, si_code_str(si->si_code),
220                         si->si_code);
221 #ifdef __i386__
222         /*
223          * 32-bit has some extra padding so that userspace can tell whether
224          * the XSTATE header is present in addition to the "legacy" FPU
225          * state.  We just assume that it is here.
226          */
227         fpregs += 0x70;
228 #endif
229         pkey_reg_offset = pkey_reg_xstate_offset();
230         pkey_reg_ptr = (void *)(&fpregs[pkey_reg_offset]);
231
232         dprintf1("siginfo: %p\n", si);
233         dprintf1(" fpregs: %p\n", fpregs);
234         /*
235          * If we got a PKEY fault, we *HAVE* to have at least one bit set in
236          * here.
237          */
238         dprintf1("pkey_reg_xstate_offset: %d\n", pkey_reg_xstate_offset());
239         if (DEBUG_LEVEL > 4)
240                 dump_mem(pkey_reg_ptr - 128, 256);
241         pkey_assert(*pkey_reg_ptr);
242
243         if ((si->si_code == SEGV_MAPERR) ||
244             (si->si_code == SEGV_ACCERR) ||
245             (si->si_code == SEGV_BNDERR)) {
246                 printf("non-PK si_code, exiting...\n");
247                 exit(4);
248         }
249
250         si_pkey_ptr = (u32 *)(((u8 *)si) + si_pkey_offset);
251         dprintf1("si_pkey_ptr: %p\n", si_pkey_ptr);
252         dump_mem((u8 *)si_pkey_ptr - 8, 24);
253         siginfo_pkey = *si_pkey_ptr;
254         pkey_assert(siginfo_pkey < NR_PKEYS);
255         last_si_pkey = siginfo_pkey;
256
257         dprintf1("signal pkey_reg from xsave: %08x\n", *pkey_reg_ptr);
258         /*
259          * need __read_pkey_reg() version so we do not do shadow_pkey_reg
260          * checking
261          */
262         dprintf1("signal pkey_reg from  pkey_reg: %016llx\n",
263                         __read_pkey_reg());
264         dprintf1("pkey from siginfo: %016llx\n", siginfo_pkey);
265         *(u64 *)pkey_reg_ptr = 0x00000000;
266         dprintf1("WARNING: set PKEY_REG=0 to allow faulting instruction to continue\n");
267         pkey_faults++;
268         dprintf1("<<<<==================================================\n");
269         dprint_in_signal = 0;
270 }
271
272 int wait_all_children(void)
273 {
274         int status;
275         return waitpid(-1, &status, 0);
276 }
277
278 void sig_chld(int x)
279 {
280         dprint_in_signal = 1;
281         dprintf2("[%d] SIGCHLD: %d\n", getpid(), x);
282         dprint_in_signal = 0;
283 }
284
285 void setup_sigsegv_handler(void)
286 {
287         int r, rs;
288         struct sigaction newact;
289         struct sigaction oldact;
290
291         /* #PF is mapped to sigsegv */
292         int signum  = SIGSEGV;
293
294         newact.sa_handler = 0;
295         newact.sa_sigaction = signal_handler;
296
297         /*sigset_t - signals to block while in the handler */
298         /* get the old signal mask. */
299         rs = sigprocmask(SIG_SETMASK, 0, &newact.sa_mask);
300         pkey_assert(rs == 0);
301
302         /* call sa_sigaction, not sa_handler*/
303         newact.sa_flags = SA_SIGINFO;
304
305         newact.sa_restorer = 0;  /* void(*)(), obsolete */
306         r = sigaction(signum, &newact, &oldact);
307         r = sigaction(SIGALRM, &newact, &oldact);
308         pkey_assert(r == 0);
309 }
310
311 void setup_handlers(void)
312 {
313         signal(SIGCHLD, &sig_chld);
314         setup_sigsegv_handler();
315 }
316
317 pid_t fork_lazy_child(void)
318 {
319         pid_t forkret;
320
321         forkret = fork();
322         pkey_assert(forkret >= 0);
323         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
324
325         if (!forkret) {
326                 /* in the child */
327                 while (1) {
328                         dprintf1("child sleeping...\n");
329                         sleep(30);
330                 }
331         }
332         return forkret;
333 }
334
335 static u32 hw_pkey_get(int pkey, unsigned long flags)
336 {
337         u64 pkey_reg = __read_pkey_reg();
338
339         dprintf1("%s(pkey=%d, flags=%lx) = %x / %d\n",
340                         __func__, pkey, flags, 0, 0);
341         dprintf2("%s() raw pkey_reg: %016llx\n", __func__, pkey_reg);
342
343         return (u32) get_pkey_bits(pkey_reg, pkey);
344 }
345
346 static int hw_pkey_set(int pkey, unsigned long rights, unsigned long flags)
347 {
348         u32 mask = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
349         u64 old_pkey_reg = __read_pkey_reg();
350         u64 new_pkey_reg;
351
352         /* make sure that 'rights' only contains the bits we expect: */
353         assert(!(rights & ~mask));
354
355         /* modify bits accordingly in old pkey_reg and assign it */
356         new_pkey_reg = set_pkey_bits(old_pkey_reg, pkey, rights);
357
358         __write_pkey_reg(new_pkey_reg);
359
360         dprintf3("%s(pkey=%d, rights=%lx, flags=%lx) = %x"
361                 " pkey_reg now: %016llx old_pkey_reg: %016llx\n",
362                 __func__, pkey, rights, flags, 0, __read_pkey_reg(),
363                 old_pkey_reg);
364         return 0;
365 }
366
367 void pkey_disable_set(int pkey, int flags)
368 {
369         unsigned long syscall_flags = 0;
370         int ret;
371         int pkey_rights;
372         u64 orig_pkey_reg = read_pkey_reg();
373
374         dprintf1("START->%s(%d, 0x%x)\n", __func__,
375                 pkey, flags);
376         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
377
378         pkey_rights = hw_pkey_get(pkey, syscall_flags);
379
380         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
381                         pkey, pkey, pkey_rights);
382
383         pkey_assert(pkey_rights >= 0);
384
385         pkey_rights |= flags;
386
387         ret = hw_pkey_set(pkey, pkey_rights, syscall_flags);
388         assert(!ret);
389         /* pkey_reg and flags have the same format */
390         shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
391         dprintf1("%s(%d) shadow: 0x%016llx\n",
392                 __func__, pkey, shadow_pkey_reg);
393
394         pkey_assert(ret >= 0);
395
396         pkey_rights = hw_pkey_get(pkey, syscall_flags);
397         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
398                         pkey, pkey, pkey_rights);
399
400         dprintf1("%s(%d) pkey_reg: 0x%016llx\n",
401                 __func__, pkey, read_pkey_reg());
402         if (flags)
403                 pkey_assert(read_pkey_reg() > orig_pkey_reg);
404         dprintf1("END<---%s(%d, 0x%x)\n", __func__,
405                 pkey, flags);
406 }
407
408 void pkey_disable_clear(int pkey, int flags)
409 {
410         unsigned long syscall_flags = 0;
411         int ret;
412         int pkey_rights = hw_pkey_get(pkey, syscall_flags);
413         u64 orig_pkey_reg = read_pkey_reg();
414
415         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
416
417         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
418                         pkey, pkey, pkey_rights);
419         pkey_assert(pkey_rights >= 0);
420
421         pkey_rights |= flags;
422
423         ret = hw_pkey_set(pkey, pkey_rights, 0);
424         shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
425         pkey_assert(ret >= 0);
426
427         pkey_rights = hw_pkey_get(pkey, syscall_flags);
428         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
429                         pkey, pkey, pkey_rights);
430
431         dprintf1("%s(%d) pkey_reg: 0x%016llx\n", __func__,
432                         pkey, read_pkey_reg());
433         if (flags)
434                 assert(read_pkey_reg() > orig_pkey_reg);
435 }
436
437 void pkey_write_allow(int pkey)
438 {
439         pkey_disable_clear(pkey, PKEY_DISABLE_WRITE);
440 }
441 void pkey_write_deny(int pkey)
442 {
443         pkey_disable_set(pkey, PKEY_DISABLE_WRITE);
444 }
445 void pkey_access_allow(int pkey)
446 {
447         pkey_disable_clear(pkey, PKEY_DISABLE_ACCESS);
448 }
449 void pkey_access_deny(int pkey)
450 {
451         pkey_disable_set(pkey, PKEY_DISABLE_ACCESS);
452 }
453
454 int sys_mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
455                 unsigned long pkey)
456 {
457         int sret;
458
459         dprintf2("%s(0x%p, %zx, prot=%lx, pkey=%lx)\n", __func__,
460                         ptr, size, orig_prot, pkey);
461
462         errno = 0;
463         sret = syscall(SYS_mprotect_key, ptr, size, orig_prot, pkey);
464         if (errno) {
465                 dprintf2("SYS_mprotect_key sret: %d\n", sret);
466                 dprintf2("SYS_mprotect_key prot: 0x%lx\n", orig_prot);
467                 dprintf2("SYS_mprotect_key failed, errno: %d\n", errno);
468                 if (DEBUG_LEVEL >= 2)
469                         perror("SYS_mprotect_pkey");
470         }
471         return sret;
472 }
473
474 int sys_pkey_alloc(unsigned long flags, unsigned long init_val)
475 {
476         int ret = syscall(SYS_pkey_alloc, flags, init_val);
477         dprintf1("%s(flags=%lx, init_val=%lx) syscall ret: %d errno: %d\n",
478                         __func__, flags, init_val, ret, errno);
479         return ret;
480 }
481
482 int alloc_pkey(void)
483 {
484         int ret;
485         unsigned long init_val = 0x0;
486
487         dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
488                         __func__, __LINE__, __read_pkey_reg(), shadow_pkey_reg);
489         ret = sys_pkey_alloc(0, init_val);
490         /*
491          * pkey_alloc() sets PKEY register, so we need to reflect it in
492          * shadow_pkey_reg:
493          */
494         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
495                         " shadow: 0x%016llx\n",
496                         __func__, __LINE__, ret, __read_pkey_reg(),
497                         shadow_pkey_reg);
498         if (ret) {
499                 /* clear both the bits: */
500                 shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
501                                                 ~PKEY_MASK);
502                 dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
503                                 " shadow: 0x%016llx\n",
504                                 __func__,
505                                 __LINE__, ret, __read_pkey_reg(),
506                                 shadow_pkey_reg);
507                 /*
508                  * move the new state in from init_val
509                  * (remember, we cheated and init_val == pkey_reg format)
510                  */
511                 shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
512                                                 init_val);
513         }
514         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
515                         " shadow: 0x%016llx\n",
516                         __func__, __LINE__, ret, __read_pkey_reg(),
517                         shadow_pkey_reg);
518         dprintf1("%s()::%d errno: %d\n", __func__, __LINE__, errno);
519         /* for shadow checking: */
520         read_pkey_reg();
521         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
522                  " shadow: 0x%016llx\n",
523                 __func__, __LINE__, ret, __read_pkey_reg(),
524                 shadow_pkey_reg);
525         return ret;
526 }
527
528 int sys_pkey_free(unsigned long pkey)
529 {
530         int ret = syscall(SYS_pkey_free, pkey);
531         dprintf1("%s(pkey=%ld) syscall ret: %d\n", __func__, pkey, ret);
532         return ret;
533 }
534
535 /*
536  * I had a bug where pkey bits could be set by mprotect() but
537  * not cleared.  This ensures we get lots of random bit sets
538  * and clears on the vma and pte pkey bits.
539  */
540 int alloc_random_pkey(void)
541 {
542         int max_nr_pkey_allocs;
543         int ret;
544         int i;
545         int alloced_pkeys[NR_PKEYS];
546         int nr_alloced = 0;
547         int random_index;
548         memset(alloced_pkeys, 0, sizeof(alloced_pkeys));
549
550         /* allocate every possible key and make a note of which ones we got */
551         max_nr_pkey_allocs = NR_PKEYS;
552         max_nr_pkey_allocs = 1;
553         for (i = 0; i < max_nr_pkey_allocs; i++) {
554                 int new_pkey = alloc_pkey();
555                 if (new_pkey < 0)
556                         break;
557                 alloced_pkeys[nr_alloced++] = new_pkey;
558         }
559
560         pkey_assert(nr_alloced > 0);
561         /* select a random one out of the allocated ones */
562         random_index = rand() % nr_alloced;
563         ret = alloced_pkeys[random_index];
564         /* now zero it out so we don't free it next */
565         alloced_pkeys[random_index] = 0;
566
567         /* go through the allocated ones that we did not want and free them */
568         for (i = 0; i < nr_alloced; i++) {
569                 int free_ret;
570                 if (!alloced_pkeys[i])
571                         continue;
572                 free_ret = sys_pkey_free(alloced_pkeys[i]);
573                 pkey_assert(!free_ret);
574         }
575         dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
576                          " shadow: 0x%016llx\n", __func__,
577                         __LINE__, ret, __read_pkey_reg(), shadow_pkey_reg);
578         return ret;
579 }
580
581 int mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
582                 unsigned long pkey)
583 {
584         int nr_iterations = random() % 100;
585         int ret;
586
587         while (0) {
588                 int rpkey = alloc_random_pkey();
589                 ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
590                 dprintf1("sys_mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
591                                 ptr, size, orig_prot, pkey, ret);
592                 if (nr_iterations-- < 0)
593                         break;
594
595                 dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
596                         " shadow: 0x%016llx\n",
597                         __func__, __LINE__, ret, __read_pkey_reg(),
598                         shadow_pkey_reg);
599                 sys_pkey_free(rpkey);
600                 dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
601                         " shadow: 0x%016llx\n",
602                         __func__, __LINE__, ret, __read_pkey_reg(),
603                         shadow_pkey_reg);
604         }
605         pkey_assert(pkey < NR_PKEYS);
606
607         ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
608         dprintf1("mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
609                         ptr, size, orig_prot, pkey, ret);
610         pkey_assert(!ret);
611         dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
612                         " shadow: 0x%016llx\n", __func__,
613                         __LINE__, ret, __read_pkey_reg(), shadow_pkey_reg);
614         return ret;
615 }
616
617 struct pkey_malloc_record {
618         void *ptr;
619         long size;
620         int prot;
621 };
622 struct pkey_malloc_record *pkey_malloc_records;
623 struct pkey_malloc_record *pkey_last_malloc_record;
624 long nr_pkey_malloc_records;
625 void record_pkey_malloc(void *ptr, long size, int prot)
626 {
627         long i;
628         struct pkey_malloc_record *rec = NULL;
629
630         for (i = 0; i < nr_pkey_malloc_records; i++) {
631                 rec = &pkey_malloc_records[i];
632                 /* find a free record */
633                 if (rec)
634                         break;
635         }
636         if (!rec) {
637                 /* every record is full */
638                 size_t old_nr_records = nr_pkey_malloc_records;
639                 size_t new_nr_records = (nr_pkey_malloc_records * 2 + 1);
640                 size_t new_size = new_nr_records * sizeof(struct pkey_malloc_record);
641                 dprintf2("new_nr_records: %zd\n", new_nr_records);
642                 dprintf2("new_size: %zd\n", new_size);
643                 pkey_malloc_records = realloc(pkey_malloc_records, new_size);
644                 pkey_assert(pkey_malloc_records != NULL);
645                 rec = &pkey_malloc_records[nr_pkey_malloc_records];
646                 /*
647                  * realloc() does not initialize memory, so zero it from
648                  * the first new record all the way to the end.
649                  */
650                 for (i = 0; i < new_nr_records - old_nr_records; i++)
651                         memset(rec + i, 0, sizeof(*rec));
652         }
653         dprintf3("filling malloc record[%d/%p]: {%p, %ld}\n",
654                 (int)(rec - pkey_malloc_records), rec, ptr, size);
655         rec->ptr = ptr;
656         rec->size = size;
657         rec->prot = prot;
658         pkey_last_malloc_record = rec;
659         nr_pkey_malloc_records++;
660 }
661
662 void free_pkey_malloc(void *ptr)
663 {
664         long i;
665         int ret;
666         dprintf3("%s(%p)\n", __func__, ptr);
667         for (i = 0; i < nr_pkey_malloc_records; i++) {
668                 struct pkey_malloc_record *rec = &pkey_malloc_records[i];
669                 dprintf4("looking for ptr %p at record[%ld/%p]: {%p, %ld}\n",
670                                 ptr, i, rec, rec->ptr, rec->size);
671                 if ((ptr <  rec->ptr) ||
672                     (ptr >= rec->ptr + rec->size))
673                         continue;
674
675                 dprintf3("found ptr %p at record[%ld/%p]: {%p, %ld}\n",
676                                 ptr, i, rec, rec->ptr, rec->size);
677                 nr_pkey_malloc_records--;
678                 ret = munmap(rec->ptr, rec->size);
679                 dprintf3("munmap ret: %d\n", ret);
680                 pkey_assert(!ret);
681                 dprintf3("clearing rec->ptr, rec: %p\n", rec);
682                 rec->ptr = NULL;
683                 dprintf3("done clearing rec->ptr, rec: %p\n", rec);
684                 return;
685         }
686         pkey_assert(false);
687 }
688
689
690 void *malloc_pkey_with_mprotect(long size, int prot, u16 pkey)
691 {
692         void *ptr;
693         int ret;
694
695         read_pkey_reg();
696         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
697                         size, prot, pkey);
698         pkey_assert(pkey < NR_PKEYS);
699         ptr = mmap(NULL, size, prot, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
700         pkey_assert(ptr != (void *)-1);
701         ret = mprotect_pkey((void *)ptr, PAGE_SIZE, prot, pkey);
702         pkey_assert(!ret);
703         record_pkey_malloc(ptr, size, prot);
704         read_pkey_reg();
705
706         dprintf1("%s() for pkey %d @ %p\n", __func__, pkey, ptr);
707         return ptr;
708 }
709
710 void *malloc_pkey_anon_huge(long size, int prot, u16 pkey)
711 {
712         int ret;
713         void *ptr;
714
715         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
716                         size, prot, pkey);
717         /*
718          * Guarantee we can fit at least one huge page in the resulting
719          * allocation by allocating space for 2:
720          */
721         size = ALIGN_UP(size, HPAGE_SIZE * 2);
722         ptr = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
723         pkey_assert(ptr != (void *)-1);
724         record_pkey_malloc(ptr, size, prot);
725         mprotect_pkey(ptr, size, prot, pkey);
726
727         dprintf1("unaligned ptr: %p\n", ptr);
728         ptr = ALIGN_PTR_UP(ptr, HPAGE_SIZE);
729         dprintf1("  aligned ptr: %p\n", ptr);
730         ret = madvise(ptr, HPAGE_SIZE, MADV_HUGEPAGE);
731         dprintf1("MADV_HUGEPAGE ret: %d\n", ret);
732         ret = madvise(ptr, HPAGE_SIZE, MADV_WILLNEED);
733         dprintf1("MADV_WILLNEED ret: %d\n", ret);
734         memset(ptr, 0, HPAGE_SIZE);
735
736         dprintf1("mmap()'d thp for pkey %d @ %p\n", pkey, ptr);
737         return ptr;
738 }
739
740 int hugetlb_setup_ok;
741 #define GET_NR_HUGE_PAGES 10
742 void setup_hugetlbfs(void)
743 {
744         int err;
745         int fd;
746         char buf[] = "123";
747
748         if (geteuid() != 0) {
749                 fprintf(stderr, "WARNING: not run as root, can not do hugetlb test\n");
750                 return;
751         }
752
753         cat_into_file(__stringify(GET_NR_HUGE_PAGES), "/proc/sys/vm/nr_hugepages");
754
755         /*
756          * Now go make sure that we got the pages and that they
757          * are 2M pages.  Someone might have made 1G the default.
758          */
759         fd = open("/sys/kernel/mm/hugepages/hugepages-2048kB/nr_hugepages", O_RDONLY);
760         if (fd < 0) {
761                 perror("opening sysfs 2M hugetlb config");
762                 return;
763         }
764
765         /* -1 to guarantee leaving the trailing \0 */
766         err = read(fd, buf, sizeof(buf)-1);
767         close(fd);
768         if (err <= 0) {
769                 perror("reading sysfs 2M hugetlb config");
770                 return;
771         }
772
773         if (atoi(buf) != GET_NR_HUGE_PAGES) {
774                 fprintf(stderr, "could not confirm 2M pages, got: '%s' expected %d\n",
775                         buf, GET_NR_HUGE_PAGES);
776                 return;
777         }
778
779         hugetlb_setup_ok = 1;
780 }
781
782 void *malloc_pkey_hugetlb(long size, int prot, u16 pkey)
783 {
784         void *ptr;
785         int flags = MAP_ANONYMOUS|MAP_PRIVATE|MAP_HUGETLB;
786
787         if (!hugetlb_setup_ok)
788                 return PTR_ERR_ENOTSUP;
789
790         dprintf1("doing %s(%ld, %x, %x)\n", __func__, size, prot, pkey);
791         size = ALIGN_UP(size, HPAGE_SIZE * 2);
792         pkey_assert(pkey < NR_PKEYS);
793         ptr = mmap(NULL, size, PROT_NONE, flags, -1, 0);
794         pkey_assert(ptr != (void *)-1);
795         mprotect_pkey(ptr, size, prot, pkey);
796
797         record_pkey_malloc(ptr, size, prot);
798
799         dprintf1("mmap()'d hugetlbfs for pkey %d @ %p\n", pkey, ptr);
800         return ptr;
801 }
802
803 void *malloc_pkey_mmap_dax(long size, int prot, u16 pkey)
804 {
805         void *ptr;
806         int fd;
807
808         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
809                         size, prot, pkey);
810         pkey_assert(pkey < NR_PKEYS);
811         fd = open("/dax/foo", O_RDWR);
812         pkey_assert(fd >= 0);
813
814         ptr = mmap(0, size, prot, MAP_SHARED, fd, 0);
815         pkey_assert(ptr != (void *)-1);
816
817         mprotect_pkey(ptr, size, prot, pkey);
818
819         record_pkey_malloc(ptr, size, prot);
820
821         dprintf1("mmap()'d for pkey %d @ %p\n", pkey, ptr);
822         close(fd);
823         return ptr;
824 }
825
826 void *(*pkey_malloc[])(long size, int prot, u16 pkey) = {
827
828         malloc_pkey_with_mprotect,
829         malloc_pkey_anon_huge,
830         malloc_pkey_hugetlb
831 /* can not do direct with the pkey_mprotect() API:
832         malloc_pkey_mmap_direct,
833         malloc_pkey_mmap_dax,
834 */
835 };
836
837 void *malloc_pkey(long size, int prot, u16 pkey)
838 {
839         void *ret;
840         static int malloc_type;
841         int nr_malloc_types = ARRAY_SIZE(pkey_malloc);
842
843         pkey_assert(pkey < NR_PKEYS);
844
845         while (1) {
846                 pkey_assert(malloc_type < nr_malloc_types);
847
848                 ret = pkey_malloc[malloc_type](size, prot, pkey);
849                 pkey_assert(ret != (void *)-1);
850
851                 malloc_type++;
852                 if (malloc_type >= nr_malloc_types)
853                         malloc_type = (random()%nr_malloc_types);
854
855                 /* try again if the malloc_type we tried is unsupported */
856                 if (ret == PTR_ERR_ENOTSUP)
857                         continue;
858
859                 break;
860         }
861
862         dprintf3("%s(%ld, prot=%x, pkey=%x) returning: %p\n", __func__,
863                         size, prot, pkey, ret);
864         return ret;
865 }
866
867 int last_pkey_faults;
868 #define UNKNOWN_PKEY -2
869 void expected_pkey_fault(int pkey)
870 {
871         dprintf2("%s(): last_pkey_faults: %d pkey_faults: %d\n",
872                         __func__, last_pkey_faults, pkey_faults);
873         dprintf2("%s(%d): last_si_pkey: %d\n", __func__, pkey, last_si_pkey);
874         pkey_assert(last_pkey_faults + 1 == pkey_faults);
875
876        /*
877         * For exec-only memory, we do not know the pkey in
878         * advance, so skip this check.
879         */
880         if (pkey != UNKNOWN_PKEY)
881                 pkey_assert(last_si_pkey == pkey);
882
883         /*
884          * The signal handler shold have cleared out PKEY register to let the
885          * test program continue.  We now have to restore it.
886          */
887         if (__read_pkey_reg() != 0)
888                 pkey_assert(0);
889
890         __write_pkey_reg(shadow_pkey_reg);
891         dprintf1("%s() set pkey_reg=%016llx to restore state after signal "
892                        "nuked it\n", __func__, shadow_pkey_reg);
893         last_pkey_faults = pkey_faults;
894         last_si_pkey = -1;
895 }
896
897 #define do_not_expect_pkey_fault(msg)   do {                    \
898         if (last_pkey_faults != pkey_faults)                    \
899                 dprintf0("unexpected PKey fault: %s\n", msg);   \
900         pkey_assert(last_pkey_faults == pkey_faults);           \
901 } while (0)
902
903 int test_fds[10] = { -1 };
904 int nr_test_fds;
905 void __save_test_fd(int fd)
906 {
907         pkey_assert(fd >= 0);
908         pkey_assert(nr_test_fds < ARRAY_SIZE(test_fds));
909         test_fds[nr_test_fds] = fd;
910         nr_test_fds++;
911 }
912
913 int get_test_read_fd(void)
914 {
915         int test_fd = open("/etc/passwd", O_RDONLY);
916         __save_test_fd(test_fd);
917         return test_fd;
918 }
919
920 void close_test_fds(void)
921 {
922         int i;
923
924         for (i = 0; i < nr_test_fds; i++) {
925                 if (test_fds[i] < 0)
926                         continue;
927                 close(test_fds[i]);
928                 test_fds[i] = -1;
929         }
930         nr_test_fds = 0;
931 }
932
933 #define barrier() __asm__ __volatile__("": : :"memory")
934 __attribute__((noinline)) int read_ptr(int *ptr)
935 {
936         /*
937          * Keep GCC from optimizing this away somehow
938          */
939         barrier();
940         return *ptr;
941 }
942
943 void test_read_of_write_disabled_region(int *ptr, u16 pkey)
944 {
945         int ptr_contents;
946
947         dprintf1("disabling write access to PKEY[1], doing read\n");
948         pkey_write_deny(pkey);
949         ptr_contents = read_ptr(ptr);
950         dprintf1("*ptr: %d\n", ptr_contents);
951         dprintf1("\n");
952 }
953 void test_read_of_access_disabled_region(int *ptr, u16 pkey)
954 {
955         int ptr_contents;
956
957         dprintf1("disabling access to PKEY[%02d], doing read @ %p\n", pkey, ptr);
958         read_pkey_reg();
959         pkey_access_deny(pkey);
960         ptr_contents = read_ptr(ptr);
961         dprintf1("*ptr: %d\n", ptr_contents);
962         expected_pkey_fault(pkey);
963 }
964 void test_write_of_write_disabled_region(int *ptr, u16 pkey)
965 {
966         dprintf1("disabling write access to PKEY[%02d], doing write\n", pkey);
967         pkey_write_deny(pkey);
968         *ptr = __LINE__;
969         expected_pkey_fault(pkey);
970 }
971 void test_write_of_access_disabled_region(int *ptr, u16 pkey)
972 {
973         dprintf1("disabling access to PKEY[%02d], doing write\n", pkey);
974         pkey_access_deny(pkey);
975         *ptr = __LINE__;
976         expected_pkey_fault(pkey);
977 }
978 void test_kernel_write_of_access_disabled_region(int *ptr, u16 pkey)
979 {
980         int ret;
981         int test_fd = get_test_read_fd();
982
983         dprintf1("disabling access to PKEY[%02d], "
984                  "having kernel read() to buffer\n", pkey);
985         pkey_access_deny(pkey);
986         ret = read(test_fd, ptr, 1);
987         dprintf1("read ret: %d\n", ret);
988         pkey_assert(ret);
989 }
990 void test_kernel_write_of_write_disabled_region(int *ptr, u16 pkey)
991 {
992         int ret;
993         int test_fd = get_test_read_fd();
994
995         pkey_write_deny(pkey);
996         ret = read(test_fd, ptr, 100);
997         dprintf1("read ret: %d\n", ret);
998         if (ret < 0 && (DEBUG_LEVEL > 0))
999                 perror("verbose read result (OK for this to be bad)");
1000         pkey_assert(ret);
1001 }
1002
1003 void test_kernel_gup_of_access_disabled_region(int *ptr, u16 pkey)
1004 {
1005         int pipe_ret, vmsplice_ret;
1006         struct iovec iov;
1007         int pipe_fds[2];
1008
1009         pipe_ret = pipe(pipe_fds);
1010
1011         pkey_assert(pipe_ret == 0);
1012         dprintf1("disabling access to PKEY[%02d], "
1013                  "having kernel vmsplice from buffer\n", pkey);
1014         pkey_access_deny(pkey);
1015         iov.iov_base = ptr;
1016         iov.iov_len = PAGE_SIZE;
1017         vmsplice_ret = vmsplice(pipe_fds[1], &iov, 1, SPLICE_F_GIFT);
1018         dprintf1("vmsplice() ret: %d\n", vmsplice_ret);
1019         pkey_assert(vmsplice_ret == -1);
1020
1021         close(pipe_fds[0]);
1022         close(pipe_fds[1]);
1023 }
1024
1025 void test_kernel_gup_write_to_write_disabled_region(int *ptr, u16 pkey)
1026 {
1027         int ignored = 0xdada;
1028         int futex_ret;
1029         int some_int = __LINE__;
1030
1031         dprintf1("disabling write to PKEY[%02d], "
1032                  "doing futex gunk in buffer\n", pkey);
1033         *ptr = some_int;
1034         pkey_write_deny(pkey);
1035         futex_ret = syscall(SYS_futex, ptr, FUTEX_WAIT, some_int-1, NULL,
1036                         &ignored, ignored);
1037         if (DEBUG_LEVEL > 0)
1038                 perror("futex");
1039         dprintf1("futex() ret: %d\n", futex_ret);
1040 }
1041
1042 /* Assumes that all pkeys other than 'pkey' are unallocated */
1043 void test_pkey_syscalls_on_non_allocated_pkey(int *ptr, u16 pkey)
1044 {
1045         int err;
1046         int i;
1047
1048         /* Note: 0 is the default pkey, so don't mess with it */
1049         for (i = 1; i < NR_PKEYS; i++) {
1050                 if (pkey == i)
1051                         continue;
1052
1053                 dprintf1("trying get/set/free to non-allocated pkey: %2d\n", i);
1054                 err = sys_pkey_free(i);
1055                 pkey_assert(err);
1056
1057                 err = sys_pkey_free(i);
1058                 pkey_assert(err);
1059
1060                 err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, i);
1061                 pkey_assert(err);
1062         }
1063 }
1064
1065 /* Assumes that all pkeys other than 'pkey' are unallocated */
1066 void test_pkey_syscalls_bad_args(int *ptr, u16 pkey)
1067 {
1068         int err;
1069         int bad_pkey = NR_PKEYS+99;
1070
1071         /* pass a known-invalid pkey in: */
1072         err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, bad_pkey);
1073         pkey_assert(err);
1074 }
1075
1076 void become_child(void)
1077 {
1078         pid_t forkret;
1079
1080         forkret = fork();
1081         pkey_assert(forkret >= 0);
1082         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
1083
1084         if (!forkret) {
1085                 /* in the child */
1086                 return;
1087         }
1088         exit(0);
1089 }
1090
1091 /* Assumes that all pkeys other than 'pkey' are unallocated */
1092 void test_pkey_alloc_exhaust(int *ptr, u16 pkey)
1093 {
1094         int err;
1095         int allocated_pkeys[NR_PKEYS] = {0};
1096         int nr_allocated_pkeys = 0;
1097         int i;
1098
1099         for (i = 0; i < NR_PKEYS*3; i++) {
1100                 int new_pkey;
1101                 dprintf1("%s() alloc loop: %d\n", __func__, i);
1102                 new_pkey = alloc_pkey();
1103                 dprintf4("%s()::%d, err: %d pkey_reg: 0x%016llx"
1104                                 " shadow: 0x%016llx\n",
1105                                 __func__, __LINE__, err, __read_pkey_reg(),
1106                                 shadow_pkey_reg);
1107                 read_pkey_reg(); /* for shadow checking */
1108                 dprintf2("%s() errno: %d ENOSPC: %d\n", __func__, errno, ENOSPC);
1109                 if ((new_pkey == -1) && (errno == ENOSPC)) {
1110                         dprintf2("%s() failed to allocate pkey after %d tries\n",
1111                                 __func__, nr_allocated_pkeys);
1112                 } else {
1113                         /*
1114                          * Ensure the number of successes never
1115                          * exceeds the number of keys supported
1116                          * in the hardware.
1117                          */
1118                         pkey_assert(nr_allocated_pkeys < NR_PKEYS);
1119                         allocated_pkeys[nr_allocated_pkeys++] = new_pkey;
1120                 }
1121
1122                 /*
1123                  * Make sure that allocation state is properly
1124                  * preserved across fork().
1125                  */
1126                 if (i == NR_PKEYS*2)
1127                         become_child();
1128         }
1129
1130         dprintf3("%s()::%d\n", __func__, __LINE__);
1131
1132         /*
1133          * There are 16 pkeys supported in hardware.  Three are
1134          * allocated by the time we get here:
1135          *   1. The default key (0)
1136          *   2. One possibly consumed by an execute-only mapping.
1137          *   3. One allocated by the test code and passed in via
1138          *      'pkey' to this function.
1139          * Ensure that we can allocate at least another 13 (16-3).
1140          */
1141         pkey_assert(i >= NR_PKEYS-3);
1142
1143         for (i = 0; i < nr_allocated_pkeys; i++) {
1144                 err = sys_pkey_free(allocated_pkeys[i]);
1145                 pkey_assert(!err);
1146                 read_pkey_reg(); /* for shadow checking */
1147         }
1148 }
1149
1150 /*
1151  * pkey 0 is special.  It is allocated by default, so you do not
1152  * have to call pkey_alloc() to use it first.  Make sure that it
1153  * is usable.
1154  */
1155 void test_mprotect_with_pkey_0(int *ptr, u16 pkey)
1156 {
1157         long size;
1158         int prot;
1159
1160         assert(pkey_last_malloc_record);
1161         size = pkey_last_malloc_record->size;
1162         /*
1163          * This is a bit of a hack.  But mprotect() requires
1164          * huge-page-aligned sizes when operating on hugetlbfs.
1165          * So, make sure that we use something that's a multiple
1166          * of a huge page when we can.
1167          */
1168         if (size >= HPAGE_SIZE)
1169                 size = HPAGE_SIZE;
1170         prot = pkey_last_malloc_record->prot;
1171
1172         /* Use pkey 0 */
1173         mprotect_pkey(ptr, size, prot, 0);
1174
1175         /* Make sure that we can set it back to the original pkey. */
1176         mprotect_pkey(ptr, size, prot, pkey);
1177 }
1178
1179 void test_ptrace_of_child(int *ptr, u16 pkey)
1180 {
1181         __attribute__((__unused__)) int peek_result;
1182         pid_t child_pid;
1183         void *ignored = 0;
1184         long ret;
1185         int status;
1186         /*
1187          * This is the "control" for our little expermient.  Make sure
1188          * we can always access it when ptracing.
1189          */
1190         int *plain_ptr_unaligned = malloc(HPAGE_SIZE);
1191         int *plain_ptr = ALIGN_PTR_UP(plain_ptr_unaligned, PAGE_SIZE);
1192
1193         /*
1194          * Fork a child which is an exact copy of this process, of course.
1195          * That means we can do all of our tests via ptrace() and then plain
1196          * memory access and ensure they work differently.
1197          */
1198         child_pid = fork_lazy_child();
1199         dprintf1("[%d] child pid: %d\n", getpid(), child_pid);
1200
1201         ret = ptrace(PTRACE_ATTACH, child_pid, ignored, ignored);
1202         if (ret)
1203                 perror("attach");
1204         dprintf1("[%d] attach ret: %ld %d\n", getpid(), ret, __LINE__);
1205         pkey_assert(ret != -1);
1206         ret = waitpid(child_pid, &status, WUNTRACED);
1207         if ((ret != child_pid) || !(WIFSTOPPED(status))) {
1208                 fprintf(stderr, "weird waitpid result %ld stat %x\n",
1209                                 ret, status);
1210                 pkey_assert(0);
1211         }
1212         dprintf2("waitpid ret: %ld\n", ret);
1213         dprintf2("waitpid status: %d\n", status);
1214
1215         pkey_access_deny(pkey);
1216         pkey_write_deny(pkey);
1217
1218         /* Write access, untested for now:
1219         ret = ptrace(PTRACE_POKEDATA, child_pid, peek_at, data);
1220         pkey_assert(ret != -1);
1221         dprintf1("poke at %p: %ld\n", peek_at, ret);
1222         */
1223
1224         /*
1225          * Try to access the pkey-protected "ptr" via ptrace:
1226          */
1227         ret = ptrace(PTRACE_PEEKDATA, child_pid, ptr, ignored);
1228         /* expect it to work, without an error: */
1229         pkey_assert(ret != -1);
1230         /* Now access from the current task, and expect an exception: */
1231         peek_result = read_ptr(ptr);
1232         expected_pkey_fault(pkey);
1233
1234         /*
1235          * Try to access the NON-pkey-protected "plain_ptr" via ptrace:
1236          */
1237         ret = ptrace(PTRACE_PEEKDATA, child_pid, plain_ptr, ignored);
1238         /* expect it to work, without an error: */
1239         pkey_assert(ret != -1);
1240         /* Now access from the current task, and expect NO exception: */
1241         peek_result = read_ptr(plain_ptr);
1242         do_not_expect_pkey_fault("read plain pointer after ptrace");
1243
1244         ret = ptrace(PTRACE_DETACH, child_pid, ignored, 0);
1245         pkey_assert(ret != -1);
1246
1247         ret = kill(child_pid, SIGKILL);
1248         pkey_assert(ret != -1);
1249
1250         wait(&status);
1251
1252         free(plain_ptr_unaligned);
1253 }
1254
1255 void *get_pointer_to_instructions(void)
1256 {
1257         void *p1;
1258
1259         p1 = ALIGN_PTR_UP(&lots_o_noops_around_write, PAGE_SIZE);
1260         dprintf3("&lots_o_noops: %p\n", &lots_o_noops_around_write);
1261         /* lots_o_noops_around_write should be page-aligned already */
1262         assert(p1 == &lots_o_noops_around_write);
1263
1264         /* Point 'p1' at the *second* page of the function: */
1265         p1 += PAGE_SIZE;
1266
1267         /*
1268          * Try to ensure we fault this in on next touch to ensure
1269          * we get an instruction fault as opposed to a data one
1270          */
1271         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1272
1273         return p1;
1274 }
1275
1276 void test_executing_on_unreadable_memory(int *ptr, u16 pkey)
1277 {
1278         void *p1;
1279         int scratch;
1280         int ptr_contents;
1281         int ret;
1282
1283         p1 = get_pointer_to_instructions();
1284         lots_o_noops_around_write(&scratch);
1285         ptr_contents = read_ptr(p1);
1286         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1287
1288         ret = mprotect_pkey(p1, PAGE_SIZE, PROT_EXEC, (u64)pkey);
1289         pkey_assert(!ret);
1290         pkey_access_deny(pkey);
1291
1292         dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
1293
1294         /*
1295          * Make sure this is an *instruction* fault
1296          */
1297         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1298         lots_o_noops_around_write(&scratch);
1299         do_not_expect_pkey_fault("executing on PROT_EXEC memory");
1300         ptr_contents = read_ptr(p1);
1301         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1302         expected_pkey_fault(pkey);
1303 }
1304
1305 void test_implicit_mprotect_exec_only_memory(int *ptr, u16 pkey)
1306 {
1307         void *p1;
1308         int scratch;
1309         int ptr_contents;
1310         int ret;
1311
1312         dprintf1("%s() start\n", __func__);
1313
1314         p1 = get_pointer_to_instructions();
1315         lots_o_noops_around_write(&scratch);
1316         ptr_contents = read_ptr(p1);
1317         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1318
1319         /* Use a *normal* mprotect(), not mprotect_pkey(): */
1320         ret = mprotect(p1, PAGE_SIZE, PROT_EXEC);
1321         pkey_assert(!ret);
1322
1323         dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
1324
1325         /* Make sure this is an *instruction* fault */
1326         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1327         lots_o_noops_around_write(&scratch);
1328         do_not_expect_pkey_fault("executing on PROT_EXEC memory");
1329         ptr_contents = read_ptr(p1);
1330         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1331         expected_pkey_fault(UNKNOWN_PKEY);
1332
1333         /*
1334          * Put the memory back to non-PROT_EXEC.  Should clear the
1335          * exec-only pkey off the VMA and allow it to be readable
1336          * again.  Go to PROT_NONE first to check for a kernel bug
1337          * that did not clear the pkey when doing PROT_NONE.
1338          */
1339         ret = mprotect(p1, PAGE_SIZE, PROT_NONE);
1340         pkey_assert(!ret);
1341
1342         ret = mprotect(p1, PAGE_SIZE, PROT_READ|PROT_EXEC);
1343         pkey_assert(!ret);
1344         ptr_contents = read_ptr(p1);
1345         do_not_expect_pkey_fault("plain read on recently PROT_EXEC area");
1346 }
1347
1348 void test_mprotect_pkey_on_unsupported_cpu(int *ptr, u16 pkey)
1349 {
1350         int size = PAGE_SIZE;
1351         int sret;
1352
1353         if (cpu_has_pku()) {
1354                 dprintf1("SKIP: %s: no CPU support\n", __func__);
1355                 return;
1356         }
1357
1358         sret = syscall(SYS_mprotect_key, ptr, size, PROT_READ, pkey);
1359         pkey_assert(sret < 0);
1360 }
1361
1362 void (*pkey_tests[])(int *ptr, u16 pkey) = {
1363         test_read_of_write_disabled_region,
1364         test_read_of_access_disabled_region,
1365         test_write_of_write_disabled_region,
1366         test_write_of_access_disabled_region,
1367         test_kernel_write_of_access_disabled_region,
1368         test_kernel_write_of_write_disabled_region,
1369         test_kernel_gup_of_access_disabled_region,
1370         test_kernel_gup_write_to_write_disabled_region,
1371         test_executing_on_unreadable_memory,
1372         test_implicit_mprotect_exec_only_memory,
1373         test_mprotect_with_pkey_0,
1374         test_ptrace_of_child,
1375         test_pkey_syscalls_on_non_allocated_pkey,
1376         test_pkey_syscalls_bad_args,
1377         test_pkey_alloc_exhaust,
1378 };
1379
1380 void run_tests_once(void)
1381 {
1382         int *ptr;
1383         int prot = PROT_READ|PROT_WRITE;
1384
1385         for (test_nr = 0; test_nr < ARRAY_SIZE(pkey_tests); test_nr++) {
1386                 int pkey;
1387                 int orig_pkey_faults = pkey_faults;
1388
1389                 dprintf1("======================\n");
1390                 dprintf1("test %d preparing...\n", test_nr);
1391
1392                 tracing_on();
1393                 pkey = alloc_random_pkey();
1394                 dprintf1("test %d starting with pkey: %d\n", test_nr, pkey);
1395                 ptr = malloc_pkey(PAGE_SIZE, prot, pkey);
1396                 dprintf1("test %d starting...\n", test_nr);
1397                 pkey_tests[test_nr](ptr, pkey);
1398                 dprintf1("freeing test memory: %p\n", ptr);
1399                 free_pkey_malloc(ptr);
1400                 sys_pkey_free(pkey);
1401
1402                 dprintf1("pkey_faults: %d\n", pkey_faults);
1403                 dprintf1("orig_pkey_faults: %d\n", orig_pkey_faults);
1404
1405                 tracing_off();
1406                 close_test_fds();
1407
1408                 printf("test %2d PASSED (iteration %d)\n", test_nr, iteration_nr);
1409                 dprintf1("======================\n\n");
1410         }
1411         iteration_nr++;
1412 }
1413
1414 void pkey_setup_shadow(void)
1415 {
1416         shadow_pkey_reg = __read_pkey_reg();
1417 }
1418
1419 int main(void)
1420 {
1421         int nr_iterations = 22;
1422
1423         setup_handlers();
1424
1425         printf("has pku: %d\n", cpu_has_pku());
1426
1427         if (!cpu_has_pku()) {
1428                 int size = PAGE_SIZE;
1429                 int *ptr;
1430
1431                 printf("running PKEY tests for unsupported CPU/OS\n");
1432
1433                 ptr  = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
1434                 assert(ptr != (void *)-1);
1435                 test_mprotect_pkey_on_unsupported_cpu(ptr, 1);
1436                 exit(0);
1437         }
1438
1439         pkey_setup_shadow();
1440         printf("startup pkey_reg: %016llx\n", read_pkey_reg());
1441         setup_hugetlbfs();
1442
1443         while (nr_iterations-- > 0)
1444                 run_tests_once();
1445
1446         printf("done (all tests OK)\n");
1447         return 0;
1448 }