x86/bpf: Clean up non-standard comments, to make the code more readable
[linux-2.6-microblaze.git] / arch / x86 / net / bpf_jit_comp.c
index ce5b2eb..ac4df93 100644 (file)
@@ -1,4 +1,5 @@
-/* bpf_jit_comp.c : BPF JIT compiler
+/*
+ * bpf_jit_comp.c: BPF JIT compiler
  *
  * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
  * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 #include <linux/netdevice.h>
 #include <linux/filter.h>
 #include <linux/if_vlan.h>
-#include <asm/cacheflush.h>
+#include <linux/bpf.h>
+
 #include <asm/set_memory.h>
 #include <asm/nospec-branch.h>
-#include <linux/bpf.h>
 
 /*
- * assembly code in arch/x86/net/bpf_jit.S
+ * Assembly code in arch/x86/net/bpf_jit.S
  */
 extern u8 sk_load_word[], sk_load_half[], sk_load_byte[];
 extern u8 sk_load_word_positive_offset[], sk_load_half_positive_offset[];
@@ -45,14 +46,15 @@ static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
 #define EMIT2(b1, b2)          EMIT((b1) + ((b2) << 8), 2)
 #define EMIT3(b1, b2, b3)      EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
 #define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
+
 #define EMIT1_off32(b1, off) \
-       do {EMIT1(b1); EMIT(off, 4); } while (0)
+       do { EMIT1(b1); EMIT(off, 4); } while (0)
 #define EMIT2_off32(b1, b2, off) \
-       do {EMIT2(b1, b2); EMIT(off, 4); } while (0)
+       do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
 #define EMIT3_off32(b1, b2, b3, off) \
