bpf: Handle sign-extension ctx member accesses
kernel/bpf/verifier.c
index 584eb34..7a6945b 100644
@@ -3421,7 +3421,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
                        return 0;
                if (opcode == BPF_MOV) {
                        if (BPF_SRC(insn->code) == BPF_X) {
-                               /* dreg = sreg
+                               /* dreg = sreg or dreg = (s8, s16, s32)sreg
                                 * dreg needs precision after this insn
                                 * sreg needs precision before this insn
                                 */
@@ -5413,12 +5413,25 @@ static bool is_flow_key_reg(struct bpf_verifier_env *env, int regno)
        return reg->type == PTR_TO_FLOW_KEYS;
 }
 
+static u32 *reg2btf_ids[__BPF_REG_TYPE_MAX] = {
+#ifdef CONFIG_NET
+       [PTR_TO_SOCKET] = &btf_sock_ids[BTF_SOCK_TYPE_SOCK],
+       [PTR_TO_SOCK_COMMON] = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
+       [PTR_TO_TCP_SOCK] = &btf_sock_ids[BTF_SOCK_TYPE_TCP],
+#endif
+       [CONST_PTR_TO_MAP] = btf_bpf_map_id,
+};
+
 static bool is_trusted_reg(const struct bpf_reg_state *reg)
 {
        /* A referenced register is always trusted. */
        if (reg->ref_obj_id)
                return true;
 
+       /* Types listed in reg2btf_ids are always trusted */
+       if (reg2btf_ids[base_type(reg->type)])
+               return true;
+
        /* If a register is not referenced, it is trusted if it has the
         * MEM_ALLOC or PTR_TRUSTED type modifiers, and no others. Some of the
         * other type modifiers may be safe, but we elect to take an opt-in
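Hoisting reg2btf_ids above is_trusted_reg() lets the trust check reuse the table: any base register type with a fixed BTF id (socket pointers, const map pointers) is now considered trusted even without a reference. A minimal sketch of a program whose pointer now passes the check, assuming a plain tc classifier (the section name and return codes are illustrative):

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>

SEC("tc")
int trusted_sk_sketch(struct __sk_buff *skb)
{
        struct bpf_sock *sk = skb->sk;  /* PTR_TO_SOCK_COMMON_OR_NULL */

        if (!sk)
                return TC_ACT_OK;
        /* sk is now PTR_TO_SOCK_COMMON with no ref_obj_id; because
         * base_type(sk) has an entry in reg2btf_ids, is_trusted_reg()
         * returns true and sk may be passed to KF_TRUSTED_ARGS kfuncs. */
        return TC_ACT_OK;
}

char _license[] SEC("license") = "GPL";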
@@ -5574,16 +5587,17 @@ static int update_stack_depth(struct bpf_verifier_env *env,
  * Since recursion is prevented by check_cfg() this algorithm
  * only needs a local stack of MAX_CALL_FRAMES to remember callsites
  */
-static int check_max_stack_depth(struct bpf_verifier_env *env)
+static int check_max_stack_depth_subprog(struct bpf_verifier_env *env, int idx)
 {
-       int depth = 0, frame = 0, idx = 0, i = 0, subprog_end;
        struct bpf_subprog_info *subprog = env->subprog_info;
        struct bpf_insn *insn = env->prog->insnsi;
+       int depth = 0, frame = 0, i, subprog_end;
        bool tail_call_reachable = false;
        int ret_insn[MAX_CALL_FRAMES];
        int ret_prog[MAX_CALL_FRAMES];
        int j;
 
+       i = subprog[idx].start;
 process_func:
        /* protect against potential stack overflow that might happen when
         * bpf2bpf calls get combined with tailcalls. Limit the caller's stack
@@ -5622,7 +5636,7 @@ process_func:
 continue_func:
        subprog_end = subprog[idx + 1].start;
        for (; i < subprog_end; i++) {
-               int next_insn;
+               int next_insn, sidx;
 
                if (!bpf_pseudo_call(insn + i) && !bpf_pseudo_func(insn + i))
                        continue;
@@ -5632,21 +5646,23 @@ continue_func:
 
                /* find the callee */
                next_insn = i + insn[i].imm + 1;
-               idx = find_subprog(env, next_insn);
-               if (idx < 0) {
+               sidx = find_subprog(env, next_insn);
+               if (sidx < 0) {
                        WARN_ONCE(1, "verifier bug. No program starts at insn %d\n",
                                  next_insn);
                        return -EFAULT;
                }
-               if (subprog[idx].is_async_cb) {
-                       if (subprog[idx].has_tail_call) {
+               if (subprog[sidx].is_async_cb) {
+                       if (subprog[sidx].has_tail_call) {
                                verbose(env, "verifier bug. subprog has tail_call and async cb\n");
                                return -EFAULT;
                        }
-                        /* async callbacks don't increase bpf prog stack size */
-                       continue;
+                       /* async callbacks don't increase bpf prog stack size unless called directly */
+                       if (!bpf_pseudo_call(insn + i))
+                               continue;
                }
                i = next_insn;
+               idx = sidx;
 
                if (subprog[idx].has_tail_call)
                        tail_call_reachable = true;
@@ -5682,6 +5698,22 @@ continue_func:
        goto continue_func;
 }
 
+static int check_max_stack_depth(struct bpf_verifier_env *env)
+{
+       struct bpf_subprog_info *si = env->subprog_info;
+       int ret;
+
+       for (int i = 0; i < env->subprog_cnt; i++) {
+               if (!i || si[i].is_async_cb) {
+                       ret = check_max_stack_depth_subprog(env, i);
+                       if (ret < 0)
+                               return ret;
+               }
+               continue;
+       }
+       return 0;
+}
+
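The wrapper walks the main program (i == 0) and every async callback as separate stack roots: an async callback runs on its own stack, so its depth must be checked independently instead of being skipped outright, while a direct call to it is now also accounted against the caller. A sketch of the shape of program this covers, assuming the usual bpf_timer helper signatures (map layout and buffer size are illustrative):

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct elem {
        struct bpf_timer t;
};

struct {
        __uint(type, BPF_MAP_TYPE_ARRAY);
        __uint(max_entries, 1);
        __type(key, int);
        __type(value, struct elem);
} timers SEC(".maps");

/* Runs asynchronously on its own stack: its frames no longer hide from
 * the depth check, but they also do not count against the caller. */
static int timer_cb(void *map, int *key, struct bpf_timer *timer)
{
        char buf[256];  /* accounted by check_max_stack_depth_subprog() */

        __builtin_memset(buf, 0, sizeof(buf));
        return 0;
}

SEC("tc")
int start(struct __sk_buff *skb)
{
        int k = 0;
        struct elem *e = bpf_map_lookup_elem(&timers, &k);

        if (!e)
                return 0;
        bpf_timer_init(&e->t, &timers, 1 /* CLOCK_MONOTONIC */);
        bpf_timer_set_callback(&e->t, timer_cb);
        bpf_timer_start(&e->t, 0, 0);
        return 0;
}

char _license[] SEC("license") = "GPL";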
 #ifndef CONFIG_BPF_JIT_ALWAYS_ON
 static int get_callee_stack_depth(struct bpf_verifier_env *env,
                                  const struct bpf_insn *insn, int idx)
@@ -5795,6 +5827,147 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
        __reg_combine_64_into_32(reg);
 }
 
+static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
+{
+       if (size == 1) {
+               reg->smin_value = reg->s32_min_value = S8_MIN;
+               reg->smax_value = reg->s32_max_value = S8_MAX;
+       } else if (size == 2) {
+               reg->smin_value = reg->s32_min_value = S16_MIN;
+               reg->smax_value = reg->s32_max_value = S16_MAX;
+       } else {
+               /* size == 4 */
+               reg->smin_value = reg->s32_min_value = S32_MIN;
+               reg->smax_value = reg->s32_max_value = S32_MAX;
+       }
+       reg->umin_value = reg->u32_min_value = 0;
+       reg->umax_value = U64_MAX;
+       reg->u32_max_value = U32_MAX;
+       reg->var_off = tnum_unknown;
+}
+
+static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
+{
+       s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval;
+       u64 top_smax_value, top_smin_value;
+       u64 num_bits = size * 8;
+
+       if (tnum_is_const(reg->var_off)) {
+               u64_cval = reg->var_off.value;
+               if (size == 1)
+                       reg->var_off = tnum_const((s8)u64_cval);
+               else if (size == 2)
+                       reg->var_off = tnum_const((s16)u64_cval);
+               else
+                       /* size == 4 */
+                       reg->var_off = tnum_const((s32)u64_cval);
+
+               u64_cval = reg->var_off.value;
+               reg->smax_value = reg->smin_value = u64_cval;
+               reg->umax_value = reg->umin_value = u64_cval;
+               reg->s32_max_value = reg->s32_min_value = u64_cval;
+               reg->u32_max_value = reg->u32_min_value = u64_cval;
+               return;
+       }
+
+       top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
+       top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
+
+       if (top_smax_value != top_smin_value)
+               goto out;
+
+       /* find the s64_max and s64_min after sign extension */
+       if (size == 1) {
+               init_s64_max = (s8)reg->smax_value;
+               init_s64_min = (s8)reg->smin_value;
+       } else if (size == 2) {
+               init_s64_max = (s16)reg->smax_value;
+               init_s64_min = (s16)reg->smin_value;
+       } else {
+               init_s64_max = (s32)reg->smax_value;
+               init_s64_min = (s32)reg->smin_value;
+       }
+
+       s64_max = max(init_s64_max, init_s64_min);
+       s64_min = min(init_s64_max, init_s64_min);
+
+       /* s64_max and s64_min are both non-negative or both negative */
+       if ((s64_max >= 0) == (s64_min >= 0)) {
+               reg->smin_value = reg->s32_min_value = s64_min;
+               reg->smax_value = reg->s32_max_value = s64_max;
+               reg->umin_value = reg->u32_min_value = s64_min;
+               reg->umax_value = reg->u32_max_value = s64_max;
+               reg->var_off = tnum_range(s64_min, s64_max);
+               return;
+       }
+
+out:
+       set_sext64_default_val(reg, size);
+}
+
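Two conditions gate the fast path above: all values in [smin_value, smax_value] must share the same upper bits beyond the load width, and the sign-extended endpoints must land on the same side of zero; only then does the old range map through the cast unchanged. A standalone userspace sketch of the size == 1 arithmetic (plain C, no kernel API):

#include <stdio.h>

int main(void)
{
        /* Register known to lie in [-5, 100] as a 64-bit value: the top
         * 56 bits of the two bounds differ (all-ones vs all-zeros), so
         * the verifier bails to set_sext64_default_val(): [-128, 127]. */
        long long smin = -5, smax = 100;
        unsigned long long top_min = ((unsigned long long)smin >> 8) << 8;
        unsigned long long top_max = ((unsigned long long)smax >> 8) << 8;

        printf("tops equal? %d\n", top_min == top_max);  /* 0 */

        /* A register in [3, 100] instead keeps its bounds: both
         * endpoints sign-extend to themselves and share a sign. */
        smin = 3;
        printf("[%lld, %lld]\n", (long long)(signed char)smin,
               (long long)(signed char)smax);            /* [3, 100] */
        return 0;
}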
+static void set_sext32_default_val(struct bpf_reg_state *reg, int size)
+{
+       if (size == 1) {
+               reg->s32_min_value = S8_MIN;
+               reg->s32_max_value = S8_MAX;
+       } else {
+               /* size == 2 */
+               reg->s32_min_value = S16_MIN;
+               reg->s32_max_value = S16_MAX;
+       }
+       reg->u32_min_value = 0;
+       reg->u32_max_value = U32_MAX;
+}
+
+static void coerce_subreg_to_size_sx(struct bpf_reg_state *reg, int size)
+{
+       s32 init_s32_max, init_s32_min, s32_max, s32_min, u32_val;
+       u32 top_smax_value, top_smin_value;
+       u32 num_bits = size * 8;
+
+       if (tnum_is_const(reg->var_off)) {
+               u32_val = reg->var_off.value;
+               if (size == 1)
+                       reg->var_off = tnum_const((s8)u32_val);
+               else
+                       reg->var_off = tnum_const((s16)u32_val);
+
+               u32_val = reg->var_off.value;
+               reg->s32_min_value = reg->s32_max_value = u32_val;
+               reg->u32_min_value = reg->u32_max_value = u32_val;
+               return;
+       }
+
+       top_smax_value = ((u32)reg->s32_max_value >> num_bits) << num_bits;
+       top_smin_value = ((u32)reg->s32_min_value >> num_bits) << num_bits;
+
+       if (top_smax_value != top_smin_value)
+               goto out;
+
+       /* find the s32_max and s32_min after sign extension */
+       if (size == 1) {
+               init_s32_max = (s8)reg->s32_max_value;
+               init_s32_min = (s8)reg->s32_min_value;
+       } else {
+               /* size == 2 */
+               init_s32_max = (s16)reg->s32_max_value;
+               init_s32_min = (s16)reg->s32_min_value;
+       }
+       s32_max = max(init_s32_max, init_s32_min);
+       s32_min = min(init_s32_max, init_s32_min);
+
+       if ((s32_min >= 0) == (s32_max >= 0)) {
+               reg->s32_min_value = s32_min;
+               reg->s32_max_value = s32_max;
+               reg->u32_min_value = (u32)s32_min;
+               reg->u32_max_value = (u32)s32_max;
+               return;
+       }
+
+out:
+       set_sext32_default_val(reg, size);
+}
+
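coerce_subreg_to_size_sx() mirrors the 64-bit logic but only updates the 32-bit subregister state; the upper half of the register is settled by the caller through the usual subreg_def bookkeeping. A worked case of the size == 2 path (numbers only, not kernel API):

/* W2 known in [-300, -200], then W1 = (s16)W2:
 * the top 16 bits of both bounds are 0xffff, (s16)-300 == -300 and
 * (s16)-200 == -200 share a sign, so s32 stays [-300, -200] and
 * u32 becomes [0xfffffed4, 0xffffff38]. */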
 static bool bpf_map_is_rdonly(const struct bpf_map *map)
 {
        /* A map is considered read-only if the following condition are true:
@@ -5815,7 +5988,8 @@ static bool bpf_map_is_rdonly(const struct bpf_map *map)
               !bpf_map_write_active(map);
 }
 
-static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
+static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val,
+                              bool is_ldsx)
 {
        void *ptr;
        u64 addr;
@@ -5828,13 +6002,13 @@ static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
 
        switch (size) {
        case sizeof(u8):
-               *val = (u64)*(u8 *)ptr;
+               *val = is_ldsx ? (s64)*(s8 *)ptr : (u64)*(u8 *)ptr;
                break;
        case sizeof(u16):
-               *val = (u64)*(u16 *)ptr;
+               *val = is_ldsx ? (s64)*(s16 *)ptr : (u64)*(u16 *)ptr;
                break;
        case sizeof(u32):
-               *val = (u64)*(u32 *)ptr;
+               *val = is_ldsx ? (s64)*(s32 *)ptr : (u64)*(u32 *)ptr;
                break;
        case sizeof(u64):
                *val = *(u64 *)ptr;
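Because loads from read-only maps are constant-folded at verification time, the folded value must match what a BPF_MEMSX load would produce at runtime. A minimal userspace sketch of the difference for a stored byte 0x80:

#include <stdio.h>

int main(void)
{
        unsigned char stored = 0x80;    /* byte in the read-only map */

        /* BPF_MEM zero-extends, BPF_MEMSX sign-extends. */
        printf("zext: %#llx\n", (unsigned long long)stored);
        printf("sext: %#llx\n",
               (unsigned long long)(long long)(signed char)stored);
        return 0;
}
/* zext: 0x80, sext: 0xffffffffffffff80 */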
@@ -6253,7 +6427,7 @@ static int check_stack_access_within_bounds(
  */
 static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regno,
                            int off, int bpf_size, enum bpf_access_type t,
-                           int value_regno, bool strict_alignment_once)
+                           int value_regno, bool strict_alignment_once, bool is_ldsx)
 {
        struct bpf_reg_state *regs = cur_regs(env);
        struct bpf_reg_state *reg = regs + regno;
@@ -6314,7 +6488,7 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
                                u64 val = 0;
 
                                err = bpf_map_direct_read(map, map_off, size,
-                                                         &val);
+                                                         &val, is_ldsx);
                                if (err)
                                        return err;
 
@@ -6484,8 +6658,11 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
 
        if (!err && size < BPF_REG_SIZE && value_regno >= 0 && t == BPF_READ &&
            regs[value_regno].type == SCALAR_VALUE) {
-               /* b/h/w load zero-extends, mark upper bits as known 0 */
-               coerce_reg_to_size(&regs[value_regno], size);
+               if (!is_ldsx)
+                       /* b/h/w load zero-extends, mark upper bits as known 0 */
+                       coerce_reg_to_size(&regs[value_regno], size);
+               else
+                       coerce_reg_to_size_sx(&regs[value_regno], size);
        }
        return err;
 }
@@ -6577,17 +6754,17 @@ static int check_atomic(struct bpf_verifier_env *env, int insn_idx, struct bpf_i
         * case to simulate the register fill.
         */
        err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-                              BPF_SIZE(insn->code), BPF_READ, -1, true);
+                              BPF_SIZE(insn->code), BPF_READ, -1, true, false);
        if (!err && load_reg >= 0)
                err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
                                       BPF_SIZE(insn->code), BPF_READ, load_reg,
-                                      true);
+                                      true, false);
        if (err)
                return err;
 
        /* Check whether we can write into the same memory. */
        err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-                              BPF_SIZE(insn->code), BPF_WRITE, -1, true);
+                              BPF_SIZE(insn->code), BPF_WRITE, -1, true, false);
        if (err)
                return err;
 
@@ -6833,7 +7010,7 @@ static int check_helper_mem_access(struct bpf_verifier_env *env, int regno,
                                return zero_size_allowed ? 0 : -EACCES;
 
                        return check_mem_access(env, env->insn_idx, regno, offset, BPF_B,
-                                               atype, -1, false);
+                                               atype, -1, false, false);
                }
 
                fallthrough;
@@ -7205,7 +7382,7 @@ static int process_dynptr_func(struct bpf_verifier_env *env, int regno, int insn
                /* we write BPF_DW bits (8 bytes) at a time */
                for (i = 0; i < BPF_DYNPTR_SIZE; i += 8) {
                        err = check_mem_access(env, insn_idx, regno,
-                                              i, BPF_DW, BPF_WRITE, -1, false);
+                                              i, BPF_DW, BPF_WRITE, -1, false, false);
                        if (err)
                                return err;
                }
@@ -7298,7 +7475,7 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
 
                for (i = 0; i < nr_slots * 8; i += BPF_REG_SIZE) {
                        err = check_mem_access(env, insn_idx, regno,
-                                              i, BPF_DW, BPF_WRITE, -1, false);
+                                              i, BPF_DW, BPF_WRITE, -1, false, false);
                        if (err)
                                return err;
                }
@@ -9442,7 +9619,7 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
         */
        for (i = 0; i < meta.access_size; i++) {
                err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
-                                      BPF_WRITE, -1, false);
+                                      BPF_WRITE, -1, false, false);
                if (err)
                        return err;
        }
@@ -10051,15 +10228,6 @@ static bool __btf_type_is_scalar_struct(struct bpf_verifier_env *env,
        return true;
 }
 
-
-static u32 *reg2btf_ids[__BPF_REG_TYPE_MAX] = {
-#ifdef CONFIG_NET
-       [PTR_TO_SOCKET] = &btf_sock_ids[BTF_SOCK_TYPE_SOCK],
-       [PTR_TO_SOCK_COMMON] = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
-       [PTR_TO_TCP_SOCK] = &btf_sock_ids[BTF_SOCK_TYPE_TCP],
-#endif
-};
-
 enum kfunc_ptr_arg_type {
        KF_ARG_PTR_TO_CTX,
        KF_ARG_PTR_TO_ALLOC_BTF_ID,    /* Allocated object */
@@ -12933,11 +13101,24 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
        } else if (opcode == BPF_MOV) {
 
                if (BPF_SRC(insn->code) == BPF_X) {
-                       if (insn->imm != 0 || insn->off != 0) {
+                       if (insn->imm != 0) {
                                verbose(env, "BPF_MOV uses reserved fields\n");
                                return -EINVAL;
                        }
 
+                       if (BPF_CLASS(insn->code) == BPF_ALU) {
+                               if (insn->off != 0 && insn->off != 8 && insn->off != 16) {
+                                       verbose(env, "BPF_MOV uses reserved fields\n");
+                                       return -EINVAL;
+                               }
+                       } else {
+                               if (insn->off != 0 && insn->off != 8 && insn->off != 16 &&
+                                   insn->off != 32) {
+                                       verbose(env, "BPF_MOV uses reserved fields\n");
+                                       return -EINVAL;
+                               }
+                       }
+
                        /* check src operand */
                        err = check_reg_arg(env, insn->src_reg, SRC_OP);
                        if (err)
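With BPF_MOV, a nonzero off now selects the source width of a sign-extending move: 8 and 16 for 32-bit BPF_ALU, plus 32 for BPF_ALU64; anything else stays reserved. Hedged encoding sketches using the BPF_RAW_INSN() macro from include/linux/filter.h (register numbers arbitrary):

/* r1 = (s8)r2  -- accepted (ALU64 allows off 8/16/32) */
BPF_RAW_INSN(BPF_ALU64 | BPF_MOV | BPF_X, BPF_REG_1, BPF_REG_2, 8, 0);
/* w1 = (s16)w2 -- accepted (ALU allows off 8/16 only) */
BPF_RAW_INSN(BPF_ALU | BPF_MOV | BPF_X, BPF_REG_1, BPF_REG_2, 16, 0);
/* off == 32 with BPF_ALU is rejected: a 32-bit destination cannot
 * sign-extend from 32 bits, so that encoding stays reserved. */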
@@ -12961,18 +13142,33 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
                                       !tnum_is_const(src_reg->var_off);
 
                        if (BPF_CLASS(insn->code) == BPF_ALU64) {
-                               /* case: R1 = R2
-                                * copy register state to dest reg
-                                */
-                               if (need_id)
-                                       /* Assign src and dst registers the same ID
-                                        * that will be used by find_equal_scalars()
-                                        * to propagate min/max range.
+                               if (insn->off == 0) {
+                                       /* case: R1 = R2
+                                        * copy register state to dest reg
                                         */
-                                       src_reg->id = ++env->id_gen;
-                               copy_register_state(dst_reg, src_reg);
-                               dst_reg->live |= REG_LIVE_WRITTEN;
-                               dst_reg->subreg_def = DEF_NOT_SUBREG;
+                                       if (need_id)
+                                               /* Assign src and dst registers the same ID
+                                                * that will be used by find_equal_scalars()
+                                                * to propagate min/max range.
+                                                */
+                                               src_reg->id = ++env->id_gen;
+                                       copy_register_state(dst_reg, src_reg);
+                                       dst_reg->live |= REG_LIVE_WRITTEN;
+                                       dst_reg->subreg_def = DEF_NOT_SUBREG;
+                               } else {
+                                       /* case: R1 = (s8, s16, s32)R2 */
+                                       bool no_sext;
+
+                                       no_sext = src_reg->umax_value < (1ULL << (insn->off - 1));
+                                       if (no_sext && need_id)
+                                               src_reg->id = ++env->id_gen;
+                                       copy_register_state(dst_reg, src_reg);
+                                       if (!no_sext)
+                                               dst_reg->id = 0;
+                                       coerce_reg_to_size_sx(dst_reg, insn->off >> 3);
+                                       dst_reg->live |= REG_LIVE_WRITTEN;
+                                       dst_reg->subreg_def = DEF_NOT_SUBREG;
+                               }
                        } else {
                                /* R1 = (u32) R2 */
                                if (is_pointer_value(env, insn->src_reg)) {
@@ -12981,19 +13177,33 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
                                                insn->src_reg);
                                        return -EACCES;
                                } else if (src_reg->type == SCALAR_VALUE) {
-                                       bool is_src_reg_u32 = src_reg->umax_value <= U32_MAX;
-
-                                       if (is_src_reg_u32 && need_id)
-                                               src_reg->id = ++env->id_gen;
-                                       copy_register_state(dst_reg, src_reg);
-                                       /* Make sure ID is cleared if src_reg is not in u32 range otherwise
-                                        * dst_reg min/max could be incorrectly
-                                        * propagated into src_reg by find_equal_scalars()
-                                        */
-                                       if (!is_src_reg_u32)
-                                               dst_reg->id = 0;
-                                       dst_reg->live |= REG_LIVE_WRITTEN;
-                                       dst_reg->subreg_def = env->insn_idx + 1;
+                                       if (insn->off == 0) {
+                                               bool is_src_reg_u32 = src_reg->umax_value <= U32_MAX;
+
+                                               if (is_src_reg_u32 && need_id)
+                                                       src_reg->id = ++env->id_gen;
+                                               copy_register_state(dst_reg, src_reg);
+                                               /* Make sure ID is cleared if src_reg is not in u32
+                                                * range otherwise dst_reg min/max could be incorrectly
+                                                * propagated into src_reg by find_equal_scalars()
+                                                */
+                                               if (!is_src_reg_u32)
+                                                       dst_reg->id = 0;
+                                               dst_reg->live |= REG_LIVE_WRITTEN;
+                                               dst_reg->subreg_def = env->insn_idx + 1;
+                                       } else {
+                                               /* case: W1 = (s8, s16)W2 */
+                                               bool no_sext = src_reg->umax_value < (1ULL << (insn->off - 1));
+
+                                               if (no_sext && need_id)
+                                                       src_reg->id = ++env->id_gen;
+                                               copy_register_state(dst_reg, src_reg);
+                                               if (!no_sext)
+                                                       dst_reg->id = 0;
+                                               dst_reg->live |= REG_LIVE_WRITTEN;
+                                               dst_reg->subreg_def = env->insn_idx + 1;
+                                               coerce_subreg_to_size_sx(dst_reg, insn->off >> 3);
+                                       }
                                } else {
                                        mark_reg_unknown(env, regs,
                                                         insn->dst_reg);
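In both widths, no_sext asks whether the source already fits the signed target: if umax_value < 2^(off-1), every possible value has a clear sign bit, the move is an identity, and src/dst may keep a shared ID for find_equal_scalars(). A worked case (numbers only):

/* r1 = (s8)r2 with r2 in [0, 100]: umax_value 100 < 128 == 1ULL << 7,
 * so no_sext is true, the registers may share an ID, and
 * coerce_reg_to_size_sx() leaves [0, 100] intact.
 * With r2 in [0, 200]: 200 >= 128 and values 128..200 sign-extend to
 * negatives, so dst_reg->id is cleared lest range propagation tie
 * r1's new bounds back to r2 unsoundly. */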
@@ -16179,7 +16389,7 @@ static int save_aux_ptr_type(struct bpf_verifier_env *env, enum bpf_reg_type typ
                         * Have to support a use case when one path through
                         * the program yields TRUSTED pointer while another
                         * is UNTRUSTED. Fallback to UNTRUSTED to generate
-                        * BPF_PROBE_MEM.
+                        * BPF_PROBE_MEM/BPF_PROBE_MEMSX.
                         */
                        *prev_type = PTR_TO_BTF_ID | PTR_UNTRUSTED;
                } else {
@@ -16320,7 +16530,8 @@ static int do_check(struct bpf_verifier_env *env)
                         */
                        err = check_mem_access(env, env->insn_idx, insn->src_reg,
                                               insn->off, BPF_SIZE(insn->code),
-                                              BPF_READ, insn->dst_reg, false);
+                                              BPF_READ, insn->dst_reg, false,
+                                              BPF_MODE(insn->code) == BPF_MEMSX);
                        if (err)
                                return err;
 
@@ -16357,7 +16568,7 @@ static int do_check(struct bpf_verifier_env *env)
                        /* check that memory (dst_reg + off) is writeable */
                        err = check_mem_access(env, env->insn_idx, insn->dst_reg,
                                               insn->off, BPF_SIZE(insn->code),
-                                              BPF_WRITE, insn->src_reg, false);
+                                              BPF_WRITE, insn->src_reg, false, false);
                        if (err)
                                return err;
 
@@ -16382,7 +16593,7 @@ static int do_check(struct bpf_verifier_env *env)
                        /* check that memory (dst_reg + off) is writeable */
                        err = check_mem_access(env, env->insn_idx, insn->dst_reg,
                                               insn->off, BPF_SIZE(insn->code),
-                                              BPF_WRITE, -1, false);
+                                              BPF_WRITE, -1, false, false);
                        if (err)
                                return err;
 
@@ -16810,7 +17021,8 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
 
        for (i = 0; i < insn_cnt; i++, insn++) {
                if (BPF_CLASS(insn->code) == BPF_LDX &&
-                   (BPF_MODE(insn->code) != BPF_MEM || insn->imm != 0)) {
+                   ((BPF_MODE(insn->code) != BPF_MEM && BPF_MODE(insn->code) != BPF_MEMSX) ||
+                   insn->imm != 0)) {
                        verbose(env, "BPF_LDX uses reserved fields\n");
                        return -EINVAL;
                }
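BPF_MEMSX thus joins BPF_MEM as a legal mode for BPF_LDX, with imm still required to be zero. An encoding sketch via BPF_RAW_INSN() (register choice arbitrary):

/* r3 = *(s16 *)(r1 + 4) */
BPF_RAW_INSN(BPF_LDX | BPF_MEMSX | BPF_H, BPF_REG_3, BPF_REG_1, 4, 0);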
@@ -17504,11 +17716,15 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
 
        for (i = 0; i < insn_cnt; i++, insn++) {
                bpf_convert_ctx_access_t convert_ctx_access;
+               u8 mode;
 
                if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
                    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
                    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
-                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
+                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
+                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
+                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
+                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
                        type = BPF_READ;
                } else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
                           insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
@@ -17567,8 +17783,12 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                 */
                case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
                        if (type == BPF_READ) {
-                               insn->code = BPF_LDX | BPF_PROBE_MEM |
-                                       BPF_SIZE((insn)->code);
+                               if (BPF_MODE(insn->code) == BPF_MEM)
+                                       insn->code = BPF_LDX | BPF_PROBE_MEM |
+                                                    BPF_SIZE((insn)->code);
+                               else
+                                       insn->code = BPF_LDX | BPF_PROBE_MEMSX |
+                                                    BPF_SIZE((insn)->code);
                                env->prog->aux->num_exentries++;
                        }
                        continue;
@@ -17578,6 +17798,7 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
 
                ctx_field_size = env->insn_aux_data[i + delta].ctx_field_size;
                size = BPF_LDST_BYTES(insn);
+               mode = BPF_MODE(insn->code);
 
                /* If the read access is a narrower load of the field,
                 * convert to a 4/8-byte load, to minimum program type specific
@@ -17637,6 +17858,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                                                                (1ULL << size * 8) - 1);
                        }
                }
+               if (mode == BPF_MEMSX)
+                       insn_buf[cnt++] = BPF_RAW_INSN(BPF_ALU64 | BPF_MOV | BPF_X,
+                                                      insn->dst_reg, insn->dst_reg,
+                                                      size * 8, 0);
 
                new_prog = bpf_patch_insn_data(env, i + delta, insn_buf, cnt);
                if (!new_prog)
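For a narrow sign-extended ctx read, the rewrite first emits the usual field-sized load plus masking, then appends the new ALU64 sign-extending move so the masked result is re-extended. A rough sketch of the emitted sequence for r0 = *(s8 *)(r1 + off) against a 4-byte field, written with the insn macros from include/linux/filter.h (converted_off is a placeholder; the exact shift/mask steps depend on the prog type's convert_ctx_access):

BPF_LDX_MEM(BPF_W, BPF_REG_0, BPF_REG_1, converted_off), /* load field */
BPF_ALU64_IMM(BPF_AND, BPF_REG_0, 0xff),                 /* keep low byte */
/* appended when mode == BPF_MEMSX: r0 = (s8)r0 */
BPF_RAW_INSN(BPF_ALU64 | BPF_MOV | BPF_X, BPF_REG_0, BPF_REG_0, 8, 0),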
@@ -17756,7 +17981,8 @@ static int jit_subprogs(struct bpf_verifier_env *env)
                insn = func[i]->insnsi;
                for (j = 0; j < func[i]->len; j++, insn++) {
                        if (BPF_CLASS(insn->code) == BPF_LDX &&
-                           BPF_MODE(insn->code) == BPF_PROBE_MEM)
+                           (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
+                            BPF_MODE(insn->code) == BPF_PROBE_MEMSX))
                                num_exentries++;
                }
                func[i]->aux->num_exentries = num_exentries;