Merge tag 'drm-next-2021-11-12' of git://anongit.freedesktop.org/drm/drm
[linux-2.6-microblaze.git] / arch / riscv / net / bpf_jit_comp64.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include "bpf_jit.h"
12
13 #define RV_REG_TCC RV_REG_A6
14 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
15
16 static const int regmap[] = {
17         [BPF_REG_0] =   RV_REG_A5,
18         [BPF_REG_1] =   RV_REG_A0,
19         [BPF_REG_2] =   RV_REG_A1,
20         [BPF_REG_3] =   RV_REG_A2,
21         [BPF_REG_4] =   RV_REG_A3,
22         [BPF_REG_5] =   RV_REG_A4,
23         [BPF_REG_6] =   RV_REG_S1,
24         [BPF_REG_7] =   RV_REG_S2,
25         [BPF_REG_8] =   RV_REG_S3,
26         [BPF_REG_9] =   RV_REG_S4,
27         [BPF_REG_FP] =  RV_REG_S5,
28         [BPF_REG_AX] =  RV_REG_T0,
29 };
30
31 static const int pt_regmap[] = {
32         [RV_REG_A0] = offsetof(struct pt_regs, a0),
33         [RV_REG_A1] = offsetof(struct pt_regs, a1),
34         [RV_REG_A2] = offsetof(struct pt_regs, a2),
35         [RV_REG_A3] = offsetof(struct pt_regs, a3),
36         [RV_REG_A4] = offsetof(struct pt_regs, a4),
37         [RV_REG_A5] = offsetof(struct pt_regs, a5),
38         [RV_REG_S1] = offsetof(struct pt_regs, s1),
39         [RV_REG_S2] = offsetof(struct pt_regs, s2),
40         [RV_REG_S3] = offsetof(struct pt_regs, s3),
41         [RV_REG_S4] = offsetof(struct pt_regs, s4),
42         [RV_REG_S5] = offsetof(struct pt_regs, s5),
43         [RV_REG_T0] = offsetof(struct pt_regs, t0),
44 };
45
46 enum {
47         RV_CTX_F_SEEN_TAIL_CALL =       0,
48         RV_CTX_F_SEEN_CALL =            RV_REG_RA,
49         RV_CTX_F_SEEN_S1 =              RV_REG_S1,
50         RV_CTX_F_SEEN_S2 =              RV_REG_S2,
51         RV_CTX_F_SEEN_S3 =              RV_REG_S3,
52         RV_CTX_F_SEEN_S4 =              RV_REG_S4,
53         RV_CTX_F_SEEN_S5 =              RV_REG_S5,
54         RV_CTX_F_SEEN_S6 =              RV_REG_S6,
55 };
56
57 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
58 {
59         u8 reg = regmap[bpf_reg];
60
61         switch (reg) {
62         case RV_CTX_F_SEEN_S1:
63         case RV_CTX_F_SEEN_S2:
64         case RV_CTX_F_SEEN_S3:
65         case RV_CTX_F_SEEN_S4:
66         case RV_CTX_F_SEEN_S5:
67         case RV_CTX_F_SEEN_S6:
68                 __set_bit(reg, &ctx->flags);
69         }
70         return reg;
71 };
72
73 static bool seen_reg(int reg, struct rv_jit_context *ctx)
74 {
75         switch (reg) {
76         case RV_CTX_F_SEEN_CALL:
77         case RV_CTX_F_SEEN_S1:
78         case RV_CTX_F_SEEN_S2:
79         case RV_CTX_F_SEEN_S3:
80         case RV_CTX_F_SEEN_S4:
81         case RV_CTX_F_SEEN_S5:
82         case RV_CTX_F_SEEN_S6:
83                 return test_bit(reg, &ctx->flags);
84         }
85         return false;
86 }
87
88 static void mark_fp(struct rv_jit_context *ctx)
89 {
90         __set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
91 }
92
93 static void mark_call(struct rv_jit_context *ctx)
94 {
95         __set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
96 }
97
98 static bool seen_call(struct rv_jit_context *ctx)
99 {
100         return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
101 }
102
103 static void mark_tail_call(struct rv_jit_context *ctx)
104 {
105         __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
106 }
107
108 static bool seen_tail_call(struct rv_jit_context *ctx)
109 {
110         return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
111 }
112
113 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
114 {
115         mark_tail_call(ctx);
116
117         if (seen_call(ctx)) {
118                 __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
119                 return RV_REG_S6;
120         }
121         return RV_REG_A6;
122 }
123
124 static bool is_32b_int(s64 val)
125 {
126         return -(1L << 31) <= val && val < (1L << 31);
127 }
128
129 static bool in_auipc_jalr_range(s64 val)
130 {
131         /*
132          * auipc+jalr can reach any signed PC-relative offset in the range
133          * [-2^31 - 2^11, 2^31 - 2^11).
134          */
135         return (-(1L << 31) - (1L << 11)) <= val &&
136                 val < ((1L << 31) - (1L << 11));
137 }
138
139 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
140 {
141         /* Note that the immediate from the add is sign-extended,
142          * which means that we need to compensate this by adding 2^12,
143          * when the 12th bit is set. A simpler way of doing this, and
144          * getting rid of the check, is to just add 2**11 before the
145          * shift. The "Loading a 32-Bit constant" example from the
146          * "Computer Organization and Design, RISC-V edition" book by
147          * Patterson/Hennessy highlights this fact.
148          *
149          * This also means that we need to process LSB to MSB.
150          */
151         s64 upper = (val + (1 << 11)) >> 12;
152         /* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
153          * and addi are signed and RVC checks will perform signed comparisons.
154          */
155         s64 lower = ((val & 0xfff) << 52) >> 52;
156         int shift;
157
158         if (is_32b_int(val)) {
159                 if (upper)
160                         emit_lui(rd, upper, ctx);
161
162                 if (!upper) {
163                         emit_li(rd, lower, ctx);
164                         return;
165                 }
166
167                 emit_addiw(rd, rd, lower, ctx);
168                 return;
169         }
170
171         shift = __ffs(upper);
172         upper >>= shift;
173         shift += 12;
174
175         emit_imm(rd, upper, ctx);
176
177         emit_slli(rd, rd, shift, ctx);
178         if (lower)
179                 emit_addi(rd, rd, lower, ctx);
180 }
181
182 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
183 {
184         int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
185
186         if (seen_reg(RV_REG_RA, ctx)) {
187                 emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
188                 store_offset -= 8;
189         }
190         emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
191         store_offset -= 8;
192         if (seen_reg(RV_REG_S1, ctx)) {
193                 emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
194                 store_offset -= 8;
195         }
196         if (seen_reg(RV_REG_S2, ctx)) {
197                 emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
198                 store_offset -= 8;
199         }
200         if (seen_reg(RV_REG_S3, ctx)) {
201                 emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
202                 store_offset -= 8;
203         }
204         if (seen_reg(RV_REG_S4, ctx)) {
205                 emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
206                 store_offset -= 8;
207         }
208         if (seen_reg(RV_REG_S5, ctx)) {
209                 emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
210                 store_offset -= 8;
211         }
212         if (seen_reg(RV_REG_S6, ctx)) {
213                 emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
214                 store_offset -= 8;
215         }
216
217         emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
218         /* Set return value. */
219         if (!is_tail_call)
220                 emit_mv(RV_REG_A0, RV_REG_A5, ctx);
221         emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
222                   is_tail_call ? 4 : 0, /* skip TCC init */
223                   ctx);
224 }
225
226 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
227                      struct rv_jit_context *ctx)
228 {
229         switch (cond) {
230         case BPF_JEQ:
231                 emit(rv_beq(rd, rs, rvoff >> 1), ctx);
232                 return;
233         case BPF_JGT:
234                 emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
235                 return;
236         case BPF_JLT:
237                 emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
238                 return;
239         case BPF_JGE:
240                 emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
241                 return;
242         case BPF_JLE:
243                 emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
244                 return;
245         case BPF_JNE:
246                 emit(rv_bne(rd, rs, rvoff >> 1), ctx);
247                 return;
248         case BPF_JSGT:
249                 emit(rv_blt(rs, rd, rvoff >> 1), ctx);
250                 return;
251         case BPF_JSLT:
252                 emit(rv_blt(rd, rs, rvoff >> 1), ctx);
253                 return;
254         case BPF_JSGE:
255                 emit(rv_bge(rd, rs, rvoff >> 1), ctx);
256                 return;
257         case BPF_JSLE:
258                 emit(rv_bge(rs, rd, rvoff >> 1), ctx);
259         }
260 }
261
262 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
263                         struct rv_jit_context *ctx)
264 {
265         s64 upper, lower;
266
267         if (is_13b_int(rvoff)) {
268                 emit_bcc(cond, rd, rs, rvoff, ctx);
269                 return;
270         }
271
272         /* Adjust for jal */
273         rvoff -= 4;
274
275         /* Transform, e.g.:
276          *   bne rd,rs,foo
277          * to
278          *   beq rd,rs,<.L1>
279          *   (auipc foo)
280          *   jal(r) foo
281          * .L1
282          */
283         cond = invert_bpf_cond(cond);
284         if (is_21b_int(rvoff)) {
285                 emit_bcc(cond, rd, rs, 8, ctx);
286                 emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
287                 return;
288         }
289
290         /* 32b No need for an additional rvoff adjustment, since we
291          * get that from the auipc at PC', where PC = PC' + 4.
292          */
293         upper = (rvoff + (1 << 11)) >> 12;
294         lower = rvoff & 0xfff;
295
296         emit_bcc(cond, rd, rs, 12, ctx);
297         emit(rv_auipc(RV_REG_T1, upper), ctx);
298         emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
299 }
300
301 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
302 {
303         emit_slli(reg, reg, 32, ctx);
304         emit_srli(reg, reg, 32, ctx);
305 }
306
307 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
308 {
309         int tc_ninsn, off, start_insn = ctx->ninsns;
310         u8 tcc = rv_tail_call_reg(ctx);
311
312         /* a0: &ctx
313          * a1: &array
314          * a2: index
315          *
316          * if (index >= array->map.max_entries)
317          *      goto out;
318          */
319         tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
320                    ctx->offset[0];
321         emit_zext_32(RV_REG_A2, ctx);
322
323         off = offsetof(struct bpf_array, map.max_entries);
324         if (is_12b_check(off, insn))
325                 return -1;
326         emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
327         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
328         emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
329
330         /* if (TCC-- < 0)
331          *     goto out;
332          */
333         emit_addi(RV_REG_T1, tcc, -1, ctx);
334         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
335         emit_branch(BPF_JSLT, tcc, RV_REG_ZERO, off, ctx);
336
337         /* prog = array->ptrs[index];
338          * if (!prog)
339          *     goto out;
340          */
341         emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
342         emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
343         off = offsetof(struct bpf_array, ptrs);
344         if (is_12b_check(off, insn))
345                 return -1;
346         emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
347         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
348         emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
349
350         /* goto *(prog->bpf_func + 4); */
351         off = offsetof(struct bpf_prog, bpf_func);
352         if (is_12b_check(off, insn))
353                 return -1;
354         emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
355         emit_mv(RV_REG_TCC, RV_REG_T1, ctx);
356         __build_epilogue(true, ctx);
357         return 0;
358 }
359
360 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
361                       struct rv_jit_context *ctx)
362 {
363         u8 code = insn->code;
364
365         switch (code) {
366         case BPF_JMP | BPF_JA:
367         case BPF_JMP | BPF_CALL:
368         case BPF_JMP | BPF_EXIT:
369         case BPF_JMP | BPF_TAIL_CALL:
370                 break;
371         default:
372                 *rd = bpf_to_rv_reg(insn->dst_reg, ctx);
373         }
374
375         if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
376             code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
377             code & BPF_LDX || code & BPF_STX)
378                 *rs = bpf_to_rv_reg(insn->src_reg, ctx);
379 }
380
381 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
382 {
383         emit_mv(RV_REG_T2, *rd, ctx);
384         emit_zext_32(RV_REG_T2, ctx);
385         emit_mv(RV_REG_T1, *rs, ctx);
386         emit_zext_32(RV_REG_T1, ctx);
387         *rd = RV_REG_T2;
388         *rs = RV_REG_T1;
389 }
390
391 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
392 {
393         emit_addiw(RV_REG_T2, *rd, 0, ctx);
394         emit_addiw(RV_REG_T1, *rs, 0, ctx);
395         *rd = RV_REG_T2;
396         *rs = RV_REG_T1;
397 }
398
399 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
400 {
401         emit_mv(RV_REG_T2, *rd, ctx);
402         emit_zext_32(RV_REG_T2, ctx);
403         emit_zext_32(RV_REG_T1, ctx);
404         *rd = RV_REG_T2;
405 }
406
407 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
408 {
409         emit_addiw(RV_REG_T2, *rd, 0, ctx);
410         *rd = RV_REG_T2;
411 }
412
413 static int emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
414                               struct rv_jit_context *ctx)
415 {
416         s64 upper, lower;
417
418         if (rvoff && is_21b_int(rvoff) && !force_jalr) {
419                 emit(rv_jal(rd, rvoff >> 1), ctx);
420                 return 0;
421         } else if (in_auipc_jalr_range(rvoff)) {
422                 upper = (rvoff + (1 << 11)) >> 12;
423                 lower = rvoff & 0xfff;
424                 emit(rv_auipc(RV_REG_T1, upper), ctx);
425                 emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
426                 return 0;
427         }
428
429         pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
430         return -ERANGE;
431 }
432
433 static bool is_signed_bpf_cond(u8 cond)
434 {
435         return cond == BPF_JSGT || cond == BPF_JSLT ||
436                 cond == BPF_JSGE || cond == BPF_JSLE;
437 }
438
439 static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
440 {
441         s64 off = 0;
442         u64 ip;
443         u8 rd;
444         int ret;
445
446         if (addr && ctx->insns) {
447                 ip = (u64)(long)(ctx->insns + ctx->ninsns);
448                 off = addr - ip;
449         }
450
451         ret = emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
452         if (ret)
453                 return ret;
454         rd = bpf_to_rv_reg(BPF_REG_0, ctx);
455         emit_mv(rd, RV_REG_A0, ctx);
456         return 0;
457 }
458
459 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
460 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
461
462 int rv_bpf_fixup_exception(const struct exception_table_entry *ex,
463                                 struct pt_regs *regs);
464 int rv_bpf_fixup_exception(const struct exception_table_entry *ex,
465                                 struct pt_regs *regs)
466 {
467         off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
468         int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
469
470         *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
471         regs->epc = (unsigned long)&ex->fixup - offset;
472
473         return 1;
474 }
475
476 /* For accesses to BTF pointers, add an entry to the exception table */
477 static int add_exception_handler(const struct bpf_insn *insn,
478                                  struct rv_jit_context *ctx,
479                                  int dst_reg, int insn_len)
480 {
481         struct exception_table_entry *ex;
482         unsigned long pc;
483         off_t offset;
484
485         if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
486                 return 0;
487
488         if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
489                 return -EINVAL;
490
491         if (WARN_ON_ONCE(insn_len > ctx->ninsns))
492                 return -EINVAL;
493
494         if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
495                 return -EINVAL;
496
497         ex = &ctx->prog->aux->extable[ctx->nexentries];
498         pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
499
500         offset = pc - (long)&ex->insn;
501         if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
502                 return -ERANGE;
503         ex->insn = pc;
504
505         /*
506          * Since the extable follows the program, the fixup offset is always
507          * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
508          * to keep things simple, and put the destination register in the upper
509          * bits. We don't need to worry about buildtime or runtime sort
510          * modifying the upper bits because the table is already sorted, and
511          * isn't part of the main exception table.
512          */
513         offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
514         if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
515                 return -ERANGE;
516
517         ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
518                 FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
519
520         ctx->nexentries++;
521         return 0;
522 }
523
524 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
525                       bool extra_pass)
526 {
527         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
528                     BPF_CLASS(insn->code) == BPF_JMP;
529         int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
530         struct bpf_prog_aux *aux = ctx->prog->aux;
531         u8 rd = -1, rs = -1, code = insn->code;
532         s16 off = insn->off;
533         s32 imm = insn->imm;
534
535         init_regs(&rd, &rs, insn, ctx);
536
537         switch (code) {
538         /* dst = src */
539         case BPF_ALU | BPF_MOV | BPF_X:
540         case BPF_ALU64 | BPF_MOV | BPF_X:
541                 if (imm == 1) {
542                         /* Special mov32 for zext */
543                         emit_zext_32(rd, ctx);
544                         break;
545                 }
546                 emit_mv(rd, rs, ctx);
547                 if (!is64 && !aux->verifier_zext)
548                         emit_zext_32(rd, ctx);
549                 break;
550
551         /* dst = dst OP src */
552         case BPF_ALU | BPF_ADD | BPF_X:
553         case BPF_ALU64 | BPF_ADD | BPF_X:
554                 emit_add(rd, rd, rs, ctx);
555                 if (!is64 && !aux->verifier_zext)
556                         emit_zext_32(rd, ctx);
557                 break;
558         case BPF_ALU | BPF_SUB | BPF_X:
559         case BPF_ALU64 | BPF_SUB | BPF_X:
560                 if (is64)
561                         emit_sub(rd, rd, rs, ctx);
562                 else
563                         emit_subw(rd, rd, rs, ctx);
564
565                 if (!is64 && !aux->verifier_zext)
566                         emit_zext_32(rd, ctx);
567                 break;
568         case BPF_ALU | BPF_AND | BPF_X:
569         case BPF_ALU64 | BPF_AND | BPF_X:
570                 emit_and(rd, rd, rs, ctx);
571                 if (!is64 && !aux->verifier_zext)
572                         emit_zext_32(rd, ctx);
573                 break;
574         case BPF_ALU | BPF_OR | BPF_X:
575         case BPF_ALU64 | BPF_OR | BPF_X:
576                 emit_or(rd, rd, rs, ctx);
577                 if (!is64 && !aux->verifier_zext)
578                         emit_zext_32(rd, ctx);
579                 break;
580         case BPF_ALU | BPF_XOR | BPF_X:
581         case BPF_ALU64 | BPF_XOR | BPF_X:
582                 emit_xor(rd, rd, rs, ctx);
583                 if (!is64 && !aux->verifier_zext)
584                         emit_zext_32(rd, ctx);
585                 break;
586         case BPF_ALU | BPF_MUL | BPF_X:
587         case BPF_ALU64 | BPF_MUL | BPF_X:
588                 emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
589                 if (!is64 && !aux->verifier_zext)
590                         emit_zext_32(rd, ctx);
591                 break;
592         case BPF_ALU | BPF_DIV | BPF_X:
593         case BPF_ALU64 | BPF_DIV | BPF_X:
594                 emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
595                 if (!is64 && !aux->verifier_zext)
596                         emit_zext_32(rd, ctx);
597                 break;
598         case BPF_ALU | BPF_MOD | BPF_X:
599         case BPF_ALU64 | BPF_MOD | BPF_X:
600                 emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
601                 if (!is64 && !aux->verifier_zext)
602                         emit_zext_32(rd, ctx);
603                 break;
604         case BPF_ALU | BPF_LSH | BPF_X:
605         case BPF_ALU64 | BPF_LSH | BPF_X:
606                 emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
607                 if (!is64 && !aux->verifier_zext)
608                         emit_zext_32(rd, ctx);
609                 break;
610         case BPF_ALU | BPF_RSH | BPF_X:
611         case BPF_ALU64 | BPF_RSH | BPF_X:
612                 emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
613                 if (!is64 && !aux->verifier_zext)
614                         emit_zext_32(rd, ctx);
615                 break;
616         case BPF_ALU | BPF_ARSH | BPF_X:
617         case BPF_ALU64 | BPF_ARSH | BPF_X:
618                 emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
619                 if (!is64 && !aux->verifier_zext)
620                         emit_zext_32(rd, ctx);
621                 break;
622
623         /* dst = -dst */
624         case BPF_ALU | BPF_NEG:
625         case BPF_ALU64 | BPF_NEG:
626                 emit_sub(rd, RV_REG_ZERO, rd, ctx);
627                 if (!is64 && !aux->verifier_zext)
628                         emit_zext_32(rd, ctx);
629                 break;
630
631         /* dst = BSWAP##imm(dst) */
632         case BPF_ALU | BPF_END | BPF_FROM_LE:
633                 switch (imm) {
634                 case 16:
635                         emit_slli(rd, rd, 48, ctx);
636                         emit_srli(rd, rd, 48, ctx);
637                         break;
638                 case 32:
639                         if (!aux->verifier_zext)
640                                 emit_zext_32(rd, ctx);
641                         break;
642                 case 64:
643                         /* Do nothing */
644                         break;
645                 }
646                 break;
647
648         case BPF_ALU | BPF_END | BPF_FROM_BE:
649                 emit_li(RV_REG_T2, 0, ctx);
650
651                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
652                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
653                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
654                 emit_srli(rd, rd, 8, ctx);
655                 if (imm == 16)
656                         goto out_be;
657
658                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
659                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
660                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
661                 emit_srli(rd, rd, 8, ctx);
662
663                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
664                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
665                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
666                 emit_srli(rd, rd, 8, ctx);
667                 if (imm == 32)
668                         goto out_be;
669
670                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
671                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
672                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
673                 emit_srli(rd, rd, 8, ctx);
674
675                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
676                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
677                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
678                 emit_srli(rd, rd, 8, ctx);
679
680                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
681                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
682                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
683                 emit_srli(rd, rd, 8, ctx);
684
685                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
686                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
687                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
688                 emit_srli(rd, rd, 8, ctx);
689 out_be:
690                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
691                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
692
693                 emit_mv(rd, RV_REG_T2, ctx);
694                 break;
695
696         /* dst = imm */
697         case BPF_ALU | BPF_MOV | BPF_K:
698         case BPF_ALU64 | BPF_MOV | BPF_K:
699                 emit_imm(rd, imm, ctx);
700                 if (!is64 && !aux->verifier_zext)
701                         emit_zext_32(rd, ctx);
702                 break;
703
704         /* dst = dst OP imm */
705         case BPF_ALU | BPF_ADD | BPF_K:
706         case BPF_ALU64 | BPF_ADD | BPF_K:
707                 if (is_12b_int(imm)) {
708                         emit_addi(rd, rd, imm, ctx);
709                 } else {
710                         emit_imm(RV_REG_T1, imm, ctx);
711                         emit_add(rd, rd, RV_REG_T1, ctx);
712                 }
713                 if (!is64 && !aux->verifier_zext)
714                         emit_zext_32(rd, ctx);
715                 break;
716         case BPF_ALU | BPF_SUB | BPF_K:
717         case BPF_ALU64 | BPF_SUB | BPF_K:
718                 if (is_12b_int(-imm)) {
719                         emit_addi(rd, rd, -imm, ctx);
720                 } else {
721                         emit_imm(RV_REG_T1, imm, ctx);
722                         emit_sub(rd, rd, RV_REG_T1, ctx);
723                 }
724                 if (!is64 && !aux->verifier_zext)
725                         emit_zext_32(rd, ctx);
726                 break;
727         case BPF_ALU | BPF_AND | BPF_K:
728         case BPF_ALU64 | BPF_AND | BPF_K:
729                 if (is_12b_int(imm)) {
730                         emit_andi(rd, rd, imm, ctx);
731                 } else {
732                         emit_imm(RV_REG_T1, imm, ctx);
733                         emit_and(rd, rd, RV_REG_T1, ctx);
734                 }
735                 if (!is64 && !aux->verifier_zext)
736                         emit_zext_32(rd, ctx);
737                 break;
738         case BPF_ALU | BPF_OR | BPF_K:
739         case BPF_ALU64 | BPF_OR | BPF_K:
740                 if (is_12b_int(imm)) {
741                         emit(rv_ori(rd, rd, imm), ctx);
742                 } else {
743                         emit_imm(RV_REG_T1, imm, ctx);
744                         emit_or(rd, rd, RV_REG_T1, ctx);
745                 }
746                 if (!is64 && !aux->verifier_zext)
747                         emit_zext_32(rd, ctx);
748                 break;
749         case BPF_ALU | BPF_XOR | BPF_K:
750         case BPF_ALU64 | BPF_XOR | BPF_K:
751                 if (is_12b_int(imm)) {
752                         emit(rv_xori(rd, rd, imm), ctx);
753                 } else {
754                         emit_imm(RV_REG_T1, imm, ctx);
755                         emit_xor(rd, rd, RV_REG_T1, ctx);
756                 }
757                 if (!is64 && !aux->verifier_zext)
758                         emit_zext_32(rd, ctx);
759                 break;
760         case BPF_ALU | BPF_MUL | BPF_K:
761         case BPF_ALU64 | BPF_MUL | BPF_K:
762                 emit_imm(RV_REG_T1, imm, ctx);
763                 emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
764                      rv_mulw(rd, rd, RV_REG_T1), ctx);
765                 if (!is64 && !aux->verifier_zext)
766                         emit_zext_32(rd, ctx);
767                 break;
768         case BPF_ALU | BPF_DIV | BPF_K:
769         case BPF_ALU64 | BPF_DIV | BPF_K:
770                 emit_imm(RV_REG_T1, imm, ctx);
771                 emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
772                      rv_divuw(rd, rd, RV_REG_T1), ctx);
773                 if (!is64 && !aux->verifier_zext)
774                         emit_zext_32(rd, ctx);
775                 break;
776         case BPF_ALU | BPF_MOD | BPF_K:
777         case BPF_ALU64 | BPF_MOD | BPF_K:
778                 emit_imm(RV_REG_T1, imm, ctx);
779                 emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
780                      rv_remuw(rd, rd, RV_REG_T1), ctx);
781                 if (!is64 && !aux->verifier_zext)
782                         emit_zext_32(rd, ctx);
783                 break;
784         case BPF_ALU | BPF_LSH | BPF_K:
785         case BPF_ALU64 | BPF_LSH | BPF_K:
786                 emit_slli(rd, rd, imm, ctx);
787
788                 if (!is64 && !aux->verifier_zext)
789                         emit_zext_32(rd, ctx);
790                 break;
791         case BPF_ALU | BPF_RSH | BPF_K:
792         case BPF_ALU64 | BPF_RSH | BPF_K:
793                 if (is64)
794                         emit_srli(rd, rd, imm, ctx);
795                 else
796                         emit(rv_srliw(rd, rd, imm), ctx);
797
798                 if (!is64 && !aux->verifier_zext)
799                         emit_zext_32(rd, ctx);
800                 break;
801         case BPF_ALU | BPF_ARSH | BPF_K:
802         case BPF_ALU64 | BPF_ARSH | BPF_K:
803                 if (is64)
804                         emit_srai(rd, rd, imm, ctx);
805                 else
806                         emit(rv_sraiw(rd, rd, imm), ctx);
807
808                 if (!is64 && !aux->verifier_zext)
809                         emit_zext_32(rd, ctx);
810                 break;
811
812         /* JUMP off */
813         case BPF_JMP | BPF_JA:
814                 rvoff = rv_offset(i, off, ctx);
815                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
816                 if (ret)
817                         return ret;
818                 break;
819
820         /* IF (dst COND src) JUMP off */
821         case BPF_JMP | BPF_JEQ | BPF_X:
822         case BPF_JMP32 | BPF_JEQ | BPF_X:
823         case BPF_JMP | BPF_JGT | BPF_X:
824         case BPF_JMP32 | BPF_JGT | BPF_X:
825         case BPF_JMP | BPF_JLT | BPF_X:
826         case BPF_JMP32 | BPF_JLT | BPF_X:
827         case BPF_JMP | BPF_JGE | BPF_X:
828         case BPF_JMP32 | BPF_JGE | BPF_X:
829         case BPF_JMP | BPF_JLE | BPF_X:
830         case BPF_JMP32 | BPF_JLE | BPF_X:
831         case BPF_JMP | BPF_JNE | BPF_X:
832         case BPF_JMP32 | BPF_JNE | BPF_X:
833         case BPF_JMP | BPF_JSGT | BPF_X:
834         case BPF_JMP32 | BPF_JSGT | BPF_X:
835         case BPF_JMP | BPF_JSLT | BPF_X:
836         case BPF_JMP32 | BPF_JSLT | BPF_X:
837         case BPF_JMP | BPF_JSGE | BPF_X:
838         case BPF_JMP32 | BPF_JSGE | BPF_X:
839         case BPF_JMP | BPF_JSLE | BPF_X:
840         case BPF_JMP32 | BPF_JSLE | BPF_X:
841         case BPF_JMP | BPF_JSET | BPF_X:
842         case BPF_JMP32 | BPF_JSET | BPF_X:
843                 rvoff = rv_offset(i, off, ctx);
844                 if (!is64) {
845                         s = ctx->ninsns;
846                         if (is_signed_bpf_cond(BPF_OP(code)))
847                                 emit_sext_32_rd_rs(&rd, &rs, ctx);
848                         else
849                                 emit_zext_32_rd_rs(&rd, &rs, ctx);
850                         e = ctx->ninsns;
851
852                         /* Adjust for extra insns */
853                         rvoff -= ninsns_rvoff(e - s);
854                 }
855
856                 if (BPF_OP(code) == BPF_JSET) {
857                         /* Adjust for and */
858                         rvoff -= 4;
859                         emit_and(RV_REG_T1, rd, rs, ctx);
860                         emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
861                                     ctx);
862                 } else {
863                         emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
864                 }
865                 break;
866
867         /* IF (dst COND imm) JUMP off */
868         case BPF_JMP | BPF_JEQ | BPF_K:
869         case BPF_JMP32 | BPF_JEQ | BPF_K:
870         case BPF_JMP | BPF_JGT | BPF_K:
871         case BPF_JMP32 | BPF_JGT | BPF_K:
872         case BPF_JMP | BPF_JLT | BPF_K:
873         case BPF_JMP32 | BPF_JLT | BPF_K:
874         case BPF_JMP | BPF_JGE | BPF_K:
875         case BPF_JMP32 | BPF_JGE | BPF_K:
876         case BPF_JMP | BPF_JLE | BPF_K:
877         case BPF_JMP32 | BPF_JLE | BPF_K:
878         case BPF_JMP | BPF_JNE | BPF_K:
879         case BPF_JMP32 | BPF_JNE | BPF_K:
880         case BPF_JMP | BPF_JSGT | BPF_K:
881         case BPF_JMP32 | BPF_JSGT | BPF_K:
882         case BPF_JMP | BPF_JSLT | BPF_K:
883         case BPF_JMP32 | BPF_JSLT | BPF_K:
884         case BPF_JMP | BPF_JSGE | BPF_K:
885         case BPF_JMP32 | BPF_JSGE | BPF_K:
886         case BPF_JMP | BPF_JSLE | BPF_K:
887         case BPF_JMP32 | BPF_JSLE | BPF_K:
888                 rvoff = rv_offset(i, off, ctx);
889                 s = ctx->ninsns;
890                 if (imm) {
891                         emit_imm(RV_REG_T1, imm, ctx);
892                         rs = RV_REG_T1;
893                 } else {
894                         /* If imm is 0, simply use zero register. */
895                         rs = RV_REG_ZERO;
896                 }
897                 if (!is64) {
898                         if (is_signed_bpf_cond(BPF_OP(code)))
899                                 emit_sext_32_rd(&rd, ctx);
900                         else
901                                 emit_zext_32_rd_t1(&rd, ctx);
902                 }
903                 e = ctx->ninsns;
904
905                 /* Adjust for extra insns */
906                 rvoff -= ninsns_rvoff(e - s);
907                 emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
908                 break;
909
910         case BPF_JMP | BPF_JSET | BPF_K:
911         case BPF_JMP32 | BPF_JSET | BPF_K:
912                 rvoff = rv_offset(i, off, ctx);
913                 s = ctx->ninsns;
914                 if (is_12b_int(imm)) {
915                         emit_andi(RV_REG_T1, rd, imm, ctx);
916                 } else {
917                         emit_imm(RV_REG_T1, imm, ctx);
918                         emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
919                 }
920                 /* For jset32, we should clear the upper 32 bits of t1, but
921                  * sign-extension is sufficient here and saves one instruction,
922                  * as t1 is used only in comparison against zero.
923                  */
924                 if (!is64 && imm < 0)
925                         emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
926                 e = ctx->ninsns;
927                 rvoff -= ninsns_rvoff(e - s);
928                 emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
929                 break;
930
931         /* function call */
932         case BPF_JMP | BPF_CALL:
933         {
934                 bool fixed;
935                 u64 addr;
936
937                 mark_call(ctx);
938                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
939                                             &fixed);
940                 if (ret < 0)
941                         return ret;
942                 ret = emit_call(fixed, addr, ctx);
943                 if (ret)
944                         return ret;
945                 break;
946         }
947         /* tail call */
948         case BPF_JMP | BPF_TAIL_CALL:
949                 if (emit_bpf_tail_call(i, ctx))
950                         return -1;
951                 break;
952
953         /* function return */
954         case BPF_JMP | BPF_EXIT:
955                 if (i == ctx->prog->len - 1)
956                         break;
957
958                 rvoff = epilogue_offset(ctx);
959                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
960                 if (ret)
961                         return ret;
962                 break;
963
964         /* dst = imm64 */
965         case BPF_LD | BPF_IMM | BPF_DW:
966         {
967                 struct bpf_insn insn1 = insn[1];
968                 u64 imm64;
969
970                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
971                 emit_imm(rd, imm64, ctx);
972                 return 1;
973         }
974
975         /* LDX: dst = *(size *)(src + off) */
976         case BPF_LDX | BPF_MEM | BPF_B:
977         case BPF_LDX | BPF_MEM | BPF_H:
978         case BPF_LDX | BPF_MEM | BPF_W:
979         case BPF_LDX | BPF_MEM | BPF_DW:
980         case BPF_LDX | BPF_PROBE_MEM | BPF_B:
981         case BPF_LDX | BPF_PROBE_MEM | BPF_H:
982         case BPF_LDX | BPF_PROBE_MEM | BPF_W:
983         case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
984         {
985                 int insn_len, insns_start;
986
987                 switch (BPF_SIZE(code)) {
988                 case BPF_B:
989                         if (is_12b_int(off)) {
990                                 insns_start = ctx->ninsns;
991                                 emit(rv_lbu(rd, off, rs), ctx);
992                                 insn_len = ctx->ninsns - insns_start;
993                                 break;
994                         }
995
996                         emit_imm(RV_REG_T1, off, ctx);
997                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
998                         insns_start = ctx->ninsns;
999                         emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1000                         insn_len = ctx->ninsns - insns_start;
1001                         if (insn_is_zext(&insn[1]))
1002                                 return 1;
1003                         break;
1004                 case BPF_H:
1005                         if (is_12b_int(off)) {
1006                                 insns_start = ctx->ninsns;
1007                                 emit(rv_lhu(rd, off, rs), ctx);
1008                                 insn_len = ctx->ninsns - insns_start;
1009                                 break;
1010                         }
1011
1012                         emit_imm(RV_REG_T1, off, ctx);
1013                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1014                         insns_start = ctx->ninsns;
1015                         emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1016                         insn_len = ctx->ninsns - insns_start;
1017                         if (insn_is_zext(&insn[1]))
1018                                 return 1;
1019                         break;
1020                 case BPF_W:
1021                         if (is_12b_int(off)) {
1022                                 insns_start = ctx->ninsns;
1023                                 emit(rv_lwu(rd, off, rs), ctx);
1024                                 insn_len = ctx->ninsns - insns_start;
1025                                 break;
1026                         }
1027
1028                         emit_imm(RV_REG_T1, off, ctx);
1029                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1030                         insns_start = ctx->ninsns;
1031                         emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1032                         insn_len = ctx->ninsns - insns_start;
1033                         if (insn_is_zext(&insn[1]))
1034                                 return 1;
1035                         break;
1036                 case BPF_DW:
1037                         if (is_12b_int(off)) {
1038                                 insns_start = ctx->ninsns;
1039                                 emit_ld(rd, off, rs, ctx);
1040                                 insn_len = ctx->ninsns - insns_start;
1041                                 break;
1042                         }
1043
1044                         emit_imm(RV_REG_T1, off, ctx);
1045                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1046                         insns_start = ctx->ninsns;
1047                         emit_ld(rd, 0, RV_REG_T1, ctx);
1048                         insn_len = ctx->ninsns - insns_start;
1049                         break;
1050                 }
1051
1052                 ret = add_exception_handler(insn, ctx, rd, insn_len);
1053                 if (ret)
1054                         return ret;
1055                 break;
1056         }
1057         /* speculation barrier */
1058         case BPF_ST | BPF_NOSPEC:
1059                 break;
1060
1061         /* ST: *(size *)(dst + off) = imm */
1062         case BPF_ST | BPF_MEM | BPF_B:
1063                 emit_imm(RV_REG_T1, imm, ctx);
1064                 if (is_12b_int(off)) {
1065                         emit(rv_sb(rd, off, RV_REG_T1), ctx);
1066                         break;
1067                 }
1068
1069                 emit_imm(RV_REG_T2, off, ctx);
1070                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1071                 emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1072                 break;
1073
1074         case BPF_ST | BPF_MEM | BPF_H:
1075                 emit_imm(RV_REG_T1, imm, ctx);
1076                 if (is_12b_int(off)) {
1077                         emit(rv_sh(rd, off, RV_REG_T1), ctx);
1078                         break;
1079                 }
1080
1081                 emit_imm(RV_REG_T2, off, ctx);
1082                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1083                 emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1084                 break;
1085         case BPF_ST | BPF_MEM | BPF_W:
1086                 emit_imm(RV_REG_T1, imm, ctx);
1087                 if (is_12b_int(off)) {
1088                         emit_sw(rd, off, RV_REG_T1, ctx);
1089                         break;
1090                 }
1091
1092                 emit_imm(RV_REG_T2, off, ctx);
1093                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1094                 emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1095                 break;
1096         case BPF_ST | BPF_MEM | BPF_DW:
1097                 emit_imm(RV_REG_T1, imm, ctx);
1098                 if (is_12b_int(off)) {
1099                         emit_sd(rd, off, RV_REG_T1, ctx);
1100                         break;
1101                 }
1102
1103                 emit_imm(RV_REG_T2, off, ctx);
1104                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1105                 emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1106                 break;
1107
1108         /* STX: *(size *)(dst + off) = src */
1109         case BPF_STX | BPF_MEM | BPF_B:
1110                 if (is_12b_int(off)) {
1111                         emit(rv_sb(rd, off, rs), ctx);
1112                         break;
1113                 }
1114
1115                 emit_imm(RV_REG_T1, off, ctx);
1116                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1117                 emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1118                 break;
1119         case BPF_STX | BPF_MEM | BPF_H:
1120                 if (is_12b_int(off)) {
1121                         emit(rv_sh(rd, off, rs), ctx);
1122                         break;
1123                 }
1124
1125                 emit_imm(RV_REG_T1, off, ctx);
1126                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1127                 emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1128                 break;
1129         case BPF_STX | BPF_MEM | BPF_W:
1130                 if (is_12b_int(off)) {
1131                         emit_sw(rd, off, rs, ctx);
1132                         break;
1133                 }
1134
1135                 emit_imm(RV_REG_T1, off, ctx);
1136                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1137                 emit_sw(RV_REG_T1, 0, rs, ctx);
1138                 break;
1139         case BPF_STX | BPF_MEM | BPF_DW:
1140                 if (is_12b_int(off)) {
1141                         emit_sd(rd, off, rs, ctx);
1142                         break;
1143                 }
1144
1145                 emit_imm(RV_REG_T1, off, ctx);
1146                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1147                 emit_sd(RV_REG_T1, 0, rs, ctx);
1148                 break;
1149         case BPF_STX | BPF_ATOMIC | BPF_W:
1150         case BPF_STX | BPF_ATOMIC | BPF_DW:
1151                 if (insn->imm != BPF_ADD) {
1152                         pr_err("bpf-jit: not supported: atomic operation %02x ***\n",
1153                                insn->imm);
1154                         return -EINVAL;
1155                 }
1156
1157                 /* atomic_add: lock *(u32 *)(dst + off) += src
1158                  * atomic_add: lock *(u64 *)(dst + off) += src
1159                  */
1160
1161                 if (off) {
1162                         if (is_12b_int(off)) {
1163                                 emit_addi(RV_REG_T1, rd, off, ctx);
1164                         } else {
1165                                 emit_imm(RV_REG_T1, off, ctx);
1166                                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1167                         }
1168
1169                         rd = RV_REG_T1;
1170                 }
1171
1172                 emit(BPF_SIZE(code) == BPF_W ?
1173                      rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0) :
1174                      rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0), ctx);
1175                 break;
1176         default:
1177                 pr_err("bpf-jit: unknown opcode %02x\n", code);
1178                 return -EINVAL;
1179         }
1180
1181         return 0;
1182 }
1183
1184 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1185 {
1186         int stack_adjust = 0, store_offset, bpf_stack_adjust;
1187
1188         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1189         if (bpf_stack_adjust)
1190                 mark_fp(ctx);
1191
1192         if (seen_reg(RV_REG_RA, ctx))
1193                 stack_adjust += 8;
1194         stack_adjust += 8; /* RV_REG_FP */
1195         if (seen_reg(RV_REG_S1, ctx))
1196                 stack_adjust += 8;
1197         if (seen_reg(RV_REG_S2, ctx))
1198                 stack_adjust += 8;
1199         if (seen_reg(RV_REG_S3, ctx))
1200                 stack_adjust += 8;
1201         if (seen_reg(RV_REG_S4, ctx))
1202                 stack_adjust += 8;
1203         if (seen_reg(RV_REG_S5, ctx))
1204                 stack_adjust += 8;
1205         if (seen_reg(RV_REG_S6, ctx))
1206                 stack_adjust += 8;
1207
1208         stack_adjust = round_up(stack_adjust, 16);
1209         stack_adjust += bpf_stack_adjust;
1210
1211         store_offset = stack_adjust - 8;
1212
1213         /* First instruction is always setting the tail-call-counter
1214          * (TCC) register. This instruction is skipped for tail calls.
1215          * Force using a 4-byte (non-compressed) instruction.
1216          */
1217         emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1218
1219         emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1220
1221         if (seen_reg(RV_REG_RA, ctx)) {
1222                 emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1223                 store_offset -= 8;
1224         }
1225         emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1226         store_offset -= 8;
1227         if (seen_reg(RV_REG_S1, ctx)) {
1228                 emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1229                 store_offset -= 8;
1230         }
1231         if (seen_reg(RV_REG_S2, ctx)) {
1232                 emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1233                 store_offset -= 8;
1234         }
1235         if (seen_reg(RV_REG_S3, ctx)) {
1236                 emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1237                 store_offset -= 8;
1238         }
1239         if (seen_reg(RV_REG_S4, ctx)) {
1240                 emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1241                 store_offset -= 8;
1242         }
1243         if (seen_reg(RV_REG_S5, ctx)) {
1244                 emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1245                 store_offset -= 8;
1246         }
1247         if (seen_reg(RV_REG_S6, ctx)) {
1248                 emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1249                 store_offset -= 8;
1250         }
1251
1252         emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1253
1254         if (bpf_stack_adjust)
1255                 emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1256
1257         /* Program contains calls and tail calls, so RV_REG_TCC need
1258          * to be saved across calls.
1259          */
1260         if (seen_tail_call(ctx) && seen_call(ctx))
1261                 emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1262
1263         ctx->stack_size = stack_adjust;
1264 }
1265
1266 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1267 {
1268         __build_epilogue(false, ctx);
1269 }