s390/bpf: Remove unused SEEN_RET0, SEEN_REG_AX and ret0_ip
[linux-2.6-microblaze.git] / arch / s390 / net / bpf_jit_comp.c
index 955eb35..1115071 100644 (file)
@@ -23,6 +23,7 @@
 #include <linux/filter.h>
 #include <linux/init.h>
 #include <linux/bpf.h>
+#include <linux/mm.h>
 #include <asm/cacheflush.h>
 #include <asm/dis.h>
 #include <asm/facility.h>
@@ -41,7 +42,6 @@ struct bpf_jit {
        int lit_start;          /* Start of literal pool */
        int lit;                /* Current position in literal pool */
        int base_ip;            /* Base address for literal pool */
-       int ret0_ip;            /* Address of return 0 */
        int exit_ip;            /* Address of exit */
        int r1_thunk_ip;        /* Address of expoline thunk for 'br %r1' */
        int r14_thunk_ip;       /* Address of expoline thunk for 'br %r14' */
@@ -51,12 +51,10 @@ struct bpf_jit {
 
 #define BPF_SIZE_MAX   0xffff  /* Max size for program (16 bit branches) */
 
-#define SEEN_MEM       (1 << 0)        /* use mem[] for temporary storage */
-#define SEEN_RET0      (1 << 1)        /* ret0_ip points to a valid return 0 */
-#define SEEN_LITERAL   (1 << 2)        /* code uses literals */
-#define SEEN_FUNC      (1 << 3)        /* calls C functions */
-#define SEEN_TAIL_CALL (1 << 4)        /* code uses tail calls */
-#define SEEN_REG_AX    (1 << 5)        /* code uses constant blinding */
+#define SEEN_MEM       BIT(0)          /* use mem[] for temporary storage */
+#define SEEN_LITERAL   BIT(1)          /* code uses literals */
+#define SEEN_FUNC      BIT(2)          /* calls C functions */
+#define SEEN_TAIL_CALL BIT(3)          /* code uses tail calls */
 #define SEEN_STACK     (SEEN_FUNC | SEEN_MEM)
 
 /*
@@ -131,13 +129,13 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
 #define _EMIT2(op)                                             \
 ({                                                             \
        if (jit->prg_buf)                                       \
-               *(u16 *) (jit->prg_buf + jit->prg) = op;        \
+               *(u16 *) (jit->prg_buf + jit->prg) = (op);      \
        jit->prg += 2;                                          \
 })
 
 #define EMIT2(op, b1, b2)                                      \
 ({                                                             \
-       _EMIT2(op | reg(b1, b2));                               \
+       _EMIT2((op) | reg(b1, b2));                             \
        REG_SET_SEEN(b1);                                       \
        REG_SET_SEEN(b2);                                       \
 })
@@ -145,20 +143,20 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
 #define _EMIT4(op)                                             \
 ({                                                             \
        if (jit->prg_buf)                                       \
-               *(u32 *) (jit->prg_buf + jit->prg) = op;        \
+               *(u32 *) (jit->prg_buf + jit->prg) = (op);      \
        jit->prg += 4;                                          \
 })
 
 #define EMIT4(op, b1, b2)                                      \
 ({                                                             \
-       _EMIT4(op | reg(b1, b2));                               \
+       _EMIT4((op) | reg(b1, b2));                             \
        REG_SET_SEEN(b1);                                       \
        REG_SET_SEEN(b2);                                       \
 })
 
 #define EMIT4_RRF(op, b1, b2, b3)                              \
 ({                                                             \
-       _EMIT4(op | reg_high(b3) << 8 | reg(b1, b2));           \
+       _EMIT4((op) | reg_high(b3) << 8 | reg(b1, b2));         \
        REG_SET_SEEN(b1);                                       \
        REG_SET_SEEN(b2);                                       \
        REG_SET_SEEN(b3);                                       \
@@ -167,13 +165,13 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
 #define _EMIT4_DISP(op, disp)                                  \
 ({                                                             \
        unsigned int __disp = (disp) & 0xfff;                   \
-       _EMIT4(op | __disp);                                    \
+       _EMIT4((op) | __disp);                                  \
 })
 
 #define EMIT4_DISP(op, b1, b2, disp)                           \
 ({                                                             \
-       _EMIT4_DISP(op | reg_high(b1) << 16 |                   \
-                   reg_high(b2) << 8, disp);                   \
+       _EMIT4_DISP((op) | reg_high(b1) << 16 |                 \
+                   reg_high(b2) << 8, (disp));                 \
        REG_SET_SEEN(b1);                                       \
        REG_SET_SEEN(b2);                                       \
 })
@@ -181,21 +179,21 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
 #define EMIT4_IMM(op, b1, imm)                                 \
 ({                                                             \
        unsigned int __imm = (imm) & 0xffff;                    \
-       _EMIT4(op | reg_high(b1) << 16 | __imm);                \
+       _EMIT4((op) | reg_high(b1) << 16 | __imm);              \
        REG_SET_SEEN(b1);                                       \
 })
 
 #define EMIT4_PCREL(op, pcrel)                                 \
 ({                                                             \
        long __pcrel = ((pcrel) >> 1) & 0xffff;                 \
-       _EMIT4(op | __pcrel);                                   \
+       _EMIT4((op) | __pcrel);                                 \
 })
 
 #define _EMIT6(op1, op2)                                       \
 ({                                                             \
        if (jit->prg_buf) {                                     \
-               *(u32 *) (jit->prg_buf + jit->prg) = op1;       \
-               *(u16 *) (jit->prg_buf + jit->prg + 4) = op2;   \
+               *(u32 *) (jit->prg_buf + jit->prg) = (op1);     \
+               *(u16 *) (jit->prg_buf + jit->prg + 4) = (op2); \
        }                                                       \
        jit->prg += 6;                                          \
 })
@@ -203,20 +201,20 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
 #define _EMIT6_DISP(op1, op2, disp)                            \
 ({                                                             \
        unsigned int __disp = (disp) & 0xfff;                   \
-       _EMIT6(op1 | __disp, op2);                              \
+       _EMIT6((op1) | __disp, op2);                            \
 })
 
 #define _EMIT6_DISP_LH(op1, op2, disp)                         \
 ({                                                             \
-       u32 _disp = (u32) disp;                                 \
+       u32 _disp = (u32) (disp);                               \
        unsigned int __disp_h = _disp & 0xff000;                \
        unsigned int __disp_l = _disp & 0x00fff;                \
-       _EMIT6(op1 | __disp_l, op2 | __disp_h >> 4);            \
+       _EMIT6((op1) | __disp_l, (op2) | __disp_h >> 4);        \
 })
 
 #define EMIT6_DISP_LH(op1, op2, b1, b2, b3, disp)              \
 ({                                                             \
-       _EMIT6_DISP_LH(op1 | reg(b1, b2) << 16 |                \
+       _EMIT6_DISP_LH((op1) | reg(b1, b2) << 16 |              \
                       reg_high(b3) << 8, op2, disp);           \
        REG_SET_SEEN(b1);                                       \
        REG_SET_SEEN(b2);                                       \
@@ -226,8 +224,8 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
 #define EMIT6_PCREL_LABEL(op1, op2, b1, b2, label, mask)       \
 ({                                                             \
        int rel = (jit->labels[label] - jit->prg) >> 1;         \
-       _EMIT6(op1 | reg(b1, b2) << 16 | (rel & 0xffff),        \
-              op2 | mask << 12);                               \
+       _EMIT6((op1) | reg(b1, b2) << 16 | (rel & 0xffff),      \
+              (op2) | (mask) << 12);                           \
        REG_SET_SEEN(b1);                                       \
        REG_SET_SEEN(b2);                                       \
 })
@@ -235,43 +233,43 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
 #define EMIT6_PCREL_IMM_LABEL(op1, op2, b1, imm, label, mask)  \
 ({                                                             \
        int rel = (jit->labels[label] - jit->prg) >> 1;         \
-       _EMIT6(op1 | (reg_high(b1) | mask) << 16 |              \
-               (rel & 0xffff), op2 | (imm & 0xff) << 8);       \
+       _EMIT6((op1) | (reg_high(b1) | (mask)) << 16 |          \
+               (rel & 0xffff), (op2) | ((imm) & 0xff) << 8);   \
        REG_SET_SEEN(b1);                                       \
-       BUILD_BUG_ON(((unsigned long) imm) > 0xff);             \
+       BUILD_BUG_ON(((unsigned long) (imm)) > 0xff);           \
 })
 
 #define EMIT6_PCREL(op1, op2, b1, b2, i, off, mask)            \
 ({                                                             \
        /* Branch instruction needs 6 bytes */                  \
-       int rel = (addrs[i + off + 1] - (addrs[i + 1] - 6)) / 2;\
-       _EMIT6(op1 | reg(b1, b2) << 16 | (rel & 0xffff), op2 | mask);   \
+       int rel = (addrs[(i) + (off) + 1] - (addrs[(i) + 1] - 6)) / 2;\
+       _EMIT6((op1) | reg(b1, b2) << 16 | (rel & 0xffff), (op2) | (mask));\
        REG_SET_SEEN(b1);                                       \
        REG_SET_SEEN(b2);                                       \
 })
 
 #define EMIT6_PCREL_RILB(op, b, target)                                \
 ({                                                             \
-       int rel = (target - jit->prg) / 2;                      \
-       _EMIT6(op | reg_high(b) << 16 | rel >> 16, rel & 0xffff);       \
+       int rel = ((target) - jit->prg) / 2;                    \
+       _EMIT6((op) | reg_high(b) << 16 | rel >> 16, rel & 0xffff);\
        REG_SET_SEEN(b);                                        \
 })
 
 #define EMIT6_PCREL_RIL(op, target)                            \
 ({                                                             \
-       int rel = (target - jit->prg) / 2;                      \
-       _EMIT6(op | rel >> 16, rel & 0xffff);                   \
+       int rel = ((target) - jit->prg) / 2;                    \
+       _EMIT6((op) | rel >> 16, rel & 0xffff);                 \
 })
 
 #define _EMIT6_IMM(op, imm)                                    \
 ({                                                             \
        unsigned int __imm = (imm);                             \
-       _EMIT6(op | (__imm >> 16), __imm & 0xffff);             \
+       _EMIT6((op) | (__imm >> 16), __imm & 0xffff);           \
 })
 
 #define EMIT6_IMM(op, b1, imm)                                 \
 ({                                                             \
-       _EMIT6_IMM(op | reg_high(b1) << 16, imm);               \
+       _EMIT6_IMM((op) | reg_high(b1) << 16, imm);             \
        REG_SET_SEEN(b1);                                       \
 })
 
