bpf: validate value_type
authorKui-Feng Lee <thinker.li@gmail.com>
Fri, 19 Jan 2024 22:50:01 +0000 (14:50 -0800)
committerMartin KaFai Lau <martin.lau@kernel.org>
Wed, 24 Jan 2024 00:37:45 +0000 (16:37 -0800)
A value_type should consist of three components: refcnt, state, and data.
refcnt and state has been move to struct bpf_struct_ops_common_value to
make it easier to check the value type.

Signed-off-by: Kui-Feng Lee <thinker.li@gmail.com>
Link: https://lore.kernel.org/r/20240119225005.668602-11-thinker.li@gmail.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
include/linux/bpf.h
kernel/bpf/bpf_struct_ops.c

index a5b4258..7c17817 100644 (file)
@@ -1688,6 +1688,18 @@ struct bpf_struct_ops_desc {
        u32 value_id;
 };
 
+enum bpf_struct_ops_state {
+       BPF_STRUCT_OPS_STATE_INIT,
+       BPF_STRUCT_OPS_STATE_INUSE,
+       BPF_STRUCT_OPS_STATE_TOBEFREE,
+       BPF_STRUCT_OPS_STATE_READY,
+};
+
+struct bpf_struct_ops_common_value {
+       refcount_t refcnt;
+       enum bpf_struct_ops_state state;
+};
+
 #if defined(CONFIG_BPF_JIT) && defined(CONFIG_BPF_SYSCALL)
 #define BPF_MODULE_OWNER ((void *)((0xeB9FUL << 2) + POISON_POINTER_DELTA))
 const struct bpf_struct_ops_desc *bpf_struct_ops_find(struct btf *btf, u32 type_id);
index 02216a8..30ab34f 100644 (file)
 #include <linux/btf_ids.h>
 #include <linux/rcupdate_wait.h>
 
-enum bpf_struct_ops_state {
-       BPF_STRUCT_OPS_STATE_INIT,
-       BPF_STRUCT_OPS_STATE_INUSE,
-       BPF_STRUCT_OPS_STATE_TOBEFREE,
-       BPF_STRUCT_OPS_STATE_READY,
-};
-
-#define BPF_STRUCT_OPS_COMMON_VALUE                    \
-       refcount_t refcnt;                              \
-       enum bpf_struct_ops_state state
-
 struct bpf_struct_ops_value {
-       BPF_STRUCT_OPS_COMMON_VALUE;
+       struct bpf_struct_ops_common_value common;
        char data[] ____cacheline_aligned_in_smp;
 };
 
@@ -81,8 +70,8 @@ static DEFINE_MUTEX(update_mutex);
 #define BPF_STRUCT_OPS_TYPE(_name)                             \
 extern struct bpf_struct_ops bpf_##_name;                      \
                                                                \
-struct bpf_struct_ops_##_name {                                                \
-       BPF_STRUCT_OPS_COMMON_VALUE;                            \
+struct bpf_struct_ops_##_name {                                        \
+       struct bpf_struct_ops_common_value common;              \
        struct _name data ____cacheline_aligned_in_smp;         \
 };
 #include "bpf_struct_ops_types.h"
@@ -113,11 +102,49 @@ const struct bpf_prog_ops bpf_struct_ops_prog_ops = {
 
 BTF_ID_LIST(st_ops_ids)
 BTF_ID(struct, module)
+BTF_ID(struct, bpf_struct_ops_common_value)
 
 enum {
        IDX_MODULE_ID,
+       IDX_ST_OPS_COMMON_VALUE_ID,
 };
 
+extern struct btf *btf_vmlinux;
+
+static bool is_valid_value_type(struct btf *btf, s32 value_id,
+                               const struct btf_type *type,
+                               const char *value_name)
+{
+       const struct btf_type *common_value_type;
+       const struct btf_member *member;
+       const struct btf_type *vt, *mt;
+
+       vt = btf_type_by_id(btf, value_id);
+       if (btf_vlen(vt) != 2) {
+               pr_warn("The number of %s's members should be 2, but we get %d\n",
+                       value_name, btf_vlen(vt));
+               return false;
+       }
+       member = btf_type_member(vt);
+       mt = btf_type_by_id(btf, member->type);
+       common_value_type = btf_type_by_id(btf_vmlinux,
+                                          st_ops_ids[IDX_ST_OPS_COMMON_VALUE_ID]);
+       if (mt != common_value_type) {
+               pr_warn("The first member of %s should be bpf_struct_ops_common_value\n",
+                       value_name);
+               return false;
+       }
+       member++;
+       mt = btf_type_by_id(btf, member->type);
+       if (mt != type) {
+               pr_warn("The second member of %s should be %s\n",
+                       value_name, btf_name_by_offset(btf, type->name_off));
+               return false;
+       }
+
+       return true;
+}
+
 static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
                                     struct btf *btf,
                                     struct bpf_verifier_log *log)
@@ -138,14 +165,6 @@ static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
        }
        sprintf(value_name, "%s%s", VALUE_PREFIX, st_ops->name);
 
