Merge tag 'net-6.8-rc2' of git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[linux-2.6-microblaze.git] / arch / riscv / net / bpf_jit_comp64.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include <linux/memory.h>
12 #include <linux/stop_machine.h>
13 #include <asm/patch.h>
14 #include "bpf_jit.h"
15
16 #define RV_FENTRY_NINSNS 2
17
18 #define RV_REG_TCC RV_REG_A6
19 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
20
21 static const int regmap[] = {
22         [BPF_REG_0] =   RV_REG_A5,
23         [BPF_REG_1] =   RV_REG_A0,
24         [BPF_REG_2] =   RV_REG_A1,
25         [BPF_REG_3] =   RV_REG_A2,
26         [BPF_REG_4] =   RV_REG_A3,
27         [BPF_REG_5] =   RV_REG_A4,
28         [BPF_REG_6] =   RV_REG_S1,
29         [BPF_REG_7] =   RV_REG_S2,
30         [BPF_REG_8] =   RV_REG_S3,
31         [BPF_REG_9] =   RV_REG_S4,
32         [BPF_REG_FP] =  RV_REG_S5,
33         [BPF_REG_AX] =  RV_REG_T0,
34 };
35
36 static const int pt_regmap[] = {
37         [RV_REG_A0] = offsetof(struct pt_regs, a0),
38         [RV_REG_A1] = offsetof(struct pt_regs, a1),
39         [RV_REG_A2] = offsetof(struct pt_regs, a2),
40         [RV_REG_A3] = offsetof(struct pt_regs, a3),
41         [RV_REG_A4] = offsetof(struct pt_regs, a4),
42         [RV_REG_A5] = offsetof(struct pt_regs, a5),
43         [RV_REG_S1] = offsetof(struct pt_regs, s1),
44         [RV_REG_S2] = offsetof(struct pt_regs, s2),
45         [RV_REG_S3] = offsetof(struct pt_regs, s3),
46         [RV_REG_S4] = offsetof(struct pt_regs, s4),
47         [RV_REG_S5] = offsetof(struct pt_regs, s5),
48         [RV_REG_T0] = offsetof(struct pt_regs, t0),
49 };
50
51 enum {
52         RV_CTX_F_SEEN_TAIL_CALL =       0,
53         RV_CTX_F_SEEN_CALL =            RV_REG_RA,
54         RV_CTX_F_SEEN_S1 =              RV_REG_S1,
55         RV_CTX_F_SEEN_S2 =              RV_REG_S2,
56         RV_CTX_F_SEEN_S3 =              RV_REG_S3,
57         RV_CTX_F_SEEN_S4 =              RV_REG_S4,
58         RV_CTX_F_SEEN_S5 =              RV_REG_S5,
59         RV_CTX_F_SEEN_S6 =              RV_REG_S6,
60 };
61
62 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
63 {
64         u8 reg = regmap[bpf_reg];
65
66         switch (reg) {
67         case RV_CTX_F_SEEN_S1:
68         case RV_CTX_F_SEEN_S2:
69         case RV_CTX_F_SEEN_S3:
70         case RV_CTX_F_SEEN_S4:
71         case RV_CTX_F_SEEN_S5:
72         case RV_CTX_F_SEEN_S6:
73                 __set_bit(reg, &ctx->flags);
74         }
75         return reg;
76 };
77
78 static bool seen_reg(int reg, struct rv_jit_context *ctx)
79 {
80         switch (reg) {
81         case RV_CTX_F_SEEN_CALL:
82         case RV_CTX_F_SEEN_S1:
83         case RV_CTX_F_SEEN_S2:
84         case RV_CTX_F_SEEN_S3:
85         case RV_CTX_F_SEEN_S4:
86         case RV_CTX_F_SEEN_S5:
87         case RV_CTX_F_SEEN_S6:
88                 return test_bit(reg, &ctx->flags);
89         }
90         return false;
91 }
92
93 static void mark_fp(struct rv_jit_context *ctx)
94 {
95         __set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
96 }
97
98 static void mark_call(struct rv_jit_context *ctx)
99 {
100         __set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101 }
102
103 static bool seen_call(struct rv_jit_context *ctx)
104 {
105         return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
106 }
107
108 static void mark_tail_call(struct rv_jit_context *ctx)
109 {
110         __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111 }
112
113 static bool seen_tail_call(struct rv_jit_context *ctx)
114 {
115         return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
116 }
117
118 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
119 {
120         mark_tail_call(ctx);
121
122         if (seen_call(ctx)) {
123                 __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
124                 return RV_REG_S6;
125         }
126         return RV_REG_A6;
127 }
128
129 static bool is_32b_int(s64 val)
130 {
131         return -(1L << 31) <= val && val < (1L << 31);
132 }
133
134 static bool in_auipc_jalr_range(s64 val)
135 {
136         /*
137          * auipc+jalr can reach any signed PC-relative offset in the range
138          * [-2^31 - 2^11, 2^31 - 2^11).
139          */
140         return (-(1L << 31) - (1L << 11)) <= val &&
141                 val < ((1L << 31) - (1L << 11));
142 }
143
144 /* Emit fixed-length instructions for address */
145 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
146 {
147         /*
148          * Use the ro_insns(RX) to calculate the offset as the BPF program will
149          * finally run from this memory region.
150          */
151         u64 ip = (u64)(ctx->ro_insns + ctx->ninsns);
152         s64 off = addr - ip;
153         s64 upper = (off + (1 << 11)) >> 12;
154         s64 lower = off & 0xfff;
155
156         if (extra_pass && !in_auipc_jalr_range(off)) {
157                 pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
158                 return -ERANGE;
159         }
160
161         emit(rv_auipc(rd, upper), ctx);
162         emit(rv_addi(rd, rd, lower), ctx);
163         return 0;
164 }
165
166 /* Emit variable-length instructions for 32-bit and 64-bit imm */
167 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
168 {
169         /* Note that the immediate from the add is sign-extended,
170          * which means that we need to compensate this by adding 2^12,
171          * when the 12th bit is set. A simpler way of doing this, and
172          * getting rid of the check, is to just add 2**11 before the
173          * shift. The "Loading a 32-Bit constant" example from the
174          * "Computer Organization and Design, RISC-V edition" book by
175          * Patterson/Hennessy highlights this fact.
176          *
177          * This also means that we need to process LSB to MSB.
178          */
179         s64 upper = (val + (1 << 11)) >> 12;
180         /* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
181          * and addi are signed and RVC checks will perform signed comparisons.
182          */
183         s64 lower = ((val & 0xfff) << 52) >> 52;
184         int shift;
185
186         if (is_32b_int(val)) {
187                 if (upper)
188                         emit_lui(rd, upper, ctx);
189
190                 if (!upper) {
191                         emit_li(rd, lower, ctx);
192                         return;
193                 }
194
195                 emit_addiw(rd, rd, lower, ctx);
196                 return;
197         }
198
199         shift = __ffs(upper);
200         upper >>= shift;
201         shift += 12;
202
203         emit_imm(rd, upper, ctx);
204
205         emit_slli(rd, rd, shift, ctx);
206         if (lower)
207                 emit_addi(rd, rd, lower, ctx);
208 }
209
210 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
211 {
212         int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
213
214         if (seen_reg(RV_REG_RA, ctx)) {
215                 emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
216                 store_offset -= 8;
217         }
218         emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
219         store_offset -= 8;
220         if (seen_reg(RV_REG_S1, ctx)) {
221                 emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
222                 store_offset -= 8;
223         }
224         if (seen_reg(RV_REG_S2, ctx)) {
225                 emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
226                 store_offset -= 8;
227         }
228         if (seen_reg(RV_REG_S3, ctx)) {
229                 emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
230                 store_offset -= 8;
231         }
232         if (seen_reg(RV_REG_S4, ctx)) {
233                 emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
234                 store_offset -= 8;
235         }
236         if (seen_reg(RV_REG_S5, ctx)) {
237                 emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
238                 store_offset -= 8;
239         }
240         if (seen_reg(RV_REG_S6, ctx)) {
241                 emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
242                 store_offset -= 8;
243         }
244
245         emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
246         /* Set return value. */
247         if (!is_tail_call)
248                 emit_addiw(RV_REG_A0, RV_REG_A5, 0, ctx);
249         emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
250                   is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
251                   ctx);
252 }
253
254 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
255                      struct rv_jit_context *ctx)
256 {
257         switch (cond) {
258         case BPF_JEQ:
259                 emit(rv_beq(rd, rs, rvoff >> 1), ctx);
260                 return;
261         case BPF_JGT:
262                 emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
263                 return;
264         case BPF_JLT:
265                 emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
266                 return;
267         case BPF_JGE:
268                 emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
269                 return;
270         case BPF_JLE:
271                 emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
272                 return;
273         case BPF_JNE:
274                 emit(rv_bne(rd, rs, rvoff >> 1), ctx);
275                 return;
276         case BPF_JSGT:
277                 emit(rv_blt(rs, rd, rvoff >> 1), ctx);
278                 return;
279         case BPF_JSLT:
280                 emit(rv_blt(rd, rs, rvoff >> 1), ctx);
281                 return;
282         case BPF_JSGE:
283                 emit(rv_bge(rd, rs, rvoff >> 1), ctx);
284                 return;
285         case BPF_JSLE:
286                 emit(rv_bge(rs, rd, rvoff >> 1), ctx);
287         }
288 }
289
290 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
291                         struct rv_jit_context *ctx)
292 {
293         s64 upper, lower;
294
295         if (is_13b_int(rvoff)) {
296                 emit_bcc(cond, rd, rs, rvoff, ctx);
297                 return;
298         }
299
300         /* Adjust for jal */
301         rvoff -= 4;
302
303         /* Transform, e.g.:
304          *   bne rd,rs,foo
305          * to
306          *   beq rd,rs,<.L1>
307          *   (auipc foo)
308          *   jal(r) foo
309          * .L1
310          */
311         cond = invert_bpf_cond(cond);
312         if (is_21b_int(rvoff)) {
313                 emit_bcc(cond, rd, rs, 8, ctx);
314                 emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
315                 return;
316         }
317
318         /* 32b No need for an additional rvoff adjustment, since we
319          * get that from the auipc at PC', where PC = PC' + 4.
320          */
321         upper = (rvoff + (1 << 11)) >> 12;
322         lower = rvoff & 0xfff;
323
324         emit_bcc(cond, rd, rs, 12, ctx);
325         emit(rv_auipc(RV_REG_T1, upper), ctx);
326         emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
327 }
328
329 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
330 {
331         emit_slli(reg, reg, 32, ctx);
332         emit_srli(reg, reg, 32, ctx);
333 }
334
335 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
336 {
337         int tc_ninsn, off, start_insn = ctx->ninsns;
338         u8 tcc = rv_tail_call_reg(ctx);
339
340         /* a0: &ctx
341          * a1: &array
342          * a2: index
343          *
344          * if (index >= array->map.max_entries)
345          *      goto out;
346          */
347         tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
348                    ctx->offset[0];
349         emit_zext_32(RV_REG_A2, ctx);
350
351         off = offsetof(struct bpf_array, map.max_entries);
352         if (is_12b_check(off, insn))
353                 return -1;
354         emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
355         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
356         emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
357
358         /* if (--TCC < 0)
359          *     goto out;
360          */
361         emit_addi(RV_REG_TCC, tcc, -1, ctx);
362         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
363         emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
364
365         /* prog = array->ptrs[index];
366          * if (!prog)
367          *     goto out;
368          */
369         emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
370         emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
371         off = offsetof(struct bpf_array, ptrs);
372         if (is_12b_check(off, insn))
373                 return -1;
374         emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
375         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
376         emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
377
378         /* goto *(prog->bpf_func + 4); */
379         off = offsetof(struct bpf_prog, bpf_func);
380         if (is_12b_check(off, insn))
381                 return -1;
382         emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
383         __build_epilogue(true, ctx);
384         return 0;
385 }
386
387 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
388                       struct rv_jit_context *ctx)
389 {
390         u8 code = insn->code;
391
392         switch (code) {
393         case BPF_JMP | BPF_JA:
394         case BPF_JMP | BPF_CALL:
395         case BPF_JMP | BPF_EXIT:
396         case BPF_JMP | BPF_TAIL_CALL:
397                 break;
398         default:
399                 *rd = bpf_to_rv_reg(insn->dst_reg, ctx);
400         }
401
402         if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
403             code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
404             code & BPF_LDX || code & BPF_STX)
405                 *rs = bpf_to_rv_reg(insn->src_reg, ctx);
406 }
407
408 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
409 {
410         emit_mv(RV_REG_T2, *rd, ctx);
411         emit_zext_32(RV_REG_T2, ctx);
412         emit_mv(RV_REG_T1, *rs, ctx);
413         emit_zext_32(RV_REG_T1, ctx);
414         *rd = RV_REG_T2;
415         *rs = RV_REG_T1;
416 }
417
418 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
419 {
420         emit_addiw(RV_REG_T2, *rd, 0, ctx);
421         emit_addiw(RV_REG_T1, *rs, 0, ctx);
422         *rd = RV_REG_T2;
423         *rs = RV_REG_T1;
424 }
425
426 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
427 {
428         emit_mv(RV_REG_T2, *rd, ctx);
429         emit_zext_32(RV_REG_T2, ctx);
430         emit_zext_32(RV_REG_T1, ctx);
431         *rd = RV_REG_T2;
432 }
433
434 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
435 {
436         emit_addiw(RV_REG_T2, *rd, 0, ctx);
437         *rd = RV_REG_T2;
438 }
439
440 static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
441                               struct rv_jit_context *ctx)
442 {
443         s64 upper, lower;
444
445         if (rvoff && fixed_addr && is_21b_int(rvoff)) {
446                 emit(rv_jal(rd, rvoff >> 1), ctx);
447                 return 0;
448         } else if (in_auipc_jalr_range(rvoff)) {
449                 upper = (rvoff + (1 << 11)) >> 12;
450                 lower = rvoff & 0xfff;
451                 emit(rv_auipc(RV_REG_T1, upper), ctx);
452                 emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
453                 return 0;
454         }
455
456         pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
457         return -ERANGE;
458 }
459
460 static bool is_signed_bpf_cond(u8 cond)
461 {
462         return cond == BPF_JSGT || cond == BPF_JSLT ||
463                 cond == BPF_JSGE || cond == BPF_JSLE;
464 }
465
466 static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
467 {
468         s64 off = 0;
469         u64 ip;
470
471         if (addr && ctx->insns && ctx->ro_insns) {
472                 /*
473                  * Use the ro_insns(RX) to calculate the offset as the BPF
474                  * program will finally run from this memory region.
475                  */
476                 ip = (u64)(long)(ctx->ro_insns + ctx->ninsns);
477                 off = addr - ip;
478         }
479
480         return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
481 }
482
483 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
484                         struct rv_jit_context *ctx)
485 {
486         u8 r0;
487         int jmp_offset;
488
489         if (off) {
490                 if (is_12b_int(off)) {
491                         emit_addi(RV_REG_T1, rd, off, ctx);
492                 } else {
493                         emit_imm(RV_REG_T1, off, ctx);
494                         emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
495                 }
496                 rd = RV_REG_T1;
497         }
498
499         switch (imm) {
500         /* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
501         case BPF_ADD:
502                 emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
503                      rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
504                 break;
505         case BPF_AND:
506                 emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
507                      rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
508                 break;
509         case BPF_OR:
510                 emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
511                      rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
512                 break;
513         case BPF_XOR:
514                 emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
515                      rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
516                 break;
517         /* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
518         case BPF_ADD | BPF_FETCH:
519                 emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
520                      rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
521                 if (!is64)
522                         emit_zext_32(rs, ctx);
523                 break;
524         case BPF_AND | BPF_FETCH:
525                 emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
526                      rv_amoand_w(rs, rs, rd, 0, 0), ctx);
527                 if (!is64)
528                         emit_zext_32(rs, ctx);
529                 break;
530         case BPF_OR | BPF_FETCH:
531                 emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
532                      rv_amoor_w(rs, rs, rd, 0, 0), ctx);
533                 if (!is64)
534                         emit_zext_32(rs, ctx);
535                 break;
536         case BPF_XOR | BPF_FETCH:
537                 emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
538                      rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
539                 if (!is64)
540                         emit_zext_32(rs, ctx);
541                 break;
542         /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
543         case BPF_XCHG:
544                 emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
545                      rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
546                 if (!is64)
547                         emit_zext_32(rs, ctx);
548                 break;
549         /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
550         case BPF_CMPXCHG:
551                 r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
552                 emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
553                      rv_addiw(RV_REG_T2, r0, 0), ctx);
554                 emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
555                      rv_lr_w(r0, 0, rd, 0, 0), ctx);
556                 jmp_offset = ninsns_rvoff(8);
557                 emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
558                 emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
559                      rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
560                 jmp_offset = ninsns_rvoff(-6);
561                 emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
562                 emit(rv_fence(0x3, 0x3), ctx);
563                 break;
564         }
565 }
566
567 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
568 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
569
570 bool ex_handler_bpf(const struct exception_table_entry *ex,
571                     struct pt_regs *regs)
572 {
573         off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
574         int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
575
576         *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
577         regs->epc = (unsigned long)&ex->fixup - offset;
578
579         return true;
580 }
581
582 /* For accesses to BTF pointers, add an entry to the exception table */
583 static int add_exception_handler(const struct bpf_insn *insn,
584                                  struct rv_jit_context *ctx,
585                                  int dst_reg, int insn_len)
586 {
587         struct exception_table_entry *ex;
588         unsigned long pc;
589         off_t ins_offset;
590         off_t fixup_offset;
591
592         if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
593             (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
594                 return 0;
595
596         if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
597                 return -EINVAL;
598
599         if (WARN_ON_ONCE(insn_len > ctx->ninsns))
600                 return -EINVAL;
601
602         if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
603                 return -EINVAL;
604
605         ex = &ctx->prog->aux->extable[ctx->nexentries];
606         pc = (unsigned long)&ctx->ro_insns[ctx->ninsns - insn_len];
607
608         /*
609          * This is the relative offset of the instruction that may fault from
610          * the exception table itself. This will be written to the exception
611          * table and if this instruction faults, the destination register will
612          * be set to '0' and the execution will jump to the next instruction.
613          */
614         ins_offset = pc - (long)&ex->insn;
615         if (WARN_ON_ONCE(ins_offset >= 0 || ins_offset < INT_MIN))
616                 return -ERANGE;
617
618         /*
619          * Since the extable follows the program, the fixup offset is always
620          * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
621          * to keep things simple, and put the destination register in the upper
622          * bits. We don't need to worry about buildtime or runtime sort
623          * modifying the upper bits because the table is already sorted, and
624          * isn't part of the main exception table.
625          *
626          * The fixup_offset is set to the next instruction from the instruction
627          * that may fault. The execution will jump to this after handling the
628          * fault.
629          */
630         fixup_offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
631         if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, fixup_offset))
632                 return -ERANGE;
633
634         /*
635          * The offsets above have been calculated using the RO buffer but we
636          * need to use the R/W buffer for writes.
637          * switch ex to rw buffer for writing.
638          */
639         ex = (void *)ctx->insns + ((void *)ex - (void *)ctx->ro_insns);
640
641         ex->insn = ins_offset;
642
643         ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, fixup_offset) |
644                 FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
645         ex->type = EX_TYPE_BPF;
646
647         ctx->nexentries++;
648         return 0;
649 }
650
651 static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
652 {
653         s64 rvoff;
654         struct rv_jit_context ctx;
655
656         ctx.ninsns = 0;
657         ctx.insns = (u16 *)insns;
658
659         if (!target) {
660                 emit(rv_nop(), &ctx);
661                 emit(rv_nop(), &ctx);
662                 return 0;
663         }
664
665         rvoff = (s64)(target - ip);
666         return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
667 }
668
669 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
670                        void *old_addr, void *new_addr)
671 {
672         u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
673         bool is_call = poke_type == BPF_MOD_CALL;
674         int ret;
675
676         if (!is_kernel_text((unsigned long)ip) &&
677             !is_bpf_text_address((unsigned long)ip))
678                 return -ENOTSUPP;
679
680         ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
681         if (ret)
682                 return ret;
683
684         if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
685                 return -EFAULT;
686
687         ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
688         if (ret)
689                 return ret;
690
691         cpus_read_lock();
692         mutex_lock(&text_mutex);
693         if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
694                 ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
695         mutex_unlock(&text_mutex);
696         cpus_read_unlock();
697
698         return ret;
699 }
700
701 static void store_args(int nregs, int args_off, struct rv_jit_context *ctx)
702 {
703         int i;
704
705         for (i = 0; i < nregs; i++) {
706                 emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
707                 args_off -= 8;
708         }
709 }
710
711 static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx)
712 {
713         int i;
714
715         for (i = 0; i < nregs; i++) {
716                 emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
717                 args_off -= 8;
718         }
719 }
720
721 static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
722                            int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
723 {
724         int ret, branch_off;
725         struct bpf_prog *p = l->link.prog;
726         int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
727
728         if (l->cookie) {
729                 emit_imm(RV_REG_T1, l->cookie, ctx);
730                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
731         } else {
732                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
733         }
734
735         /* arg1: prog */
736         emit_imm(RV_REG_A0, (const s64)p, ctx);
737         /* arg2: &run_ctx */
738         emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
739         ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
740         if (ret)
741                 return ret;
742
743         /* if (__bpf_prog_enter(prog) == 0)
744          *      goto skip_exec_of_prog;
745          */
746         branch_off = ctx->ninsns;
747         /* nop reserved for conditional jump */
748         emit(rv_nop(), ctx);
749
750         /* store prog start time */
751         emit_mv(RV_REG_S1, RV_REG_A0, ctx);
752
753         /* arg1: &args_off */
754         emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
755         if (!p->jited)
756                 /* arg2: progs[i]->insnsi for interpreter */
757                 emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
758         ret = emit_call((const u64)p->bpf_func, true, ctx);
759         if (ret)
760                 return ret;
761
762         if (save_ret) {
763                 emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
764                 emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
765         }
766
767         /* update branch with beqz */
768         if (ctx->insns) {
769                 int offset = ninsns_rvoff(ctx->ninsns - branch_off);
770                 u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
771                 *(u32 *)(ctx->insns + branch_off) = insn;
772         }
773
774         /* arg1: prog */
775         emit_imm(RV_REG_A0, (const s64)p, ctx);
776         /* arg2: prog start time */
777         emit_mv(RV_REG_A1, RV_REG_S1, ctx);
778         /* arg3: &run_ctx */
779         emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
780         ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
781
782         return ret;
783 }
784
785 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
786                                          const struct btf_func_model *m,
787                                          struct bpf_tramp_links *tlinks,
788                                          void *func_addr, u32 flags,
789                                          struct rv_jit_context *ctx)
790 {
791         int i, ret, offset;
792         int *branches_off = NULL;
793         int stack_size = 0, nregs = m->nr_args;
794         int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
795         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
796         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
797         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
798         bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
799         void *orig_call = func_addr;
800         bool save_ret;
801         u32 insn;
802
803         /* Two types of generated trampoline stack layout:
804          *
805          * 1. trampoline called from function entry
806          * --------------------------------------
807          * FP + 8           [ RA to parent func ] return address to parent
808          *                                        function
809          * FP + 0           [ FP of parent func ] frame pointer of parent
810          *                                        function
811          * FP - 8           [ T0 to traced func ] return address of traced
812          *                                        function
813          * FP - 16          [ FP of traced func ] frame pointer of traced
814          *                                        function
815          * --------------------------------------
816          *
817          * 2. trampoline called directly
818          * --------------------------------------
819          * FP - 8           [ RA to caller func ] return address to caller
820          *                                        function
821          * FP - 16          [ FP of caller func ] frame pointer of caller
822          *                                        function
823          * --------------------------------------
824          *
825          * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
826          *                                        BPF_TRAMP_F_RET_FENTRY_RET
827          *                  [ argN              ]
828          *                  [ ...               ]
829          * FP - args_off    [ arg1              ]
830          *
831          * FP - nregs_off   [ regs count        ]
832          *
833          * FP - ip_off      [ traced func       ] BPF_TRAMP_F_IP_ARG
834          *
835          * FP - run_ctx_off [ bpf_tramp_run_ctx ]
836          *
837          * FP - sreg_off    [ callee saved reg  ]
838          *
839          *                  [ pads              ] pads for 16 bytes alignment
840          */
841
842         if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
843                 return -ENOTSUPP;
844
845         /* extra regiters for struct arguments */
846         for (i = 0; i < m->nr_args; i++)
847                 if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
848                         nregs += round_up(m->arg_size[i], 8) / 8 - 1;
849
850         /* 8 arguments passed by registers */
851         if (nregs > 8)
852                 return -ENOTSUPP;
853
854         /* room of trampoline frame to store return address and frame pointer */
855         stack_size += 16;
856
857         save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
858         if (save_ret) {
859                 stack_size += 16; /* Save both A5 (BPF R0) and A0 */
860                 retval_off = stack_size;
861         }
862
863         stack_size += nregs * 8;
864         args_off = stack_size;
865
866         stack_size += 8;
867         nregs_off = stack_size;
868
869         if (flags & BPF_TRAMP_F_IP_ARG) {
870                 stack_size += 8;
871                 ip_off = stack_size;
872         }
873
874         stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
875         run_ctx_off = stack_size;
876
877         stack_size += 8;
878         sreg_off = stack_size;
879
880         stack_size = round_up(stack_size, 16);
881
882         if (!is_struct_ops) {
883                 /* For the trampoline called from function entry,
884                  * the frame of traced function and the frame of
885                  * trampoline need to be considered.
886                  */
887                 emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
888                 emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
889                 emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
890                 emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
891
892                 emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
893                 emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
894                 emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
895                 emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
896         } else {
897                 /* For the trampoline called directly, just handle
898                  * the frame of trampoline.
899                  */
900                 emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
901                 emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
902                 emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
903                 emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
904         }
905
906         /* callee saved register S1 to pass start time */
907         emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
908
909         /* store ip address of the traced function */
910         if (flags & BPF_TRAMP_F_IP_ARG) {
911                 emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
912                 emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
913         }
914
915         emit_li(RV_REG_T1, nregs, ctx);
916         emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
917
918         store_args(nregs, args_off, ctx);
919
920         /* skip to actual body of traced function */
921         if (flags & BPF_TRAMP_F_SKIP_FRAME)
922                 orig_call += RV_FENTRY_NINSNS * 4;
923
924         if (flags & BPF_TRAMP_F_CALL_ORIG) {
925                 emit_imm(RV_REG_A0, (const s64)im, ctx);
926                 ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
927                 if (ret)
928                         return ret;
929         }
930
931         for (i = 0; i < fentry->nr_links; i++) {
932                 ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
933                                       flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
934                 if (ret)
935                         return ret;
936         }
937
938         if (fmod_ret->nr_links) {
939                 branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
940                 if (!branches_off)
941                         return -ENOMEM;
942
943                 /* cleanup to avoid garbage return value confusion */
944                 emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
945                 for (i = 0; i < fmod_ret->nr_links; i++) {
946                         ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
947                                               run_ctx_off, true, ctx);
948                         if (ret)
949                                 goto out;
950                         emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
951                         branches_off[i] = ctx->ninsns;
952                         /* nop reserved for conditional jump */
953                         emit(rv_nop(), ctx);
954                 }
955         }
956
957         if (flags & BPF_TRAMP_F_CALL_ORIG) {
958                 restore_args(nregs, args_off, ctx);
959                 ret = emit_call((const u64)orig_call, true, ctx);
960                 if (ret)
961                         goto out;
962                 emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
963                 emit_sd(RV_REG_FP, -(retval_off - 8), regmap[BPF_REG_0], ctx);
964                 im->ip_after_call = ctx->insns + ctx->ninsns;
965                 /* 2 nops reserved for auipc+jalr pair */
966                 emit(rv_nop(), ctx);
967                 emit(rv_nop(), ctx);
968         }
969
970         /* update branches saved in invoke_bpf_mod_ret with bnez */
971         for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
972                 offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
973                 insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
974                 *(u32 *)(ctx->insns + branches_off[i]) = insn;
975         }
976
977         for (i = 0; i < fexit->nr_links; i++) {
978                 ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
979                                       run_ctx_off, false, ctx);
980                 if (ret)
981                         goto out;
982         }
983
984         if (flags & BPF_TRAMP_F_CALL_ORIG) {
985                 im->ip_epilogue = ctx->insns + ctx->ninsns;
986                 emit_imm(RV_REG_A0, (const s64)im, ctx);
987                 ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
988                 if (ret)
989                         goto out;
990         }
991
992         if (flags & BPF_TRAMP_F_RESTORE_REGS)
993                 restore_args(nregs, args_off, ctx);
994
995         if (save_ret) {
996                 emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
997                 emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx);
998         }
999
1000         emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
1001
1002         if (!is_struct_ops) {
1003                 /* trampoline called from function entry */
1004                 emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
1005                 emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1006                 emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1007
1008                 emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
1009                 emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
1010                 emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);
1011
1012                 if (flags & BPF_TRAMP_F_SKIP_FRAME)
1013                         /* return to parent function */
1014                         emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1015                 else
1016                         /* return to traced function */
1017                         emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
1018         } else {
1019                 /* trampoline called directly */
1020                 emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
1021                 emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
1022                 emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
1023
1024                 emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
1025         }
1026
1027         ret = ctx->ninsns;
1028 out:
1029         kfree(branches_off);
1030         return ret;
1031 }
1032
1033 int arch_bpf_trampoline_size(const struct btf_func_model *m, u32 flags,
1034                              struct bpf_tramp_links *tlinks, void *func_addr)
1035 {
1036         struct bpf_tramp_image im;
1037         struct rv_jit_context ctx;
1038         int ret;
1039
1040         ctx.ninsns = 0;
1041         ctx.insns = NULL;
1042         ctx.ro_insns = NULL;
1043         ret = __arch_prepare_bpf_trampoline(&im, m, tlinks, func_addr, flags, &ctx);
1044
1045         return ret < 0 ? ret : ninsns_rvoff(ctx.ninsns);
1046 }
1047
1048 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
1049                                 void *image_end, const struct btf_func_model *m,
1050                                 u32 flags, struct bpf_tramp_links *tlinks,
1051                                 void *func_addr)
1052 {
1053         int ret;
1054         struct rv_jit_context ctx;
1055
1056         ctx.ninsns = 0;
1057         /*
1058          * The bpf_int_jit_compile() uses a RW buffer (ctx.insns) to write the
1059          * JITed instructions and later copies it to a RX region (ctx.ro_insns).
1060          * It also uses ctx.ro_insns to calculate offsets for jumps etc. As the
1061          * trampoline image uses the same memory area for writing and execution,
1062          * both ctx.insns and ctx.ro_insns can be set to image.
1063          */
1064         ctx.insns = image;
1065         ctx.ro_insns = image;
1066         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1067         if (ret < 0)
1068                 return ret;
1069
1070         bpf_flush_icache(ctx.insns, ctx.insns + ctx.ninsns);
1071
1072         return ninsns_rvoff(ret);
1073 }
1074
1075 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1076                       bool extra_pass)
1077 {
1078         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1079                     BPF_CLASS(insn->code) == BPF_JMP;
1080         int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1081         struct bpf_prog_aux *aux = ctx->prog->aux;
1082         u8 rd = -1, rs = -1, code = insn->code;
1083         s16 off = insn->off;
1084         s32 imm = insn->imm;
1085
1086         init_regs(&rd, &rs, insn, ctx);
1087
1088         switch (code) {
1089         /* dst = src */
1090         case BPF_ALU | BPF_MOV | BPF_X:
1091         case BPF_ALU64 | BPF_MOV | BPF_X:
1092                 if (imm == 1) {
1093                         /* Special mov32 for zext */
1094                         emit_zext_32(rd, ctx);
1095                         break;
1096                 }
1097                 switch (insn->off) {
1098                 case 0:
1099                         emit_mv(rd, rs, ctx);
1100                         break;
1101                 case 8:
1102                 case 16:
1103                         emit_slli(RV_REG_T1, rs, 64 - insn->off, ctx);
1104                         emit_srai(rd, RV_REG_T1, 64 - insn->off, ctx);
1105                         break;
1106                 case 32:
1107                         emit_addiw(rd, rs, 0, ctx);
1108                         break;
1109                 }
1110                 if (!is64 && !aux->verifier_zext)
1111                         emit_zext_32(rd, ctx);
1112                 break;
1113
1114         /* dst = dst OP src */
1115         case BPF_ALU | BPF_ADD | BPF_X:
1116         case BPF_ALU64 | BPF_ADD | BPF_X:
1117                 emit_add(rd, rd, rs, ctx);
1118                 if (!is64 && !aux->verifier_zext)
1119                         emit_zext_32(rd, ctx);
1120                 break;
1121         case BPF_ALU | BPF_SUB | BPF_X:
1122         case BPF_ALU64 | BPF_SUB | BPF_X:
1123                 if (is64)
1124                         emit_sub(rd, rd, rs, ctx);
1125                 else
1126                         emit_subw(rd, rd, rs, ctx);
1127
1128                 if (!is64 && !aux->verifier_zext)
1129                         emit_zext_32(rd, ctx);
1130                 break;
1131         case BPF_ALU | BPF_AND | BPF_X:
1132         case BPF_ALU64 | BPF_AND | BPF_X:
1133                 emit_and(rd, rd, rs, ctx);
1134                 if (!is64 && !aux->verifier_zext)
1135                         emit_zext_32(rd, ctx);
1136                 break;
1137         case BPF_ALU | BPF_OR | BPF_X:
1138         case BPF_ALU64 | BPF_OR | BPF_X:
1139                 emit_or(rd, rd, rs, ctx);
1140                 if (!is64 && !aux->verifier_zext)
1141                         emit_zext_32(rd, ctx);
1142                 break;
1143         case BPF_ALU | BPF_XOR | BPF_X:
1144         case BPF_ALU64 | BPF_XOR | BPF_X:
1145                 emit_xor(rd, rd, rs, ctx);
1146                 if (!is64 && !aux->verifier_zext)
1147                         emit_zext_32(rd, ctx);
1148                 break;
1149         case BPF_ALU | BPF_MUL | BPF_X:
1150         case BPF_ALU64 | BPF_MUL | BPF_X:
1151                 emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1152                 if (!is64 && !aux->verifier_zext)
1153                         emit_zext_32(rd, ctx);
1154                 break;
1155         case BPF_ALU | BPF_DIV | BPF_X:
1156         case BPF_ALU64 | BPF_DIV | BPF_X:
1157                 if (off)
1158                         emit(is64 ? rv_div(rd, rd, rs) : rv_divw(rd, rd, rs), ctx);
1159                 else
1160                         emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1161                 if (!is64 && !aux->verifier_zext)
1162                         emit_zext_32(rd, ctx);
1163                 break;
1164         case BPF_ALU | BPF_MOD | BPF_X:
1165         case BPF_ALU64 | BPF_MOD | BPF_X:
1166                 if (off)
1167                         emit(is64 ? rv_rem(rd, rd, rs) : rv_remw(rd, rd, rs), ctx);
1168                 else
1169                         emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1170                 if (!is64 && !aux->verifier_zext)
1171                         emit_zext_32(rd, ctx);
1172                 break;
1173         case BPF_ALU | BPF_LSH | BPF_X:
1174         case BPF_ALU64 | BPF_LSH | BPF_X:
1175                 emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1176                 if (!is64 && !aux->verifier_zext)
1177                         emit_zext_32(rd, ctx);
1178                 break;
1179         case BPF_ALU | BPF_RSH | BPF_X:
1180         case BPF_ALU64 | BPF_RSH | BPF_X:
1181                 emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1182                 if (!is64 && !aux->verifier_zext)
1183                         emit_zext_32(rd, ctx);
1184                 break;
1185         case BPF_ALU | BPF_ARSH | BPF_X:
1186         case BPF_ALU64 | BPF_ARSH | BPF_X:
1187                 emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1188                 if (!is64 && !aux->verifier_zext)
1189                         emit_zext_32(rd, ctx);
1190                 break;
1191
1192         /* dst = -dst */
1193         case BPF_ALU | BPF_NEG:
1194         case BPF_ALU64 | BPF_NEG:
1195                 emit_sub(rd, RV_REG_ZERO, rd, ctx);
1196                 if (!is64 && !aux->verifier_zext)
1197                         emit_zext_32(rd, ctx);
1198                 break;
1199
1200         /* dst = BSWAP##imm(dst) */
1201         case BPF_ALU | BPF_END | BPF_FROM_LE:
1202                 switch (imm) {
1203                 case 16:
1204                         emit_slli(rd, rd, 48, ctx);
1205                         emit_srli(rd, rd, 48, ctx);
1206                         break;
1207                 case 32:
1208                         if (!aux->verifier_zext)
1209                                 emit_zext_32(rd, ctx);
1210                         break;
1211                 case 64:
1212                         /* Do nothing */
1213                         break;
1214                 }
1215                 break;
1216
1217         case BPF_ALU | BPF_END | BPF_FROM_BE:
1218         case BPF_ALU64 | BPF_END | BPF_FROM_LE:
1219                 emit_li(RV_REG_T2, 0, ctx);
1220
1221                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1222                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1223                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1224                 emit_srli(rd, rd, 8, ctx);
1225                 if (imm == 16)
1226                         goto out_be;
1227
1228                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1229                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1230                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1231                 emit_srli(rd, rd, 8, ctx);
1232
1233                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1234                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1235                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1236                 emit_srli(rd, rd, 8, ctx);
1237                 if (imm == 32)
1238                         goto out_be;
1239
1240                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1241                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1242                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1243                 emit_srli(rd, rd, 8, ctx);
1244
1245                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1246                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1247                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1248                 emit_srli(rd, rd, 8, ctx);
1249
1250                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1251                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1252                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1253                 emit_srli(rd, rd, 8, ctx);
1254
1255                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1256                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1257                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1258                 emit_srli(rd, rd, 8, ctx);
1259 out_be:
1260                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1261                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1262
1263                 emit_mv(rd, RV_REG_T2, ctx);
1264                 break;
1265
1266         /* dst = imm */
1267         case BPF_ALU | BPF_MOV | BPF_K:
1268         case BPF_ALU64 | BPF_MOV | BPF_K:
1269                 emit_imm(rd, imm, ctx);
1270                 if (!is64 && !aux->verifier_zext)
1271                         emit_zext_32(rd, ctx);
1272                 break;
1273
1274         /* dst = dst OP imm */
1275         case BPF_ALU | BPF_ADD | BPF_K:
1276         case BPF_ALU64 | BPF_ADD | BPF_K:
1277                 if (is_12b_int(imm)) {
1278                         emit_addi(rd, rd, imm, ctx);
1279                 } else {
1280                         emit_imm(RV_REG_T1, imm, ctx);
1281                         emit_add(rd, rd, RV_REG_T1, ctx);
1282                 }
1283                 if (!is64 && !aux->verifier_zext)
1284                         emit_zext_32(rd, ctx);
1285                 break;
1286         case BPF_ALU | BPF_SUB | BPF_K:
1287         case BPF_ALU64 | BPF_SUB | BPF_K:
1288                 if (is_12b_int(-imm)) {
1289                         emit_addi(rd, rd, -imm, ctx);
1290                 } else {
1291                         emit_imm(RV_REG_T1, imm, ctx);
1292                         emit_sub(rd, rd, RV_REG_T1, ctx);
1293                 }
1294                 if (!is64 && !aux->verifier_zext)
1295                         emit_zext_32(rd, ctx);
1296                 break;
1297         case BPF_ALU | BPF_AND | BPF_K:
1298         case BPF_ALU64 | BPF_AND | BPF_K:
1299                 if (is_12b_int(imm)) {
1300                         emit_andi(rd, rd, imm, ctx);
1301                 } else {
1302                         emit_imm(RV_REG_T1, imm, ctx);
1303                         emit_and(rd, rd, RV_REG_T1, ctx);
1304                 }
1305                 if (!is64 && !aux->verifier_zext)
1306                         emit_zext_32(rd, ctx);
1307                 break;
1308         case BPF_ALU | BPF_OR | BPF_K:
1309         case BPF_ALU64 | BPF_OR | BPF_K:
1310                 if (is_12b_int(imm)) {
1311                         emit(rv_ori(rd, rd, imm), ctx);
1312                 } else {
1313                         emit_imm(RV_REG_T1, imm, ctx);
1314                         emit_or(rd, rd, RV_REG_T1, ctx);
1315                 }
1316                 if (!is64 && !aux->verifier_zext)
1317                         emit_zext_32(rd, ctx);
1318                 break;
1319         case BPF_ALU | BPF_XOR | BPF_K:
1320         case BPF_ALU64 | BPF_XOR | BPF_K:
1321                 if (is_12b_int(imm)) {
1322                         emit(rv_xori(rd, rd, imm), ctx);
1323                 } else {
1324                         emit_imm(RV_REG_T1, imm, ctx);
1325                         emit_xor(rd, rd, RV_REG_T1, ctx);
1326                 }
1327                 if (!is64 && !aux->verifier_zext)
1328                         emit_zext_32(rd, ctx);
1329                 break;
1330         case BPF_ALU | BPF_MUL | BPF_K:
1331         case BPF_ALU64 | BPF_MUL | BPF_K:
1332                 emit_imm(RV_REG_T1, imm, ctx);
1333                 emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1334                      rv_mulw(rd, rd, RV_REG_T1), ctx);
1335                 if (!is64 && !aux->verifier_zext)
1336                         emit_zext_32(rd, ctx);
1337                 break;
1338         case BPF_ALU | BPF_DIV | BPF_K:
1339         case BPF_ALU64 | BPF_DIV | BPF_K:
1340                 emit_imm(RV_REG_T1, imm, ctx);
1341                 if (off)
1342                         emit(is64 ? rv_div(rd, rd, RV_REG_T1) :
1343                              rv_divw(rd, rd, RV_REG_T1), ctx);
1344                 else
1345                         emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1346                              rv_divuw(rd, rd, RV_REG_T1), ctx);
1347                 if (!is64 && !aux->verifier_zext)
1348                         emit_zext_32(rd, ctx);
1349                 break;
1350         case BPF_ALU | BPF_MOD | BPF_K:
1351         case BPF_ALU64 | BPF_MOD | BPF_K:
1352                 emit_imm(RV_REG_T1, imm, ctx);
1353                 if (off)
1354                         emit(is64 ? rv_rem(rd, rd, RV_REG_T1) :
1355                              rv_remw(rd, rd, RV_REG_T1), ctx);
1356                 else
1357                         emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1358                              rv_remuw(rd, rd, RV_REG_T1), ctx);
1359                 if (!is64 && !aux->verifier_zext)
1360                         emit_zext_32(rd, ctx);
1361                 break;
1362         case BPF_ALU | BPF_LSH | BPF_K:
1363         case BPF_ALU64 | BPF_LSH | BPF_K:
1364                 emit_slli(rd, rd, imm, ctx);
1365
1366                 if (!is64 && !aux->verifier_zext)
1367                         emit_zext_32(rd, ctx);
1368                 break;
1369         case BPF_ALU | BPF_RSH | BPF_K:
1370         case BPF_ALU64 | BPF_RSH | BPF_K:
1371                 if (is64)
1372                         emit_srli(rd, rd, imm, ctx);
1373                 else
1374                         emit(rv_srliw(rd, rd, imm), ctx);
1375
1376                 if (!is64 && !aux->verifier_zext)
1377                         emit_zext_32(rd, ctx);
1378                 break;
1379         case BPF_ALU | BPF_ARSH | BPF_K:
1380         case BPF_ALU64 | BPF_ARSH | BPF_K:
1381                 if (is64)
1382                         emit_srai(rd, rd, imm, ctx);
1383                 else
1384                         emit(rv_sraiw(rd, rd, imm), ctx);
1385
1386                 if (!is64 && !aux->verifier_zext)
1387                         emit_zext_32(rd, ctx);
1388                 break;
1389
1390         /* JUMP off */
1391         case BPF_JMP | BPF_JA:
1392         case BPF_JMP32 | BPF_JA:
1393                 if (BPF_CLASS(code) == BPF_JMP)
1394                         rvoff = rv_offset(i, off, ctx);
1395                 else
1396                         rvoff = rv_offset(i, imm, ctx);
1397                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1398                 if (ret)
1399                         return ret;
1400                 break;
1401
1402         /* IF (dst COND src) JUMP off */
1403         case BPF_JMP | BPF_JEQ | BPF_X:
1404         case BPF_JMP32 | BPF_JEQ | BPF_X:
1405         case BPF_JMP | BPF_JGT | BPF_X:
1406         case BPF_JMP32 | BPF_JGT | BPF_X:
1407         case BPF_JMP | BPF_JLT | BPF_X:
1408         case BPF_JMP32 | BPF_JLT | BPF_X:
1409         case BPF_JMP | BPF_JGE | BPF_X:
1410         case BPF_JMP32 | BPF_JGE | BPF_X:
1411         case BPF_JMP | BPF_JLE | BPF_X:
1412         case BPF_JMP32 | BPF_JLE | BPF_X:
1413         case BPF_JMP | BPF_JNE | BPF_X:
1414         case BPF_JMP32 | BPF_JNE | BPF_X:
1415         case BPF_JMP | BPF_JSGT | BPF_X:
1416         case BPF_JMP32 | BPF_JSGT | BPF_X:
1417         case BPF_JMP | BPF_JSLT | BPF_X:
1418         case BPF_JMP32 | BPF_JSLT | BPF_X:
1419         case BPF_JMP | BPF_JSGE | BPF_X:
1420         case BPF_JMP32 | BPF_JSGE | BPF_X:
1421         case BPF_JMP | BPF_JSLE | BPF_X:
1422         case BPF_JMP32 | BPF_JSLE | BPF_X:
1423         case BPF_JMP | BPF_JSET | BPF_X:
1424         case BPF_JMP32 | BPF_JSET | BPF_X:
1425                 rvoff = rv_offset(i, off, ctx);
1426                 if (!is64) {
1427                         s = ctx->ninsns;
1428                         if (is_signed_bpf_cond(BPF_OP(code)))
1429                                 emit_sext_32_rd_rs(&rd, &rs, ctx);
1430                         else
1431                                 emit_zext_32_rd_rs(&rd, &rs, ctx);
1432                         e = ctx->ninsns;
1433
1434                         /* Adjust for extra insns */
1435                         rvoff -= ninsns_rvoff(e - s);
1436                 }
1437
1438                 if (BPF_OP(code) == BPF_JSET) {
1439                         /* Adjust for and */
1440                         rvoff -= 4;
1441                         emit_and(RV_REG_T1, rd, rs, ctx);
1442                         emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
1443                                     ctx);
1444                 } else {
1445                         emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1446                 }
1447                 break;
1448
1449         /* IF (dst COND imm) JUMP off */
1450         case BPF_JMP | BPF_JEQ | BPF_K:
1451         case BPF_JMP32 | BPF_JEQ | BPF_K:
1452         case BPF_JMP | BPF_JGT | BPF_K:
1453         case BPF_JMP32 | BPF_JGT | BPF_K:
1454         case BPF_JMP | BPF_JLT | BPF_K:
1455         case BPF_JMP32 | BPF_JLT | BPF_K:
1456         case BPF_JMP | BPF_JGE | BPF_K:
1457         case BPF_JMP32 | BPF_JGE | BPF_K:
1458         case BPF_JMP | BPF_JLE | BPF_K:
1459         case BPF_JMP32 | BPF_JLE | BPF_K:
1460         case BPF_JMP | BPF_JNE | BPF_K:
1461         case BPF_JMP32 | BPF_JNE | BPF_K:
1462         case BPF_JMP | BPF_JSGT | BPF_K:
1463         case BPF_JMP32 | BPF_JSGT | BPF_K:
1464         case BPF_JMP | BPF_JSLT | BPF_K:
1465         case BPF_JMP32 | BPF_JSLT | BPF_K:
1466         case BPF_JMP | BPF_JSGE | BPF_K:
1467         case BPF_JMP32 | BPF_JSGE | BPF_K:
1468         case BPF_JMP | BPF_JSLE | BPF_K:
1469         case BPF_JMP32 | BPF_JSLE | BPF_K:
1470                 rvoff = rv_offset(i, off, ctx);
1471                 s = ctx->ninsns;
1472                 if (imm) {
1473                         emit_imm(RV_REG_T1, imm, ctx);
1474                         rs = RV_REG_T1;
1475                 } else {
1476                         /* If imm is 0, simply use zero register. */
1477                         rs = RV_REG_ZERO;
1478                 }
1479                 if (!is64) {
1480                         if (is_signed_bpf_cond(BPF_OP(code)))
1481                                 emit_sext_32_rd(&rd, ctx);
1482                         else
1483                                 emit_zext_32_rd_t1(&rd, ctx);
1484                 }
1485                 e = ctx->ninsns;
1486
1487                 /* Adjust for extra insns */
1488                 rvoff -= ninsns_rvoff(e - s);
1489                 emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1490                 break;
1491
1492         case BPF_JMP | BPF_JSET | BPF_K:
1493         case BPF_JMP32 | BPF_JSET | BPF_K:
1494                 rvoff = rv_offset(i, off, ctx);
1495                 s = ctx->ninsns;
1496                 if (is_12b_int(imm)) {
1497                         emit_andi(RV_REG_T1, rd, imm, ctx);
1498                 } else {
1499                         emit_imm(RV_REG_T1, imm, ctx);
1500                         emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1501                 }
1502                 /* For jset32, we should clear the upper 32 bits of t1, but
1503                  * sign-extension is sufficient here and saves one instruction,
1504                  * as t1 is used only in comparison against zero.
1505                  */
1506                 if (!is64 && imm < 0)
1507                         emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
1508                 e = ctx->ninsns;
1509                 rvoff -= ninsns_rvoff(e - s);
1510                 emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1511                 break;
1512
1513         /* function call */
1514         case BPF_JMP | BPF_CALL:
1515         {
1516                 bool fixed_addr;
1517                 u64 addr;
1518
1519                 mark_call(ctx);
1520                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1521                                             &addr, &fixed_addr);
1522                 if (ret < 0)
1523                         return ret;
1524
1525                 ret = emit_call(addr, fixed_addr, ctx);
1526                 if (ret)
1527                         return ret;
1528
1529                 if (insn->src_reg != BPF_PSEUDO_CALL)
1530                         emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1531                 break;
1532         }
1533         /* tail call */
1534         case BPF_JMP | BPF_TAIL_CALL:
1535                 if (emit_bpf_tail_call(i, ctx))
1536                         return -1;
1537                 break;
1538
1539         /* function return */
1540         case BPF_JMP | BPF_EXIT:
1541                 if (i == ctx->prog->len - 1)
1542                         break;
1543
1544                 rvoff = epilogue_offset(ctx);
1545                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1546                 if (ret)
1547                         return ret;
1548                 break;
1549
1550         /* dst = imm64 */
1551         case BPF_LD | BPF_IMM | BPF_DW:
1552         {
1553                 struct bpf_insn insn1 = insn[1];
1554                 u64 imm64;
1555
1556                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
1557                 if (bpf_pseudo_func(insn)) {
1558                         /* fixed-length insns for extra jit pass */
1559                         ret = emit_addr(rd, imm64, extra_pass, ctx);
1560                         if (ret)
1561                                 return ret;
1562                 } else {
1563                         emit_imm(rd, imm64, ctx);
1564                 }
1565
1566                 return 1;
1567         }
1568
1569         /* LDX: dst = *(unsigned size *)(src + off) */
1570         case BPF_LDX | BPF_MEM | BPF_B:
1571         case BPF_LDX | BPF_MEM | BPF_H:
1572         case BPF_LDX | BPF_MEM | BPF_W:
1573         case BPF_LDX | BPF_MEM | BPF_DW:
1574         case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1575         case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1576         case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1577         case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1578         /* LDSX: dst = *(signed size *)(src + off) */
1579         case BPF_LDX | BPF_MEMSX | BPF_B:
1580         case BPF_LDX | BPF_MEMSX | BPF_H:
1581         case BPF_LDX | BPF_MEMSX | BPF_W:
1582         case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
1583         case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
1584         case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
1585         {
1586                 int insn_len, insns_start;
1587                 bool sign_ext;
1588
1589                 sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
1590                            BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
1591
1592                 switch (BPF_SIZE(code)) {
1593                 case BPF_B:
1594                         if (is_12b_int(off)) {
1595                                 insns_start = ctx->ninsns;
1596                                 if (sign_ext)
1597                                         emit(rv_lb(rd, off, rs), ctx);
1598                                 else
1599                                         emit(rv_lbu(rd, off, rs), ctx);
1600                                 insn_len = ctx->ninsns - insns_start;
1601                                 break;
1602                         }
1603
1604                         emit_imm(RV_REG_T1, off, ctx);
1605                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1606                         insns_start = ctx->ninsns;
1607                         if (sign_ext)
1608                                 emit(rv_lb(rd, 0, RV_REG_T1), ctx);
1609                         else
1610                                 emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1611                         insn_len = ctx->ninsns - insns_start;
1612                         break;
1613                 case BPF_H:
1614                         if (is_12b_int(off)) {
1615                                 insns_start = ctx->ninsns;
1616                                 if (sign_ext)
1617                                         emit(rv_lh(rd, off, rs), ctx);
1618                                 else
1619                                         emit(rv_lhu(rd, off, rs), ctx);
1620                                 insn_len = ctx->ninsns - insns_start;
1621                                 break;
1622                         }
1623
1624                         emit_imm(RV_REG_T1, off, ctx);
1625                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1626                         insns_start = ctx->ninsns;
1627                         if (sign_ext)
1628                                 emit(rv_lh(rd, 0, RV_REG_T1), ctx);
1629                         else
1630                                 emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1631                         insn_len = ctx->ninsns - insns_start;
1632                         break;
1633                 case BPF_W:
1634                         if (is_12b_int(off)) {
1635                                 insns_start = ctx->ninsns;
1636                                 if (sign_ext)
1637                                         emit(rv_lw(rd, off, rs), ctx);
1638                                 else
1639                                         emit(rv_lwu(rd, off, rs), ctx);
1640                                 insn_len = ctx->ninsns - insns_start;
1641                                 break;
1642                         }
1643
1644                         emit_imm(RV_REG_T1, off, ctx);
1645                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1646                         insns_start = ctx->ninsns;
1647                         if (sign_ext)
1648                                 emit(rv_lw(rd, 0, RV_REG_T1), ctx);
1649                         else
1650                                 emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1651                         insn_len = ctx->ninsns - insns_start;
1652                         break;
1653                 case BPF_DW:
1654                         if (is_12b_int(off)) {
1655                                 insns_start = ctx->ninsns;
1656                                 emit_ld(rd, off, rs, ctx);
1657                                 insn_len = ctx->ninsns - insns_start;
1658                                 break;
1659                         }
1660
1661                         emit_imm(RV_REG_T1, off, ctx);
1662                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1663                         insns_start = ctx->ninsns;
1664                         emit_ld(rd, 0, RV_REG_T1, ctx);
1665                         insn_len = ctx->ninsns - insns_start;
1666                         break;
1667                 }
1668
1669                 ret = add_exception_handler(insn, ctx, rd, insn_len);
1670                 if (ret)
1671                         return ret;
1672
1673                 if (BPF_SIZE(code) != BPF_DW && insn_is_zext(&insn[1]))
1674                         return 1;
1675                 break;
1676         }
1677         /* speculation barrier */
1678         case BPF_ST | BPF_NOSPEC:
1679                 break;
1680
1681         /* ST: *(size *)(dst + off) = imm */
1682         case BPF_ST | BPF_MEM | BPF_B:
1683                 emit_imm(RV_REG_T1, imm, ctx);
1684                 if (is_12b_int(off)) {
1685                         emit(rv_sb(rd, off, RV_REG_T1), ctx);
1686                         break;
1687                 }
1688
1689                 emit_imm(RV_REG_T2, off, ctx);
1690                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1691                 emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1692                 break;
1693
1694         case BPF_ST | BPF_MEM | BPF_H:
1695                 emit_imm(RV_REG_T1, imm, ctx);
1696                 if (is_12b_int(off)) {
1697                         emit(rv_sh(rd, off, RV_REG_T1), ctx);
1698                         break;
1699                 }
1700
1701                 emit_imm(RV_REG_T2, off, ctx);
1702                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1703                 emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1704                 break;
1705         case BPF_ST | BPF_MEM | BPF_W:
1706                 emit_imm(RV_REG_T1, imm, ctx);
1707                 if (is_12b_int(off)) {
1708                         emit_sw(rd, off, RV_REG_T1, ctx);
1709                         break;
1710                 }
1711
1712                 emit_imm(RV_REG_T2, off, ctx);
1713                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1714                 emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1715                 break;
1716         case BPF_ST | BPF_MEM | BPF_DW:
1717                 emit_imm(RV_REG_T1, imm, ctx);
1718                 if (is_12b_int(off)) {
1719                         emit_sd(rd, off, RV_REG_T1, ctx);
1720                         break;
1721                 }
1722
1723                 emit_imm(RV_REG_T2, off, ctx);
1724                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1725                 emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1726                 break;
1727
1728         /* STX: *(size *)(dst + off) = src */
1729         case BPF_STX | BPF_MEM | BPF_B:
1730                 if (is_12b_int(off)) {
1731                         emit(rv_sb(rd, off, rs), ctx);
1732                         break;
1733                 }
1734
1735                 emit_imm(RV_REG_T1, off, ctx);
1736                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1737                 emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1738                 break;
1739         case BPF_STX | BPF_MEM | BPF_H:
1740                 if (is_12b_int(off)) {
1741                         emit(rv_sh(rd, off, rs), ctx);
1742                         break;
1743                 }
1744
1745                 emit_imm(RV_REG_T1, off, ctx);
1746                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1747                 emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1748                 break;
1749         case BPF_STX | BPF_MEM | BPF_W:
1750                 if (is_12b_int(off)) {
1751                         emit_sw(rd, off, rs, ctx);
1752                         break;
1753                 }
1754
1755                 emit_imm(RV_REG_T1, off, ctx);
1756                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1757                 emit_sw(RV_REG_T1, 0, rs, ctx);
1758                 break;
1759         case BPF_STX | BPF_MEM | BPF_DW:
1760                 if (is_12b_int(off)) {
1761                         emit_sd(rd, off, rs, ctx);
1762                         break;
1763                 }
1764
1765                 emit_imm(RV_REG_T1, off, ctx);
1766                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1767                 emit_sd(RV_REG_T1, 0, rs, ctx);
1768                 break;
1769         case BPF_STX | BPF_ATOMIC | BPF_W:
1770         case BPF_STX | BPF_ATOMIC | BPF_DW:
1771                 emit_atomic(rd, rs, off, imm,
1772                             BPF_SIZE(code) == BPF_DW, ctx);
1773                 break;
1774         default:
1775                 pr_err("bpf-jit: unknown opcode %02x\n", code);
1776                 return -EINVAL;
1777         }
1778
1779         return 0;
1780 }
1781
1782 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1783 {
1784         int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
1785
1786         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1787         if (bpf_stack_adjust)
1788                 mark_fp(ctx);
1789
1790         if (seen_reg(RV_REG_RA, ctx))
1791                 stack_adjust += 8;
1792         stack_adjust += 8; /* RV_REG_FP */
1793         if (seen_reg(RV_REG_S1, ctx))
1794                 stack_adjust += 8;
1795         if (seen_reg(RV_REG_S2, ctx))
1796                 stack_adjust += 8;
1797         if (seen_reg(RV_REG_S3, ctx))
1798                 stack_adjust += 8;
1799         if (seen_reg(RV_REG_S4, ctx))
1800                 stack_adjust += 8;
1801         if (seen_reg(RV_REG_S5, ctx))
1802                 stack_adjust += 8;
1803         if (seen_reg(RV_REG_S6, ctx))
1804                 stack_adjust += 8;
1805
1806         stack_adjust = round_up(stack_adjust, 16);
1807         stack_adjust += bpf_stack_adjust;
1808
1809         store_offset = stack_adjust - 8;
1810
1811         /* nops reserved for auipc+jalr pair */
1812         for (i = 0; i < RV_FENTRY_NINSNS; i++)
1813                 emit(rv_nop(), ctx);
1814
1815         /* First instruction is always setting the tail-call-counter
1816          * (TCC) register. This instruction is skipped for tail calls.
1817          * Force using a 4-byte (non-compressed) instruction.
1818          */
1819         emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1820
1821         emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1822
1823         if (seen_reg(RV_REG_RA, ctx)) {
1824                 emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1825                 store_offset -= 8;
1826         }
1827         emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1828         store_offset -= 8;
1829         if (seen_reg(RV_REG_S1, ctx)) {
1830                 emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1831                 store_offset -= 8;
1832         }
1833         if (seen_reg(RV_REG_S2, ctx)) {
1834                 emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1835                 store_offset -= 8;
1836         }
1837         if (seen_reg(RV_REG_S3, ctx)) {
1838                 emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1839                 store_offset -= 8;
1840         }
1841         if (seen_reg(RV_REG_S4, ctx)) {
1842                 emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1843                 store_offset -= 8;
1844         }
1845         if (seen_reg(RV_REG_S5, ctx)) {
1846                 emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1847                 store_offset -= 8;
1848         }
1849         if (seen_reg(RV_REG_S6, ctx)) {
1850                 emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1851                 store_offset -= 8;
1852         }
1853
1854         emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1855
1856         if (bpf_stack_adjust)
1857                 emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1858
1859         /* Program contains calls and tail calls, so RV_REG_TCC need
1860          * to be saved across calls.
1861          */
1862         if (seen_tail_call(ctx) && seen_call(ctx))
1863                 emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1864
1865         ctx->stack_size = stack_adjust;
1866 }
1867
1868 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1869 {
1870         __build_epilogue(false, ctx);
1871 }
1872
1873 bool bpf_jit_supports_kfunc_call(void)
1874 {
1875         return true;
1876 }