1 /* bpf_jit_comp.c : BPF JIT compiler
2  *
3  * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
4  * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; version 2
9  * of the License.
10  */
11 #include <linux/netdevice.h>
12 #include <linux/filter.h>
13 #include <linux/if_vlan.h>
14 #include <linux/bpf.h>
15
16 #include <asm/set_memory.h>
17 #include <asm/nospec-branch.h>
18
19 /*
20  * assembly code in arch/x86/net/bpf_jit.S
21  */
22 extern u8 sk_load_word[], sk_load_half[], sk_load_byte[];
23 extern u8 sk_load_word_positive_offset[], sk_load_half_positive_offset[];
24 extern u8 sk_load_byte_positive_offset[];
25 extern u8 sk_load_word_negative_offset[], sk_load_half_negative_offset[];
26 extern u8 sk_load_byte_negative_offset[];
27
28 static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
29 {
30         if (len == 1)
31                 *ptr = bytes;
32         else if (len == 2)
33                 *(u16 *)ptr = bytes;
34         else {
35                 *(u32 *)ptr = bytes;
36                 barrier();
37         }
38         return ptr + len;
39 }
40
41 #define EMIT(bytes, len) \
42         do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)
43
44 #define EMIT1(b1)               EMIT(b1, 1)
45 #define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
46 #define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
47 #define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
48 #define EMIT1_off32(b1, off) \
49         do {EMIT1(b1); EMIT(off, 4); } while (0)
50 #define EMIT2_off32(b1, b2, off) \
51         do {EMIT2(b1, b2); EMIT(off, 4); } while (0)
52 #define EMIT3_off32(b1, b2, b3, off) \
53         do {EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
54 #define EMIT4_off32(b1, b2, b3, b4, off) \
55         do {EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
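/* For illustration: EMIT4(0x48, 0x83, 0xC0, 0x20) packs the bytes
 * 48 83 c0 20 ('add rax, 0x20') into a single u32, least significant
 * byte first, and emit_code() stores that u32 directly, so on
 * little-endian x86 the opcode bytes land in memory in the intended order.
 */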
56
57 static bool is_imm8(int value)
58 {
59         return value <= 127 && value >= -128;
60 }
61
62 static bool is_simm32(s64 value)
63 {
64         return value == (s64)(s32)value;
65 }
66
67 static bool is_uimm32(u64 value)
68 {
69         return value == (u64)(u32)value;
70 }
71
72 /* mov dst, src */
73 #define EMIT_mov(DST, SRC) \
74         do {if (DST != SRC) \
75                 EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
76         } while (0)
77
78 static int bpf_size_to_x86_bytes(int bpf_size)
79 {
80         if (bpf_size == BPF_W)
81                 return 4;
82         else if (bpf_size == BPF_H)
83                 return 2;
84         else if (bpf_size == BPF_B)
85                 return 1;
86         else if (bpf_size == BPF_DW)
87                 return 4; /* imm32 */
88         else
89                 return 0;
90 }
91
92 /* list of x86 conditional jump opcodes (. + s8)
93  * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
94  */
95 #define X86_JB  0x72
96 #define X86_JAE 0x73
97 #define X86_JE  0x74
98 #define X86_JNE 0x75
99 #define X86_JBE 0x76
100 #define X86_JA  0x77
101 #define X86_JL  0x7C
102 #define X86_JGE 0x7D
103 #define X86_JLE 0x7E
104 #define X86_JG  0x7F
105
106 #define CHOOSE_LOAD_FUNC(K, func) \
107         ((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)
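/* For example: a positive K (an offset into the packet) picks
 * func##_positive_offset, a negative K at or above SKF_LL_OFF picks
 * func##_negative_offset, and anything more negative falls back to the
 * generic func, which sorts out the offset at run time.
 */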
108
109 /* pick a register outside of BPF range for JIT internal work */
110 #define AUX_REG (MAX_BPF_JIT_REG + 1)
111
112 /* The following table maps BPF registers to x64 registers.
113  *
114  * x64 register r12 is unused, since if used as base address
115  * register in load/store instructions, it always needs an
116  * extra byte of encoding and is callee saved.
117  *
118  *  r9 caches skb->len - skb->data_len
119  * r10 caches skb->data, and is used for blinding (if enabled)
120  */
121 static const int reg2hex[] = {
122         [BPF_REG_0] = 0,  /* rax */
123         [BPF_REG_1] = 7,  /* rdi */
124         [BPF_REG_2] = 6,  /* rsi */
125         [BPF_REG_3] = 2,  /* rdx */
126         [BPF_REG_4] = 1,  /* rcx */
127         [BPF_REG_5] = 0,  /* r8 */
128         [BPF_REG_6] = 3,  /* rbx callee saved */
129         [BPF_REG_7] = 5,  /* r13 callee saved */
130         [BPF_REG_8] = 6,  /* r14 callee saved */
131         [BPF_REG_9] = 7,  /* r15 callee saved */
132         [BPF_REG_FP] = 5, /* rbp readonly */
133         [BPF_REG_AX] = 2, /* r10 temp register */
134         [AUX_REG] = 3,    /* r11 temp register */
135 };
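/* The table stores only the low 3 bits of the x64 register number
 * (the ModRM field); for r8..r15 the 4th bit is supplied separately
 * through a REX prefix, see is_ereg() and add_1mod()/add_2mod() below.
 * E.g. BPF_REG_7 -> r13: 13 & 7 == 5 here, plus a REX extension bit.
 */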
136
137 /* is_ereg() == true if BPF register 'reg' maps to x64 r8..r15
138  * which need an extra byte of encoding.
139  * rax,rcx,...,rbp have a simpler encoding
140  */
141 static bool is_ereg(u32 reg)
142 {
143         return (1 << reg) & (BIT(BPF_REG_5) |
144                              BIT(AUX_REG) |
145                              BIT(BPF_REG_7) |
146                              BIT(BPF_REG_8) |
147                              BIT(BPF_REG_9) |
148                              BIT(BPF_REG_AX));
149 }
150
151 static bool is_axreg(u32 reg)
152 {
153         return reg == BPF_REG_0;
154 }
155
156 /* add modifiers if 'reg' maps to x64 registers r8..r15 */
157 static u8 add_1mod(u8 byte, u32 reg)
158 {
159         if (is_ereg(reg))
160                 byte |= 1;
161         return byte;
162 }
163
164 static u8 add_2mod(u8 byte, u32 r1, u32 r2)
165 {
166         if (is_ereg(r1))
167                 byte |= 1;
168         if (is_ereg(r2))
169                 byte |= 4;
170         return byte;
171 }
172
173 /* encode 'dst_reg' register into x64 opcode 'byte' */
174 static u8 add_1reg(u8 byte, u32 dst_reg)
175 {
176         return byte + reg2hex[dst_reg];
177 }
178
179 /* encode 'dst_reg' and 'src_reg' registers into x64 opcode 'byte' */
180 static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
181 {
182         return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
183 }
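/* For illustration, EMIT_mov(BPF_REG_1, BPF_REG_2) (defined above) expands to
 * EMIT3(add_2mod(0x48, BPF_REG_1, BPF_REG_2), 0x89,
 *       add_2reg(0xC0, BPF_REG_1, BPF_REG_2))
 * i.e. the bytes 48 89 f7, 'mov rdi, rsi': REX.W=1, opcode 0x89
 * (MOV r/m64, r64), ModRM 0xC0 + rdi(7) + (rsi(6) << 3). If either
 * register were one of r8..r15, add_2mod() would also set REX.B/REX.R.
 */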
184
185 static void jit_fill_hole(void *area, unsigned int size)
186 {
187         /* fill whole space with int3 instructions */
188         memset(area, 0xcc, size);
189 }
190
191 struct jit_context {
192         int cleanup_addr; /* epilogue code offset */
193         bool seen_ld_abs;
194         bool seen_ax_reg;
195 };
196
197 /* maximum number of bytes emitted while JITing one eBPF insn */
198 #define BPF_MAX_INSN_SIZE       128
199 #define BPF_INSN_SAFETY         64
200
201 #define AUX_STACK_SPACE \
202         (32 /* space for rbx, r13, r14, r15 */ + \
203          8 /* space for skb_copy_bits() buffer */)
204
205 #define PROLOGUE_SIZE 37
206
207 /* emit x64 prologue code for BPF program and check its size.
208  * bpf_tail_call helper will skip it while jumping into another program
209  */
210 static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
211 {
212         u8 *prog = *pprog;
213         int cnt = 0;
214
215         EMIT1(0x55); /* push rbp */
216         EMIT3(0x48, 0x89, 0xE5); /* mov rbp,rsp */
217
218         /* sub rsp, rounded_stack_depth + AUX_STACK_SPACE */
219         EMIT3_off32(0x48, 0x81, 0xEC,
220                     round_up(stack_depth, 8) + AUX_STACK_SPACE);
221
222         /* sub rbp, AUX_STACK_SPACE */
223         EMIT4(0x48, 0x83, 0xED, AUX_STACK_SPACE);
224
225         /* all classic BPF filters use R6(rbx), so save it */
226
227         /* mov qword ptr [rbp+0],rbx */
228         EMIT4(0x48, 0x89, 0x5D, 0);
229
230         /* bpf_convert_filter() maps classic BPF register X to R7 and uses R8
231          * as temporary, so all tcpdump filters need to spill/fill R7(r13) and
232          * R8(r14). R9(r15) spill could be made conditional, but there is only
233          * one 'bpf_error' return path out of helper functions inside bpf_jit.S.
234          * The overhead of extra spill is negligible for any filter other
235          * than synthetic ones. Therefore not worth adding complexity.
236          */
237
238         /* mov qword ptr [rbp+8],r13 */
239         EMIT4(0x4C, 0x89, 0x6D, 8);
240         /* mov qword ptr [rbp+16],r14 */
241         EMIT4(0x4C, 0x89, 0x75, 16);
242         /* mov qword ptr [rbp+24],r15 */
243         EMIT4(0x4C, 0x89, 0x7D, 24);
244
245         if (!ebpf_from_cbpf) {
246                 /* Clear the tail call counter (tail_call_cnt): for eBPF tail
247                  * calls we need to reset the counter to 0. It's done in two
248                  * instructions, resetting rax register to 0, and moving it
249                  * to the counter location.
250                  */
251
252                 /* xor eax, eax */
253                 EMIT2(0x31, 0xc0);
254                 /* mov qword ptr [rbp+32], rax */
255                 EMIT4(0x48, 0x89, 0x45, 32);
256
257                 BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
258         }
259
260         *pprog = prog;
261 }
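/* Resulting layout relative to the adjusted rbp:
 *   [rbp +  0] saved rbx
 *   [rbp +  8] saved r13
 *   [rbp + 16] saved r14
 *   [rbp + 24] saved r15
 *   [rbp + 32] tail_call_cnt area (zeroed above for native eBPF;
 *              read and updated by emit_bpf_tail_call() below)
 */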
262
263 /* generate the following code:
264  * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
265  *   if (index >= array->map.max_entries)
266  *     goto out;
267  *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
268  *     goto out;
269  *   prog = array->ptrs[index];
270  *   if (prog == NULL)
271  *     goto out;
272  *   goto *(prog->bpf_func + prologue_size);
273  * out:
274  */
275 static void emit_bpf_tail_call(u8 **pprog)
276 {
277         u8 *prog = *pprog;
278         int label1, label2, label3;
279         int cnt = 0;
280
281         /* rdi - pointer to ctx
282          * rsi - pointer to bpf_array
283          * rdx - index in bpf_array
284          */
285
286         /* if (index >= array->map.max_entries)
287          *   goto out;
288          */
289         EMIT2(0x89, 0xD2);                        /* mov edx, edx */
290         EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
291               offsetof(struct bpf_array, map.max_entries));
292 #define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
293         EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
294         label1 = cnt;
295
296         /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
297          *   goto out;
298          */
299         EMIT2_off32(0x8B, 0x85, 36);              /* mov eax, dword ptr [rbp + 36] */
300         EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
301 #define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
302         EMIT2(X86_JA, OFFSET2);                   /* ja out */
303         label2 = cnt;
304         EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
305         EMIT2_off32(0x89, 0x85, 36);              /* mov dword ptr [rbp + 36], eax */
306
307         /* prog = array->ptrs[index]; */
308         EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
309                     offsetof(struct bpf_array, ptrs));
310
311         /* if (prog == NULL)
312          *   goto out;
313          */
314         EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
315 #define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
316         EMIT2(X86_JE, OFFSET3);                   /* je out */
317         label3 = cnt;
318
319         /* goto *(prog->bpf_func + prologue_size); */
320         EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
321               offsetof(struct bpf_prog, bpf_func));
322         EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */
323
324         /* now we're ready to jump into next BPF program
325          * rdi == ctx (1st arg)
326          * rax == prog->bpf_func + prologue_size
327          */
328         RETPOLINE_RAX_BPF_JIT();
329
330         /* out: */
331         BUILD_BUG_ON(cnt - label1 != OFFSET1);
332         BUILD_BUG_ON(cnt - label2 != OFFSET2);
333         BUILD_BUG_ON(cnt - label3 != OFFSET3);
334         *pprog = prog;
335 }
336
337
338 static void emit_load_skb_data_hlen(u8 **pprog)
339 {
340         u8 *prog = *pprog;
341         int cnt = 0;
342
343         /* r9d = skb->len - skb->data_len (headlen)
344          * r10 = skb->data
345          */
346         /* mov %r9d, off32(%rdi) */
347         EMIT3_off32(0x44, 0x8b, 0x8f, offsetof(struct sk_buff, len));
348
349         /* sub %r9d, off32(%rdi) */
350         EMIT3_off32(0x44, 0x2b, 0x8f, offsetof(struct sk_buff, data_len));
351
352         /* mov %r10, off32(%rdi) */
353         EMIT3_off32(0x4c, 0x8b, 0x97, offsetof(struct sk_buff, data));
354         *pprog = prog;
355 }
356
357 static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
358                            u32 dst_reg, const u32 imm32)
359 {
360         u8 *prog = *pprog;
361         u8 b1, b2, b3;
362         int cnt = 0;
363
364         /* optimization: if imm32 is positive, use 'mov %eax, imm32'
365          * (which zero-extends imm32) to save 2 bytes.
366          */
367         if (sign_propagate && (s32)imm32 < 0) {
368                 /* 'mov %rax, imm32' sign extends imm32 */
369                 b1 = add_1mod(0x48, dst_reg);
370                 b2 = 0xC7;
371                 b3 = 0xC0;
372                 EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
373                 goto done;
374         }
375
376         /* optimization: if imm32 is zero, use 'xor %eax, %eax'
377          * to save 3 bytes.
378          */
379         if (imm32 == 0) {
380                 if (is_ereg(dst_reg))
381                         EMIT1(add_2mod(0x40, dst_reg, dst_reg));
382                 b2 = 0x31; /* xor */
383                 b3 = 0xC0;
384                 EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
385                 goto done;
386         }
387
388         /* mov %eax, imm32 */
389         if (is_ereg(dst_reg))
390                 EMIT1(add_1mod(0x40, dst_reg));
391         EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
392 done:
393         *pprog = prog;
394 }
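/* A few resulting encodings for dst_reg == BPF_REG_0 (rax):
 *   imm32 = -1, sign_propagate: 48 c7 c0 ff ff ff ff  (mov rax, -1)
 *   imm32 = 0:                  31 c0                 (xor eax, eax)
 *   imm32 = 5:                  b8 05 00 00 00        (mov eax, 5)
 */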
395
396 static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
397                            const u32 imm32_hi, const u32 imm32_lo)
398 {
399         u8 *prog = *pprog;
400         int cnt = 0;
401
402         if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
403                 /* For emitting plain u32, where the sign bit must not be
404                  * propagated, LLVM tends to load imm64 over mov32
405                  * directly, so save a couple of bytes by just doing
406                  * 'mov %eax, imm32' instead.
407                  */
408                 emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
409         } else {
410                 /* movabsq %rax, imm64 */
411                 EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
412                 EMIT(imm32_lo, 4);
413                 EMIT(imm32_hi, 4);
414         }
415
416         *pprog = prog;
417 }
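/* Constants that fit in 32 bits thus reuse the short forms above, while
 * a full 64-bit constant costs the 10-byte movabs, e.g. for rax:
 * 48 b8 <imm32_lo> <imm32_hi>.
 */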
418
419 static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
420 {
421         u8 *prog = *pprog;
422         int cnt = 0;
423
424         if (is64) {
425                 /* mov dst, src */
426                 EMIT_mov(dst_reg, src_reg);
427         } else {
428                 /* mov32 dst, src */
429                 if (is_ereg(dst_reg) || is_ereg(src_reg))
430                         EMIT1(add_2mod(0x40, dst_reg, src_reg));
431                 EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
432         }
433
434         *pprog = prog;
435 }
436
437 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
438                   int oldproglen, struct jit_context *ctx)
439 {
440         struct bpf_insn *insn = bpf_prog->insnsi;
441         int insn_cnt = bpf_prog->len;
442         bool seen_ld_abs = ctx->seen_ld_abs | (oldproglen == 0);
443         bool seen_ax_reg = ctx->seen_ax_reg | (oldproglen == 0);
444         bool seen_exit = false;
445         u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
446         int i, cnt = 0;
447         int proglen = 0;
448         u8 *prog = temp;
449
450         emit_prologue(&prog, bpf_prog->aux->stack_depth,
451                       bpf_prog_was_classic(bpf_prog));
452
453         if (seen_ld_abs)
454                 emit_load_skb_data_hlen(&prog);
455
456         for (i = 0; i < insn_cnt; i++, insn++) {
457                 const s32 imm32 = insn->imm;
458                 u32 dst_reg = insn->dst_reg;
459                 u32 src_reg = insn->src_reg;
460                 u8 b2 = 0, b3 = 0;
461                 s64 jmp_offset;
462                 u8 jmp_cond;
463                 bool reload_skb_data;
464                 int ilen;
465                 u8 *func;
466
467                 if (dst_reg == BPF_REG_AX || src_reg == BPF_REG_AX)
468                         ctx->seen_ax_reg = seen_ax_reg = true;
469
470                 switch (insn->code) {
471                         /* ALU */
472                 case BPF_ALU | BPF_ADD | BPF_X:
473                 case BPF_ALU | BPF_SUB | BPF_X:
474                 case BPF_ALU | BPF_AND | BPF_X:
475                 case BPF_ALU | BPF_OR | BPF_X:
476                 case BPF_ALU | BPF_XOR | BPF_X:
477                 case BPF_ALU64 | BPF_ADD | BPF_X:
478                 case BPF_ALU64 | BPF_SUB | BPF_X:
479                 case BPF_ALU64 | BPF_AND | BPF_X:
480                 case BPF_ALU64 | BPF_OR | BPF_X:
481                 case BPF_ALU64 | BPF_XOR | BPF_X:
482                         switch (BPF_OP(insn->code)) {
483                         case BPF_ADD: b2 = 0x01; break;
484                         case BPF_SUB: b2 = 0x29; break;
485                         case BPF_AND: b2 = 0x21; break;
486                         case BPF_OR: b2 = 0x09; break;
487                         case BPF_XOR: b2 = 0x31; break;
488                         }
489                         if (BPF_CLASS(insn->code) == BPF_ALU64)
490                                 EMIT1(add_2mod(0x48, dst_reg, src_reg));
491                         else if (is_ereg(dst_reg) || is_ereg(src_reg))
492                                 EMIT1(add_2mod(0x40, dst_reg, src_reg));
493                         EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
494                         break;
495
496                 case BPF_ALU64 | BPF_MOV | BPF_X:
497                 case BPF_ALU | BPF_MOV | BPF_X:
498                         emit_mov_reg(&prog,
499                                      BPF_CLASS(insn->code) == BPF_ALU64,
500                                      dst_reg, src_reg);
501                         break;
502
503                         /* neg dst */
504                 case BPF_ALU | BPF_NEG:
505                 case BPF_ALU64 | BPF_NEG:
506                         if (BPF_CLASS(insn->code) == BPF_ALU64)
507                                 EMIT1(add_1mod(0x48, dst_reg));
508                         else if (is_ereg(dst_reg))
509                                 EMIT1(add_1mod(0x40, dst_reg));
510                         EMIT2(0xF7, add_1reg(0xD8, dst_reg));
511                         break;
512
513                 case BPF_ALU | BPF_ADD | BPF_K:
514                 case BPF_ALU | BPF_SUB | BPF_K:
515                 case BPF_ALU | BPF_AND | BPF_K:
516                 case BPF_ALU | BPF_OR | BPF_K:
517                 case BPF_ALU | BPF_XOR | BPF_K:
518                 case BPF_ALU64 | BPF_ADD | BPF_K:
519                 case BPF_ALU64 | BPF_SUB | BPF_K:
520                 case BPF_ALU64 | BPF_AND | BPF_K:
521                 case BPF_ALU64 | BPF_OR | BPF_K:
522                 case BPF_ALU64 | BPF_XOR | BPF_K:
523                         if (BPF_CLASS(insn->code) == BPF_ALU64)
524                                 EMIT1(add_1mod(0x48, dst_reg));
525                         else if (is_ereg(dst_reg))
526                                 EMIT1(add_1mod(0x40, dst_reg));
527
528                         /* b3 holds the 'normal' opcode, b2 the short form that is
529                          * only valid when dst is eax/rax.
530                          */
531                         switch (BPF_OP(insn->code)) {
532                         case BPF_ADD:
533                                 b3 = 0xC0;
534                                 b2 = 0x05;
535                                 break;
536                         case BPF_SUB:
537                                 b3 = 0xE8;
538                                 b2 = 0x2D;
539                                 break;
540                         case BPF_AND:
541                                 b3 = 0xE0;
542                                 b2 = 0x25;
543                                 break;
544                         case BPF_OR:
545                                 b3 = 0xC8;
546                                 b2 = 0x0D;
547                                 break;
548                         case BPF_XOR:
549                                 b3 = 0xF0;
550                                 b2 = 0x35;
551                                 break;
552                         }
553
554                         if (is_imm8(imm32))
555                                 EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
556                         else if (is_axreg(dst_reg))
557                                 EMIT1_off32(b2, imm32);
558                         else
559                                 EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
560                         break;
561
562                 case BPF_ALU64 | BPF_MOV | BPF_K:
563                 case BPF_ALU | BPF_MOV | BPF_K:
564                         emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
565                                        dst_reg, imm32);
566                         break;
567
568                 case BPF_LD | BPF_IMM | BPF_DW:
569                         emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
570                         insn++;
571                         i++;
572                         break;
573
574                         /* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
575                 case BPF_ALU | BPF_MOD | BPF_X:
576                 case BPF_ALU | BPF_DIV | BPF_X:
577                 case BPF_ALU | BPF_MOD | BPF_K:
578                 case BPF_ALU | BPF_DIV | BPF_K:
579                 case BPF_ALU64 | BPF_MOD | BPF_X:
580                 case BPF_ALU64 | BPF_DIV | BPF_X:
581                 case BPF_ALU64 | BPF_MOD | BPF_K:
582                 case BPF_ALU64 | BPF_DIV | BPF_K:
583                         EMIT1(0x50); /* push rax */
584                         EMIT1(0x52); /* push rdx */
585
586                         if (BPF_SRC(insn->code) == BPF_X)
587                                 /* mov r11, src_reg */
588                                 EMIT_mov(AUX_REG, src_reg);
589                         else
590                                 /* mov r11, imm32 */
591                                 EMIT3_off32(0x49, 0xC7, 0xC3, imm32);
592
593                         /* mov rax, dst_reg */
594                         EMIT_mov(BPF_REG_0, dst_reg);
595
596                         /* xor edx, edx
597                          * equivalent to 'xor rdx, rdx', but one byte less
598                          */
599                         EMIT2(0x31, 0xd2);
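                        /* Unsigned DIV divides edx:eax (rdx:rax for the
                         * 64-bit form) by the operand, so edx must be zero
                         * here; the quotient ends up in rax and the
                         * remainder in rdx, which is why BPF_MOD takes rdx
                         * and BPF_DIV takes rax below.
                         */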
600
601                         if (BPF_CLASS(insn->code) == BPF_ALU64)
602                                 /* div r11 */
603                                 EMIT3(0x49, 0xF7, 0xF3);
604                         else
605                                 /* div r11d */
606                                 EMIT3(0x41, 0xF7, 0xF3);
607
608                         if (BPF_OP(insn->code) == BPF_MOD)
609                                 /* mov r11, rdx */
610                                 EMIT3(0x49, 0x89, 0xD3);
611                         else
612                                 /* mov r11, rax */
613                                 EMIT3(0x49, 0x89, 0xC3);
614
615                         EMIT1(0x5A); /* pop rdx */
616                         EMIT1(0x58); /* pop rax */
617
618                         /* mov dst_reg, r11 */
619                         EMIT_mov(dst_reg, AUX_REG);
620                         break;
621
622                 case BPF_ALU | BPF_MUL | BPF_K:
623                 case BPF_ALU | BPF_MUL | BPF_X:
624                 case BPF_ALU64 | BPF_MUL | BPF_K:
625                 case BPF_ALU64 | BPF_MUL | BPF_X:
626                 {
627                         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
628
629                         if (dst_reg != BPF_REG_0)
630                                 EMIT1(0x50); /* push rax */
631                         if (dst_reg != BPF_REG_3)
632                                 EMIT1(0x52); /* push rdx */
633
634                         /* mov r11, dst_reg */
635                         EMIT_mov(AUX_REG, dst_reg);
636
637                         if (BPF_SRC(insn->code) == BPF_X)
638                                 emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
639                         else
640                                 emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);
641
642                         if (is64)
643                                 EMIT1(add_1mod(0x48, AUX_REG));
644                         else if (is_ereg(AUX_REG))
645                                 EMIT1(add_1mod(0x40, AUX_REG));
646                         /* mul(q) r11 */
647                         EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
648
649                         if (dst_reg != BPF_REG_3)
650                                 EMIT1(0x5A); /* pop rdx */
651                         if (dst_reg != BPF_REG_0) {
652                                 /* mov dst_reg, rax */
653                                 EMIT_mov(dst_reg, BPF_REG_0);
654                                 EMIT1(0x58); /* pop rax */
655                         }
656                         break;
657                 }
658                         /* shifts */
659                 case BPF_ALU | BPF_LSH | BPF_K:
660                 case BPF_ALU | BPF_RSH | BPF_K:
661                 case BPF_ALU | BPF_ARSH | BPF_K:
662                 case BPF_ALU64 | BPF_LSH | BPF_K:
663                 case BPF_ALU64 | BPF_RSH | BPF_K:
664                 case BPF_ALU64 | BPF_ARSH | BPF_K:
665                         if (BPF_CLASS(insn->code) == BPF_ALU64)
666                                 EMIT1(add_1mod(0x48, dst_reg));
667                         else if (is_ereg(dst_reg))
668                                 EMIT1(add_1mod(0x40, dst_reg));
669
670                         switch (BPF_OP(insn->code)) {
671                         case BPF_LSH: b3 = 0xE0; break;
672                         case BPF_RSH: b3 = 0xE8; break;
673                         case BPF_ARSH: b3 = 0xF8; break;
674                         }
675
676                         if (imm32 == 1)
677                                 EMIT2(0xD1, add_1reg(b3, dst_reg));
678                         else
679                                 EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
680                         break;
681
682                 case BPF_ALU | BPF_LSH | BPF_X:
683                 case BPF_ALU | BPF_RSH | BPF_X:
684                 case BPF_ALU | BPF_ARSH | BPF_X:
685                 case BPF_ALU64 | BPF_LSH | BPF_X:
686                 case BPF_ALU64 | BPF_RSH | BPF_X:
687                 case BPF_ALU64 | BPF_ARSH | BPF_X:
688
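                        /* x86 variable-count shifts take the count only in
                         * %cl, hence the shuffling of src_reg into rcx (and
                         * of dst_reg out of the way via r11 when dst_reg
                         * itself is rcx) below.
                         */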
689                         /* check for bad case when dst_reg == rcx */
690                         if (dst_reg == BPF_REG_4) {
691                                 /* mov r11, dst_reg */
692                                 EMIT_mov(AUX_REG, dst_reg);
693                                 dst_reg = AUX_REG;
694                         }
695
696                         if (src_reg != BPF_REG_4) { /* common case */
697                                 EMIT1(0x51); /* push rcx */
698
699                                 /* mov rcx, src_reg */
700                                 EMIT_mov(BPF_REG_4, src_reg);
701                         }
702
703                         /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
704                         if (BPF_CLASS(insn->code) == BPF_ALU64)
705                                 EMIT1(add_1mod(0x48, dst_reg));
706                         else if (is_ereg(dst_reg))
707                                 EMIT1(add_1mod(0x40, dst_reg));
708
709                         switch (BPF_OP(insn->code)) {
710                         case BPF_LSH: b3 = 0xE0; break;
711                         case BPF_RSH: b3 = 0xE8; break;
712                         case BPF_ARSH: b3 = 0xF8; break;
713                         }
714                         EMIT2(0xD3, add_1reg(b3, dst_reg));
715
716                         if (src_reg != BPF_REG_4)
717                                 EMIT1(0x59); /* pop rcx */
718
719                         if (insn->dst_reg == BPF_REG_4)
720                                 /* mov dst_reg, r11 */
721                                 EMIT_mov(insn->dst_reg, AUX_REG);
722                         break;
723
724                 case BPF_ALU | BPF_END | BPF_FROM_BE:
725                         switch (imm32) {
726                         case 16:
727                                 /* emit 'ror %ax, 8' to swap lower 2 bytes */
728                                 EMIT1(0x66);
729                                 if (is_ereg(dst_reg))
730                                         EMIT1(0x41);
731                                 EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
732
733                                 /* emit 'movzwl eax, ax' */
734                                 if (is_ereg(dst_reg))
735                                         EMIT3(0x45, 0x0F, 0xB7);
736                                 else
737                                         EMIT2(0x0F, 0xB7);
738                                 EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
739                                 break;
740                         case 32:
741                                 /* emit 'bswap eax' to swap lower 4 bytes */
742                                 if (is_ereg(dst_reg))
743                                         EMIT2(0x41, 0x0F);
744                                 else
745                                         EMIT1(0x0F);
746                                 EMIT1(add_1reg(0xC8, dst_reg));
747                                 break;
748                         case 64:
749                                 /* emit 'bswap rax' to swap 8 bytes */
750                                 EMIT3(add_1mod(0x48, dst_reg), 0x0F,
751                                       add_1reg(0xC8, dst_reg));
752                                 break;
753                         }
754                         break;
755
756                 case BPF_ALU | BPF_END | BPF_FROM_LE:
757                         switch (imm32) {
758                         case 16:
759                                 /* emit 'movzwl eax, ax' to zero-extend 16 bits
760                                  * into 64 bits
761                                  */
762                                 if (is_ereg(dst_reg))
763                                         EMIT3(0x45, 0x0F, 0xB7);
764                                 else
765                                         EMIT2(0x0F, 0xB7);
766                                 EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
767                                 break;
768                         case 32:
769                                 /* emit 'mov eax, eax' to clear the upper 32 bits */
770                                 if (is_ereg(dst_reg))
771                                         EMIT1(0x45);
772                                 EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
773                                 break;
774                         case 64:
775                                 /* nop */
776                                 break;
777                         }
778                         break;
779
780                         /* ST: *(u8*)(dst_reg + off) = imm */
781                 case BPF_ST | BPF_MEM | BPF_B:
782                         if (is_ereg(dst_reg))
783                                 EMIT2(0x41, 0xC6);
784                         else
785                                 EMIT1(0xC6);
786                         goto st;
787                 case BPF_ST | BPF_MEM | BPF_H:
788                         if (is_ereg(dst_reg))
789                                 EMIT3(0x66, 0x41, 0xC7);
790                         else
791                                 EMIT2(0x66, 0xC7);
792                         goto st;
793                 case BPF_ST | BPF_MEM | BPF_W:
794                         if (is_ereg(dst_reg))
795                                 EMIT2(0x41, 0xC7);
796                         else
797                                 EMIT1(0xC7);
798                         goto st;
799                 case BPF_ST | BPF_MEM | BPF_DW:
800                         EMIT2(add_1mod(0x48, dst_reg), 0xC7);
801
802 st:                     if (is_imm8(insn->off))
803                                 EMIT2(add_1reg(0x40, dst_reg), insn->off);
804                         else
805                                 EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);
806
807                         EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
808                         break;
809
810                         /* STX: *(u8*)(dst_reg + off) = src_reg */
811                 case BPF_STX | BPF_MEM | BPF_B:
812                         /* emit 'mov byte ptr [rax + off], al' */
813                         if (is_ereg(dst_reg) || is_ereg(src_reg) ||
814                             /* have to add extra byte for x86 SIL, DIL regs */
815                             src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
816                                 EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
817                         else
818                                 EMIT1(0x88);
819                         goto stx;
820                 case BPF_STX | BPF_MEM | BPF_H:
821                         if (is_ereg(dst_reg) || is_ereg(src_reg))
822                                 EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
823                         else
824                                 EMIT2(0x66, 0x89);
825                         goto stx;
826                 case BPF_STX | BPF_MEM | BPF_W:
827                         if (is_ereg(dst_reg) || is_ereg(src_reg))
828                                 EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
829                         else
830                                 EMIT1(0x89);
831                         goto stx;
832                 case BPF_STX | BPF_MEM | BPF_DW:
833                         EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
834 stx:                    if (is_imm8(insn->off))
835                                 EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
836                         else
837                                 EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
838                                             insn->off);
839                         break;
840
841                         /* LDX: dst_reg = *(u8*)(src_reg + off) */
842                 case BPF_LDX | BPF_MEM | BPF_B:
843                         /* emit 'movzx rax, byte ptr [rax + off]' */
844                         EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
845                         goto ldx;
846                 case BPF_LDX | BPF_MEM | BPF_H:
847                         /* emit 'movzx rax, word ptr [rax + off]' */
848                         EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
849                         goto ldx;
850                 case BPF_LDX | BPF_MEM | BPF_W:
851                         /* emit 'mov eax, dword ptr [rax+0x14]' */
852                         if (is_ereg(dst_reg) || is_ereg(src_reg))
853                                 EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
854                         else
855                                 EMIT1(0x8B);
856                         goto ldx;
857                 case BPF_LDX | BPF_MEM | BPF_DW:
858                         /* emit 'mov rax, qword ptr [rax+0x14]' */
859                         EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
860 ldx:                    /* if insn->off == 0 we could save one extra byte, but
861                          * the special case of x86 r13, which always needs an offset,
862                          * makes it not worth the hassle
863                          */
864                         if (is_imm8(insn->off))
865                                 EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
866                         else
867                                 EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
868                                             insn->off);
869                         break;
870
871                         /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
872                 case BPF_STX | BPF_XADD | BPF_W:
873                         /* emit 'lock add dword ptr [rax + off], eax' */
874                         if (is_ereg(dst_reg) || is_ereg(src_reg))
875                                 EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
876                         else
877                                 EMIT2(0xF0, 0x01);
878                         goto xadd;
879                 case BPF_STX | BPF_XADD | BPF_DW:
880                         EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
881 xadd:                   if (is_imm8(insn->off))
882                                 EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
883                         else
884                                 EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
885                                             insn->off);
886                         break;
887
888                         /* call */
889                 case BPF_JMP | BPF_CALL:
890                         func = (u8 *) __bpf_call_base + imm32;
891                         jmp_offset = func - (image + addrs[i]);
892                         if (seen_ld_abs) {
893                                 reload_skb_data = bpf_helper_changes_pkt_data(func);
894                                 if (reload_skb_data) {
895                                         EMIT1(0x57); /* push %rdi */
896                                         jmp_offset += 22; /* pop, mov, sub, mov */
897                                 } else {
898                                         EMIT2(0x41, 0x52); /* push %r10 */
899                                         EMIT2(0x41, 0x51); /* push %r9 */
900                                         /* need to adjust jmp offset, since
901                                          * pop %r9, pop %r10 take 4 bytes after call insn
902                                          */
903                                         jmp_offset += 4;
904                                 }
905                         }
906                         if (!imm32 || !is_simm32(jmp_offset)) {
907                                 pr_err("unsupported bpf func %d addr %p image %p\n",
908                                        imm32, func, image);
909                                 return -EINVAL;
910                         }
911                         EMIT1_off32(0xE8, jmp_offset);
912                         if (seen_ld_abs) {
913                                 if (reload_skb_data) {
914                                         EMIT1(0x5F); /* pop %rdi */
915                                         emit_load_skb_data_hlen(&prog);
916                                 } else {
917                                         EMIT2(0x41, 0x59); /* pop %r9 */
918                                         EMIT2(0x41, 0x5A); /* pop %r10 */
919                                 }
920                         }
921                         break;
922
923                 case BPF_JMP | BPF_TAIL_CALL:
924                         emit_bpf_tail_call(&prog);
925                         break;
926
927                         /* cond jump */
928                 case BPF_JMP | BPF_JEQ | BPF_X:
929                 case BPF_JMP | BPF_JNE | BPF_X:
930                 case BPF_JMP | BPF_JGT | BPF_X:
931                 case BPF_JMP | BPF_JLT | BPF_X:
932                 case BPF_JMP | BPF_JGE | BPF_X:
933                 case BPF_JMP | BPF_JLE | BPF_X:
934                 case BPF_JMP | BPF_JSGT | BPF_X:
935                 case BPF_JMP | BPF_JSLT | BPF_X:
936                 case BPF_JMP | BPF_JSGE | BPF_X:
937                 case BPF_JMP | BPF_JSLE | BPF_X:
938                         /* cmp dst_reg, src_reg */
939                         EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x39,
940                               add_2reg(0xC0, dst_reg, src_reg));
941                         goto emit_cond_jmp;
942
943                 case BPF_JMP | BPF_JSET | BPF_X:
944                         /* test dst_reg, src_reg */
945                         EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x85,
946                               add_2reg(0xC0, dst_reg, src_reg));
947                         goto emit_cond_jmp;
948
949                 case BPF_JMP | BPF_JSET | BPF_K:
950                         /* test dst_reg, imm32 */
951                         EMIT1(add_1mod(0x48, dst_reg));
952                         EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
953                         goto emit_cond_jmp;
954
955                 case BPF_JMP | BPF_JEQ | BPF_K:
956                 case BPF_JMP | BPF_JNE | BPF_K:
957                 case BPF_JMP | BPF_JGT | BPF_K:
958                 case BPF_JMP | BPF_JLT | BPF_K:
959                 case BPF_JMP | BPF_JGE | BPF_K:
960                 case BPF_JMP | BPF_JLE | BPF_K:
961                 case BPF_JMP | BPF_JSGT | BPF_K:
962                 case BPF_JMP | BPF_JSLT | BPF_K:
963                 case BPF_JMP | BPF_JSGE | BPF_K:
964                 case BPF_JMP | BPF_JSLE | BPF_K:
965                         /* cmp dst_reg, imm8/32 */
966                         EMIT1(add_1mod(0x48, dst_reg));
967
968                         if (is_imm8(imm32))
969                                 EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
970                         else
971                                 EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
972
973 emit_cond_jmp:          /* convert BPF opcode to x86 */
974                         switch (BPF_OP(insn->code)) {
975                         case BPF_JEQ:
976                                 jmp_cond = X86_JE;
977                                 break;
978                         case BPF_JSET:
979                         case BPF_JNE:
980                                 jmp_cond = X86_JNE;
981                                 break;
982                         case BPF_JGT:
983                                 /* GT is unsigned '>', JA in x86 */
984                                 jmp_cond = X86_JA;
985                                 break;
986                         case BPF_JLT:
987                                 /* LT is unsigned '<', JB in x86 */
988                                 jmp_cond = X86_JB;
989                                 break;
990                         case BPF_JGE:
991                                 /* GE is unsigned '>=', JAE in x86 */
992                                 jmp_cond = X86_JAE;
993                                 break;
994                         case BPF_JLE:
995                                 /* LE is unsigned '<=', JBE in x86 */
996                                 jmp_cond = X86_JBE;
997                                 break;
998                         case BPF_JSGT:
999                                 /* signed '>', GT in x86 */
1000                                 jmp_cond = X86_JG;
1001                                 break;
1002                         case BPF_JSLT:
1003                                 /* signed '<', LT in x86 */
1004                                 jmp_cond = X86_JL;
1005                                 break;
1006                         case BPF_JSGE:
1007                                 /* signed '>=', GE in x86 */
1008                                 jmp_cond = X86_JGE;
1009                                 break;
1010                         case BPF_JSLE:
1011                                 /* signed '<=', LE in x86 */
1012                                 jmp_cond = X86_JLE;
1013                                 break;
1014                         default: /* to silence gcc warning */
1015                                 return -EFAULT;
1016                         }
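                        /* addrs[i] is the offset of the end of insn i, so
                         * this is the displacement from the byte following
                         * the jcc (where rel8/rel32 are counted from) to
                         * the start of the target instruction.
                         */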
1017                         jmp_offset = addrs[i + insn->off] - addrs[i];
1018                         if (is_imm8(jmp_offset)) {
1019                                 EMIT2(jmp_cond, jmp_offset);
1020                         } else if (is_simm32(jmp_offset)) {
1021                                 EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
1022                         } else {
1023                                 pr_err("cond_jmp gen bug %llx\n", jmp_offset);
1024                                 return -EFAULT;
1025                         }
1026
1027                         break;
1028
1029                 case BPF_JMP | BPF_JA:
1030                         jmp_offset = addrs[i + insn->off] - addrs[i];
1031                         if (!jmp_offset)
1032                                 /* optimize out nop jumps */
1033                                 break;
1034 emit_jmp:
1035                         if (is_imm8(jmp_offset)) {
1036                                 EMIT2(0xEB, jmp_offset);
1037                         } else if (is_simm32(jmp_offset)) {
1038                                 EMIT1_off32(0xE9, jmp_offset);
1039                         } else {
1040                                 pr_err("jmp gen bug %llx\n", jmp_offset);
1041                                 return -EFAULT;
1042                         }
1043                         break;
1044
1045                 case BPF_LD | BPF_IND | BPF_W:
1046                         func = sk_load_word;
1047                         goto common_load;
1048                 case BPF_LD | BPF_ABS | BPF_W:
1049                         func = CHOOSE_LOAD_FUNC(imm32, sk_load_word);
1050 common_load:
1051                         ctx->seen_ld_abs = seen_ld_abs = true;
1052                         jmp_offset = func - (image + addrs[i]);
1053                         if (!func || !is_simm32(jmp_offset)) {
1054                                 pr_err("unsupported bpf func %d addr %p image %p\n",
1055                                        imm32, func, image);
1056                                 return -EINVAL;
1057                         }
1058                         if (BPF_MODE(insn->code) == BPF_ABS) {
1059                                 /* mov %esi, imm32 */
1060                                 EMIT1_off32(0xBE, imm32);
1061                         } else {
1062                                 /* mov %rsi, src_reg */
1063                                 EMIT_mov(BPF_REG_2, src_reg);
1064                                 if (imm32) {
1065                                         if (is_imm8(imm32))
1066                                                 /* add %esi, imm8 */
1067                                                 EMIT3(0x83, 0xC6, imm32);
1068                                         else
1069                                                 /* add %esi, imm32 */
1070                                                 EMIT2_off32(0x81, 0xC6, imm32);
1071                                 }
1072                         }
1073                         /* skb pointer is in R6 (%rbx), it will be copied into
1074                          * %rdi if skb_copy_bits() call is necessary.
1075                          * sk_load_* helpers also use %r10 and %r9d.
1076                          * See bpf_jit.S
1077                          */
1078                         if (seen_ax_reg)
1079                                 /* r10 = skb->data, mov %r10, off32(%rbx) */
1080                                 EMIT3_off32(0x4c, 0x8b, 0x93,
1081                                             offsetof(struct sk_buff, data));
1082                         EMIT1_off32(0xE8, jmp_offset); /* call */
1083                         break;
1084
1085                 case BPF_LD | BPF_IND | BPF_H:
1086                         func = sk_load_half;
1087                         goto common_load;
1088                 case BPF_LD | BPF_ABS | BPF_H:
1089                         func = CHOOSE_LOAD_FUNC(imm32, sk_load_half);
1090                         goto common_load;
1091                 case BPF_LD | BPF_IND | BPF_B:
1092                         func = sk_load_byte;
1093                         goto common_load;
1094                 case BPF_LD | BPF_ABS | BPF_B:
1095                         func = CHOOSE_LOAD_FUNC(imm32, sk_load_byte);
1096                         goto common_load;
1097
1098                 case BPF_JMP | BPF_EXIT:
1099                         if (seen_exit) {
1100                                 jmp_offset = ctx->cleanup_addr - addrs[i];
1101                                 goto emit_jmp;
1102                         }
1103                         seen_exit = true;
1104                         /* update cleanup_addr */
1105                         ctx->cleanup_addr = proglen;
1106                         /* mov rbx, qword ptr [rbp+0] */
1107                         EMIT4(0x48, 0x8B, 0x5D, 0);
1108                         /* mov r13, qword ptr [rbp+8] */
1109                         EMIT4(0x4C, 0x8B, 0x6D, 8);
1110                         /* mov r14, qword ptr [rbp+16] */
1111                         EMIT4(0x4C, 0x8B, 0x75, 16);
1112                         /* mov r15, qword ptr [rbp+24] */
1113                         EMIT4(0x4C, 0x8B, 0x7D, 24);
1114
1115                         /* add rbp, AUX_STACK_SPACE */
1116                         EMIT4(0x48, 0x83, 0xC5, AUX_STACK_SPACE);
1117                         EMIT1(0xC9); /* leave */
1118                         EMIT1(0xC3); /* ret */
1119                         break;
1120
1121                 default:
1122                         /* By design the x64 JIT should support all BPF instructions.
1123                          * This error will be seen if a new instruction was added
1124                          * to the interpreter but not to the JIT,
1125                          * or if there is junk in bpf_prog
1126                          */
1127                         pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
1128                         return -EINVAL;
1129                 }
1130
1131                 ilen = prog - temp;
1132                 if (ilen > BPF_MAX_INSN_SIZE) {
1133                         pr_err("bpf_jit: fatal insn size error\n");
1134                         return -EFAULT;
1135                 }
1136
1137                 if (image) {
1138                         if (unlikely(proglen + ilen > oldproglen)) {
1139                                 pr_err("bpf_jit: fatal error\n");
1140                                 return -EFAULT;
1141                         }
1142                         memcpy(image + proglen, temp, ilen);
1143                 }
1144                 proglen += ilen;
1145                 addrs[i] = proglen;
1146                 prog = temp;
1147         }
1148         return proglen;
1149 }
1150
1151 struct x64_jit_data {
1152         struct bpf_binary_header *header;
1153         int *addrs;
1154         u8 *image;
1155         int proglen;
1156         struct jit_context ctx;
1157 };
1158
1159 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1160 {
1161         struct bpf_binary_header *header = NULL;
1162         struct bpf_prog *tmp, *orig_prog = prog;
1163         struct x64_jit_data *jit_data;
1164         int proglen, oldproglen = 0;
1165         struct jit_context ctx = {};
1166         bool tmp_blinded = false;
1167         bool extra_pass = false;
1168         u8 *image = NULL;
1169         int *addrs;
1170         int pass;
1171         int i;
1172
1173         if (!prog->jit_requested)
1174                 return orig_prog;
1175
1176         tmp = bpf_jit_blind_constants(prog);
1177         /* If blinding was requested and we failed during blinding,
1178          * we must fall back to the interpreter.
1179          */
1180         if (IS_ERR(tmp))
1181                 return orig_prog;
1182         if (tmp != prog) {
1183                 tmp_blinded = true;
1184                 prog = tmp;
1185         }
1186
1187         jit_data = prog->aux->jit_data;
1188         if (!jit_data) {
1189                 jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1190                 if (!jit_data) {
1191                         prog = orig_prog;
1192                         goto out;
1193                 }
1194                 prog->aux->jit_data = jit_data;
1195         }
1196         addrs = jit_data->addrs;
1197         if (addrs) {
1198                 ctx = jit_data->ctx;
1199                 oldproglen = jit_data->proglen;
1200                 image = jit_data->image;
1201                 header = jit_data->header;
1202                 extra_pass = true;
1203                 goto skip_init_addrs;
1204         }
1205         addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
1206         if (!addrs) {
1207                 prog = orig_prog;
1208                 goto out_addrs;
1209         }
1210
1211         /* Before the first pass, make a rough estimate of addrs[]:
1212          * each BPF instruction is translated to less than 64 bytes
1213          */
1214         for (proglen = 0, i = 0; i < prog->len; i++) {
1215                 proglen += 64;
1216                 addrs[i] = proglen;
1217         }
1218         ctx.cleanup_addr = proglen;
1219 skip_init_addrs:
1220
1221         /* JITed image shrinks with every pass and the loop iterates
1222          * until the image stops shrinking. Very large bpf programs
1223          * may converge on the last pass. In such a case, do one more
1224          * pass to emit the final image.
1225          */
1226         for (pass = 0; pass < 20 || image; pass++) {
1227                 proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
1228                 if (proglen <= 0) {
1229                         image = NULL;
1230                         if (header)
1231                                 bpf_jit_binary_free(header);
1232                         prog = orig_prog;
1233                         goto out_addrs;
1234                 }
1235                 if (image) {
1236                         if (proglen != oldproglen) {
1237                                 pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
1238                                        proglen, oldproglen);
1239                                 prog = orig_prog;
1240                                 goto out_addrs;
1241                         }
1242                         break;
1243                 }
1244                 if (proglen == oldproglen) {
1245                         header = bpf_jit_binary_alloc(proglen, &image,
1246                                                       1, jit_fill_hole);
1247                         if (!header) {
1248                                 prog = orig_prog;
1249                                 goto out_addrs;
1250                         }
1251                 }
1252                 oldproglen = proglen;
1253                 cond_resched();
1254         }
1255
1256         if (bpf_jit_enable > 1)
1257                 bpf_jit_dump(prog->len, proglen, pass + 1, image);
1258
1259         if (image) {
1260                 if (!prog->is_func || extra_pass) {
1261                         bpf_jit_binary_lock_ro(header);
1262                 } else {
1263                         jit_data->addrs = addrs;
1264                         jit_data->ctx = ctx;
1265                         jit_data->proglen = proglen;
1266                         jit_data->image = image;
1267                         jit_data->header = header;
1268                 }
1269                 prog->bpf_func = (void *)image;
1270                 prog->jited = 1;
1271                 prog->jited_len = proglen;
1272         } else {
1273                 prog = orig_prog;
1274         }
1275
1276         if (!prog->is_func || extra_pass) {
1277 out_addrs:
1278                 kfree(addrs);
1279                 kfree(jit_data);
1280                 prog->aux->jit_data = NULL;
1281         }
1282 out:
1283         if (tmp_blinded)
1284                 bpf_jit_prog_release_other(prog, prog == orig_prog ?
1285                                            tmp : orig_prog);
1286         return prog;
1287 }