bpf, arm64: use bpf_prog_pack for memory management
[linux-2.6-microblaze.git] / arch / arm64 / net / bpf_jit_comp.c
index 20720ec..5afc7a5 100644 (file)
@@ -76,6 +76,7 @@ struct jit_ctx {
        int *offset;
        int exentry_idx;
        __le32 *image;
+       __le32 *ro_image;
        u32 stack_size;
        int fpb_offset;
 };
@@ -205,6 +206,14 @@ static void jit_fill_hole(void *area, unsigned int size)
                *ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
 }
 
+int bpf_arch_text_invalidate(void *dst, size_t len)
+{      /*
+        * Invalidate @len bytes of JIT text starting at @dst by handing the
+        * range to aarch64_insn_set() with AARCH64_BREAK_FAULT as the value.
+        *
+        * Returns 0 on success, or -EINVAL when aarch64_insn_set() yields a
+        * NULL/zero result.
+        */
+       if (!aarch64_insn_set(dst, AARCH64_BREAK_FAULT, len))
+               return -EINVAL;
+
+       return 0;
+}
+
 static inline int epilogue_offset(const struct jit_ctx *ctx)
 {
        int to = ctx->epilogue_offset;
@@ -746,7 +755,8 @@ static int add_exception_handler(const struct bpf_insn *insn,
                                 struct jit_ctx *ctx,
                                 int dst_reg)
 {
-       off_t offset;
+       off_t ins_offset;
+       off_t fixup_offset;
        unsigned long pc;
        struct exception_table_entry *ex;
 
@@ -763,12 +773,17 @@ static int add_exception_handler(const struct bpf_insn *insn,
                return -EINVAL;
 
        ex = &ctx->prog->aux->extable[ctx->exentry_idx];
-       pc = (unsigned long)&ctx->image[ctx->idx - 1];
+       pc = (unsigned long)&ctx->ro_image[ctx->idx - 1];
 
-       offset = pc - (long)&ex->insn;
-       if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
+       /*
+        * This is the offset of the instruction that may fault, relative to
+        * its exception table entry. It will be written to the exception
+        * table; if this instruction faults, the destination register will
+        * be set to '0' and execution will jump to the next instruction.
+        */
+       ins_offset = pc - (long)&ex->insn;
+       if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
                return -ERANGE;
-       ex->insn = offset;
 
        /*
         * Since the extable follows the program, the fixup offset is always
@@ -777,12 +792,25 @@ static int add_exception_handler(const struct bpf_insn *insn,
         * bits. We don't need to worry about buildtime or runtime sort
         * modifying the upper bits because the table is already sorted, and
         * isn't part of the main exception table.
+        *
+        * The fixup_offset is set to the next instruction from the instruction
+        * that may fault. The execution will jump to this after handling the
+        * fault.
         */
-       offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
-       if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
+       fixup_offset = (long)&ex->fixup - (pc + AARCH64_INSN_SIZE);
+       if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
                return -ERANGE;
 
-       ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
+       /*
+        * The offsets above have been calculated using the RO buffer, but
+        * writes must go through the R/W buffer, so switch ex to the
+        * corresponding location in the R/W buffer before writing.
+        */
+       ex = (void *)ctx->image + ((void *)ex - (void *)ctx->ro_image);
+
+       ex->insn = ins_offset;
+
+       ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
                    FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
 
        ex->type = EX_TYPE_BPF;
@@ -1550,7 +1578,8 @@ static inline void bpf_flush_icache(void *start, void *end)
 
 struct arm64_jit_data {
        struct bpf_binary_header *header;
-       u8 *image;
+       u8 *ro_image;
+       struct bpf_binary_header *ro_header;
        struct jit_ctx ctx;
 };
 
@@ -1559,12 +1588,14 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        int image_size, prog_size, extable_size, extable_align, extable_offset;
        struct bpf_prog *tmp, *orig_prog = prog;
        struct bpf_binary_header *header;
+       struct bpf_binary_header *ro_header;
        struct arm64_jit_data *jit_data;
        bool was_classic = bpf_prog_was_classic(prog);
        bool tmp_blinded = false;
        bool extra_pass = false;
        struct jit_ctx ctx;
        u8 *image_ptr;
+       u8 *ro_image_ptr;
 
        if (!prog->jit_requested)
                return orig_prog;
@@ -1591,8 +1622,11 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        }
        if (jit_data->ctx.offset) {
                ctx = jit_data->ctx;
-               image_ptr = jit_data->image;
+               ro_image_ptr = jit_data->ro_image;
+               ro_header = jit_data->ro_header;
                header = jit_data->header;
+               image_ptr = (void *)header + ((void *)ro_image_ptr
+                                                - (void *)ro_header);
                extra_pass = true;
                prog_size = sizeof(u32) * ctx.idx;
                goto skip_init_ctx;
@@ -1637,18 +1671,27 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
        /* also allocate space for plt target */
        extable_offset = round_up(prog_size + PLT_TARGET_SIZE, extable_align);
        image_size = extable_offset + extable_size;
-       header = bpf_jit_binary_alloc(image_size, &image_ptr,
-                                     sizeof(u32), jit_fill_hole);
-       if (header == NULL) {
+       ro_header = bpf_jit_binary_pack_alloc(image_size, &ro_image_ptr,
+                                             sizeof(u32), &header, &image_ptr,
+                                             jit_fill_hole);
+       if (!ro_header) {
                prog = orig_prog;
                goto out_off;
        }
 
        /* 2. Now, the actual pass. */
 
+       /*
+        * Use the image (RW) for writing the JITed instructions, but also
+        * save the ro_image (RX) for calculating the offsets in the image.
+        * The RW image will later be copied to the RX image, from where the
+        * program will run; bpf_jit_binary_pack_finalize() does this copy in
+        * the final step.
+        */
        ctx.image = (__le32 *)image_ptr;
+       ctx.ro_image = (__le32 *)ro_image_ptr;
        if (extable_size)
-               prog->aux->extable = (void *)image_ptr + extable_offset;
+               prog->aux->extable = (void *)ro_image_ptr + extable_offset;
 skip_init_ctx:
        ctx.idx = 0;
        ctx.exentry_idx = 0;
@@ -1656,9 +1699,8 @@ skip_init_ctx:
        build_prologue(&ctx, was_classic, prog->aux->exception_cb);
 
        if (build_body(&ctx, extra_pass)) {
-               bpf_jit_binary_free(header);
                prog = orig_prog;
-               goto out_off;
+               goto out_free_hdr;
        }
 
        build_epilogue(&ctx, prog->aux->exception_cb);
@@ -1666,34 +1708,44 @@ skip_init_ctx:
 
        /* 3. Extra pass to validate JITed code. */
        if (validate_ctx(&ctx)) {
-               bpf_jit_binary_free(header);
                prog = orig_prog;
-               goto out_off;
+               goto out_free_hdr;
        }
 
        /* And we're done. */
        if (bpf_jit_enable > 1)
                bpf_jit_dump(prog->len, prog_size, 2, ctx.image);
 
-       bpf_flush_icache(header, ctx.image + ctx.idx);
-
        if (!prog->is_func || extra_pass) {
                if (extra_pass && ctx.idx != jit_data->ctx.idx) {
                        pr_err_once("multi-func JIT bug %d != %d\n",
                                    ctx.idx, jit_data->ctx.idx);
-                       bpf_jit_binary_free(header);
                        prog->bpf_func = NULL;
                        prog->jited = 0;
                        prog->jited_len = 0;
+                       goto out_free_hdr;
+               }
+               if (WARN_ON(bpf_jit_binary_pack_finalize(prog, ro_header,
+                                                        header))) {
+                       /* ro_header has been freed */
+                       ro_header = NULL;
+                       prog = orig_prog;
                        goto out_off;
                }
-               bpf_jit_binary_lock_ro(header);
+               /*
+                * The instructions have now been copied to the ROX region from
+                * where they will execute. Now the data cache has to be cleaned to
+                * the PoU and the I-cache has to be invalidated for the VAs.
+                */
+               bpf_flush_icache(ro_header, ctx.ro_image + ctx.idx);
        } else {
                jit_data->ctx = ctx;
-               jit_data->image = image_ptr;
+               jit_data->ro_image = ro_image_ptr;
                jit_data->header = header;
+               jit_data->ro_header = ro_header;
        }
-       prog->bpf_func = (void *)ctx.image;
+
+       prog->bpf_func = (void *)ctx.ro_image;
        prog->jited = 1;
        prog->jited_len = prog_size;
 
@@ -1714,6 +1766,14 @@ out:
                bpf_jit_prog_release_other(prog, prog == orig_prog ?
                                           tmp : orig_prog);
        return prog;
+
+out_free_hdr:
+       if (header) {
+               bpf_arch_text_copy(&ro_header->size, &header->size,
+                                  sizeof(header->size));
+               bpf_jit_binary_pack_free(ro_header, header);
+       }
+       goto out_off;
 }
 
 bool bpf_jit_supports_kfunc_call(void)
@@ -1721,6 +1781,13 @@ bool bpf_jit_supports_kfunc_call(void)
        return true;
 }
 
+void *bpf_arch_text_copy(void *dst, void *src, size_t len)
+{      /*
+        * Copy @len bytes from @src to @dst by calling aarch64_insn_copy().
+        *
+        * Returns @dst on success, or ERR_PTR(-EINVAL) when
+        * aarch64_insn_copy() yields a NULL/zero result.
+        */
+       if (!aarch64_insn_copy(dst, src, len))
+               return ERR_PTR(-EINVAL);
+       return dst;
+}
+
 u64 bpf_jit_alloc_exec_limit(void)
 {
        return VMALLOC_END - VMALLOC_START;
@@ -2359,3 +2426,27 @@ bool bpf_jit_supports_exceptions(void)
         */
        return true;
 }
+
+void bpf_jit_free(struct bpf_prog *prog)
+{      /*
+        * Free @prog. When the program is JITed, its binary header is
+        * looked up with bpf_jit_binary_pack_hdr() and released with
+        * bpf_jit_binary_pack_free() first; bpf_prog_unlock_free() is then
+        * called unconditionally.
+        */
+       if (prog->jited) {
+               struct arm64_jit_data *jit_data = prog->aux->jit_data;
+               struct bpf_binary_header *hdr;
+
+               /*
+                * If we fail the final pass of JIT (from jit_subprogs),
+                * the program may not be finalized yet, in which case
+                * jit_data is still attached to the program. Copy the size
+                * field of the RW header into the RO header with
+                * bpf_arch_text_copy() and release jit_data before the
+                * image is freed below.
+                *
+                * NOTE(review): this copies only the size field; it does
+                * not call bpf_jit_binary_pack_finalize(). The RW header
+                * (jit_data->header) is not handed to any free routine
+                * here, and bpf_jit_binary_pack_free() below receives NULL
+                * as its second argument - confirm the RW header is not
+                * leaked on this path.
+                */
+               if (jit_data) {
+                       bpf_arch_text_copy(&jit_data->ro_header->size, &jit_data->header->size,
+                                          sizeof(jit_data->header->size));
+                       kfree(jit_data);
+               }
+               hdr = bpf_jit_binary_pack_hdr(prog);
+               bpf_jit_binary_pack_free(hdr, NULL);
+               WARN_ON_ONCE(!bpf_prog_kallsyms_verify_off(prog));
+       }
+
+       bpf_prog_unlock_free(prog);
+}