-       do {EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
+       do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
 #define EMIT4_off32(b1, b2, b3, b4, off) \
-       do {EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
+       do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
 
 static bool is_imm8(int value)
 {
@@ -61,13 +63,19 @@ static bool is_imm8(int value)
 
 static bool is_simm32(s64 value)
 {
-       return value == (s64) (s32) value;
+       return value == (s64)(s32)value;
+}
+
+static bool is_uimm32(u64 value)
+{
+       return value == (u64)(u32)value;
 }
 
 /* mov dst, src */
-#define EMIT_mov(DST, SRC) \
-       do {if (DST != SRC) \
-               EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
+#define EMIT_mov(DST, SRC)                                                              \
+       do {                                                                             \
+               if (DST != SRC)                                                          \
+                       EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
        } while (0)
 
 static int bpf_size_to_x86_bytes(int bpf_size)
@@ -84,7 +92,8 @@ static int bpf_size_to_x86_bytes(int bpf_size)
                return 0;
 }
 
-/* list of x86 cond jumps opcodes (. + s8)
+/*
+ * List of x86 cond jumps opcodes (. + s8)
  * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
  */
 #define X86_JB  0x72
@@ -98,48 +107,40 @@ static int bpf_size_to_x86_bytes(int bpf_size)
 #define X86_JLE 0x7E
 #define X86_JG  0x7F
 
-static void bpf_flush_icache(void *start, void *end)
-{
-       mm_segment_t old_fs = get_fs();
-
-       set_fs(KERNEL_DS);
-       smp_wmb();
-       flush_icache_range((unsigned long)start, (unsigned long)end);
-       set_fs(old_fs);
-}
-
 #define CHOOSE_LOAD_FUNC(K, func) \
        ((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)
 
-/* pick a register outside of BPF range for JIT internal work */
+/* Pick a register outside of BPF range for JIT internal work */
 #define AUX_REG (MAX_BPF_JIT_REG + 1)
 
-/* The following table maps BPF registers to x64 registers.
+/*
+ * The following table maps BPF registers to x86-64 registers.
  *
- * x64 register r12 is unused, since if used as base address
+ * x86-64 register R12 is unused, since if used as base address
  * register in load/store instructions, it always needs an
  * extra byte of encoding and is callee saved.
  *
- *  r9 caches skb->len - skb->data_len
- * r10 caches skb->data, and used for blinding (if enabled)
+ * R9  caches skb->len - skb->data_len
+ * R10 caches skb->data, and used for blinding (if enabled)
  */
 static const int reg2hex[] = {
-       [BPF_REG_0] = 0,  /* rax */
-       [BPF_REG_1] = 7,  /* rdi */
-       [BPF_REG_2] = 6,  /* rsi */
-       [BPF_REG_3] = 2,  /* rdx */
-       [BPF_REG_4] = 1,  /* rcx */
-       [BPF_REG_5] = 0,  /* r8 */
-       [BPF_REG_6] = 3,  /* rbx callee saved */
-       [BPF_REG_7] = 5,  /* r13 callee saved */
-       [BPF_REG_8] = 6,  /* r14 callee saved */
-       [BPF_REG_9] = 7,  /* r15 callee saved */
-       [BPF_REG_FP] = 5, /* rbp readonly */
-       [BPF_REG_AX] = 2, /* r10 temp register */
-       [AUX_REG] = 3,    /* r11 temp register */
+       [BPF_REG_0] = 0,  /* RAX */
+       [BPF_REG_1] = 7,  /* RDI */
+       [BPF_REG_2] = 6,  /* RSI */
+       [BPF_REG_3] = 2,  /* RDX */
+       [BPF_REG_4] = 1,  /* RCX */
+       [BPF_REG_5] = 0,  /* R8  */
+       [BPF_REG_6] = 3,  /* RBX callee saved */
+       [BPF_REG_7] = 5,  /* R13 callee saved */
+       [BPF_REG_8] = 6,  /* R14 callee saved */
+       [BPF_REG_9] = 7,  /* R15 callee saved */
+       [BPF_REG_FP] = 5, /* RBP readonly */
+       [BPF_REG_AX] = 2, /* R10 temp register */
+       [AUX_REG] = 3,    /* R11 temp register */
 };
 
-/* is_ereg() == true if BPF register 'reg' maps to x64 r8..r15
+/*
+ * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
  * which need extra byte of encoding.
  * rax,rcx,...,rbp have simpler encoding
  */
@@ -158,7 +159,7 @@ static bool is_axreg(u32 reg)
        return reg == BPF_REG_0;
 }
 
-/* add modifiers if 'reg' maps to x64 registers r8..r15 */
+/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
 static u8 add_1mod(u8 byte, u32 reg)
 {
        if (is_ereg(reg))
@@ -175,13 +176,13 @@ static u8 add_2mod(u8 byte, u32 r1, u32 r2)
        return byte;
 }
 
-/* encode 'dst_reg' register into x64 opcode 'byte' */
+/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
 static u8 add_1reg(u8 byte, u32 dst_reg)
 {
        return byte + reg2hex[dst_reg];
 }
 
-/* encode 'dst_reg' and 'src_reg' registers into x64 opcode 'byte' */
+/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
 static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
 {
        return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
@@ -189,36 +190,40 @@ static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
 
 static void jit_fill_hole(void *area, unsigned int size)
 {
-       /* fill whole space with int3 instructions */
+       /* Fill whole space with INT3 instructions */
        memset(area, 0xcc, size);
 }
 
 struct jit_context {
-       int cleanup_addr; /* epilogue code offset */
+       int cleanup_addr; /* Epilogue code offset */
        bool seen_ld_abs;
        bool seen_ax_reg;
 };
 
-/* maximum number of bytes emitted while JITing one eBPF insn */
+/* Maximum number of bytes emitted while JITing one eBPF insn */
 #define BPF_MAX_INSN_SIZE      128
 #define BPF_INSN_SAFETY                64
 
 #define AUX_STACK_SPACE \
-       (32 /* space for rbx, r13, r14, r15 */ + \
-        8 /* space for skb_copy_bits() buffer */)
+       (32 /* Space for RBX, R13, R14, R15 */ + \
+         8 /* Space for skb_copy_bits() buffer */)
 
 #define PROLOGUE_SIZE 37
 
-/* emit x64 prologue code for BPF program and check it's size.
+/*
+ * Emit x86-64 prologue code for BPF program and check its size.
  * bpf_tail_call helper will skip it while jumping into another program
  */
-static void emit_prologue(u8 **pprog, u32 stack_depth)
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
 {
        u8 *prog = *pprog;
        int cnt = 0;
 
-       EMIT1(0x55); /* push rbp */
-       EMIT3(0x48, 0x89, 0xE5); /* mov rbp,rsp */
+       /* push rbp */
+       EMIT1(0x55);
+
+       /* mov rbp,rsp */
+       EMIT3(0x48, 0x89, 0xE5);
 
        /* sub rsp, rounded_stack_depth + AUX_STACK_SPACE */
        EMIT3_off32(0x48, 0x81, 0xEC,
@@ -227,14 +232,15 @@ static void emit_prologue(u8 **pprog, u32 stack_depth)
        /* sub rbp, AUX_STACK_SPACE */
        EMIT4(0x48, 0x83, 0xED, AUX_STACK_SPACE);
 
-       /* all classic BPF filters use R6(rbx) save it */
+       /* All classic BPF filters use R6(rbx) save it */
 
        /* mov qword ptr [rbp+0],rbx */
        EMIT4(0x48, 0x89, 0x5D, 0);
 
-       /* bpf_convert_filter() maps classic BPF register X to R7 and uses R8
-        * as temporary, so all tcpdump filters need to spill/fill R7(r13) and
-        * R8(r14). R9(r15) spill could be made conditional, but there is only
+       /*
+        * bpf_convert_filter() maps classic BPF register X to R7 and uses R8
+        * as temporary, so all tcpdump filters need to spill/fill R7(R13) and
+        * R8(R14). R9(R15) spill could be made conditional, but there is only
         * one 'bpf_error' return path out of helper functions inside bpf_jit.S
         * The overhead of extra spill is negligible for any filter other
         * than synthetic ones. Therefore not worth adding complexity.
@@ -247,22 +253,28 @@ static void emit_prologue(u8 **pprog, u32 stack_depth)
        /* mov qword ptr [rbp+24],r15 */
        EMIT4(0x4C, 0x89, 0x7D, 24);
 
-       /* Clear the tail call counter (tail_call_cnt): for eBPF tail calls
-        * we need to reset the counter to 0. It's done in two instructions,
-        * resetting rax register to 0 (xor on eax gets 0 extended), and
-        * moving it to the counter location.
-        */
+       if (!ebpf_from_cbpf) {
+               /*
+                * Clear the tail call counter (tail_call_cnt): for eBPF tail
+                * calls we need to reset the counter to 0. It's done in two
+                * instructions, resetting RAX register to 0, and moving it
+                * to the counter location.
+                */
 
-       /* xor eax, eax */
-       EMIT2(0x31, 0xc0);
-       /* mov qword ptr [rbp+32], rax */
-       EMIT4(0x48, 0x89, 0x45, 32);
+               /* xor eax, eax */
+               EMIT2(0x31, 0xc0);
+               /* mov qword ptr [rbp+32], rax */
+               EMIT4(0x48, 0x89, 0x45, 32);
+
+               BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
+       }
 
-       BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
        *pprog = prog;
 }
 
-/* generate the following code:
+/*
+ * Generate the following code:
+ *
  * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
  *   if (index >= array->map.max_entries)
  *     goto out;
@@ -280,23 +292,26 @@ static void emit_bpf_tail_call(u8 **pprog)
        int label1, label2, label3;
        int cnt = 0;
 
-       /* rdi - pointer to ctx
+       /*
+        * rdi - pointer to ctx
         * rsi - pointer to bpf_array
         * rdx - index in bpf_array
         */
 
-       /* if (index >= array->map.max_entries)
-        *   goto out;
+       /*
+        * if (index >= array->map.max_entries)
+        *      goto out;
         */
        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
              offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
+#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
        EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
        label1 = cnt;
 
-       /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
-        *   goto out;
+       /*
+        * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
+        *      goto out;
         */
        EMIT2_off32(0x8B, 0x85, 36);              /* mov eax, dword ptr [rbp + 36] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
@@ -310,8 +325,9 @@ static void emit_bpf_tail_call(u8 **pprog)
        EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
                    offsetof(struct bpf_array, ptrs));
 
-       /* if (prog == NULL)
-        *   goto out;
+       /*
+        * if (prog == NULL)
+        *      goto out;
         */
        EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
 #define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
@@ -323,7 +339,8 @@ static void emit_bpf_tail_call(u8 **pprog)
              offsetof(struct bpf_prog, bpf_func));
        EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */
 
-       /* now we're ready to jump into next BPF program
+       /*
+        * Wow we're ready to jump into next BPF program
         * rdi == ctx (1st arg)
         * rax == prog->bpf_func + prologue_size
         */
@@ -342,7 +359,8 @@ static void emit_load_skb_data_hlen(u8 **pprog)
        u8 *prog = *pprog;
        int cnt = 0;
 
-       /* r9d = skb->len - skb->data_len (headlen)
+       /*
+        * r9d = skb->len - skb->data_len (headlen)
         * r10 = skb->data
         */
        /* mov %r9d, off32(%rdi) */
@@ -356,6 +374,89 @@ static void emit_load_skb_data_hlen(u8 **pprog)
        *pprog = prog;
 }
 
+static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
+                          u32 dst_reg, const u32 imm32)
+{
+       u8 *prog = *pprog;
+       u8 b1, b2, b3;
+       int cnt = 0;
+
+       /*
+        * Optimization: if imm32 is positive, use 'mov %eax, imm32'
+        * (which zero-extends imm32) to save 2 bytes.
+        */
+       if (sign_propagate && (s32)imm32 < 0) {
+               /* 'mov %rax, imm32' sign extends imm32 */
+               b1 = add_1mod(0x48, dst_reg);
+               b2 = 0xC7;
+               b3 = 0xC0;
+               EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
+               goto done;
+       }
+
+       /*
+        * Optimization: if imm32 is zero, use 'xor %eax, %eax'
+        * to save 3 bytes.
+        */
+       if (imm32 == 0) {
+               if (is_ereg(dst_reg))
+                       EMIT1(add_2mod(0x40, dst_reg, dst_reg));
+               b2 = 0x31; /* xor */
+               b3 = 0xC0;
+               EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
+               goto done;
+       }
+
+       /* mov %eax, imm32 */
+       if (is_ereg(dst_reg))
+               EMIT1(add_1mod(0x40, dst_reg));
+       EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+done:
+       *pprog = prog;
+}
+
+static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
+                          const u32 imm32_hi, const u32 imm32_lo)
+{
+       u8 *prog = *pprog;
+       int cnt = 0;
+
+       if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
+               /*
+                * For emitting plain u32, where sign bit must not be
+                * propagated LLVM tends to load imm64 over mov32
+                * directly, so save couple of bytes by just doing
+                * 'mov %eax, imm32' instead.
+                */
+               emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
+       } else {
+               /* movabsq %rax, imm64 */
+               EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
+               EMIT(imm32_lo, 4);
+               EMIT(imm32_hi, 4);
+       }
+
+       *pprog = prog;
+}
+
+static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
+{
+       u8 *prog = *pprog;
+       int cnt = 0;
+
+       if (is64) {
+               /* mov dst, src */
+               EMIT_mov(dst_reg, src_reg);
+       } else {
+               /* mov32 dst, src */
+               if (is_ereg(dst_reg) || is_ereg(src_reg))
+                       EMIT1(add_2mod(0x40, dst_reg, src_reg));
+               EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+       }
+
+       *pprog = prog;
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                  int oldproglen, struct jit_context *ctx)
 {
@@ -369,7 +470,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
        int proglen = 0;
        u8 *prog = temp;
 
-       emit_prologue(&prog, bpf_prog->aux->stack_depth);
+       emit_prologue(&prog, bpf_prog->aux->stack_depth,
+                     bpf_prog_was_classic(bpf_prog));
 
        if (seen_ld_abs)
                emit_load_skb_data_hlen(&prog);
@@ -378,7 +480,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                const s32 imm32 = insn->imm;
                u32 dst_reg = insn->dst_reg;
                u32 src_reg = insn->src_reg;
-               u8 b1 = 0, b2 = 0, b3 = 0;
+               u8 b2 = 0, b3 = 0;
                s64 jmp_offset;
                u8 jmp_cond;
                bool reload_skb_data;
@@ -414,16 +516,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
                        break;
 
-                       /* mov dst, src */
                case BPF_ALU64 | BPF_MOV | BPF_X:
-                       EMIT_mov(dst_reg, src_reg);
-                       break;
-
-                       /* mov32 dst, src */
                case BPF_ALU | BPF_MOV | BPF_X:
-                       if (is_ereg(dst_reg) || is_ereg(src_reg))
-                               EMIT1(add_2mod(0x40, dst_reg, src_reg));
-                       EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+                       emit_mov_reg(&prog,
+                                    BPF_CLASS(insn->code) == BPF_ALU64,
+                                    dst_reg, src_reg);
                        break;
 
                        /* neg dst */
@@ -451,7 +548,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));
 
