bpf: fix subprog verifier bypass by div/mod by 0 exception
[linux-2.6-microblaze.git] / net / core / filter.c
index d339ef1..08ab4c6 100644 (file)
@@ -401,8 +401,8 @@ do_pass:
                /* Classic BPF expects A and X to be reset first. These need
                 * to be guaranteed to be the first two instructions.
                 */
-               *new_insn++ = BPF_ALU64_REG(BPF_XOR, BPF_REG_A, BPF_REG_A);
-               *new_insn++ = BPF_ALU64_REG(BPF_XOR, BPF_REG_X, BPF_REG_X);
+               *new_insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_A, BPF_REG_A);
+               *new_insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_X, BPF_REG_X);
 
                /* All programs must keep CTX in callee saved BPF_REG_CTX.
                 * In eBPF case it's done by the compiler, here we need to
@@ -458,6 +458,17 @@ do_pass:
                            convert_bpf_extensions(fp, &insn))
                                break;
 
+                       if (fp->code == (BPF_ALU | BPF_DIV | BPF_X) ||
+                           fp->code == (BPF_ALU | BPF_MOD | BPF_X)) {
+                               *insn++ = BPF_MOV32_REG(BPF_REG_X, BPF_REG_X);
+                               /* Error with exception code on div/mod by 0.
+                                * For cBPF programs, this was always return 0.
+                                */
+                               *insn++ = BPF_JMP_IMM(BPF_JNE, BPF_REG_X, 0, 2);
+                               *insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_A, BPF_REG_A);
+                               *insn++ = BPF_EXIT_INSN();
+                       }
+
                        *insn = BPF_RAW_INSN(fp->code, BPF_REG_A, BPF_REG_X, 0, fp->k);
                        break;
 
@@ -2682,8 +2693,9 @@ static int __xdp_generic_ok_fwd_dev(struct sk_buff *skb, struct net_device *fwd)
        return 0;
 }
 
