bpf: Undo incorrect __reg_bound_offset32 handling
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 1cc945d..2a84f73 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -19,6 +19,8 @@
 #include <linux/sort.h>
 #include <linux/perf_event.h>
 #include <linux/ctype.h>
+#include <linux/error-injection.h>
+#include <linux/bpf_lsm.h>
 
 #include "disasm.h"
 
@@ -1034,17 +1036,6 @@ static void __reg_bound_offset(struct bpf_reg_state *reg)
                                                 reg->umax_value));
 }
 
-static void __reg_bound_offset32(struct bpf_reg_state *reg)
-{
-       u64 mask = 0xffffFFFF;
-       struct tnum range = tnum_range(reg->umin_value & mask,
-                                      reg->umax_value & mask);
-       struct tnum lo32 = tnum_cast(reg->var_off, 4);
-       struct tnum hi32 = tnum_lshift(tnum_rshift(reg->var_off, 32), 32);
-
-       reg->var_off = tnum_or(hi32, tnum_intersect(lo32, range));
-}
-
 /* Reset the min/max bounds of a register */
 static void __mark_reg_unbounded(struct bpf_reg_state *reg)
 {
@@ -3460,13 +3451,17 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
                expected_type = CONST_PTR_TO_MAP;
                if (type != expected_type)
                        goto err_type;
-       } else if (arg_type == ARG_PTR_TO_CTX) {
+       } else if (arg_type == ARG_PTR_TO_CTX ||
+                  arg_type == ARG_PTR_TO_CTX_OR_NULL) {
                expected_type = PTR_TO_CTX;
-               if (type != expected_type)
-                       goto err_type;
-               err = check_ctx_reg(env, reg, regno);
-               if (err < 0)
-                       return err;
+               if (!(register_is_null(reg) &&
+                     arg_type == ARG_PTR_TO_CTX_OR_NULL)) {
+                       if (type != expected_type)
+                               goto err_type;
+                       err = check_ctx_reg(env, reg, regno);
+                       if (err < 0)
+                               return err;
+               }
        } else if (arg_type == ARG_PTR_TO_SOCK_COMMON) {
                expected_type = PTR_TO_SOCK_COMMON;
                /* Any sk pointer can be ARG_PTR_TO_SOCK_COMMON */
@@ -3649,7 +3644,8 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
                if (func_id != BPF_FUNC_perf_event_read &&
                    func_id != BPF_FUNC_perf_event_output &&
                    func_id != BPF_FUNC_skb_output &&
-                   func_id != BPF_FUNC_perf_event_read_value)
+                   func_id != BPF_FUNC_perf_event_read_value &&
+                   func_id != BPF_FUNC_xdp_output)
                        goto error;
                break;
        case BPF_MAP_TYPE_STACK_TRACE:
@@ -3693,14 +3689,16 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
                if (func_id != BPF_FUNC_sk_redirect_map &&
                    func_id != BPF_FUNC_sock_map_update &&
                    func_id != BPF_FUNC_map_delete_elem &&
-                   func_id != BPF_FUNC_msg_redirect_map)
+                   func_id != BPF_FUNC_msg_redirect_map &&
+                   func_id != BPF_FUNC_sk_select_reuseport)
                        goto error;
                break;
        case BPF_MAP_TYPE_SOCKHASH:
                if (func_id != BPF_FUNC_sk_redirect_hash &&
                    func_id != BPF_FUNC_sock_hash_update &&
                    func_id != BPF_FUNC_map_delete_elem &&
-                   func_id != BPF_FUNC_msg_redirect_hash)
+                   func_id != BPF_FUNC_msg_redirect_hash &&
+                   func_id != BPF_FUNC_sk_select_reuseport)
                        goto error;
                break;
        case BPF_MAP_TYPE_REUSEPORT_SOCKARRAY:
@@ -3737,6 +3735,7 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
        case BPF_FUNC_perf_event_output:
        case BPF_FUNC_perf_event_read_value:
        case BPF_FUNC_skb_output:
+       case BPF_FUNC_xdp_output:
                if (map->map_type != BPF_MAP_TYPE_PERF_EVENT_ARRAY)
                        goto error;
                break;
@@ -3774,7 +3773,9 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
                        goto error;
                break;
        case BPF_FUNC_sk_select_reuseport:
-               if (map->map_type != BPF_MAP_TYPE_REUSEPORT_SOCKARRAY)
+               if (map->map_type != BPF_MAP_TYPE_REUSEPORT_SOCKARRAY &&
+                   map->map_type != BPF_MAP_TYPE_SOCKMAP &&
+                   map->map_type != BPF_MAP_TYPE_SOCKHASH)
                        goto error;
                break;
        case BPF_FUNC_map_peek_elem:
@@ -4836,6 +4837,237 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
        return 0;
 }
 
+static void scalar_min_max_add(struct bpf_reg_state *dst_reg,
+                              struct bpf_reg_state *src_reg)
+{
+       s64 smin_val = src_reg->smin_value;
+       s64 smax_val = src_reg->smax_value;
+       u64 umin_val = src_reg->umin_value;
+       u64 umax_val = src_reg->umax_value;
+
+       if (signed_add_overflows(dst_reg->smin_value, smin_val) ||
+           signed_add_overflows(dst_reg->smax_value, smax_val)) {
+               dst_reg->smin_value = S64_MIN;
+               dst_reg->smax_value = S64_MAX;
+       } else {
+               dst_reg->smin_value += smin_val;
+               dst_reg->smax_value += smax_val;
+       }
+       if (dst_reg->umin_value + umin_val < umin_val ||
+           dst_reg->umax_value + umax_val < umax_val) {
+               dst_reg->umin_value = 0;
+               dst_reg->umax_value = U64_MAX;
+       } else {
+               dst_reg->umin_value += umin_val;
+               dst_reg->umax_value += umax_val;
+       }
+       dst_reg->var_off = tnum_add(dst_reg->var_off, src_reg->var_off);
+}
+
+static void scalar_min_max_sub(struct bpf_reg_state *dst_reg,
+                              struct bpf_reg_state *src_reg)
+{
+       s64 smin_val = src_reg->smin_value;
+       s64 smax_val = src_reg->smax_value;
+       u64 umin_val = src_reg->umin_value;
+       u64 umax_val = src_reg->umax_value;
+
+       if (signed_sub_overflows(dst_reg->smin_value, smax_val) ||
+           signed_sub_overflows(dst_reg->smax_value, smin_val)) {
+               /* Overflow possible, we know nothing */
+               dst_reg->smin_value = S64_MIN;
+               dst_reg->smax_value = S64_MAX;
+       } else {
+               dst_reg->smin_value -= smax_val;
+               dst_reg->smax_value -= smin_val;
+       }
+       if (dst_reg->umin_value < umax_val) {
+               /* Overflow possible, we know nothing */
+               dst_reg->umin_value = 0;
+               dst_reg->umax_value = U64_MAX;
+       } else {
+               /* Cannot overflow (as long as bounds are consistent) */
+               dst_reg->umin_value -= umax_val;
+               dst_reg->umax_value -= umin_val;
+       }
+       dst_reg->var_off = tnum_sub(dst_reg->var_off, src_reg->var_off);
+}
+
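The add/sub helpers above rely on well-defined unsigned wraparound to detect when the new bounds would overflow. A minimal userspace sketch of the same unsigned-addition check (simplified types, not part of this patch):

#include <stdint.h>
#include <stdio.h>

struct urange { uint64_t umin, umax; };

/* Mirrors the unsigned half of scalar_min_max_add(): "a + b < b"
 * detects 64-bit wraparound, in which case the bounds are reset to
 * "know nothing".
 */
static struct urange urange_add(struct urange a, struct urange b)
{
        struct urange r;

        if (a.umin + b.umin < b.umin || a.umax + b.umax < b.umax) {
                r.umin = 0;
                r.umax = UINT64_MAX;
        } else {
                r.umin = a.umin + b.umin;
                r.umax = a.umax + b.umax;
        }
        return r;
}

int main(void)
{
        struct urange a = { 10, 20 };
        struct urange b = { UINT64_MAX - 5, UINT64_MAX };
        struct urange r = urange_add(a, b);   /* umax wraps -> unbounded */

        printf("[%llu, %llu]\n", (unsigned long long)r.umin,
               (unsigned long long)r.umax);
        return 0;
}

scalar_min_max_sub() is the mirror image: the result can only go below zero when dst_reg->umin_value < umax_val, so a single comparison is enough there.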
+static void scalar_min_max_mul(struct bpf_reg_state *dst_reg,
+                              struct bpf_reg_state *src_reg)
+{
+       s64 smin_val = src_reg->smin_value;
+       u64 umin_val = src_reg->umin_value;
+       u64 umax_val = src_reg->umax_value;
+
+       dst_reg->var_off = tnum_mul(dst_reg->var_off, src_reg->var_off);
+       if (smin_val < 0 || dst_reg->smin_value < 0) {
+               /* Ain't nobody got time to multiply that sign */
+               __mark_reg_unbounded(dst_reg);
+               __update_reg_bounds(dst_reg);
+               return;
+       }
+       /* Both values are positive, so we can work with unsigned and
+        * copy the result to signed (unless it exceeds S64_MAX).
+        */
+       if (umax_val > U32_MAX || dst_reg->umax_value > U32_MAX) {
+               /* Potential overflow, we know nothing */
+               __mark_reg_unbounded(dst_reg);
+               /* (except what we can learn from the var_off) */
+               __update_reg_bounds(dst_reg);
+               return;
+       }
+       dst_reg->umin_value *= umin_val;
+       dst_reg->umax_value *= umax_val;
+       if (dst_reg->umax_value > S64_MAX) {
+               /* Overflow possible, we know nothing */
+               dst_reg->smin_value = S64_MIN;
+               dst_reg->smax_value = S64_MAX;
+       } else {
+               dst_reg->smin_value = dst_reg->umin_value;
+               dst_reg->smax_value = dst_reg->umax_value;
+       }
+}
+
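The U32_MAX guard in scalar_min_max_mul() is what makes the unsigned multiplications that follow it safe: if both operands are at most 2^32 - 1, the product is at most 2^64 - 2^33 + 1 and cannot wrap a u64. A small sketch of that bound (not part of this patch):

#include <assert.h>
#include <stdint.h>

int main(void)
{
        uint64_t a = UINT32_MAX;             /* largest value passing the guard */
        uint64_t b = UINT32_MAX;
        uint64_t prod = a * b;               /* 2^64 - 2^33 + 1, no u64 wrap */

        assert(prod == UINT64_MAX - (2ULL << 32) + 2);
        assert(prod > (uint64_t)INT64_MAX);  /* but it can still exceed S64_MAX */
        return 0;
}

The second assertion is why the helper still checks dst_reg->umax_value against S64_MAX before copying the unsigned bounds into the signed ones.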
+static void scalar_min_max_and(struct bpf_reg_state *dst_reg,
+                              struct bpf_reg_state *src_reg)
+{
+       s64 smin_val = src_reg->smin_value;
+       u64 umax_val = src_reg->umax_value;
+
+       /* We get our minimum from the var_off, since that's inherently
+        * bitwise.  Our maximum is the minimum of the operands' maxima.
+        */
+       dst_reg->var_off = tnum_and(dst_reg->var_off, src_reg->var_off);
+       dst_reg->umin_value = dst_reg->var_off.value;
+       dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
+       if (dst_reg->smin_value < 0 || smin_val < 0) {
+               /* Lose signed bounds when ANDing negative numbers,
+                * ain't nobody got time for that.
+                */
+               dst_reg->smin_value = S64_MIN;
+               dst_reg->smax_value = S64_MAX;
+       } else {
+               /* ANDing two positives gives a positive, so safe to
+                * cast result into s64.
+                */
+               dst_reg->smin_value = dst_reg->umin_value;
+               dst_reg->smax_value = dst_reg->umax_value;
+       }
+       /* We may learn something more from the var_off */
+       __update_reg_bounds(dst_reg);
+}
+
+static void scalar_min_max_or(struct bpf_reg_state *dst_reg,
+                             struct bpf_reg_state *src_reg)
+{
+       s64 smin_val = src_reg->smin_value;
+       u64 umin_val = src_reg->umin_value;
+
+       /* We get our maximum from the var_off, and our minimum is the
+        * maximum of the operands' minima
+        */
+       dst_reg->var_off = tnum_or(dst_reg->var_off, src_reg->var_off);
+       dst_reg->umin_value = max(dst_reg->umin_value, umin_val);
+       dst_reg->umax_value = dst_reg->var_off.value | dst_reg->var_off.mask;
+       if (dst_reg->smin_value < 0 || smin_val < 0) {
+               /* Lose signed bounds when ORing negative numbers,
+                * ain't nobody got time for that.
+                */
+               dst_reg->smin_value = S64_MIN;
+               dst_reg->smax_value = S64_MAX;
+       } else {
+               /* ORing two positives gives a positive, so safe to
+                * cast result into s64.
+                */
+               dst_reg->smin_value = dst_reg->umin_value;
+               dst_reg->smax_value = dst_reg->umax_value;
+       }
+       /* We may learn something more from the var_off */
+       __update_reg_bounds(dst_reg);
+}
+
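Both bitwise helpers above lean on the tnum (var_off) representation: a (value, mask) pair where mask bits are unknown and value holds the bits known to be set. A simplified userspace sketch in the spirit of the tnum_and()/tnum_or() helpers in kernel/bpf/tnum.c, showing why value is a valid umin after AND and value | mask a valid umax after OR:

#include <stdint.h>
#include <stdio.h>

struct tnum { uint64_t value, mask; };  /* mask bits unknown, value bits known set */

static struct tnum tnum_and(struct tnum a, struct tnum b)
{
        uint64_t alpha = a.value | a.mask;  /* bits that may be set in a */
        uint64_t beta = b.value | b.mask;   /* bits that may be set in b */
        uint64_t v = a.value & b.value;     /* bits known set in both */

        return (struct tnum){ v, alpha & beta & ~v };
}

static struct tnum tnum_or(struct tnum a, struct tnum b)
{
        uint64_t v = a.value | b.value;
        uint64_t mu = a.mask | b.mask;

        return (struct tnum){ v, mu & ~v };
}

int main(void)
{
        struct tnum a = { 0x10, 0x0f };     /* values 0x10..0x1f */
        struct tnum b = { 0x00, 0xff };     /* values 0x00..0xff */
        struct tnum n = tnum_and(a, b);
        struct tnum o = tnum_or(a, b);

        /* umin after AND: the bits known set; umax after OR: value | mask */
        printf("AND umin = %#llx, OR umax = %#llx\n",
               (unsigned long long)n.value,
               (unsigned long long)(o.value | o.mask));   /* 0 and 0xff */
        return 0;
}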
+static void scalar_min_max_lsh(struct bpf_reg_state *dst_reg,
+                              struct bpf_reg_state *src_reg)
+{
+       u64 umax_val = src_reg->umax_value;
+       u64 umin_val = src_reg->umin_value;
+
+       /* We lose all sign bit information (except what we can pick
+        * up from var_off)
+        */
+       dst_reg->smin_value = S64_MIN;
+       dst_reg->smax_value = S64_MAX;
+       /* If we might shift our top bit out, then we know nothing */
+       if (dst_reg->umax_value > 1ULL << (63 - umax_val)) {
+               dst_reg->umin_value = 0;
+               dst_reg->umax_value = U64_MAX;
+       } else {
+               dst_reg->umin_value <<= umin_val;
+               dst_reg->umax_value <<= umax_val;
+       }
+       dst_reg->var_off = tnum_lshift(dst_reg->var_off, umin_val);
+       /* We may learn something more from the var_off */
+       __update_reg_bounds(dst_reg);
+}
+
+static void scalar_min_max_rsh(struct bpf_reg_state *dst_reg,
+                              struct bpf_reg_state *src_reg)
+{
+       u64 umax_val = src_reg->umax_value;
+       u64 umin_val = src_reg->umin_value;
+
+       /* BPF_RSH is an unsigned shift.  If the value in dst_reg might
+        * be negative, then either:
+        * 1) src_reg might be zero, so the sign bit of the result is
+        *    unknown, so we lose our signed bounds
+        * 2) it's known negative, thus the unsigned bounds capture the
+        *    signed bounds
+        * 3) the signed bounds cross zero, so they tell us nothing
+        *    about the result
+        * If the value in dst_reg is known nonnegative, then again the
+        * unsigned bounds capture the signed bounds.

+        * Thus, in all cases it suffices to blow away our signed bounds
+        * and rely on inferring new ones from the unsigned bounds and
+        * var_off of the result.
+        */
+       dst_reg->smin_value = S64_MIN;
+       dst_reg->smax_value = S64_MAX;
+       dst_reg->var_off = tnum_rshift(dst_reg->var_off, umin_val);
+       dst_reg->umin_value >>= umax_val;
+       dst_reg->umax_value >>= umin_val;
+       /* We may learn something more from the var_off */
+       __update_reg_bounds(dst_reg);
+}
+
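For the shifts, note the asymmetry in scalar_min_max_rsh(): a larger right shift gives a smaller result, so the new unsigned minimum uses the largest possible shift and the new maximum the smallest. A tiny sketch (not part of this patch):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
        uint64_t umin = 0x1000, umax = 0x1fff;   /* dst_reg unsigned bounds */
        uint64_t shift_min = 4, shift_max = 8;   /* src_reg (shift amount) bounds */

        uint64_t new_umin = umin >> shift_max;   /* smallest result: largest shift */
        uint64_t new_umax = umax >> shift_min;   /* largest result: smallest shift */

        printf("[%#llx, %#llx]\n", (unsigned long long)new_umin,
               (unsigned long long)new_umax);    /* [0x10, 0x1ff] */
        return 0;
}

For BPF_LSH the corresponding hazard is shifting the top bit out, which is why scalar_min_max_lsh() gives up on the unsigned bounds as soon as dst_reg->umax_value exceeds 1ULL << (63 - umax_val).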
+static void scalar_min_max_arsh(struct bpf_reg_state *dst_reg,
+                               struct bpf_reg_state *src_reg,
+                               u64 insn_bitness)
+{
+       u64 umin_val = src_reg->umin_value;
+
+       /* Upon reaching here, src_known is true and
+        * umax_val is equal to umin_val.
+        */
+       if (insn_bitness == 32) {
+               dst_reg->smin_value = (u32)(((s32)dst_reg->smin_value) >> umin_val);
+               dst_reg->smax_value = (u32)(((s32)dst_reg->smax_value) >> umin_val);
+       } else {
+               dst_reg->smin_value >>= umin_val;
+               dst_reg->smax_value >>= umin_val;
+       }
+
+       dst_reg->var_off = tnum_arshift(dst_reg->var_off, umin_val,
+                                       insn_bitness);
+
+       /* blow away the dst_reg umin_value/umax_value and rely on
+        * dst_reg var_off to refine the result.
+        */
+       dst_reg->umin_value = 0;
+       dst_reg->umax_value = U64_MAX;
+       __update_reg_bounds(dst_reg);
+}
+
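In the 32-bit BPF_ARSH case the bounds are updated by casting through s32 so the shift replicates bit 31, then truncating back to 32 bits. A sketch of that expression (not part of this patch):

#include <stdint.h>
#include <stdio.h>

int main(void)
{
        int64_t smin_value = -256;          /* dst_reg->smin_value */
        uint64_t umin_val = 4;              /* known shift amount */

        /* Same expression as above: sign-extend within 32 bits, then
         * truncate.  Right-shifting a negative signed value is
         * implementation-defined in ISO C; gcc and clang shift
         * arithmetically, which is what the kernel relies on.
         */
        uint64_t res = (uint32_t)(((int32_t)smin_value) >> umin_val);

        printf("%#llx\n", (unsigned long long)res);   /* 0xfffffff0 == -16 as s32 */
        return 0;
}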
 /* WARNING: This function does calculations on 64-bit values, but the actual
  * execution may occur on 32-bit values. Therefore, things like bitshifts
  * need extra checks in the 32-bit case.
@@ -4892,23 +5124,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                        verbose(env, "R%d tried to add from different pointers or scalars\n", dst);
                        return ret;
                }
-               if (signed_add_overflows(dst_reg->smin_value, smin_val) ||
-                   signed_add_overflows(dst_reg->smax_value, smax_val)) {
-                       dst_reg->smin_value = S64_MIN;
-                       dst_reg->smax_value = S64_MAX;
-               } else {
-                       dst_reg->smin_value += smin_val;
-                       dst_reg->smax_value += smax_val;
-               }
-               if (dst_reg->umin_value + umin_val < umin_val ||
-                   dst_reg->umax_value + umax_val < umax_val) {
-                       dst_reg->umin_value = 0;
-                       dst_reg->umax_value = U64_MAX;
-               } else {
-                       dst_reg->umin_value += umin_val;
-                       dst_reg->umax_value += umax_val;
-               }
-               dst_reg->var_off = tnum_add(dst_reg->var_off, src_reg.var_off);
+               scalar_min_max_add(dst_reg, &src_reg);
                break;
        case BPF_SUB:
                ret = sanitize_val_alu(env, insn);
@@ -4916,54 +5132,10 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                        verbose(env, "R%d tried to sub from different pointers or scalars\n", dst);
                        return ret;
                }
-               if (signed_sub_overflows(dst_reg->smin_value, smax_val) ||
-                   signed_sub_overflows(dst_reg->smax_value, smin_val)) {
-                       /* Overflow possible, we know nothing */
-                       dst_reg->smin_value = S64_MIN;
-                       dst_reg->smax_value = S64_MAX;
-               } else {
-                       dst_reg->smin_value -= smax_val;
-                       dst_reg->smax_value -= smin_val;
-               }
-               if (dst_reg->umin_value < umax_val) {
-                       /* Overflow possible, we know nothing */
-                       dst_reg->umin_value = 0;
-                       dst_reg->umax_value = U64_MAX;
-               } else {
-                       /* Cannot overflow (as long as bounds are consistent) */
-                       dst_reg->umin_value -= umax_val;
-                       dst_reg->umax_value -= umin_val;
-               }
-               dst_reg->var_off = tnum_sub(dst_reg->var_off, src_reg.var_off);
+               scalar_min_max_sub(dst_reg, &src_reg);
                break;
        case BPF_MUL:
-               dst_reg->var_off = tnum_mul(dst_reg->var_off, src_reg.var_off);
-               if (smin_val < 0 || dst_reg->smin_value < 0) {
-                       /* Ain't nobody got time to multiply that sign */
-                       __mark_reg_unbounded(dst_reg);
-                       __update_reg_bounds(dst_reg);
-                       break;
-               }
-               /* Both values are positive, so we can work with unsigned and
-                * copy the result to signed (unless it exceeds S64_MAX).
-                */
-               if (umax_val > U32_MAX || dst_reg->umax_value > U32_MAX) {
-                       /* Potential overflow, we know nothing */
-                       __mark_reg_unbounded(dst_reg);
-                       /* (except what we can learn from the var_off) */
-                       __update_reg_bounds(dst_reg);
-                       break;
-               }
-               dst_reg->umin_value *= umin_val;
-               dst_reg->umax_value *= umax_val;
-               if (dst_reg->umax_value > S64_MAX) {
-                       /* Overflow possible, we know nothing */
-                       dst_reg->smin_value = S64_MIN;
-                       dst_reg->smax_value = S64_MAX;
-               } else {
-                       dst_reg->smin_value = dst_reg->umin_value;
-                       dst_reg->smax_value = dst_reg->umax_value;
-               }
+               scalar_min_max_mul(dst_reg, &src_reg);
                break;
        case BPF_AND:
                if (src_known && dst_known) {
@@ -4971,27 +5143,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                                                  src_reg.var_off.value);
                        break;
                }
-               /* We get our minimum from the var_off, since that's inherently
-                * bitwise.  Our maximum is the minimum of the operands' maxima.
-                */
-               dst_reg->var_off = tnum_and(dst_reg->var_off, src_reg.var_off);
-               dst_reg->umin_value = dst_reg->var_off.value;
-               dst_reg->umax_value = min(dst_reg->umax_value, umax_val);
-               if (dst_reg->smin_value < 0 || smin_val < 0) {
-                       /* Lose signed bounds when ANDing negative numbers,
-                        * ain't nobody got time for that.
-                        */
-                       dst_reg->smin_value = S64_MIN;
-                       dst_reg->smax_value = S64_MAX;
-               } else {
-                       /* ANDing two positives gives a positive, so safe to
-                        * cast result into s64.
-                        */
-                       dst_reg->smin_value = dst_reg->umin_value;
-                       dst_reg->smax_value = dst_reg->umax_value;
-               }
-               /* We may learn something more from the var_off */
-               __update_reg_bounds(dst_reg);
+               scalar_min_max_and(dst_reg, &src_reg);
                break;
        case BPF_OR:
                if (src_known && dst_known) {
@@ -4999,28 +5151,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                                                  src_reg.var_off.value);
                        break;
                }
-               /* We get our maximum from the var_off, and our minimum is the
-                * maximum of the operands' minima
-                */
-               dst_reg->var_off = tnum_or(dst_reg->var_off, src_reg.var_off);
-               dst_reg->umin_value = max(dst_reg->umin_value, umin_val);
-               dst_reg->umax_value = dst_reg->var_off.value |
-                                     dst_reg->var_off.mask;
-               if (dst_reg->smin_value < 0 || smin_val < 0) {
-                       /* Lose signed bounds when ORing negative numbers,
-                        * ain't nobody got time for that.
-                        */
-                       dst_reg->smin_value = S64_MIN;
-                       dst_reg->smax_value = S64_MAX;
-               } else {
-                       /* ORing two positives gives a positive, so safe to
-                        * cast result into s64.
-                        */
-                       dst_reg->smin_value = dst_reg->umin_value;
-                       dst_reg->smax_value = dst_reg->umax_value;
-               }
-               /* We may learn something more from the var_off */
-               __update_reg_bounds(dst_reg);
+               scalar_min_max_or(dst_reg, &src_reg);
                break;
        case BPF_LSH:
                if (umax_val >= insn_bitness) {
@@ -5030,22 +5161,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                        mark_reg_unknown(env, regs, insn->dst_reg);
                        break;
                }
-               /* We lose all sign bit information (except what we can pick
-                * up from var_off)
-                */
-               dst_reg->smin_value = S64_MIN;
-               dst_reg->smax_value = S64_MAX;
-               /* If we might shift our top bit out, then we know nothing */
-               if (dst_reg->umax_value > 1ULL << (63 - umax_val)) {
-                       dst_reg->umin_value = 0;
-                       dst_reg->umax_value = U64_MAX;
-               } else {
-                       dst_reg->umin_value <<= umin_val;
-                       dst_reg->umax_value <<= umax_val;
-               }
-               dst_reg->var_off = tnum_lshift(dst_reg->var_off, umin_val);
-               /* We may learn something more from the var_off */
-               __update_reg_bounds(dst_reg);
+               scalar_min_max_lsh(dst_reg, &src_reg);
                break;
        case BPF_RSH:
                if (umax_val >= insn_bitness) {
@@ -5055,27 +5171,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                        mark_reg_unknown(env, regs, insn->dst_reg);
                        break;
                }
-               /* BPF_RSH is an unsigned shift.  If the value in dst_reg might
-                * be negative, then either:
-                * 1) src_reg might be zero, so the sign bit of the result is
-                *    unknown, so we lose our signed bounds
-                * 2) it's known negative, thus the unsigned bounds capture the
-                *    signed bounds
-                * 3) the signed bounds cross zero, so they tell us nothing
-                *    about the result
-                * If the value in dst_reg is known nonnegative, then again the
-                * unsigned bounts capture the signed bounds.
-                * Thus, in all cases it suffices to blow away our signed bounds
-                * and rely on inferring new ones from the unsigned bounds and
-                * var_off of the result.
-                */
-               dst_reg->smin_value = S64_MIN;
-               dst_reg->smax_value = S64_MAX;
-               dst_reg->var_off = tnum_rshift(dst_reg->var_off, umin_val);
-               dst_reg->umin_value >>= umax_val;
-               dst_reg->umax_value >>= umin_val;
-               /* We may learn something more from the var_off */
-               __update_reg_bounds(dst_reg);
+               scalar_min_max_rsh(dst_reg, &src_reg);
                break;
        case BPF_ARSH:
                if (umax_val >= insn_bitness) {
@@ -5085,27 +5181,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                        mark_reg_unknown(env, regs, insn->dst_reg);
                        break;
                }
-
-               /* Upon reaching here, src_known is true and
-                * umax_val is equal to umin_val.
-                */
-               if (insn_bitness == 32) {
-                       dst_reg->smin_value = (u32)(((s32)dst_reg->smin_value) >> umin_val);
-                       dst_reg->smax_value = (u32)(((s32)dst_reg->smax_value) >> umin_val);
-               } else {
-                       dst_reg->smin_value >>= umin_val;
-                       dst_reg->smax_value >>= umin_val;
-               }
-
-               dst_reg->var_off = tnum_arshift(dst_reg->var_off, umin_val,
-                                               insn_bitness);
-
-               /* blow away the dst_reg umin_value/umax_value and rely on
-                * dst_reg var_off to refine the result.
-                */
-               dst_reg->umin_value = 0;
-               dst_reg->umax_value = U64_MAX;
-               __update_reg_bounds(dst_reg);
+               scalar_min_max_arsh(dst_reg, &src_reg, insn_bitness);
                break;
        default:
                mark_reg_unknown(env, regs, insn->dst_reg);
@@ -5117,6 +5193,7 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                coerce_reg_to_size(dst_reg, 4);
        }
 
+       __update_reg_bounds(dst_reg);
        __reg_deduce_bounds(dst_reg);
        __reg_bound_offset(dst_reg);
        return 0;
@@ -5717,10 +5794,6 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
        /* We might have learned some bits from the bounds. */
        __reg_bound_offset(false_reg);
        __reg_bound_offset(true_reg);
-       if (is_jmp32) {
-               __reg_bound_offset32(false_reg);
-               __reg_bound_offset32(true_reg);
-       }
        /* Intersecting with the old var_off might have improved our bounds
         * slightly.  e.g. if umax was 0x7f...f and var_off was (0; 0xf...fc),
         * then new var_off is (0; 0x7f...fc) which improves our umax.
@@ -5830,10 +5903,6 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
        /* We might have learned some bits from the bounds. */
        __reg_bound_offset(false_reg);
        __reg_bound_offset(true_reg);
-       if (is_jmp32) {
-               __reg_bound_offset32(false_reg);
-               __reg_bound_offset32(true_reg);
-       }
        /* Intersecting with the old var_off might have improved our bounds
         * slightly.  e.g. if umax was 0x7f...f and var_off was (0; 0xf...fc),
         * then new var_off is (0; 0x7f...fc) which improves our umax.
@@ -6405,8 +6474,9 @@ static int check_return_code(struct bpf_verifier_env *env)
        struct tnum range = tnum_range(0, 1);
        int err;
 
-       /* The struct_ops func-ptr's return type could be "void" */
-       if (env->prog->type == BPF_PROG_TYPE_STRUCT_OPS &&
+       /* LSM and struct_ops func-ptr's return type could be "void" */
+       if ((env->prog->type == BPF_PROG_TYPE_STRUCT_OPS ||
+            env->prog->type == BPF_PROG_TYPE_LSM) &&
            !prog->aux->attach_func_proto->type)
                return 0;
 
@@ -8139,26 +8209,48 @@ static bool is_tracing_prog_type(enum bpf_prog_type type)
        }
 }
 
+static bool is_preallocated_map(struct bpf_map *map)
+{
+       if (!check_map_prealloc(map))
+               return false;
+       if (map->inner_map_meta && !check_map_prealloc(map->inner_map_meta))
+               return false;
+       return true;
+}
+
 static int check_map_prog_compatibility(struct bpf_verifier_env *env,
                                        struct bpf_map *map,
                                        struct bpf_prog *prog)
 
 {
-       /* Make sure that BPF_PROG_TYPE_PERF_EVENT programs only use
-        * preallocated hash maps, since doing memory allocation
-        * in overflow_handler can crash depending on where nmi got
-        * triggered.
+       /*
+        * Validate that trace type programs use preallocated hash maps.
+        *
+        * For programs attached to PERF events this is mandatory as the
+        * perf NMI can hit any arbitrary code sequence.
+        *
+        * All other trace types not using preallocated hash maps are unsafe
+        * as well because tracepoints or kprobes can be inside locked regions
+        * of the memory allocator or at a place where a recursion into the
+        * memory allocator would see inconsistent state.
+        *
+        * On RT enabled kernels run-time allocation of all trace type
+        * programs is strictly prohibited due to lock type constraints. On
+        * !RT kernels it is allowed for backwards compatibility reasons for
+        * now, but warnings are emitted so developers are made aware of
+        * the unsafety and can fix their programs before this is enforced.
         */
-       if (prog->type == BPF_PROG_TYPE_PERF_EVENT) {
-               if (!check_map_prealloc(map)) {
+       if (is_tracing_prog_type(prog->type) && !is_preallocated_map(map)) {
+               if (prog->type == BPF_PROG_TYPE_PERF_EVENT) {
                        verbose(env, "perf_event programs can only use preallocated hash map\n");
                        return -EINVAL;
                }
-               if (map->inner_map_meta &&
-                   !check_map_prealloc(map->inner_map_meta)) {
-                       verbose(env, "perf_event programs can only use preallocated inner hash map\n");
+               if (IS_ENABLED(CONFIG_PREEMPT_RT)) {
+                       verbose(env, "trace type programs can only use preallocated hash map\n");
                        return -EINVAL;
                }
+               WARN_ONCE(1, "trace type BPF program uses run-time allocation\n");
+               verbose(env, "trace type programs with run-time allocated hash maps are unsafe. Switch to preallocated hash maps.\n");
        }
 
        if ((is_tracing_prog_type(prog->type) ||
@@ -9774,6 +9866,26 @@ static int check_struct_ops_btf_id(struct bpf_verifier_env *env)
 
        return 0;
 }
+#define SECURITY_PREFIX "security_"
+
+static int check_attach_modify_return(struct bpf_verifier_env *env)
+{
+       struct bpf_prog *prog = env->prog;
+       unsigned long addr = (unsigned long) prog->aux->trampoline->func.addr;
+
+       /* This is expected to be cleaned up in the future with the KRSI effort
+        * introducing the LSM_HOOK macro for cleaning up lsm_hooks.h.
+        */
+       if (within_error_injection_list(addr) ||
+           !strncmp(SECURITY_PREFIX, prog->aux->attach_func_name,
+                    sizeof(SECURITY_PREFIX) - 1))
+               return 0;
+
+       verbose(env, "fmod_ret attach_btf_id %u (%s) is not modifiable\n",
+               prog->aux->attach_btf_id, prog->aux->attach_func_name);
+
+       return -EINVAL;
+}
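check_attach_modify_return() allows fmod_ret attachment only to functions on the error-injection list or to security_* hooks, which it recognizes by name prefix; sizeof(SECURITY_PREFIX) - 1 excludes the terminating NUL so strncmp() compares exactly the prefix. A standalone sketch of that test (the is_security_hook() wrapper is illustrative, not a kernel function):

#include <stdio.h>
#include <string.h>

#define SECURITY_PREFIX "security_"

/* sizeof(SECURITY_PREFIX) - 1 leaves out the NUL terminator, so only
 * the "security_" prefix itself is compared.
 */
static int is_security_hook(const char *name)
{
        return !strncmp(SECURITY_PREFIX, name, sizeof(SECURITY_PREFIX) - 1);
}

int main(void)
{
        printf("%d %d\n", is_security_hook("security_socket_connect"),
               is_security_hook("vfs_read"));          /* prints "1 0" */
        return 0;
}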
 
 static int check_attach_btf_id(struct bpf_verifier_env *env)
 {
@@ -9794,7 +9906,9 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
        if (prog->type == BPF_PROG_TYPE_STRUCT_OPS)
                return check_struct_ops_btf_id(env);
 
-       if (prog->type != BPF_PROG_TYPE_TRACING && !prog_extension)
+       if (prog->type != BPF_PROG_TYPE_TRACING &&
+           prog->type != BPF_PROG_TYPE_LSM &&
+           !prog_extension)
                return 0;
 
        if (!btf_id) {
@@ -9924,8 +10038,17 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
                if (!prog_extension)
                        return -EINVAL;
                /* fallthrough */
+       case BPF_MODIFY_RETURN:
+       case BPF_LSM_MAC:
        case BPF_TRACE_FENTRY:
        case BPF_TRACE_FEXIT:
+               prog->aux->attach_func_name = tname;
+               if (prog->type == BPF_PROG_TYPE_LSM) {
+                       ret = bpf_lsm_verify_prog(&env->log, prog);
+                       if (ret < 0)
+                               return ret;
+               }
+
                if (!btf_type_is_func(t)) {
                        verbose(env, "attach_btf_id %u is not a function\n",
                                btf_id);
@@ -9940,7 +10063,6 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
                tr = bpf_trampoline_lookup(key);
                if (!tr)
                        return -ENOMEM;
-               prog->aux->attach_func_name = tname;
                /* t is either vmlinux type or another program's type */
                prog->aux->attach_func_proto = t;
                mutex_lock(&tr->mutex);
@@ -9973,6 +10095,9 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
                }
                tr->func.addr = (void *)addr;
                prog->aux->trampoline = tr;
+
+               if (prog->expected_attach_type == BPF_MODIFY_RETURN)
+                       ret = check_attach_modify_return(env);
 out:
                mutex_unlock(&tr->mutex);
                if (ret)