-                       /* b3 holds 'normal' opcode, b2 short form only valid
+                       /*
+                        * b3 holds 'normal' opcode, b2 short form only valid
                         * in case dst is eax/rax.
                         */
                        switch (BPF_OP(insn->code)) {
@@ -486,58 +584,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        break;
 
                case BPF_ALU64 | BPF_MOV | BPF_K:
-                       /* optimization: if imm32 is positive,
-                        * use 'mov eax, imm32' (which zero-extends imm32)
-                        * to save 2 bytes
-                        */
-                       if (imm32 < 0) {
-                               /* 'mov rax, imm32' sign extends imm32 */
-                               b1 = add_1mod(0x48, dst_reg);
-                               b2 = 0xC7;
-                               b3 = 0xC0;
-                               EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
-                               break;
-                       }
-
                case BPF_ALU | BPF_MOV | BPF_K:
-                       /* optimization: if imm32 is zero, use 'xor <dst>,<dst>'
-                        * to save 3 bytes.
-                        */
-                       if (imm32 == 0) {
-                               if (is_ereg(dst_reg))
-                                       EMIT1(add_2mod(0x40, dst_reg, dst_reg));
-                               b2 = 0x31; /* xor */
-                               b3 = 0xC0;
-                               EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
-                               break;
-                       }
-
-                       /* mov %eax, imm32 */
-                       if (is_ereg(dst_reg))
-                               EMIT1(add_1mod(0x40, dst_reg));
-                       EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+                       emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
+                                      dst_reg, imm32);
                        break;
 
                case BPF_LD | BPF_IMM | BPF_DW:
-                       /* optimization: if imm64 is zero, use 'xor <dst>,<dst>'
-                        * to save 7 bytes.
-                        */
-                       if (insn[0].imm == 0 && insn[1].imm == 0) {
-                               b1 = add_2mod(0x48, dst_reg, dst_reg);
-                               b2 = 0x31; /* xor */
-                               b3 = 0xC0;
-                               EMIT3(b1, b2, add_2reg(b3, dst_reg, dst_reg));
-
-                               insn++;
-                               i++;
-                               break;
-                       }
-
-                       /* movabsq %rax, imm64 */
-                       EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
-                       EMIT(insn[0].imm, 4);
-                       EMIT(insn[1].imm, 4);
-
+                       emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
                        insn++;
                        i++;
                        break;
@@ -564,7 +617,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        /* mov rax, dst_reg */
                        EMIT_mov(BPF_REG_0, dst_reg);
 
-                       /* xor edx, edx
+                       /*
+                        * xor edx, edx
                         * equivalent to 'xor rdx, rdx', but one byte less
                         */
                        EMIT2(0x31, 0xd2);
@@ -594,37 +648,39 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                case BPF_ALU | BPF_MUL | BPF_X:
                case BPF_ALU64 | BPF_MUL | BPF_K:
                case BPF_ALU64 | BPF_MUL | BPF_X:
-                       EMIT1(0x50); /* push rax */
-                       EMIT1(0x52); /* push rdx */
+               {
+                       bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
+
+                       if (dst_reg != BPF_REG_0)
+                               EMIT1(0x50); /* push rax */
+                       if (dst_reg != BPF_REG_3)
+                               EMIT1(0x52); /* push rdx */
 
                        /* mov r11, dst_reg */
                        EMIT_mov(AUX_REG, dst_reg);
 
                        if (BPF_SRC(insn->code) == BPF_X)
-                               /* mov rax, src_reg */
-                               EMIT_mov(BPF_REG_0, src_reg);
+                               emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
                        else
-                               /* mov rax, imm32 */
-                               EMIT3_off32(0x48, 0xC7, 0xC0, imm32);
+                               emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);
 
-                       if (BPF_CLASS(insn->code) == BPF_ALU64)
+                       if (is64)
                                EMIT1(add_1mod(0x48, AUX_REG));
                        else if (is_ereg(AUX_REG))
                                EMIT1(add_1mod(0x40, AUX_REG));
                        /* mul(q) r11 */
                        EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
 
-                       /* mov r11, rax */
-                       EMIT_mov(AUX_REG, BPF_REG_0);
-
-                       EMIT1(0x5A); /* pop rdx */
-                       EMIT1(0x58); /* pop rax */
-
-                       /* mov dst_reg, r11 */
-                       EMIT_mov(dst_reg, AUX_REG);
+                       if (dst_reg != BPF_REG_3)
+                               EMIT1(0x5A); /* pop rdx */
+                       if (dst_reg != BPF_REG_0) {
+                               /* mov dst_reg, rax */
+                               EMIT_mov(dst_reg, BPF_REG_0);
+                               EMIT1(0x58); /* pop rax */
+                       }
                        break;