-int xdp_do_generic_redirect_map(struct net_device *dev, struct sk_buff *skb,
-                               struct bpf_prog *xdp_prog)
+static int xdp_do_generic_redirect_map(struct net_device *dev,
+                                      struct sk_buff *skb,
+                                      struct bpf_prog *xdp_prog)
 {
        struct redirect_info *ri = this_cpu_ptr(&redirect_info);
        unsigned long map_owner = ri->map_owner;
@@ -2860,7 +2872,7 @@ static const struct bpf_func_proto bpf_skb_event_output_proto = {
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_ANYTHING,
        .arg4_type      = ARG_PTR_TO_MEM,
-       .arg5_type      = ARG_CONST_SIZE,
+       .arg5_type      = ARG_CONST_SIZE_OR_ZERO,
 };
 
 static unsigned short bpf_tunnel_key_af(u64 flags)
@@ -3011,6 +3023,8 @@ BPF_CALL_4(bpf_skb_set_tunnel_key, struct sk_buff *, skb,
        info->key.tun_flags = TUNNEL_KEY | TUNNEL_CSUM | TUNNEL_NOCACHE;
        if (flags & BPF_F_DONT_FRAGMENT)
                info->key.tun_flags |= TUNNEL_DONT_FRAGMENT;
+       if (flags & BPF_F_ZERO_CSUM_TX)
+               info->key.tun_flags &= ~TUNNEL_CSUM;
 
        info->key.tun_id = cpu_to_be64(from->tunnel_id);
        info->key.tos = from->tunnel_tos;
@@ -3024,8 +3038,6 @@ BPF_CALL_4(bpf_skb_set_tunnel_key, struct sk_buff *, skb,
                                  IPV6_FLOWLABEL_MASK;
        } else {
                info->key.u.ipv4.dst = cpu_to_be32(from->remote_ipv4);
-               if (flags & BPF_F_ZERO_CSUM_TX)
-                       info->key.tun_flags &= ~TUNNEL_CSUM;
        }
 
        return 0;
@@ -3149,7 +3161,7 @@ static const struct bpf_func_proto bpf_xdp_event_output_proto = {
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_ANYTHING,
        .arg4_type      = ARG_PTR_TO_MEM,
-       .arg5_type      = ARG_CONST_SIZE,
+       .arg5_type      = ARG_CONST_SIZE_OR_ZERO,
 };
 
 BPF_CALL_1(bpf_get_socket_cookie, struct sk_buff *, skb)
@@ -3227,6 +3239,29 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
                        ret = -EINVAL;
                }
 #ifdef CONFIG_INET
+#if IS_ENABLED(CONFIG_IPV6)
+       } else if (level == SOL_IPV6) {
+               if (optlen != sizeof(int) || sk->sk_family != AF_INET6)
+                       return -EINVAL;
+
+               val = *((int *)optval);
+               /* Only some options are supported */
+               switch (optname) {
+               case IPV6_TCLASS:
+                       if (val < -1 || val > 0xff) {
+                               ret = -EINVAL;
+                       } else {
+                               struct ipv6_pinfo *np = inet6_sk(sk);
+
+                               if (val == -1)
+                                       val = 0;
+                               np->tclass = val;
+                       }
+                       break;
+               default:
+                       ret = -EINVAL;
+               }
+#endif
        } else if (level == SOL_TCP &&
                   sk->sk_prot->setsockopt == tcp_setsockopt) {
                if (optname == TCP_CONGESTION) {
@@ -3236,7 +3271,8 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
                        strncpy(name, optval, min_t(long, optlen,
                                                    TCP_CA_NAME_MAX-1));
                        name[TCP_CA_NAME_MAX-1] = 0;
-                       ret = tcp_set_congestion_control(sk, name, false, reinit);
+                       ret = tcp_set_congestion_control(sk, name, false,
+                                                        reinit);
                } else {
                        struct tcp_sock *tp = tcp_sk(sk);
 
@@ -3302,6 +3338,22 @@ BPF_CALL_5(bpf_getsockopt, struct bpf_sock_ops_kern *, bpf_sock,
                } else {
                        goto err_clear;
                }
+#if IS_ENABLED(CONFIG_IPV6)
+       } else if (level == SOL_IPV6) {
+               struct ipv6_pinfo *np = inet6_sk(sk);
+
+               if (optlen != sizeof(int) || sk->sk_family != AF_INET6)
+                       goto err_clear;
+
+               /* Only some options are supported */
+               switch (optname) {
+               case IPV6_TCLASS:
+                       *((int *)optval) = (int)np->tclass;
+                       break;
+               default:
+                       goto err_clear;
+               }
+#endif
        } else {
                goto err_clear;
        }
@@ -3323,6 +3375,33 @@ static const struct bpf_func_proto bpf_getsockopt_proto = {
        .arg5_type      = ARG_CONST_SIZE,
 };
 
+BPF_CALL_2(bpf_sock_ops_cb_flags_set, struct bpf_sock_ops_kern *, bpf_sock,
+          int, argval)
+{
+       struct sock *sk = bpf_sock->sk;
+       int val = argval & BPF_SOCK_OPS_ALL_CB_FLAGS;
+
+       if (!sk_fullsock(sk))
+               return -EINVAL;
+
+#ifdef CONFIG_INET
+       if (val)
+               tcp_sk(sk)->bpf_sock_ops_cb_flags = val;
+
+       return argval & (~BPF_SOCK_OPS_ALL_CB_FLAGS);
+#else
+       return -EINVAL;
+#endif
+}
+
+static const struct bpf_func_proto bpf_sock_ops_cb_flags_set_proto = {
+       .func           = bpf_sock_ops_cb_flags_set,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+};
+
 static const struct bpf_func_proto *
 bpf_base_func_proto(enum bpf_func_id func_id)
 {
@@ -3455,6 +3534,8 @@ xdp_func_proto(enum bpf_func_id func_id)
                return &bpf_xdp_event_output_proto;
        case BPF_FUNC_get_smp_processor_id:
                return &bpf_get_smp_processor_id_proto;
+       case BPF_FUNC_csum_diff:
+               return &bpf_csum_diff_proto;
        case BPF_FUNC_xdp_adjust_head:
                return &bpf_xdp_adjust_head_proto;
        case BPF_FUNC_xdp_adjust_meta:
@@ -3503,6 +3584,8 @@ static const struct bpf_func_proto *
                return &bpf_setsockopt_proto;
        case BPF_FUNC_getsockopt:
                return &bpf_getsockopt_proto;
+       case BPF_FUNC_sock_ops_cb_flags_set:
+               return &bpf_sock_ops_cb_flags_set_proto;
        case BPF_FUNC_sock_map_update:
                return &bpf_sock_map_update_proto;
        default:
@@ -3819,34 +3902,44 @@ void bpf_warn_invalid_xdp_action(u32 act)
 }
 EXPORT_SYMBOL_GPL(bpf_warn_invalid_xdp_action);
 
-static bool __is_valid_sock_ops_access(int off, int size)
+static bool sock_ops_is_valid_access(int off, int size,
+                                    enum bpf_access_type type,
+                                    struct bpf_insn_access_aux *info)
 {
+       const int size_default = sizeof(__u32);
+
        if (off < 0 || off >= sizeof(struct bpf_sock_ops))
                return false;
+
        /* The verifier guarantees that size > 0. */
        if (off % size != 0)
                return false;
-       if (size != sizeof(__u32))
-               return false;
 
-       return true;
-}
-
-static bool sock_ops_is_valid_access(int off, int size,
-                                    enum bpf_access_type type,
-                                    struct bpf_insn_access_aux *info)
-{
        if (type == BPF_WRITE) {
                switch (off) {
-               case offsetof(struct bpf_sock_ops, op) ...
-                    offsetof(struct bpf_sock_ops, replylong[3]):
+               case offsetof(struct bpf_sock_ops, reply):
+               case offsetof(struct bpf_sock_ops, sk_txhash):
+                       if (size != size_default)
+                               return false;
                        break;
                default:
                        return false;
                }
+       } else {
+               switch (off) {
+               case bpf_ctx_range_till(struct bpf_sock_ops, bytes_received,
+                                       bytes_acked):
+                       if (size != sizeof(__u64))
+                               return false;
+                       break;
+               default:
+                       if (size != size_default)
+                               return false;
+                       break;
+               }
        }
 
-       return __is_valid_sock_ops_access(off, size);
+       return true;
 }
 
 static int sk_skb_prologue(struct bpf_insn *insn_buf, bool direct_write,
@@ -4301,6 +4394,24 @@ static u32 xdp_convert_ctx_access(enum bpf_access_type type,
                                      si->dst_reg, si->src_reg,
                                      offsetof(struct xdp_buff, data_end));
                break;
+       case offsetof(struct xdp_md, ingress_ifindex):
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, rxq),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct xdp_buff, rxq));
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_rxq_info, dev),
+                                     si->dst_reg, si->dst_reg,
+                                     offsetof(struct xdp_rxq_info, dev));
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+                                     offsetof(struct net_device, ifindex));
+               break;
+       case offsetof(struct xdp_md, rx_queue_index):
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, rxq),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct xdp_buff, rxq));
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+                                     offsetof(struct xdp_rxq_info,
+                                              queue_index));
+               break;
        }
 
        return insn - insn_buf;
