Merge branch 'akpm' (patches from Andrew)
[linux-2.6-microblaze.git] / arch / arm64 / net / bpf_jit_comp.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * BPF JIT compiler for ARM64
4  *
5  * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
6  */
7
8 #define pr_fmt(fmt) "bpf_jit: " fmt
9
10 #include <linux/bitfield.h>
11 #include <linux/bpf.h>
12 #include <linux/filter.h>
13 #include <linux/printk.h>
14 #include <linux/slab.h>
15
16 #include <asm/asm-extable.h>
17 #include <asm/byteorder.h>
18 #include <asm/cacheflush.h>
19 #include <asm/debug-monitors.h>
20 #include <asm/insn.h>
21 #include <asm/set_memory.h>
22
23 #include "bpf_jit.h"
24
25 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
26 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
27 #define TCALL_CNT (MAX_BPF_JIT_REG + 2)
28 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
29
30 #define check_imm(bits, imm) do {                               \
31         if ((((imm) > 0) && ((imm) >> (bits))) ||               \
32             (((imm) < 0) && (~(imm) >> (bits)))) {              \
33                 pr_info("[%2d] imm=%d(0x%x) out of range\n",    \
34                         i, imm, imm);                           \
35                 return -EINVAL;                                 \
36         }                                                       \
37 } while (0)
38 #define check_imm19(imm) check_imm(19, imm)
39 #define check_imm26(imm) check_imm(26, imm)
40
41 /* Map BPF registers to A64 registers */
42 static const int bpf2a64[] = {
43         /* return value from in-kernel function, and exit value from eBPF */
44         [BPF_REG_0] = A64_R(7),
45         /* arguments from eBPF program to in-kernel function */
46         [BPF_REG_1] = A64_R(0),
47         [BPF_REG_2] = A64_R(1),
48         [BPF_REG_3] = A64_R(2),
49         [BPF_REG_4] = A64_R(3),
50         [BPF_REG_5] = A64_R(4),
51         /* callee saved registers that in-kernel function will preserve */
52         [BPF_REG_6] = A64_R(19),
53         [BPF_REG_7] = A64_R(20),
54         [BPF_REG_8] = A64_R(21),
55         [BPF_REG_9] = A64_R(22),
56         /* read-only frame pointer to access stack */
57         [BPF_REG_FP] = A64_R(25),
58         /* temporary registers for BPF JIT */
59         [TMP_REG_1] = A64_R(10),
60         [TMP_REG_2] = A64_R(11),
61         [TMP_REG_3] = A64_R(12),
62         /* tail_call_cnt */
63         [TCALL_CNT] = A64_R(26),
64         /* temporary register for blinding constants */
65         [BPF_REG_AX] = A64_R(9),
66 };
67
68 struct jit_ctx {
69         const struct bpf_prog *prog;
70         int idx;
71         int epilogue_offset;
72         int *offset;
73         int exentry_idx;
74         __le32 *image;
75         u32 stack_size;
76 };
77
78 static inline void emit(const u32 insn, struct jit_ctx *ctx)
79 {
80         if (ctx->image != NULL)
81                 ctx->image[ctx->idx] = cpu_to_le32(insn);
82
83         ctx->idx++;
84 }
85
86 static inline void emit_a64_mov_i(const int is64, const int reg,
87                                   const s32 val, struct jit_ctx *ctx)
88 {
89         u16 hi = val >> 16;
90         u16 lo = val & 0xffff;
91
92         if (hi & 0x8000) {
93                 if (hi == 0xffff) {
94                         emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
95                 } else {
96                         emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
97                         if (lo != 0xffff)
98                                 emit(A64_MOVK(is64, reg, lo, 0), ctx);
99                 }
100         } else {
101                 emit(A64_MOVZ(is64, reg, lo, 0), ctx);
102                 if (hi)
103                         emit(A64_MOVK(is64, reg, hi, 16), ctx);
104         }
105 }
106
107 static int i64_i16_blocks(const u64 val, bool inverse)
108 {
109         return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
110                (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
111                (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
112                (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
113 }
114
115 static inline void emit_a64_mov_i64(const int reg, const u64 val,
116                                     struct jit_ctx *ctx)
117 {
118         u64 nrm_tmp = val, rev_tmp = ~val;
119         bool inverse;
120         int shift;
121
122         if (!(nrm_tmp >> 32))
123                 return emit_a64_mov_i(0, reg, (u32)val, ctx);
124
125         inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
126         shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
127                                           (fls64(nrm_tmp) - 1)), 16), 0);
128         if (inverse)
129                 emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
130         else
131                 emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
132         shift -= 16;
133         while (shift >= 0) {
134                 if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
135                         emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
136                 shift -= 16;
137         }
138 }
139
140 /*
141  * Kernel addresses in the vmalloc space use at most 48 bits, and the
142  * remaining bits are guaranteed to be 0x1. So we can compose the address
143  * with a fixed length movn/movk/movk sequence.
144  */
145 static inline void emit_addr_mov_i64(const int reg, const u64 val,
146                                      struct jit_ctx *ctx)
147 {
148         u64 tmp = val;
149         int shift = 0;
150
151         emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
152         while (shift < 32) {
153                 tmp >>= 16;
154                 shift += 16;
155                 emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
156         }
157 }
158
159 static inline int bpf2a64_offset(int bpf_insn, int off,
160                                  const struct jit_ctx *ctx)
161 {
162         /* BPF JMP offset is relative to the next instruction */
163         bpf_insn++;
164         /*
165          * Whereas arm64 branch instructions encode the offset
166          * from the branch itself, so we must subtract 1 from the
167          * instruction offset.
168          */
169         return ctx->offset[bpf_insn + off] - (ctx->offset[bpf_insn] - 1);
170 }
171
172 static void jit_fill_hole(void *area, unsigned int size)
173 {
174         __le32 *ptr;
175         /* We are guaranteed to have aligned memory. */
176         for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
177                 *ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
178 }
179
180 static inline int epilogue_offset(const struct jit_ctx *ctx)
181 {
182         int to = ctx->epilogue_offset;
183         int from = ctx->idx;
184
185         return to - from;
186 }
187
188 static bool is_addsub_imm(u32 imm)
189 {
190         /* Either imm12 or shifted imm12. */
191         return !(imm & ~0xfff) || !(imm & ~0xfff000);
192 }
193
194 /* Tail call offset to jump into */
195 #if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
196 #define PROLOGUE_OFFSET 8
197 #else
198 #define PROLOGUE_OFFSET 7
199 #endif
200
201 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
202 {
203         const struct bpf_prog *prog = ctx->prog;
204         const u8 r6 = bpf2a64[BPF_REG_6];
205         const u8 r7 = bpf2a64[BPF_REG_7];
206         const u8 r8 = bpf2a64[BPF_REG_8];
207         const u8 r9 = bpf2a64[BPF_REG_9];
208         const u8 fp = bpf2a64[BPF_REG_FP];
209         const u8 tcc = bpf2a64[TCALL_CNT];
210         const int idx0 = ctx->idx;
211         int cur_offset;
212
213         /*
214          * BPF prog stack layout
215          *
216          *                         high
217          * original A64_SP =>   0:+-----+ BPF prologue
218          *                        |FP/LR|
219          * current A64_FP =>  -16:+-----+
220          *                        | ... | callee saved registers
221          * BPF fp register => -64:+-----+ <= (BPF_FP)
222          *                        |     |
223          *                        | ... | BPF prog stack
224          *                        |     |
225          *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
226          *                        |RSVD | padding
227          * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
228          *                        |     |
229          *                        | ... | Function call stack
230          *                        |     |
231          *                        +-----+
232          *                          low
233          *
234          */
235
236         /* BTI landing pad */
237         if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
238                 emit(A64_BTI_C, ctx);
239
240         /* Save FP and LR registers to stay align with ARM64 AAPCS */
241         emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
242         emit(A64_MOV(1, A64_FP, A64_SP), ctx);
243
244         /* Save callee-saved registers */
245         emit(A64_PUSH(r6, r7, A64_SP), ctx);
246         emit(A64_PUSH(r8, r9, A64_SP), ctx);
247         emit(A64_PUSH(fp, tcc, A64_SP), ctx);
248
249         /* Set up BPF prog stack base register */
250         emit(A64_MOV(1, fp, A64_SP), ctx);
251
252         if (!ebpf_from_cbpf) {
253                 /* Initialize tail_call_cnt */
254                 emit(A64_MOVZ(1, tcc, 0, 0), ctx);
255
256                 cur_offset = ctx->idx - idx0;
257                 if (cur_offset != PROLOGUE_OFFSET) {
258                         pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
259                                     cur_offset, PROLOGUE_OFFSET);
260                         return -1;
261                 }
262
263                 /* BTI landing pad for the tail call, done with a BR */
264                 if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
265                         emit(A64_BTI_J, ctx);
266         }
267
268         /* Stack must be multiples of 16B */
269         ctx->stack_size = round_up(prog->aux->stack_depth, 16);
270
271         /* Set up function call stack */
272         emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
273         return 0;
274 }
275
276 static int out_offset = -1; /* initialized on the first pass of build_body() */
277 static int emit_bpf_tail_call(struct jit_ctx *ctx)
278 {
279         /* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
280         const u8 r2 = bpf2a64[BPF_REG_2];
281         const u8 r3 = bpf2a64[BPF_REG_3];
282
283         const u8 tmp = bpf2a64[TMP_REG_1];
284         const u8 prg = bpf2a64[TMP_REG_2];
285         const u8 tcc = bpf2a64[TCALL_CNT];
286         const int idx0 = ctx->idx;
287 #define cur_offset (ctx->idx - idx0)
288 #define jmp_offset (out_offset - (cur_offset))
289         size_t off;
290
291         /* if (index >= array->map.max_entries)
292          *     goto out;
293          */
294         off = offsetof(struct bpf_array, map.max_entries);
295         emit_a64_mov_i64(tmp, off, ctx);
296         emit(A64_LDR32(tmp, r2, tmp), ctx);
297         emit(A64_MOV(0, r3, r3), ctx);
298         emit(A64_CMP(0, r3, tmp), ctx);
299         emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
300
301         /*
302          * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
303          *     goto out;
304          * tail_call_cnt++;
305          */
306         emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
307         emit(A64_CMP(1, tcc, tmp), ctx);
308         emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
309         emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
310
311         /* prog = array->ptrs[index];
312          * if (prog == NULL)
313          *     goto out;
314          */
315         off = offsetof(struct bpf_array, ptrs);
316         emit_a64_mov_i64(tmp, off, ctx);
317         emit(A64_ADD(1, tmp, r2, tmp), ctx);
318         emit(A64_LSL(1, prg, r3, 3), ctx);
319         emit(A64_LDR64(prg, tmp, prg), ctx);
320         emit(A64_CBZ(1, prg, jmp_offset), ctx);
321
322         /* goto *(prog->bpf_func + prologue_offset); */
323         off = offsetof(struct bpf_prog, bpf_func);
324         emit_a64_mov_i64(tmp, off, ctx);
325         emit(A64_LDR64(tmp, prg, tmp), ctx);
326         emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
327         emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
328         emit(A64_BR(tmp), ctx);
329
330         /* out: */
331         if (out_offset == -1)
332                 out_offset = cur_offset;
333         if (cur_offset != out_offset) {
334                 pr_err_once("tail_call out_offset = %d, expected %d!\n",
335                             cur_offset, out_offset);
336                 return -1;
337         }
338         return 0;
339 #undef cur_offset
340 #undef jmp_offset
341 }
342
343 #ifdef CONFIG_ARM64_LSE_ATOMICS
344 static int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
345 {
346         const u8 code = insn->code;
347         const u8 dst = bpf2a64[insn->dst_reg];
348         const u8 src = bpf2a64[insn->src_reg];
349         const u8 tmp = bpf2a64[TMP_REG_1];
350         const u8 tmp2 = bpf2a64[TMP_REG_2];
351         const bool isdw = BPF_SIZE(code) == BPF_DW;
352         const s16 off = insn->off;
353         u8 reg;
354
355         if (!off) {
356                 reg = dst;
357         } else {
358                 emit_a64_mov_i(1, tmp, off, ctx);
359                 emit(A64_ADD(1, tmp, tmp, dst), ctx);
360                 reg = tmp;
361         }
362
363         switch (insn->imm) {
364         /* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
365         case BPF_ADD:
366                 emit(A64_STADD(isdw, reg, src), ctx);
367                 break;
368         case BPF_AND:
369                 emit(A64_MVN(isdw, tmp2, src), ctx);
370                 emit(A64_STCLR(isdw, reg, tmp2), ctx);
371                 break;
372         case BPF_OR:
373                 emit(A64_STSET(isdw, reg, src), ctx);
374                 break;
375         case BPF_XOR:
376                 emit(A64_STEOR(isdw, reg, src), ctx);
377                 break;
378         /* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
379         case BPF_ADD | BPF_FETCH:
380                 emit(A64_LDADDAL(isdw, src, reg, src), ctx);
381                 break;
382         case BPF_AND | BPF_FETCH:
383                 emit(A64_MVN(isdw, tmp2, src), ctx);
384                 emit(A64_LDCLRAL(isdw, src, reg, tmp2), ctx);
385                 break;
386         case BPF_OR | BPF_FETCH:
387                 emit(A64_LDSETAL(isdw, src, reg, src), ctx);
388                 break;
389         case BPF_XOR | BPF_FETCH:
390                 emit(A64_LDEORAL(isdw, src, reg, src), ctx);
391                 break;
392         /* src_reg = atomic_xchg(dst_reg + off, src_reg); */
393         case BPF_XCHG:
394                 emit(A64_SWPAL(isdw, src, reg, src), ctx);
395                 break;
396         /* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
397         case BPF_CMPXCHG:
398                 emit(A64_CASAL(isdw, src, reg, bpf2a64[BPF_REG_0]), ctx);
399                 break;
400         default:
401                 pr_err_once("unknown atomic op code %02x\n", insn->imm);
402                 return -EINVAL;
403         }
404
405         return 0;
406 }
407 #else
408 static inline int emit_lse_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
409 {
410         return -EINVAL;
411 }
412 #endif
413
414 static int emit_ll_sc_atomic(const struct bpf_insn *insn, struct jit_ctx *ctx)
415 {
416         const u8 code = insn->code;
417         const u8 dst = bpf2a64[insn->dst_reg];
418         const u8 src = bpf2a64[insn->src_reg];
419         const u8 tmp = bpf2a64[TMP_REG_1];
420         const u8 tmp2 = bpf2a64[TMP_REG_2];
421         const u8 tmp3 = bpf2a64[TMP_REG_3];
422         const int i = insn - ctx->prog->insnsi;
423         const s32 imm = insn->imm;
424         const s16 off = insn->off;
425         const bool isdw = BPF_SIZE(code) == BPF_DW;
426         u8 reg;
427         s32 jmp_offset;
428
429         if (!off) {
430                 reg = dst;
431         } else {
432                 emit_a64_mov_i(1, tmp, off, ctx);
433                 emit(A64_ADD(1, tmp, tmp, dst), ctx);
434                 reg = tmp;
435         }
436
437         if (imm == BPF_ADD || imm == BPF_AND ||
438             imm == BPF_OR || imm == BPF_XOR) {
439                 /* lock *(u32/u64 *)(dst_reg + off) <op>= src_reg */
440                 emit(A64_LDXR(isdw, tmp2, reg), ctx);
441                 if (imm == BPF_ADD)
442                         emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
443                 else if (imm == BPF_AND)
444                         emit(A64_AND(isdw, tmp2, tmp2, src), ctx);
445                 else if (imm == BPF_OR)
446                         emit(A64_ORR(isdw, tmp2, tmp2, src), ctx);
447                 else
448                         emit(A64_EOR(isdw, tmp2, tmp2, src), ctx);
449                 emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
450                 jmp_offset = -3;
451                 check_imm19(jmp_offset);
452                 emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
453         } else if (imm == (BPF_ADD | BPF_FETCH) ||
454                    imm == (BPF_AND | BPF_FETCH) ||
455                    imm == (BPF_OR | BPF_FETCH) ||
456                    imm == (BPF_XOR | BPF_FETCH)) {
457                 /* src_reg = atomic_fetch_<op>(dst_reg + off, src_reg) */
458                 const u8 ax = bpf2a64[BPF_REG_AX];
459
460                 emit(A64_MOV(isdw, ax, src), ctx);
461                 emit(A64_LDXR(isdw, src, reg), ctx);
462                 if (imm == (BPF_ADD | BPF_FETCH))
463                         emit(A64_ADD(isdw, tmp2, src, ax), ctx);
464                 else if (imm == (BPF_AND | BPF_FETCH))
465                         emit(A64_AND(isdw, tmp2, src, ax), ctx);
466                 else if (imm == (BPF_OR | BPF_FETCH))
467                         emit(A64_ORR(isdw, tmp2, src, ax), ctx);
468                 else
469                         emit(A64_EOR(isdw, tmp2, src, ax), ctx);
470                 emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
471                 jmp_offset = -3;
472                 check_imm19(jmp_offset);
473                 emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
474                 emit(A64_DMB_ISH, ctx);
475         } else if (imm == BPF_XCHG) {
476                 /* src_reg = atomic_xchg(dst_reg + off, src_reg); */
477                 emit(A64_MOV(isdw, tmp2, src), ctx);
478                 emit(A64_LDXR(isdw, src, reg), ctx);
479                 emit(A64_STLXR(isdw, tmp2, reg, tmp3), ctx);
480                 jmp_offset = -2;
481                 check_imm19(jmp_offset);
482                 emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
483                 emit(A64_DMB_ISH, ctx);
484         } else if (imm == BPF_CMPXCHG) {
485                 /* r0 = atomic_cmpxchg(dst_reg + off, r0, src_reg); */
486                 const u8 r0 = bpf2a64[BPF_REG_0];
487
488                 emit(A64_MOV(isdw, tmp2, r0), ctx);
489                 emit(A64_LDXR(isdw, r0, reg), ctx);
490                 emit(A64_EOR(isdw, tmp3, r0, tmp2), ctx);
491                 jmp_offset = 4;
492                 check_imm19(jmp_offset);
493                 emit(A64_CBNZ(isdw, tmp3, jmp_offset), ctx);
494                 emit(A64_STLXR(isdw, src, reg, tmp3), ctx);
495                 jmp_offset = -4;
496                 check_imm19(jmp_offset);
497                 emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
498                 emit(A64_DMB_ISH, ctx);
499         } else {
500                 pr_err_once("unknown atomic op code %02x\n", imm);
501                 return -EINVAL;
502         }
503
504         return 0;
505 }
506
507 static void build_epilogue(struct jit_ctx *ctx)
508 {
509         const u8 r0 = bpf2a64[BPF_REG_0];
510         const u8 r6 = bpf2a64[BPF_REG_6];
511         const u8 r7 = bpf2a64[BPF_REG_7];
512         const u8 r8 = bpf2a64[BPF_REG_8];
513         const u8 r9 = bpf2a64[BPF_REG_9];
514         const u8 fp = bpf2a64[BPF_REG_FP];
515
516         /* We're done with BPF stack */
517         emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
518
519         /* Restore fs (x25) and x26 */
520         emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
521
522         /* Restore callee-saved register */
523         emit(A64_POP(r8, r9, A64_SP), ctx);
524         emit(A64_POP(r6, r7, A64_SP), ctx);
525
526         /* Restore FP/LR registers */
527         emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
528
529         /* Set return value */
530         emit(A64_MOV(1, A64_R(0), r0), ctx);
531
532         emit(A64_RET(A64_LR), ctx);
533 }
534
535 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
536 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
537
538 bool ex_handler_bpf(const struct exception_table_entry *ex,
539                     struct pt_regs *regs)
540 {
541         off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
542         int dst_reg = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
543
544         regs->regs[dst_reg] = 0;
545         regs->pc = (unsigned long)&ex->fixup - offset;
546         return true;
547 }
548
549 /* For accesses to BTF pointers, add an entry to the exception table */
550 static int add_exception_handler(const struct bpf_insn *insn,
551                                  struct jit_ctx *ctx,
552                                  int dst_reg)
553 {
554         off_t offset;
555         unsigned long pc;
556         struct exception_table_entry *ex;
557
558         if (!ctx->image)
559                 /* First pass */
560                 return 0;
561
562         if (BPF_MODE(insn->code) != BPF_PROBE_MEM)
563                 return 0;
564
565         if (!ctx->prog->aux->extable ||
566             WARN_ON_ONCE(ctx->exentry_idx >= ctx->prog->aux->num_exentries))
567                 return -EINVAL;
568
569         ex = &ctx->prog->aux->extable[ctx->exentry_idx];
570         pc = (unsigned long)&ctx->image[ctx->idx - 1];
571
572         offset = pc - (long)&ex->insn;
573         if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
574                 return -ERANGE;
575         ex->insn = offset;
576
577         /*
578          * Since the extable follows the program, the fixup offset is always
579          * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
580          * to keep things simple, and put the destination register in the upper
581          * bits. We don't need to worry about buildtime or runtime sort
582          * modifying the upper bits because the table is already sorted, and
583          * isn't part of the main exception table.
584          */
585         offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
586         if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
587                 return -ERANGE;
588
589         ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
590                     FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
591
592         ex->type = EX_TYPE_BPF;
593
594         ctx->exentry_idx++;
595         return 0;
596 }
597
598 /* JITs an eBPF instruction.
599  * Returns:
600  * 0  - successfully JITed an 8-byte eBPF instruction.
601  * >0 - successfully JITed a 16-byte eBPF instruction.
602  * <0 - failed to JIT.
603  */
604 static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
605                       bool extra_pass)
606 {
607         const u8 code = insn->code;
608         const u8 dst = bpf2a64[insn->dst_reg];
609         const u8 src = bpf2a64[insn->src_reg];
610         const u8 tmp = bpf2a64[TMP_REG_1];
611         const u8 tmp2 = bpf2a64[TMP_REG_2];
612         const s16 off = insn->off;
613         const s32 imm = insn->imm;
614         const int i = insn - ctx->prog->insnsi;
615         const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
616                           BPF_CLASS(code) == BPF_JMP;
617         u8 jmp_cond;
618         s32 jmp_offset;
619         u32 a64_insn;
620         int ret;
621
622         switch (code) {
623         /* dst = src */
624         case BPF_ALU | BPF_MOV | BPF_X:
625         case BPF_ALU64 | BPF_MOV | BPF_X:
626                 emit(A64_MOV(is64, dst, src), ctx);
627                 break;
628         /* dst = dst OP src */
629         case BPF_ALU | BPF_ADD | BPF_X:
630         case BPF_ALU64 | BPF_ADD | BPF_X:
631                 emit(A64_ADD(is64, dst, dst, src), ctx);
632                 break;
633         case BPF_ALU | BPF_SUB | BPF_X:
634         case BPF_ALU64 | BPF_SUB | BPF_X:
635                 emit(A64_SUB(is64, dst, dst, src), ctx);
636                 break;
637         case BPF_ALU | BPF_AND | BPF_X:
638         case BPF_ALU64 | BPF_AND | BPF_X:
639                 emit(A64_AND(is64, dst, dst, src), ctx);
640                 break;
641         case BPF_ALU | BPF_OR | BPF_X:
642         case BPF_ALU64 | BPF_OR | BPF_X:
643                 emit(A64_ORR(is64, dst, dst, src), ctx);
644                 break;
645         case BPF_ALU | BPF_XOR | BPF_X:
646         case BPF_ALU64 | BPF_XOR | BPF_X:
647                 emit(A64_EOR(is64, dst, dst, src), ctx);
648                 break;
649         case BPF_ALU | BPF_MUL | BPF_X:
650         case BPF_ALU64 | BPF_MUL | BPF_X:
651                 emit(A64_MUL(is64, dst, dst, src), ctx);
652                 break;
653         case BPF_ALU | BPF_DIV | BPF_X:
654         case BPF_ALU64 | BPF_DIV | BPF_X:
655                 emit(A64_UDIV(is64, dst, dst, src), ctx);
656                 break;
657         case BPF_ALU | BPF_MOD | BPF_X:
658         case BPF_ALU64 | BPF_MOD | BPF_X:
659                 emit(A64_UDIV(is64, tmp, dst, src), ctx);
660                 emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
661                 break;
662         case BPF_ALU | BPF_LSH | BPF_X:
663         case BPF_ALU64 | BPF_LSH | BPF_X:
664                 emit(A64_LSLV(is64, dst, dst, src), ctx);
665                 break;
666         case BPF_ALU | BPF_RSH | BPF_X:
667         case BPF_ALU64 | BPF_RSH | BPF_X:
668                 emit(A64_LSRV(is64, dst, dst, src), ctx);
669                 break;
670         case BPF_ALU | BPF_ARSH | BPF_X:
671         case BPF_ALU64 | BPF_ARSH | BPF_X:
672                 emit(A64_ASRV(is64, dst, dst, src), ctx);
673                 break;
674         /* dst = -dst */
675         case BPF_ALU | BPF_NEG:
676         case BPF_ALU64 | BPF_NEG:
677                 emit(A64_NEG(is64, dst, dst), ctx);
678                 break;
679         /* dst = BSWAP##imm(dst) */
680         case BPF_ALU | BPF_END | BPF_FROM_LE:
681         case BPF_ALU | BPF_END | BPF_FROM_BE:
682 #ifdef CONFIG_CPU_BIG_ENDIAN
683                 if (BPF_SRC(code) == BPF_FROM_BE)
684                         goto emit_bswap_uxt;
685 #else /* !CONFIG_CPU_BIG_ENDIAN */
686                 if (BPF_SRC(code) == BPF_FROM_LE)
687                         goto emit_bswap_uxt;
688 #endif
689                 switch (imm) {
690                 case 16:
691                         emit(A64_REV16(is64, dst, dst), ctx);
692                         /* zero-extend 16 bits into 64 bits */
693                         emit(A64_UXTH(is64, dst, dst), ctx);
694                         break;
695                 case 32:
696                         emit(A64_REV32(is64, dst, dst), ctx);
697                         /* upper 32 bits already cleared */
698                         break;
699                 case 64:
700                         emit(A64_REV64(dst, dst), ctx);
701                         break;
702                 }
703                 break;
704 emit_bswap_uxt:
705                 switch (imm) {
706                 case 16:
707                         /* zero-extend 16 bits into 64 bits */
708                         emit(A64_UXTH(is64, dst, dst), ctx);
709                         break;
710                 case 32:
711                         /* zero-extend 32 bits into 64 bits */
712                         emit(A64_UXTW(is64, dst, dst), ctx);
713                         break;
714                 case 64:
715                         /* nop */
716                         break;
717                 }
718                 break;
719         /* dst = imm */
720         case BPF_ALU | BPF_MOV | BPF_K:
721         case BPF_ALU64 | BPF_MOV | BPF_K:
722                 emit_a64_mov_i(is64, dst, imm, ctx);
723                 break;
724         /* dst = dst OP imm */
725         case BPF_ALU | BPF_ADD | BPF_K:
726         case BPF_ALU64 | BPF_ADD | BPF_K:
727                 if (is_addsub_imm(imm)) {
728                         emit(A64_ADD_I(is64, dst, dst, imm), ctx);
729                 } else if (is_addsub_imm(-imm)) {
730                         emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
731                 } else {
732                         emit_a64_mov_i(is64, tmp, imm, ctx);
733                         emit(A64_ADD(is64, dst, dst, tmp), ctx);
734                 }
735                 break;
736         case BPF_ALU | BPF_SUB | BPF_K:
737         case BPF_ALU64 | BPF_SUB | BPF_K:
738                 if (is_addsub_imm(imm)) {
739                         emit(A64_SUB_I(is64, dst, dst, imm), ctx);
740                 } else if (is_addsub_imm(-imm)) {
741                         emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
742                 } else {
743                         emit_a64_mov_i(is64, tmp, imm, ctx);
744                         emit(A64_SUB(is64, dst, dst, tmp), ctx);
745                 }
746                 break;
747         case BPF_ALU | BPF_AND | BPF_K:
748         case BPF_ALU64 | BPF_AND | BPF_K:
749                 a64_insn = A64_AND_I(is64, dst, dst, imm);
750                 if (a64_insn != AARCH64_BREAK_FAULT) {
751                         emit(a64_insn, ctx);
752                 } else {
753                         emit_a64_mov_i(is64, tmp, imm, ctx);
754                         emit(A64_AND(is64, dst, dst, tmp), ctx);
755                 }
756                 break;
757         case BPF_ALU | BPF_OR | BPF_K:
758         case BPF_ALU64 | BPF_OR | BPF_K:
759                 a64_insn = A64_ORR_I(is64, dst, dst, imm);
760                 if (a64_insn != AARCH64_BREAK_FAULT) {
761                         emit(a64_insn, ctx);
762                 } else {
763                         emit_a64_mov_i(is64, tmp, imm, ctx);
764                         emit(A64_ORR(is64, dst, dst, tmp), ctx);
765                 }
766                 break;
767         case BPF_ALU | BPF_XOR | BPF_K:
768         case BPF_ALU64 | BPF_XOR | BPF_K:
769                 a64_insn = A64_EOR_I(is64, dst, dst, imm);
770                 if (a64_insn != AARCH64_BREAK_FAULT) {
771                         emit(a64_insn, ctx);
772                 } else {
773                         emit_a64_mov_i(is64, tmp, imm, ctx);
774                         emit(A64_EOR(is64, dst, dst, tmp), ctx);
775                 }
776                 break;
777         case BPF_ALU | BPF_MUL | BPF_K:
778         case BPF_ALU64 | BPF_MUL | BPF_K:
779                 emit_a64_mov_i(is64, tmp, imm, ctx);
780                 emit(A64_MUL(is64, dst, dst, tmp), ctx);
781                 break;
782         case BPF_ALU | BPF_DIV | BPF_K:
783         case BPF_ALU64 | BPF_DIV | BPF_K:
784                 emit_a64_mov_i(is64, tmp, imm, ctx);
785                 emit(A64_UDIV(is64, dst, dst, tmp), ctx);
786                 break;
787         case BPF_ALU | BPF_MOD | BPF_K:
788         case BPF_ALU64 | BPF_MOD | BPF_K:
789                 emit_a64_mov_i(is64, tmp2, imm, ctx);
790                 emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
791                 emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
792                 break;
793         case BPF_ALU | BPF_LSH | BPF_K:
794         case BPF_ALU64 | BPF_LSH | BPF_K:
795                 emit(A64_LSL(is64, dst, dst, imm), ctx);
796                 break;
797         case BPF_ALU | BPF_RSH | BPF_K:
798         case BPF_ALU64 | BPF_RSH | BPF_K:
799                 emit(A64_LSR(is64, dst, dst, imm), ctx);
800                 break;
801         case BPF_ALU | BPF_ARSH | BPF_K:
802         case BPF_ALU64 | BPF_ARSH | BPF_K:
803                 emit(A64_ASR(is64, dst, dst, imm), ctx);
804                 break;
805
806         /* JUMP off */
807         case BPF_JMP | BPF_JA:
808                 jmp_offset = bpf2a64_offset(i, off, ctx);
809                 check_imm26(jmp_offset);
810                 emit(A64_B(jmp_offset), ctx);
811                 break;
812         /* IF (dst COND src) JUMP off */
813         case BPF_JMP | BPF_JEQ | BPF_X:
814         case BPF_JMP | BPF_JGT | BPF_X:
815         case BPF_JMP | BPF_JLT | BPF_X:
816         case BPF_JMP | BPF_JGE | BPF_X:
817         case BPF_JMP | BPF_JLE | BPF_X:
818         case BPF_JMP | BPF_JNE | BPF_X:
819         case BPF_JMP | BPF_JSGT | BPF_X:
820         case BPF_JMP | BPF_JSLT | BPF_X:
821         case BPF_JMP | BPF_JSGE | BPF_X:
822         case BPF_JMP | BPF_JSLE | BPF_X:
823         case BPF_JMP32 | BPF_JEQ | BPF_X:
824         case BPF_JMP32 | BPF_JGT | BPF_X:
825         case BPF_JMP32 | BPF_JLT | BPF_X:
826         case BPF_JMP32 | BPF_JGE | BPF_X:
827         case BPF_JMP32 | BPF_JLE | BPF_X:
828         case BPF_JMP32 | BPF_JNE | BPF_X:
829         case BPF_JMP32 | BPF_JSGT | BPF_X:
830         case BPF_JMP32 | BPF_JSLT | BPF_X:
831         case BPF_JMP32 | BPF_JSGE | BPF_X:
832         case BPF_JMP32 | BPF_JSLE | BPF_X:
833                 emit(A64_CMP(is64, dst, src), ctx);
834 emit_cond_jmp:
835                 jmp_offset = bpf2a64_offset(i, off, ctx);
836                 check_imm19(jmp_offset);
837                 switch (BPF_OP(code)) {
838                 case BPF_JEQ:
839                         jmp_cond = A64_COND_EQ;
840                         break;
841                 case BPF_JGT:
842                         jmp_cond = A64_COND_HI;
843                         break;
844                 case BPF_JLT:
845                         jmp_cond = A64_COND_CC;
846                         break;
847                 case BPF_JGE:
848                         jmp_cond = A64_COND_CS;
849                         break;
850                 case BPF_JLE:
851                         jmp_cond = A64_COND_LS;
852                         break;
853                 case BPF_JSET:
854                 case BPF_JNE:
855                         jmp_cond = A64_COND_NE;
856                         break;
857                 case BPF_JSGT:
858                         jmp_cond = A64_COND_GT;
859                         break;
860                 case BPF_JSLT:
861                         jmp_cond = A64_COND_LT;
862                         break;
863                 case BPF_JSGE:
864                         jmp_cond = A64_COND_GE;
865                         break;
866                 case BPF_JSLE:
867                         jmp_cond = A64_COND_LE;
868                         break;
869                 default:
870                         return -EFAULT;
871                 }
872                 emit(A64_B_(jmp_cond, jmp_offset), ctx);
873                 break;
874         case BPF_JMP | BPF_JSET | BPF_X:
875         case BPF_JMP32 | BPF_JSET | BPF_X:
876                 emit(A64_TST(is64, dst, src), ctx);
877                 goto emit_cond_jmp;
878         /* IF (dst COND imm) JUMP off */
879         case BPF_JMP | BPF_JEQ | BPF_K:
880         case BPF_JMP | BPF_JGT | BPF_K:
881         case BPF_JMP | BPF_JLT | BPF_K:
882         case BPF_JMP | BPF_JGE | BPF_K:
883         case BPF_JMP | BPF_JLE | BPF_K:
884         case BPF_JMP | BPF_JNE | BPF_K:
885         case BPF_JMP | BPF_JSGT | BPF_K:
886         case BPF_JMP | BPF_JSLT | BPF_K:
887         case BPF_JMP | BPF_JSGE | BPF_K:
888         case BPF_JMP | BPF_JSLE | BPF_K:
889         case BPF_JMP32 | BPF_JEQ | BPF_K:
890         case BPF_JMP32 | BPF_JGT | BPF_K:
891         case BPF_JMP32 | BPF_JLT | BPF_K:
892         case BPF_JMP32 | BPF_JGE | BPF_K:
893         case BPF_JMP32 | BPF_JLE | BPF_K:
894         case BPF_JMP32 | BPF_JNE | BPF_K:
895         case BPF_JMP32 | BPF_JSGT | BPF_K:
896         case BPF_JMP32 | BPF_JSLT | BPF_K:
897         case BPF_JMP32 | BPF_JSGE | BPF_K:
898         case BPF_JMP32 | BPF_JSLE | BPF_K:
899                 if (is_addsub_imm(imm)) {
900                         emit(A64_CMP_I(is64, dst, imm), ctx);
901                 } else if (is_addsub_imm(-imm)) {
902                         emit(A64_CMN_I(is64, dst, -imm), ctx);
903                 } else {
904                         emit_a64_mov_i(is64, tmp, imm, ctx);
905                         emit(A64_CMP(is64, dst, tmp), ctx);
906                 }
907                 goto emit_cond_jmp;
908         case BPF_JMP | BPF_JSET | BPF_K:
909         case BPF_JMP32 | BPF_JSET | BPF_K:
910                 a64_insn = A64_TST_I(is64, dst, imm);
911                 if (a64_insn != AARCH64_BREAK_FAULT) {
912                         emit(a64_insn, ctx);
913                 } else {
914                         emit_a64_mov_i(is64, tmp, imm, ctx);
915                         emit(A64_TST(is64, dst, tmp), ctx);
916                 }
917                 goto emit_cond_jmp;
918         /* function call */
919         case BPF_JMP | BPF_CALL:
920         {
921                 const u8 r0 = bpf2a64[BPF_REG_0];
922                 bool func_addr_fixed;
923                 u64 func_addr;
924
925                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
926                                             &func_addr, &func_addr_fixed);
927                 if (ret < 0)
928                         return ret;
929                 emit_addr_mov_i64(tmp, func_addr, ctx);
930                 emit(A64_BLR(tmp), ctx);
931                 emit(A64_MOV(1, r0, A64_R(0)), ctx);
932                 break;
933         }
934         /* tail call */
935         case BPF_JMP | BPF_TAIL_CALL:
936                 if (emit_bpf_tail_call(ctx))
937                         return -EFAULT;
938                 break;
939         /* function return */
940         case BPF_JMP | BPF_EXIT:
941                 /* Optimization: when last instruction is EXIT,
942                    simply fallthrough to epilogue. */
943                 if (i == ctx->prog->len - 1)
944                         break;
945                 jmp_offset = epilogue_offset(ctx);
946                 check_imm26(jmp_offset);
947                 emit(A64_B(jmp_offset), ctx);
948                 break;
949
950         /* dst = imm64 */
951         case BPF_LD | BPF_IMM | BPF_DW:
952         {
953                 const struct bpf_insn insn1 = insn[1];
954                 u64 imm64;
955
956                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
957                 if (bpf_pseudo_func(insn))
958                         emit_addr_mov_i64(dst, imm64, ctx);
959                 else
960                         emit_a64_mov_i64(dst, imm64, ctx);
961
962                 return 1;
963         }
964
965         /* LDX: dst = *(size *)(src + off) */
966         case BPF_LDX | BPF_MEM | BPF_W:
967         case BPF_LDX | BPF_MEM | BPF_H:
968         case BPF_LDX | BPF_MEM | BPF_B:
969         case BPF_LDX | BPF_MEM | BPF_DW:
970         case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
971         case BPF_LDX | BPF_PROBE_MEM | BPF_W:
972         case BPF_LDX | BPF_PROBE_MEM | BPF_H:
973         case BPF_LDX | BPF_PROBE_MEM | BPF_B:
974                 emit_a64_mov_i(1, tmp, off, ctx);
975                 switch (BPF_SIZE(code)) {
976                 case BPF_W:
977                         emit(A64_LDR32(dst, src, tmp), ctx);
978                         break;
979                 case BPF_H:
980                         emit(A64_LDRH(dst, src, tmp), ctx);
981                         break;
982                 case BPF_B:
983                         emit(A64_LDRB(dst, src, tmp), ctx);
984                         break;
985                 case BPF_DW:
986                         emit(A64_LDR64(dst, src, tmp), ctx);
987                         break;
988                 }
989
990                 ret = add_exception_handler(insn, ctx, dst);
991                 if (ret)
992                         return ret;
993                 break;
994
995         /* speculation barrier */
996         case BPF_ST | BPF_NOSPEC:
997                 /*
998                  * Nothing required here.
999                  *
1000                  * In case of arm64, we rely on the firmware mitigation of
1001                  * Speculative Store Bypass as controlled via the ssbd kernel
1002                  * parameter. Whenever the mitigation is enabled, it works
1003                  * for all of the kernel code with no need to provide any
1004                  * additional instructions.
1005                  */
1006                 break;
1007
1008         /* ST: *(size *)(dst + off) = imm */
1009         case BPF_ST | BPF_MEM | BPF_W:
1010         case BPF_ST | BPF_MEM | BPF_H:
1011         case BPF_ST | BPF_MEM | BPF_B:
1012         case BPF_ST | BPF_MEM | BPF_DW:
1013                 /* Load imm to a register then store it */
1014                 emit_a64_mov_i(1, tmp2, off, ctx);
1015                 emit_a64_mov_i(1, tmp, imm, ctx);
1016                 switch (BPF_SIZE(code)) {
1017                 case BPF_W:
1018                         emit(A64_STR32(tmp, dst, tmp2), ctx);
1019                         break;
1020                 case BPF_H:
1021                         emit(A64_STRH(tmp, dst, tmp2), ctx);
1022                         break;
1023                 case BPF_B:
1024                         emit(A64_STRB(tmp, dst, tmp2), ctx);
1025                         break;
1026                 case BPF_DW:
1027                         emit(A64_STR64(tmp, dst, tmp2), ctx);
1028                         break;
1029                 }
1030                 break;
1031
1032         /* STX: *(size *)(dst + off) = src */
1033         case BPF_STX | BPF_MEM | BPF_W:
1034         case BPF_STX | BPF_MEM | BPF_H:
1035         case BPF_STX | BPF_MEM | BPF_B:
1036         case BPF_STX | BPF_MEM | BPF_DW:
1037                 emit_a64_mov_i(1, tmp, off, ctx);
1038                 switch (BPF_SIZE(code)) {
1039                 case BPF_W:
1040                         emit(A64_STR32(src, dst, tmp), ctx);
1041                         break;
1042                 case BPF_H:
1043                         emit(A64_STRH(src, dst, tmp), ctx);
1044                         break;
1045                 case BPF_B:
1046                         emit(A64_STRB(src, dst, tmp), ctx);
1047                         break;
1048                 case BPF_DW:
1049                         emit(A64_STR64(src, dst, tmp), ctx);
1050                         break;
1051                 }
1052                 break;
1053
1054         case BPF_STX | BPF_ATOMIC | BPF_W:
1055         case BPF_STX | BPF_ATOMIC | BPF_DW:
1056                 if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS))
1057                         ret = emit_lse_atomic(insn, ctx);
1058                 else
1059                         ret = emit_ll_sc_atomic(insn, ctx);
1060                 if (ret)
1061                         return ret;
1062                 break;
1063
1064         default:
1065                 pr_err_once("unknown opcode %02x\n", code);
1066                 return -EINVAL;
1067         }
1068
1069         return 0;
1070 }
1071
1072 static int build_body(struct jit_ctx *ctx, bool extra_pass)
1073 {
1074         const struct bpf_prog *prog = ctx->prog;
1075         int i;
1076
1077         /*
1078          * - offset[0] offset of the end of prologue,
1079          *   start of the 1st instruction.
1080          * - offset[1] - offset of the end of 1st instruction,
1081          *   start of the 2nd instruction
1082          * [....]
1083          * - offset[3] - offset of the end of 3rd instruction,
1084          *   start of 4th instruction
1085          */
1086         for (i = 0; i < prog->len; i++) {
1087                 const struct bpf_insn *insn = &prog->insnsi[i];
1088                 int ret;
1089
1090                 if (ctx->image == NULL)
1091                         ctx->offset[i] = ctx->idx;
1092                 ret = build_insn(insn, ctx, extra_pass);
1093                 if (ret > 0) {
1094                         i++;
1095                         if (ctx->image == NULL)
1096                                 ctx->offset[i] = ctx->idx;
1097                         continue;
1098                 }
1099                 if (ret)
1100                         return ret;
1101         }
1102         /*
1103          * offset is allocated with prog->len + 1 so fill in
1104          * the last element with the offset after the last
1105          * instruction (end of program)
1106          */
1107         if (ctx->image == NULL)
1108                 ctx->offset[i] = ctx->idx;
1109
1110         return 0;
1111 }
1112
1113 static int validate_code(struct jit_ctx *ctx)
1114 {
1115         int i;
1116
1117         for (i = 0; i < ctx->idx; i++) {
1118                 u32 a64_insn = le32_to_cpu(ctx->image[i]);
1119
1120                 if (a64_insn == AARCH64_BREAK_FAULT)
1121                         return -1;
1122         }
1123
1124         if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
1125                 return -1;
1126
1127         return 0;
1128 }
1129
1130 static inline void bpf_flush_icache(void *start, void *end)
1131 {
1132         flush_icache_range((unsigned long)start, (unsigned long)end);
1133 }
1134
1135 struct arm64_jit_data {
1136         struct bpf_binary_header *header;
1137         u8 *image;
1138         struct jit_ctx ctx;
1139 };
1140
1141 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1142 {
1143         int image_size, prog_size, extable_size;
1144         struct bpf_prog *tmp, *orig_prog = prog;
1145         struct bpf_binary_header *header;
1146         struct arm64_jit_data *jit_data;
1147         bool was_classic = bpf_prog_was_classic(prog);
1148         bool tmp_blinded = false;
1149         bool extra_pass = false;
1150         struct jit_ctx ctx;
1151         u8 *image_ptr;
1152
1153         if (!prog->jit_requested)
1154                 return orig_prog;
1155
1156         tmp = bpf_jit_blind_constants(prog);
1157         /* If blinding was requested and we failed during blinding,
1158          * we must fall back to the interpreter.
1159          */
1160         if (IS_ERR(tmp))
1161                 return orig_prog;
1162         if (tmp != prog) {
1163                 tmp_blinded = true;
1164                 prog = tmp;
1165         }
1166
1167         jit_data = prog->aux->jit_data;
1168         if (!jit_data) {
1169                 jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1170                 if (!jit_data) {
1171                         prog = orig_prog;
1172                         goto out;
1173                 }
1174                 prog->aux->jit_data = jit_data;
1175         }
1176         if (jit_data->ctx.offset) {
1177                 ctx = jit_data->ctx;
1178                 image_ptr = jit_data->image;
1179                 header = jit_data->header;
1180                 extra_pass = true;
1181                 prog_size = sizeof(u32) * ctx.idx;
1182                 goto skip_init_ctx;
1183         }
1184         memset(&ctx, 0, sizeof(ctx));
1185         ctx.prog = prog;
1186
1187         ctx.offset = kcalloc(prog->len + 1, sizeof(int), GFP_KERNEL);
1188         if (ctx.offset == NULL) {
1189                 prog = orig_prog;
1190                 goto out_off;
1191         }
1192
1193         /*
1194          * 1. Initial fake pass to compute ctx->idx and ctx->offset.
1195          *
1196          * BPF line info needs ctx->offset[i] to be the offset of
1197          * instruction[i] in jited image, so build prologue first.
1198          */
1199         if (build_prologue(&ctx, was_classic)) {
1200                 prog = orig_prog;
1201                 goto out_off;
1202         }
1203
1204         if (build_body(&ctx, extra_pass)) {
1205                 prog = orig_prog;
1206                 goto out_off;
1207         }
1208
1209         ctx.epilogue_offset = ctx.idx;
1210         build_epilogue(&ctx);
1211
1212         extable_size = prog->aux->num_exentries *
1213                 sizeof(struct exception_table_entry);
1214
1215         /* Now we know the actual image size. */
1216         prog_size = sizeof(u32) * ctx.idx;
1217         image_size = prog_size + extable_size;
1218         header = bpf_jit_binary_alloc(image_size, &image_ptr,
1219                                       sizeof(u32), jit_fill_hole);
1220         if (header == NULL) {
1221                 prog = orig_prog;
1222                 goto out_off;
1223         }
1224
1225         /* 2. Now, the actual pass. */
1226
1227         ctx.image = (__le32 *)image_ptr;
1228         if (extable_size)
1229                 prog->aux->extable = (void *)image_ptr + prog_size;
1230 skip_init_ctx:
1231         ctx.idx = 0;
1232         ctx.exentry_idx = 0;
1233
1234         build_prologue(&ctx, was_classic);
1235
1236         if (build_body(&ctx, extra_pass)) {
1237                 bpf_jit_binary_free(header);
1238                 prog = orig_prog;
1239                 goto out_off;
1240         }
1241
1242         build_epilogue(&ctx);
1243
1244         /* 3. Extra pass to validate JITed code. */
1245         if (validate_code(&ctx)) {
1246                 bpf_jit_binary_free(header);
1247                 prog = orig_prog;
1248                 goto out_off;
1249         }
1250
1251         /* And we're done. */
1252         if (bpf_jit_enable > 1)
1253                 bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
1254
1255         bpf_flush_icache(header, ctx.image + ctx.idx);
1256
1257         if (!prog->is_func || extra_pass) {
1258                 if (extra_pass && ctx.idx != jit_data->ctx.idx) {
1259                         pr_err_once("multi-func JIT bug %d != %d\n",
1260                                     ctx.idx, jit_data->ctx.idx);
1261                         bpf_jit_binary_free(header);
1262                         prog->bpf_func = NULL;
1263                         prog->jited = 0;
1264                         goto out_off;
1265                 }
1266                 bpf_jit_binary_lock_ro(header);
1267         } else {
1268                 jit_data->ctx = ctx;
1269                 jit_data->image = image_ptr;
1270                 jit_data->header = header;
1271         }
1272         prog->bpf_func = (void *)ctx.image;
1273         prog->jited = 1;
1274         prog->jited_len = prog_size;
1275
1276         if (!prog->is_func || extra_pass) {
1277                 int i;
1278
1279                 /* offset[prog->len] is the size of program */
1280                 for (i = 0; i <= prog->len; i++)
1281                         ctx.offset[i] *= AARCH64_INSN_SIZE;
1282                 bpf_prog_fill_jited_linfo(prog, ctx.offset + 1);
1283 out_off:
1284                 kfree(ctx.offset);
1285                 kfree(jit_data);
1286                 prog->aux->jit_data = NULL;
1287         }
1288 out:
1289         if (tmp_blinded)
1290                 bpf_jit_prog_release_other(prog, prog == orig_prog ?
1291                                            tmp : orig_prog);
1292         return prog;
1293 }
1294
1295 bool bpf_jit_supports_kfunc_call(void)
1296 {
1297         return true;
1298 }
1299
1300 u64 bpf_jit_alloc_exec_limit(void)
1301 {
1302         return VMALLOC_END - VMALLOC_START;
1303 }
1304
1305 void *bpf_jit_alloc_exec(unsigned long size)
1306 {
1307         /* Memory is intended to be executable, reset the pointer tag. */
1308         return kasan_reset_tag(vmalloc(size));
1309 }
1310
1311 void bpf_jit_free_exec(void *addr)
1312 {
1313         return vfree(addr);
1314 }