@@ -281,7 +279,7 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
        ret = jit->lit - jit->base_ip;                          \
        jit->seen |= SEEN_LITERAL;                              \
        if (jit->prg_buf)                                       \
-               *(u32 *) (jit->prg_buf + jit->lit) = (u32) val; \
+               *(u32 *) (jit->prg_buf + jit->lit) = (u32) (val);\
        jit->lit += 4;                                          \
        ret;                                                    \
 })
@@ -292,7 +290,7 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1)
        ret = jit->lit - jit->base_ip;                          \
        jit->seen |= SEEN_LITERAL;                              \
        if (jit->prg_buf)                                       \
-               *(u64 *) (jit->prg_buf + jit->lit) = (u64) val; \
+               *(u64 *) (jit->prg_buf + jit->lit) = (u64) (val);\
        jit->lit += 8;                                          \
        ret;                                                    \
 })
@@ -446,12 +444,6 @@ static void bpf_jit_prologue(struct bpf_jit *jit, u32 stack_depth)
  */
 static void bpf_jit_epilogue(struct bpf_jit *jit, u32 stack_depth)
 {
-       /* Return 0 */
-       if (jit->seen & SEEN_RET0) {
-               jit->ret0_ip = jit->prg;
-               /* lghi %b0,0 */
-               EMIT4_IMM(0xa7090000, BPF_REG_0, 0);
-       }
        jit->exit_ip = jit->prg;
        /* Load exit code: lgr %r2,%b0 */
        EMIT4(0xb9040000, REG_2, BPF_REG_0);
@@ -502,7 +494,8 @@ static void bpf_jit_epilogue(struct bpf_jit *jit, u32 stack_depth)
  * NOTE: Use noinline because for gcov (-fprofile-arcs) gcc allocates a lot of
  * stack space for the large switch statement.
  */
-static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, int i)
+static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp,
+                                int i, bool extra_pass)
 {
        struct bpf_insn *insn = &fp->insnsi[i];
        int jmp_off, last, insn_count = 1;
@@ -513,8 +506,6 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, int i
        s16 off = insn->off;
        unsigned int mask;
 
-       if (dst_reg == BPF_REG_AX || src_reg == BPF_REG_AX)
-               jit->seen |= SEEN_REG_AX;
        switch (insn->code) {
        /*
         * BPF_MOV
@@ -1011,10 +1002,14 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, int i
         */
        case BPF_JMP | BPF_CALL:
        {
-               /*
-                * b0 = (__bpf_call_base + imm)(b1, b2, b3, b4, b5)
-                */
-               const u64 func = (u64)__bpf_call_base + imm;
+               u64 func;
+               bool func_addr_fixed;
+               int ret;
+
+               ret = bpf_jit_get_func_addr(fp, insn, extra_pass,
+                                           &func, &func_addr_fixed);
+               if (ret < 0)
+                       return -1;
 
                REG_SET_SEEN(BPF_REG_5);
                jit->seen |= SEEN_FUNC;
@@ -1105,7 +1100,7 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, int i
                break;
        case BPF_JMP | BPF_EXIT: /* return b0 */
                last = (i == fp->len - 1) ? 1 : 0;
-               if (last && !(jit->seen & SEEN_RET0))
+               if (last)
                        break;
                /* j <exit> */
                EMIT4_PCREL(0xa7f40000, jit->exit_ip - jit->prg);
@@ -1283,7 +1278,8 @@ branch_oc:
 /*
  * Compile eBPF program into s390x code
  */
-static int bpf_jit_prog(struct bpf_jit *jit, struct bpf_prog *fp)
+static int bpf_jit_prog(struct bpf_jit *jit, struct bpf_prog *fp,
+                       bool extra_pass)
 {
        int i, insn_count;
 
@@ -1292,7 +1288,7 @@ static int bpf_jit_prog(struct bpf_jit *jit, struct bpf_prog *fp)
 
        bpf_jit_prologue(jit, fp->aux->stack_depth);
        for (i = 0; i < fp->len; i += insn_count) {
-               insn_count = bpf_jit_insn(jit, fp, i);
+               insn_count = bpf_jit_insn(jit, fp, i, extra_pass);
                if (insn_count < 0)
                        return -1;
                /* Next instruction address */
@@ -1311,6 +1307,12 @@ bool bpf_jit_needs_zext(void)
        return true;
 }
 
+struct s390_jit_data {
+       struct bpf_binary_header *header;
+       struct bpf_jit ctx;
+       int pass;
+};
+
 /*
  * Compile eBPF program "fp"
  */
@@ -1318,7 +1320,9 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *fp)
 {
        struct bpf_prog *tmp, *orig_fp = fp;
        struct bpf_binary_header *header;
+       struct s390_jit_data *jit_data;
        bool tmp_blinded = false;
+       bool extra_pass = false;
        struct bpf_jit jit;
        int pass;
 
@@ -1337,8 +1341,25 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *fp)
                fp = tmp;
        }
 
+       jit_data = fp->aux->jit_data;
+       if (!jit_data) {
+               jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
+               if (!jit_data) {
+                       fp = orig_fp;
+                       goto out;
+               }
+               fp->aux->jit_data = jit_data;
+       }
+       if (jit_data->ctx.addrs) {
+               jit = jit_data->ctx;
+               header = jit_data->header;
+               extra_pass = true;
+               pass = jit_data->pass + 1;
+               goto skip_init_ctx;
+       }
+
        memset(&jit, 0, sizeof(jit));
-       jit.addrs = kcalloc(fp->len + 1, sizeof(*jit.addrs), GFP_KERNEL);
+       jit.addrs = kvcalloc(fp->len + 1, sizeof(*jit.addrs), GFP_KERNEL);
        if (jit.addrs == NULL) {
                fp = orig_fp;
                goto out;
@@ -1349,7 +1370,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *fp)
         *   - 3:   Calculate program size and addrs arrray
         */
        for (pass = 1; pass <= 3; pass++) {
-               if (bpf_jit_prog(&jit, fp)) {
+               if (bpf_jit_prog(&jit, fp, extra_pass)) {
                        fp = orig_fp;
                        goto free_addrs;
                }
@@ -1361,12 +1382,14 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *fp)
                fp = orig_fp;
                goto free_addrs;
        }
+
        header = bpf_jit_binary_alloc(jit.size, &jit.prg_buf, 2, jit_fill_hole);
        if (!header) {
                fp = orig_fp;
                goto free_addrs;
        }
-       if (bpf_jit_prog(&jit, fp)) {
+skip_init_ctx:
+       if (bpf_jit_prog(&jit, fp, extra_pass)) {
                bpf_jit_binary_free(header);
                fp = orig_fp;
                goto free_addrs;
@@ -1375,12 +1398,24 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *fp)
                bpf_jit_dump(fp->len, jit.size, pass, jit.prg_buf);
                print_fn_code(jit.prg_buf, jit.size_prg);
        }
-       bpf_jit_binary_lock_ro(header);
+       if (!fp->is_func || extra_pass) {
+               bpf_jit_binary_lock_ro(header);
+       } else {
+               jit_data->header = header;
+               jit_data->ctx = jit;
+               jit_data->pass = pass;
+       }
        fp->bpf_func = (void *) jit.prg_buf;
        fp->jited = 1;
        fp->jited_len = jit.size;
+
+       if (!fp->is_func || extra_pass) {
+               bpf_prog_fill_jited_linfo(fp, jit.addrs + 1);
 free_addrs:
-       kfree(jit.addrs);
+               kvfree(jit.addrs);
+               kfree(jit_data);
+               fp->aux->jit_data = NULL;
+       }
 out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(fp, fp == orig_fp ?