-
-                       /* shifts */
+               }
+                       /* Shifts */
                case BPF_ALU | BPF_LSH | BPF_K:
                case BPF_ALU | BPF_RSH | BPF_K:
                case BPF_ALU | BPF_ARSH | BPF_K:
@@ -641,7 +697,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                        case BPF_RSH: b3 = 0xE8; break;
                        case BPF_ARSH: b3 = 0xF8; break;
                        }
-                       EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
+
+                       if (imm32 == 1)
+                               EMIT2(0xD1, add_1reg(b3, dst_reg));
+                       else
+                               EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
                        break;
 
                case BPF_ALU | BPF_LSH | BPF_X:
@@ -651,7 +711,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                case BPF_ALU64 | BPF_RSH | BPF_X:
                case BPF_ALU64 | BPF_ARSH | BPF_X:
 
-                       /* check for bad case when dst_reg == rcx */
+                       /* Check for bad case when dst_reg == rcx */
                        if (dst_reg == BPF_REG_4) {
                                /* mov r11, dst_reg */
                                EMIT_mov(AUX_REG, dst_reg);
@@ -689,13 +749,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                case BPF_ALU | BPF_END | BPF_FROM_BE:
                        switch (imm32) {
                        case 16:
-                               /* emit 'ror %ax, 8' to swap lower 2 bytes */
+                               /* Emit 'ror %ax, 8' to swap lower 2 bytes */
                                EMIT1(0x66);
                                if (is_ereg(dst_reg))
                                        EMIT1(0x41);
                                EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
 
-                               /* emit 'movzwl eax, ax' */
+                               /* Emit 'movzwl eax, ax' */
                                if (is_ereg(dst_reg))
                                        EMIT3(0x45, 0x0F, 0xB7);
                                else
@@ -703,7 +763,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 32:
-                               /* emit 'bswap eax' to swap lower 4 bytes */
+                               /* Emit 'bswap eax' to swap lower 4 bytes */
                                if (is_ereg(dst_reg))
                                        EMIT2(0x41, 0x0F);
                                else
@@ -711,7 +771,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                                EMIT1(add_1reg(0xC8, dst_reg));
                                break;
                        case 64:
-                               /* emit 'bswap rax' to swap 8 bytes */
+                               /* Emit 'bswap rax' to swap 8 bytes */
                                EMIT3(add_1mod(0x48, dst_reg), 0x0F,
                                      add_1reg(0xC8, dst_reg));
                                break;
@@ -721,7 +781,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                case BPF_ALU | BPF_END | BPF_FROM_LE:
                        switch (imm32) {
                        case 16:
-                               /* emit 'movzwl eax, ax' to zero extend 16-bit
+                               /*
+                                * Emit 'movzwl eax, ax' to zero extend 16-bit
                                 * into 64 bit
                                 */
                                if (is_ereg(dst_reg))
@@ -731,7 +792,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 32:
-                               /* emit 'mov eax, eax' to clear upper 32-bits */
+                               /* Emit 'mov eax, eax' to clear upper 32-bits */
                                if (is_ereg(dst_reg))
                                        EMIT1(0x45);
                                EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
@@ -774,9 +835,9 @@ st:                 if (is_imm8(insn->off))
 
                        /* STX: *(u8*)(dst_reg + off) = src_reg */
                case BPF_STX | BPF_MEM | BPF_B:
-                       /* emit 'mov byte ptr [rax + off], al' */
+                       /* Emit 'mov byte ptr [rax + off], al' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg) ||
-                           /* have to add extra byte for x86 SIL, DIL regs */
+                           /* We have to add extra byte for x86 SIL, DIL regs */
                            src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
                                EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
                        else
@@ -805,25 +866,26 @@ stx:                      if (is_imm8(insn->off))
 
                        /* LDX: dst_reg = *(u8*)(src_reg + off) */
                case BPF_LDX | BPF_MEM | BPF_B:
-                       /* emit 'movzx rax, byte ptr [rax + off]' */
+                       /* Emit 'movzx rax, byte ptr [rax + off]' */
                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_H:
-                       /* emit 'movzx rax, word ptr [rax + off]' */
+                       /* Emit 'movzx rax, word ptr [rax + off]' */
                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_W:
-                       /* emit 'mov eax, dword ptr [rax+0x14]' */
+                       /* Emit 'mov eax, dword ptr [rax+0x14]' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
                        else
                                EMIT1(0x8B);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_DW:
-                       /* emit 'mov rax, qword ptr [rax+0x14]' */
+                       /* Emit 'mov rax, qword ptr [rax+0x14]' */
                        EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
-ldx:                   /* if insn->off == 0 we can save one extra byte, but
-                        * special case of x86 r13 which always needs an offset
+ldx:                   /*
+                        * If insn->off == 0 we can save one extra byte, but
+                        * special case of x86 R13 which always needs an offset
                         * is not worth the hassle
                         */
                        if (is_imm8(insn->off))
@@ -835,7 +897,7 @@ ldx:                        /* if insn->off == 0 we can save one extra byte, but
 
                        /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
                case BPF_STX | BPF_XADD | BPF_W:
-                       /* emit 'lock add dword ptr [rax + off], eax' */
+                       /* Emit 'lock add dword ptr [rax + off], eax' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
                        else
@@ -862,14 +924,15 @@ xadd:                     if (is_imm8(insn->off))
                                } else {
                                        EMIT2(0x41, 0x52); /* push %r10 */
                                        EMIT2(0x41, 0x51); /* push %r9 */
-                                       /* need to adjust jmp offset, since
+                                       /*
+                                        * We need to adjust jmp offset, since
                                         * pop %r9, pop %r10 take 4 bytes after call insn
                                         */
                                        jmp_offset += 4;
                                }
                        }
                        if (!imm32 || !is_simm32(jmp_offset)) {
-                               pr_err("unsupported bpf func %d addr %p image %p\n",
+                               pr_err("unsupported BPF func %d addr %p image %p\n",
                                       imm32, func, image);
                                return -EINVAL;
                        }
@@ -935,7 +998,7 @@ xadd:                       if (is_imm8(insn->off))
                        else
                                EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
 
-emit_cond_jmp:         /* convert BPF opcode to x86 */
+emit_cond_jmp:         /* Convert BPF opcode to x86 */
                        switch (BPF_OP(insn->code)) {
                        case BPF_JEQ:
                                jmp_cond = X86_JE;
@@ -961,22 +1024,22 @@ emit_cond_jmp:           /* convert BPF opcode to x86 */
                                jmp_cond = X86_JBE;
                                break;
                        case BPF_JSGT:
-                               /* signed '>', GT in x86 */
+                               /* Signed '>', GT in x86 */
                                jmp_cond = X86_JG;
                                break;
                        case BPF_JSLT:
-                               /* signed '<', LT in x86 */
+                               /* Signed '<', LT in x86 */
                                jmp_cond = X86_JL;
                                break;
                        case BPF_JSGE:
-                               /* signed '>=', GE in x86 */
+                               /* Signed '>=', GE in x86 */
                                jmp_cond = X86_JGE;
                                break;
                        case BPF_JSLE:
-                               /* signed '<=', LE in x86 */
+                               /* Signed '<=', LE in x86 */
                                jmp_cond = X86_JLE;
                                break;
-                       default: /* to silence gcc warning */
+                       default: /* to silence GCC warning */
                                return -EFAULT;
                        }
                        jmp_offset = addrs[i + insn->off] - addrs[i];
@@ -994,7 +1057,7 @@ emit_cond_jmp:             /* convert BPF opcode to x86 */
                case BPF_JMP | BPF_JA:
                        jmp_offset = addrs[i + insn->off] - addrs[i];
                        if (!jmp_offset)
-                               /* optimize out nop jumps */
+                               /* Optimize out nop jumps */
                                break;
 emit_jmp:
                        if (is_imm8(jmp_offset)) {
@@ -1016,7 +1079,7 @@ common_load:
                        ctx->seen_ld_abs = seen_ld_abs = true;
                        jmp_offset = func - (image + addrs[i]);
                        if (!func || !is_simm32(jmp_offset)) {
-                               pr_err("unsupported bpf func %d addr %p image %p\n",
+                               pr_err("unsupported BPF func %d addr %p image %p\n",
                                       imm32, func, image);
                                return -EINVAL;
                        }
@@ -1035,7 +1098,8 @@ common_load:
                                                EMIT2_off32(0x81, 0xC6, imm32);
                                }
                        }
-                       /* skb pointer is in R6 (%rbx), it will be copied into
+                       /*
+                        * skb pointer is in R6 (%rbx), it will be copied into
                         * %rdi if skb_copy_bits() call is necessary.
                         * sk_load_* helpers also use %r10 and %r9d.
                         * See bpf_jit.S
@@ -1066,7 +1130,7 @@ common_load:
                                goto emit_jmp;
                        }
                        seen_exit = true;
-                       /* update cleanup_addr */
+                       /* Update cleanup_addr */
                        ctx->cleanup_addr = proglen;
                        /* mov rbx, qword ptr [rbp+0] */
                        EMIT4(0x48, 0x8B, 0x5D, 0);
@@ -1084,10 +1148,11 @@ common_load:
                        break;
 
                default:
-                       /* By design x64 JIT should support all BPF instructions
+                       /*
+                        * By design x86-64 JIT should support all BPF instructions.
                         * This error will be seen if new instruction was added
-                        * to interpreter, but not to JIT
-                        * or if there is junk in bpf_prog
+                        * to the interpreter, but not to the JIT, or if there is
+                        * junk in bpf_prog.
                         */
                        pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
                        return -EINVAL;
@@ -1139,7 +1204,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                return orig_prog;
 
        tmp = bpf_jit_blind_constants(prog);
-       /* If blinding was requested and we failed during blinding,
+       /*
+        * If blinding was requested and we failed during blinding,
         * we must fall back to the interpreter.
         */
        if (IS_ERR(tmp))
@@ -1173,8 +1239,9 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                goto out_addrs;
        }
 
-       /* Before first pass, make a rough estimation of addrs[]
-        * each bpf instruction is translated to less than 64 bytes
+       /*
+        * Before first pass, make a rough estimation of addrs[]
+        * each BPF instruction is translated to less than 64 bytes
         */
        for (proglen = 0, i = 0; i < prog->len; i++) {
                proglen += 64;
@@ -1183,10 +1250,11 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        ctx.cleanup_addr = proglen;
 skip_init_addrs:
 
-       /* JITed image shrinks with every pass and the loop iterates
-        * until the image stops shrinking. Very large bpf programs
+       /*
+        * JITed image shrinks with every pass and the loop iterates
+        * until the image stops shrinking. Very large BPF programs
         * may converge on the last pass. In such case do one more
-        * pass to emit the final image
+        * pass to emit the final image.
         */
        for (pass = 0; pass < 20 || image; pass++) {
                proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
@@ -1222,7 +1290,6 @@ skip_init_addrs:
                bpf_jit_dump(prog->len, proglen, pass + 1, image);
 
        if (image) {
-               bpf_flush_icache(header, image + proglen);
                if (!prog->is_func || extra_pass) {
                        bpf_jit_binary_lock_ro(header);
                } else {