-       value_id = btf_find_by_name_kind(btf, value_name,
-                                        BTF_KIND_STRUCT);
-       if (value_id < 0) {
-               pr_warn("Cannot find struct %s in %s\n",
-                       value_name, btf_get_name(btf));
-               return;
-       }
-
        type_id = btf_find_by_name_kind(btf, st_ops->name,
                                        BTF_KIND_STRUCT);
        if (type_id < 0) {
@@ -160,6 +179,16 @@ static void bpf_struct_ops_desc_init(struct bpf_struct_ops_desc *st_ops_desc,
                return;
        }
 
+       value_id = btf_find_by_name_kind(btf, value_name,
+                                        BTF_KIND_STRUCT);
+       if (value_id < 0) {
+               pr_warn("Cannot find struct %s in %s\n",
+                       value_name, btf_get_name(btf));
+               return;
+       }
+       if (!is_valid_value_type(btf, value_id, t, value_name))
+               return;
+
        for_each_member(i, t, member) {
                const struct btf_type *func_proto;
 
@@ -219,8 +248,6 @@ void bpf_struct_ops_init(struct btf *btf, struct bpf_verifier_log *log)
        }
 }
 
-extern struct btf *btf_vmlinux;
-
 static const struct bpf_struct_ops_desc *
 bpf_struct_ops_find_value(struct btf *btf, u32 value_id)
 {
@@ -276,7 +303,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
 
        kvalue = &st_map->kvalue;
        /* Pair with smp_store_release() during map_update */
-       state = smp_load_acquire(&kvalue->state);
+       state = smp_load_acquire(&kvalue->common.state);
        if (state == BPF_STRUCT_OPS_STATE_INIT) {
                memset(value, 0, map->value_size);
                return 0;
@@ -287,7 +314,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
         */
        uvalue = value;
        memcpy(uvalue, st_map->uvalue, map->value_size);
-       uvalue->state = state;
+       uvalue->common.state = state;
 
        /* This value offers the user space a general estimate of how
         * many sockets are still utilizing this struct_ops for TCP
@@ -295,7 +322,7 @@ int bpf_struct_ops_map_sys_lookup_elem(struct bpf_map *map, void *key,
         * should sufficiently meet our present goals.
         */
        refcnt = atomic64_read(&map->refcnt) - atomic64_read(&map->usercnt);
-       refcount_set(&uvalue->refcnt, max_t(s64, refcnt, 0));
+       refcount_set(&uvalue->common.refcnt, max_t(s64, refcnt, 0));
 
        return 0;
 }
@@ -413,7 +440,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
        if (err)
                return err;
 
-       if (uvalue->state || refcount_read(&uvalue->refcnt))
+       if (uvalue->common.state || refcount_read(&uvalue->common.refcnt))
                return -EINVAL;
 
        tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL);
@@ -425,7 +452,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
 
        mutex_lock(&st_map->lock);
 
-       if (kvalue->state != BPF_STRUCT_OPS_STATE_INIT) {
+       if (kvalue->common.state != BPF_STRUCT_OPS_STATE_INIT) {
                err = -EBUSY;
                goto unlock;
        }
@@ -540,7 +567,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
                 *
                 * Pair with smp_load_acquire() during lookup_elem().
                 */
-               smp_store_release(&kvalue->state, BPF_STRUCT_OPS_STATE_READY);
+               smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_READY);
                goto unlock;
        }
 
@@ -558,7 +585,7 @@ static long bpf_struct_ops_map_update_elem(struct bpf_map *map, void *key,
                 * It ensures the above udata updates (e.g. prog->aux->id)
                 * can be seen once BPF_STRUCT_OPS_STATE_INUSE is set.
                 */
-               smp_store_release(&kvalue->state, BPF_STRUCT_OPS_STATE_INUSE);
+               smp_store_release(&kvalue->common.state, BPF_STRUCT_OPS_STATE_INUSE);
                goto unlock;
        }
 
@@ -588,7 +615,7 @@ static long bpf_struct_ops_map_delete_elem(struct bpf_map *map, void *key)
        if (st_map->map.map_flags & BPF_F_LINK)
                return -EOPNOTSUPP;
 
-       prev_state = cmpxchg(&st_map->kvalue.state,
+       prev_state = cmpxchg(&st_map->kvalue.common.state,
                             BPF_STRUCT_OPS_STATE_INUSE,
                             BPF_STRUCT_OPS_STATE_TOBEFREE);
        switch (prev_state) {
@@ -848,7 +875,7 @@ static bool bpf_struct_ops_valid_to_reg(struct bpf_map *map)
        return map->map_type == BPF_MAP_TYPE_STRUCT_OPS &&
                map->map_flags & BPF_F_LINK &&
                /* Pair with smp_store_release() during map_update */
-               smp_load_acquire(&st_map->kvalue.state) == BPF_STRUCT_OPS_STATE_READY;
+               smp_load_acquire(&st_map->kvalue.common.state) == BPF_STRUCT_OPS_STATE_READY;
 }
 
 static void bpf_struct_ops_map_link_dealloc(struct bpf_link *link)