@@ -4435,6 +4546,211 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type,
                *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
                                      offsetof(struct sock_common, skc_num));
                break;
+
+       case offsetof(struct bpf_sock_ops, is_fullsock):
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
+                                               struct bpf_sock_ops_kern,
+                                               is_fullsock),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct bpf_sock_ops_kern,
+                                              is_fullsock));
+               break;
+
+       case offsetof(struct bpf_sock_ops, state):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_state) != 1);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
+                                               struct bpf_sock_ops_kern, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct bpf_sock_ops_kern, sk));
+               *insn++ = BPF_LDX_MEM(BPF_B, si->dst_reg, si->dst_reg,
+                                     offsetof(struct sock_common, skc_state));
+               break;
+
+       case offsetof(struct bpf_sock_ops, rtt_min):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct tcp_sock, rtt_min) !=
+                            sizeof(struct minmax));
+               BUILD_BUG_ON(sizeof(struct minmax) <
+                            sizeof(struct minmax_sample));
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
+                                               struct bpf_sock_ops_kern, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct bpf_sock_ops_kern, sk));
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+                                     offsetof(struct tcp_sock, rtt_min) +
+                                     FIELD_SIZEOF(struct minmax_sample, t));
+               break;
+
+/* Helper macro for adding read access to tcp_sock or sock fields. */
+#define SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)                        \
+       do {                                                                  \
+               BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >                   \
+                            FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern,     \
+                                               is_fullsock),                 \
+                                     si->dst_reg, si->src_reg,               \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              is_fullsock));                 \
+               *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 2);            \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern, sk),\
+                                     si->dst_reg, si->src_reg,               \
+                                     offsetof(struct bpf_sock_ops_kern, sk));\
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(OBJ,                   \
+                                                      OBJ_FIELD),            \
+                                     si->dst_reg, si->dst_reg,               \
+                                     offsetof(OBJ, OBJ_FIELD));              \
+       } while (0)
+
+/* Helper macro for adding write access to tcp_sock or sock fields.
+ * The macro is called with two registers, dst_reg which contains a pointer
+ * to ctx (context) and src_reg which contains the value that should be
+ * stored. However, we need an additional register since we cannot overwrite
+ * dst_reg because it may be used later in the program.
+ * Instead we "borrow" one of the other register. We first save its value
+ * into a new (temp) field in bpf_sock_ops_kern, use it, and then restore
+ * it at the end of the macro.
+ */
+#define SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)                        \
+       do {                                                                  \
+               int reg = BPF_REG_9;                                          \
+               BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >                   \
+                            FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
+               if (si->dst_reg == reg || si->src_reg == reg)                 \
+                       reg--;                                                \
+               if (si->dst_reg == reg || si->src_reg == reg)                 \
+                       reg--;                                                \
+               *insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, reg,               \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              temp));                        \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern,     \
+                                               is_fullsock),                 \
+                                     reg, si->dst_reg,                       \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              is_fullsock));                 \
+               *insn++ = BPF_JMP_IMM(BPF_JEQ, reg, 0, 2);                    \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern, sk),\
+                                     reg, si->dst_reg,                       \
+                                     offsetof(struct bpf_sock_ops_kern, sk));\
+               *insn++ = BPF_STX_MEM(BPF_FIELD_SIZEOF(OBJ, OBJ_FIELD),       \
+                                     reg, si->src_reg,                       \
+                                     offsetof(OBJ, OBJ_FIELD));              \
+               *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->dst_reg,               \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              temp));                        \
+       } while (0)
+
+#define SOCK_OPS_GET_OR_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ, TYPE)           \
+       do {                                                                  \
+               if (TYPE == BPF_WRITE)                                        \
+                       SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);        \
+               else                                                          \
+                       SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);        \
+       } while (0)
+
+       case offsetof(struct bpf_sock_ops, snd_cwnd):
+               SOCK_OPS_GET_FIELD(snd_cwnd, snd_cwnd, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, srtt_us):
+               SOCK_OPS_GET_FIELD(srtt_us, srtt_us, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, bpf_sock_ops_cb_flags):
+               SOCK_OPS_GET_FIELD(bpf_sock_ops_cb_flags, bpf_sock_ops_cb_flags,
+                                  struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, snd_ssthresh):
+               SOCK_OPS_GET_FIELD(snd_ssthresh, snd_ssthresh, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, rcv_nxt):
+               SOCK_OPS_GET_FIELD(rcv_nxt, rcv_nxt, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, snd_nxt):
+               SOCK_OPS_GET_FIELD(snd_nxt, snd_nxt, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, snd_una):
+               SOCK_OPS_GET_FIELD(snd_una, snd_una, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, mss_cache):
+               SOCK_OPS_GET_FIELD(mss_cache, mss_cache, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, ecn_flags):
+               SOCK_OPS_GET_FIELD(ecn_flags, ecn_flags, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, rate_delivered):
+               SOCK_OPS_GET_FIELD(rate_delivered, rate_delivered,
+                                  struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, rate_interval_us):
+               SOCK_OPS_GET_FIELD(rate_interval_us, rate_interval_us,
+                                  struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, packets_out):
+               SOCK_OPS_GET_FIELD(packets_out, packets_out, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, retrans_out):
+               SOCK_OPS_GET_FIELD(retrans_out, retrans_out, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, total_retrans):
+               SOCK_OPS_GET_FIELD(total_retrans, total_retrans,
+                                  struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, segs_in):
+               SOCK_OPS_GET_FIELD(segs_in, segs_in, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, data_segs_in):
+               SOCK_OPS_GET_FIELD(data_segs_in, data_segs_in, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, segs_out):
+               SOCK_OPS_GET_FIELD(segs_out, segs_out, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, data_segs_out):
+               SOCK_OPS_GET_FIELD(data_segs_out, data_segs_out,
+                                  struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, lost_out):
+               SOCK_OPS_GET_FIELD(lost_out, lost_out, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, sacked_out):
+               SOCK_OPS_GET_FIELD(sacked_out, sacked_out, struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, sk_txhash):
+               SOCK_OPS_GET_OR_SET_FIELD(sk_txhash, sk_txhash,
+                                         struct sock, type);
+               break;
+
+       case offsetof(struct bpf_sock_ops, bytes_received):
+               SOCK_OPS_GET_FIELD(bytes_received, bytes_received,
+                                  struct tcp_sock);
+               break;
+
+       case offsetof(struct bpf_sock_ops, bytes_acked):
+               SOCK_OPS_GET_FIELD(bytes_acked, bytes_acked, struct tcp_sock);
+               break;
+
        }
        return insn - insn_buf;
 }
@@ -4471,6 +4787,7 @@ const struct bpf_verifier_ops sk_filter_verifier_ops = {
 };
 
 const struct bpf_prog_ops sk_filter_prog_ops = {
+       .test_run               = bpf_prog_test_run_skb,
 };
 
 const struct bpf_verifier_ops tc_cls_act_verifier_ops = {