1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook  */
3 #include <linux/rculist.h>
4 #include <linux/list.h>
5 #include <linux/hash.h>
6 #include <linux/types.h>
7 #include <linux/spinlock.h>
8 #include <linux/bpf.h>
9 #include <linux/btf.h>
10 #include <linux/btf_ids.h>
11 #include <linux/bpf_local_storage.h>
12 #include <net/bpf_sk_storage.h>
13 #include <net/sock.h>
14 #include <uapi/linux/sock_diag.h>
15 #include <uapi/linux/btf.h>
16 #include <linux/rcupdate_trace.h>
17
18 DEFINE_BPF_STORAGE_CACHE(sk_cache);
19
20 static struct bpf_local_storage_data *
21 bpf_sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
22 {
23         struct bpf_local_storage *sk_storage;
24         struct bpf_local_storage_map *smap;
25
26         sk_storage =
27                 rcu_dereference_check(sk->sk_bpf_storage, bpf_rcu_lock_held());
28         if (!sk_storage)
29                 return NULL;
30
31         smap = (struct bpf_local_storage_map *)map;
32         return bpf_local_storage_lookup(sk_storage, smap, cacheit_lockit);
33 }
34
35 static int bpf_sk_storage_del(struct sock *sk, struct bpf_map *map)
36 {
37         struct bpf_local_storage_data *sdata;
38
39         sdata = bpf_sk_storage_lookup(sk, map, false);
40         if (!sdata)
41                 return -ENOENT;
42
43         bpf_selem_unlink(SELEM(sdata), true);
44
45         return 0;
46 }
47
48 /* Called by __sk_destruct() & bpf_sk_storage_clone() */
49 void bpf_sk_storage_free(struct sock *sk)
50 {
51         struct bpf_local_storage_elem *selem;
52         struct bpf_local_storage *sk_storage;
53         bool free_sk_storage = false;
54         struct hlist_node *n;
55
56         rcu_read_lock();
57         sk_storage = rcu_dereference(sk->sk_bpf_storage);
58         if (!sk_storage) {
59                 rcu_read_unlock();
60                 return;
61         }
62
63         /* Neither the bpf_prog nor the bpf-map's syscall
64          * could be modifying the sk_storage->list now.
65          * Thus, no elem can be added to or deleted from the
66          * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
67          *
68          * It is racing with bpf_local_storage_map_free() alone
69          * when unlinking elem from the sk_storage->list and
70          * the map's bucket->list.
71          */
72         raw_spin_lock_bh(&sk_storage->lock);
73         hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
74                 /* Always unlink from map before unlinking from
75                  * sk_storage.
76                  */
77                 bpf_selem_unlink_map(selem);
78                 free_sk_storage = bpf_selem_unlink_storage_nolock(
79                         sk_storage, selem, true, false);
80         }
81         raw_spin_unlock_bh(&sk_storage->lock);
82         rcu_read_unlock();
83
84         if (free_sk_storage)
85                 kfree_rcu(sk_storage, rcu);
86 }
87
88 static void bpf_sk_storage_map_free(struct bpf_map *map)
89 {
90         struct bpf_local_storage_map *smap;
91
92         smap = (struct bpf_local_storage_map *)map;
93         bpf_local_storage_cache_idx_free(&sk_cache, smap->cache_idx);
94         bpf_local_storage_map_free(smap, NULL);
95 }
96
97 static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
98 {
99         struct bpf_local_storage_map *smap;
100
101         smap = bpf_local_storage_map_alloc(attr);
102         if (IS_ERR(smap))
103                 return ERR_CAST(smap);
104
105         smap->cache_idx = bpf_local_storage_cache_idx_get(&sk_cache);
106         return &smap->map;
107 }
108
109 static int notsupp_get_next_key(struct bpf_map *map, void *key,
110                                 void *next_key)
111 {
112         return -ENOTSUPP;
113 }
114
115 static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
116 {
117         struct bpf_local_storage_data *sdata;
118         struct socket *sock;
119         int fd, err;
120
121         fd = *(int *)key;
122         sock = sockfd_lookup(fd, &err);
123         if (sock) {
124                 sdata = bpf_sk_storage_lookup(sock->sk, map, true);
125                 sockfd_put(sock);
126                 return sdata ? sdata->data : NULL;
127         }
128
129         return ERR_PTR(err);
130 }
131
132 static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
133                                          void *value, u64 map_flags)
134 {
135         struct bpf_local_storage_data *sdata;
136         struct socket *sock;
137         int fd, err;
138
139         fd = *(int *)key;
140         sock = sockfd_lookup(fd, &err);
141         if (sock) {
142                 sdata = bpf_local_storage_update(
143                         sock->sk, (struct bpf_local_storage_map *)map, value,
144                         map_flags, GFP_ATOMIC);
145                 sockfd_put(sock);
146                 return PTR_ERR_OR_ZERO(sdata);
147         }
148
149         return err;
150 }
151
152 static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
153 {
154         struct socket *sock;
155         int fd, err;
156
157         fd = *(int *)key;
158         sock = sockfd_lookup(fd, &err);
159         if (sock) {
160                 err = bpf_sk_storage_del(sock->sk, map);
161                 sockfd_put(sock);
162                 return err;
163         }
164
165         return err;
166 }
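/* Illustrative sketch (not part of this file): from userspace, the three
 * bpf_fd_sk_storage_*() functions above are reached through the bpf(2)
 * map commands, with a socket fd (not a struct sock pointer) as the key.
 * Hedged libbpf usage; map_fd, sock_fd and my_val are hypothetical:
 *
 *	__u64 my_val = 42;
 *
 *	bpf_map_update_elem(map_fd, &sock_fd, &my_val, BPF_ANY);
 *	bpf_map_lookup_elem(map_fd, &sock_fd, &my_val);
 *	bpf_map_delete_elem(map_fd, &sock_fd);
 *
 * Each call resolves the fd with sockfd_lookup() and operates on that
 * socket's storage; bpf_map_get_next_key() is intentionally unsupported
 * (notsupp_get_next_key() above returns -ENOTSUPP).
 */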
167
168 static struct bpf_local_storage_elem *
169 bpf_sk_storage_clone_elem(struct sock *newsk,
170                           struct bpf_local_storage_map *smap,
171                           struct bpf_local_storage_elem *selem)
172 {
173         struct bpf_local_storage_elem *copy_selem;
174
175         copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC);
176         if (!copy_selem)
177                 return NULL;
178
179         if (map_value_has_spin_lock(&smap->map))
180                 copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
181                                       SDATA(selem)->data, true);
182         else
183                 copy_map_value(&smap->map, SDATA(copy_selem)->data,
184                                SDATA(selem)->data);
185
186         return copy_selem;
187 }
188
189 int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
190 {
191         struct bpf_local_storage *new_sk_storage = NULL;
192         struct bpf_local_storage *sk_storage;
193         struct bpf_local_storage_elem *selem;
194         int ret = 0;
195
196         RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
197
198         rcu_read_lock();
199         sk_storage = rcu_dereference(sk->sk_bpf_storage);
200
201         if (!sk_storage || hlist_empty(&sk_storage->list))
202                 goto out;
203
204         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
205                 struct bpf_local_storage_elem *copy_selem;
206                 struct bpf_local_storage_map *smap;
207                 struct bpf_map *map;
208
209                 smap = rcu_dereference(SDATA(selem)->smap);
210                 if (!(smap->map.map_flags & BPF_F_CLONE))
211                         continue;
212
213                 /* Note that for lockless listeners, adding a new element
214                  * here can race with the cleanup in bpf_local_storage_map_free().
215                  * Try to grab the map refcnt to make sure it is still
216                  * alive and to prevent concurrent removal.
217                  */
218                 map = bpf_map_inc_not_zero(&smap->map);
219                 if (IS_ERR(map))
220                         continue;
221
222                 copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
223                 if (!copy_selem) {
224                         ret = -ENOMEM;
225                         bpf_map_put(map);
226                         goto out;
227                 }
228
229                 if (new_sk_storage) {
230                         bpf_selem_link_map(smap, copy_selem);
231                         bpf_selem_link_storage_nolock(new_sk_storage, copy_selem);
232                 } else {
233                         ret = bpf_local_storage_alloc(newsk, smap, copy_selem, GFP_ATOMIC);
234                         if (ret) {
235                                 kfree(copy_selem);
236                                 atomic_sub(smap->elem_size,
237                                            &newsk->sk_omem_alloc);
238                                 bpf_map_put(map);
239                                 goto out;
240                         }
241
242                         new_sk_storage =
243                                 rcu_dereference(copy_selem->local_storage);
244                 }
245                 bpf_map_put(map);
246         }
247
248 out:
249         rcu_read_unlock();
250
251         /* In case of an error, don't free anything explicitly here; the
252          * caller is responsible for calling bpf_sk_storage_free().
253          */
254
255         return ret;
256 }
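/* Illustrative sketch (not part of this file): only storages whose map was
 * created with BPF_F_CLONE are copied by bpf_sk_storage_clone() above.  A
 * BPF-side map definition opting in could look like this (libbpf
 * conventions assumed, the map name is hypothetical):
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SK_STORAGE);
 *		__uint(map_flags, BPF_F_NO_PREALLOC | BPF_F_CLONE);
 *		__type(key, int);
 *		__type(value, __u32);
 *	} cloned_stg SEC(".maps");
 *
 * A value attached to a listening socket is then duplicated onto every
 * accepted child when sk_clone_lock() invokes bpf_sk_storage_clone().
 */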
257
258 /* *gfp_flags* is a hidden argument provided by the verifier */
259 BPF_CALL_5(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
260            void *, value, u64, flags, gfp_t, gfp_flags)
261 {
262         struct bpf_local_storage_data *sdata;
263
264         WARN_ON_ONCE(!bpf_rcu_lock_held());
265         if (!sk || !sk_fullsock(sk) || flags > BPF_SK_STORAGE_GET_F_CREATE)
266                 return (unsigned long)NULL;
267
268         sdata = bpf_sk_storage_lookup(sk, map, true);
269         if (sdata)
270                 return (unsigned long)sdata->data;
271
272         if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
273             /* Cannot add a new elem to a sk that is
274              * going away.  Otherwise, the new elem may
275              * leak (and cause other memory issues
276              * during map destruction).
277              */
278             refcount_inc_not_zero(&sk->sk_refcnt)) {
279                 sdata = bpf_local_storage_update(
280                         sk, (struct bpf_local_storage_map *)map, value,
281                         BPF_NOEXIST, gfp_flags);
282                 /* sk must be a fullsock (guaranteed by verifier),
283                  * so sock_gen_put() is unnecessary.
284                  */
285                 sock_put(sk);
286                 return IS_ERR(sdata) ?
287                         (unsigned long)NULL : (unsigned long)sdata->data;
288         }
289
290         return (unsigned long)NULL;
291 }
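/* Illustrative sketch (not part of this file): how a BPF program would
 * typically call this helper.  libbpf conventions are assumed and the map,
 * program and value-layout names are hypothetical:
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SK_STORAGE);
 *		__uint(map_flags, BPF_F_NO_PREALLOC);
 *		__type(key, int);
 *		__type(value, __u64);
 *	} sk_stg SEC(".maps");
 *
 *	SEC("sockops")
 *	int count_sockops(struct bpf_sock_ops *skops)
 *	{
 *		__u64 *val;
 *
 *		if (!skops->sk)
 *			return 1;
 *		val = bpf_sk_storage_get(&sk_stg, skops->sk, NULL,
 *					 BPF_SK_STORAGE_GET_F_CREATE);
 *		if (val)
 *			(*val)++;
 *		return 1;
 *	}
 *
 * Passing NULL as the value together with the F_CREATE flag creates a
 * zero-filled element on first use; bpf_sk_storage_delete(&sk_stg,
 * skops->sk) would remove it again.
 */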
292
293 BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
294 {
295         WARN_ON_ONCE(!bpf_rcu_lock_held());
296         if (!sk || !sk_fullsock(sk))
297                 return -EINVAL;
298
299         if (refcount_inc_not_zero(&sk->sk_refcnt)) {
300                 int err;
301
302                 err = bpf_sk_storage_del(sk, map);
303                 sock_put(sk);
304                 return err;
305         }
306
307         return -ENOENT;
308 }
309
310 static int bpf_sk_storage_charge(struct bpf_local_storage_map *smap,
311                                  void *owner, u32 size)
312 {
313         struct sock *sk = (struct sock *)owner;
314
315         /* same check as in sock_kmalloc() */
316         if (size <= sysctl_optmem_max &&
317             atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
318                 atomic_add(size, &sk->sk_omem_alloc);
319                 return 0;
320         }
321
322         return -ENOMEM;
323 }
324
325 static void bpf_sk_storage_uncharge(struct bpf_local_storage_map *smap,
326                                     void *owner, u32 size)
327 {
328         struct sock *sk = owner;
329
330         atomic_sub(size, &sk->sk_omem_alloc);
331 }
332
333 static struct bpf_local_storage __rcu **
334 bpf_sk_storage_ptr(void *owner)
335 {
336         struct sock *sk = owner;
337
338         return &sk->sk_bpf_storage;
339 }
340
341 BTF_ID_LIST_SINGLE(sk_storage_map_btf_ids, struct, bpf_local_storage_map)
342 const struct bpf_map_ops sk_storage_map_ops = {
343         .map_meta_equal = bpf_map_meta_equal,
344         .map_alloc_check = bpf_local_storage_map_alloc_check,
345         .map_alloc = bpf_sk_storage_map_alloc,
346         .map_free = bpf_sk_storage_map_free,
347         .map_get_next_key = notsupp_get_next_key,
348         .map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
349         .map_update_elem = bpf_fd_sk_storage_update_elem,
350         .map_delete_elem = bpf_fd_sk_storage_delete_elem,
351         .map_check_btf = bpf_local_storage_map_check_btf,
352         .map_btf_id = &sk_storage_map_btf_ids[0],
353         .map_local_storage_charge = bpf_sk_storage_charge,
354         .map_local_storage_uncharge = bpf_sk_storage_uncharge,
355         .map_owner_storage_ptr = bpf_sk_storage_ptr,
356 };
357
358 const struct bpf_func_proto bpf_sk_storage_get_proto = {
359         .func           = bpf_sk_storage_get,
360         .gpl_only       = false,
361         .ret_type       = RET_PTR_TO_MAP_VALUE_OR_NULL,
362         .arg1_type      = ARG_CONST_MAP_PTR,
363         .arg2_type      = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
364         .arg3_type      = ARG_PTR_TO_MAP_VALUE_OR_NULL,
365         .arg4_type      = ARG_ANYTHING,
366 };
367
368 const struct bpf_func_proto bpf_sk_storage_get_cg_sock_proto = {
369         .func           = bpf_sk_storage_get,
370         .gpl_only       = false,
371         .ret_type       = RET_PTR_TO_MAP_VALUE_OR_NULL,
372         .arg1_type      = ARG_CONST_MAP_PTR,
373         .arg2_type      = ARG_PTR_TO_CTX, /* context is 'struct sock' */
374         .arg3_type      = ARG_PTR_TO_MAP_VALUE_OR_NULL,
375         .arg4_type      = ARG_ANYTHING,
376 };
377
378 const struct bpf_func_proto bpf_sk_storage_delete_proto = {
379         .func           = bpf_sk_storage_delete,
380         .gpl_only       = false,
381         .ret_type       = RET_INTEGER,
382         .arg1_type      = ARG_CONST_MAP_PTR,
383         .arg2_type      = ARG_PTR_TO_BTF_ID_SOCK_COMMON,
384 };
385
386 static bool bpf_sk_storage_tracing_allowed(const struct bpf_prog *prog)
387 {
388         const struct btf *btf_vmlinux;
389         const struct btf_type *t;
390         const char *tname;
391         u32 btf_id;
392
393         if (prog->aux->dst_prog)
394                 return false;
395
396         /* Ensure that a tracing program which uses the
397          * bpf_sk_storage_(get|delete) helpers is not itself
398          * tracing any bpf_sk_storage*() function.
399          */
400         switch (prog->expected_attach_type) {
401         case BPF_TRACE_ITER:
402         case BPF_TRACE_RAW_TP:
403                 /* bpf_sk_storage has no trace point */
404                 return true;
405         case BPF_TRACE_FENTRY:
406         case BPF_TRACE_FEXIT:
407                 btf_vmlinux = bpf_get_btf_vmlinux();
408                 if (IS_ERR_OR_NULL(btf_vmlinux))
409                         return false;
410                 btf_id = prog->aux->attach_btf_id;
411                 t = btf_type_by_id(btf_vmlinux, btf_id);
412                 tname = btf_name_by_offset(btf_vmlinux, t->name_off);
413                 return !!strncmp(tname, "bpf_sk_storage",
414                                  strlen("bpf_sk_storage"));
415         default:
416                 return false;
417         }
418
419         return false;
420 }
421
422 /* *gfp_flags* is a hidden argument provided by the verifier */
423 BPF_CALL_5(bpf_sk_storage_get_tracing, struct bpf_map *, map, struct sock *, sk,
424            void *, value, u64, flags, gfp_t, gfp_flags)
425 {
426         WARN_ON_ONCE(!bpf_rcu_lock_held());
427         if (in_hardirq() || in_nmi())
428                 return (unsigned long)NULL;
429
430         return (unsigned long)____bpf_sk_storage_get(map, sk, value, flags,
431                                                      gfp_flags);
432 }
433
434 BPF_CALL_2(bpf_sk_storage_delete_tracing, struct bpf_map *, map,
435            struct sock *, sk)
436 {
437         WARN_ON_ONCE(!bpf_rcu_lock_held());
438         if (in_hardirq() || in_nmi())
439                 return -EPERM;
440
441         return ____bpf_sk_storage_delete(map, sk);
442 }
443
444 const struct bpf_func_proto bpf_sk_storage_get_tracing_proto = {
445         .func           = bpf_sk_storage_get_tracing,
446         .gpl_only       = false,
447         .ret_type       = RET_PTR_TO_MAP_VALUE_OR_NULL,
448         .arg1_type      = ARG_CONST_MAP_PTR,
449         .arg2_type      = ARG_PTR_TO_BTF_ID,
450         .arg2_btf_id    = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
451         .arg3_type      = ARG_PTR_TO_MAP_VALUE_OR_NULL,
452         .arg4_type      = ARG_ANYTHING,
453         .allowed        = bpf_sk_storage_tracing_allowed,
454 };
455
456 const struct bpf_func_proto bpf_sk_storage_delete_tracing_proto = {
457         .func           = bpf_sk_storage_delete_tracing,
458         .gpl_only       = false,
459         .ret_type       = RET_INTEGER,
460         .arg1_type      = ARG_CONST_MAP_PTR,
461         .arg2_type      = ARG_PTR_TO_BTF_ID,
462         .arg2_btf_id    = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON],
463         .allowed        = bpf_sk_storage_tracing_allowed,
464 };
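/* Illustrative sketch (not part of this file): a tracing program reaching
 * the *_tracing protos above still calls the regular helper name.  fentry
 * attachment and libbpf conventions are assumed; "sk_stg" (a
 * BPF_MAP_TYPE_SK_STORAGE map as in the earlier sketch) and the program
 * name are hypothetical:
 *
 *	SEC("fentry/tcp_close")
 *	int BPF_PROG(trace_close, struct sock *sk)
 *	{
 *		__u64 *val;
 *
 *		val = bpf_sk_storage_get(&sk_stg, sk, NULL, 0);
 *		if (val)
 *			bpf_printk("closing sk, counter=%llu", *val);
 *		return 0;
 *	}
 *
 * bpf_sk_storage_tracing_allowed() above rejects attaching such a program
 * to the bpf_sk_storage_*() functions themselves, avoiding recursion into
 * the storage code.
 */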
465
466 struct bpf_sk_storage_diag {
467         u32 nr_maps;
468         struct bpf_map *maps[];
469 };
470
471 /* The reply will be like:
472  * INET_DIAG_BPF_SK_STORAGES (nla_nest)
473  *      SK_DIAG_BPF_STORAGE (nla_nest)
474  *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
475  *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
476  *      SK_DIAG_BPF_STORAGE (nla_nest)
477  *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
478  *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
479  *      ....
480  */
481 static int nla_value_size(u32 value_size)
482 {
483         /* SK_DIAG_BPF_STORAGE (nla_nest)
484          *      SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
485          *      SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
486          */
487         return nla_total_size(0) + nla_total_size(sizeof(u32)) +
488                 nla_total_size_64bit(value_size);
489 }
490
491 void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
492 {
493         u32 i;
494
495         if (!diag)
496                 return;
497
498         for (i = 0; i < diag->nr_maps; i++)
499                 bpf_map_put(diag->maps[i]);
500
501         kfree(diag);
502 }
503 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
504
505 static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
506                            const struct bpf_map *map)
507 {
508         u32 i;
509
510         for (i = 0; i < diag->nr_maps; i++) {
511                 if (diag->maps[i] == map)
512                         return true;
513         }
514
515         return false;
516 }
517
518 struct bpf_sk_storage_diag *
519 bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
520 {
521         struct bpf_sk_storage_diag *diag;
522         struct nlattr *nla;
523         u32 nr_maps = 0;
524         int rem, err;
525
526         /* bpf_local_storage_map is currently limited to CAP_SYS_ADMIN,
527          * matching the check done on the map_alloc_check() side.
528          */
529         if (!bpf_capable())
530                 return ERR_PTR(-EPERM);
531
532         nla_for_each_nested(nla, nla_stgs, rem) {
533                 if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
534                         nr_maps++;
535         }
536
537         diag = kzalloc(struct_size(diag, maps, nr_maps), GFP_KERNEL);
538         if (!diag)
539                 return ERR_PTR(-ENOMEM);
540
541         nla_for_each_nested(nla, nla_stgs, rem) {
542                 struct bpf_map *map;
543                 int map_fd;
544
545                 if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
546                         continue;
547
548                 map_fd = nla_get_u32(nla);
549                 map = bpf_map_get(map_fd);
550                 if (IS_ERR(map)) {
551                         err = PTR_ERR(map);
552                         goto err_free;
553                 }
554                 if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
555                         bpf_map_put(map);
556                         err = -EINVAL;
557                         goto err_free;
558                 }
559                 if (diag_check_dup(diag, map)) {
560                         bpf_map_put(map);
561                         err = -EEXIST;
562                         goto err_free;
563                 }
564                 diag->maps[diag->nr_maps++] = map;
565         }
566
567         return diag;
568
569 err_free:
570         bpf_sk_storage_diag_free(diag);
571         return ERR_PTR(err);
572 }
573 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
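/* Illustrative sketch (not part of this file): a sock_diag requester can
 * restrict the dump to specific maps by nesting map fds in the request.
 * Hedged userspace snippet using libmnl; construction of nlh and of the
 * inet_diag request itself is elided, and map_fd is hypothetical:
 *
 *	struct nlattr *nest;
 *
 *	nest = mnl_attr_nest_start(nlh, INET_DIAG_REQ_SK_BPF_STORAGES);
 *	mnl_attr_put_u32(nlh, SK_DIAG_BPF_STORAGE_REQ_MAP_FD, map_fd);
 *	mnl_attr_nest_end(nlh, nest);
 *
 * bpf_sk_storage_diag_alloc() above resolves each fd with bpf_map_get(),
 * rejecting non sk_storage maps (-EINVAL) and duplicates (-EEXIST); the
 * reply layout is the INET_DIAG_BPF_SK_STORAGES nest documented above
 * nla_value_size().
 */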
574
575 static int diag_get(struct bpf_local_storage_data *sdata, struct sk_buff *skb)
576 {
577         struct nlattr *nla_stg, *nla_value;
578         struct bpf_local_storage_map *smap;
579
580         /* The value cannot exceed a nlattr's maximum payload */
581         BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < BPF_LOCAL_STORAGE_MAX_VALUE_SIZE);
582
583         nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
584         if (!nla_stg)
585                 return -EMSGSIZE;
586
587         smap = rcu_dereference(sdata->smap);
588         if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
589                 goto errout;
590
591         nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
592                                       smap->map.value_size,
593                                       SK_DIAG_BPF_STORAGE_PAD);
594         if (!nla_value)
595                 goto errout;
596
597         if (map_value_has_spin_lock(&smap->map))
598                 copy_map_value_locked(&smap->map, nla_data(nla_value),
599                                       sdata->data, true);
600         else
601                 copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
602
603         nla_nest_end(skb, nla_stg);
604         return 0;
605
606 errout:
607         nla_nest_cancel(skb, nla_stg);
608         return -EMSGSIZE;
609 }
610
611 static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
612                                        int stg_array_type,
613                                        unsigned int *res_diag_size)
614 {
615         /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
616         unsigned int diag_size = nla_total_size(0);
617         struct bpf_local_storage *sk_storage;
618         struct bpf_local_storage_elem *selem;
619         struct bpf_local_storage_map *smap;
620         struct nlattr *nla_stgs;
621         unsigned int saved_len;
622         int err = 0;
623
624         rcu_read_lock();
625
626         sk_storage = rcu_dereference(sk->sk_bpf_storage);
627         if (!sk_storage || hlist_empty(&sk_storage->list)) {
628                 rcu_read_unlock();
629                 return 0;
630         }
631
632         nla_stgs = nla_nest_start(skb, stg_array_type);
633         if (!nla_stgs)
634                 /* Continue to learn diag_size */
635                 err = -EMSGSIZE;
636
637         saved_len = skb->len;
638         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
639                 smap = rcu_dereference(SDATA(selem)->smap);
640                 diag_size += nla_value_size(smap->map.value_size);
641
642                 if (nla_stgs && diag_get(SDATA(selem), skb))
643                         /* Continue to learn diag_size */
644                         err = -EMSGSIZE;
645         }
646
647         rcu_read_unlock();
648
649         if (nla_stgs) {
650                 if (saved_len == skb->len)
651                         nla_nest_cancel(skb, nla_stgs);
652                 else
653                         nla_nest_end(skb, nla_stgs);
654         }
655
656         if (diag_size == nla_total_size(0)) {
657                 *res_diag_size = 0;
658                 return 0;
659         }
660
661         *res_diag_size = diag_size;
662         return err;
663 }
664
665 int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
666                             struct sock *sk, struct sk_buff *skb,
667                             int stg_array_type,
668                             unsigned int *res_diag_size)
669 {
670         /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
671         unsigned int diag_size = nla_total_size(0);
672         struct bpf_local_storage *sk_storage;
673         struct bpf_local_storage_data *sdata;
674         struct nlattr *nla_stgs;
675         unsigned int saved_len;
676         int err = 0;
677         u32 i;
678
679         *res_diag_size = 0;
680
681         /* No map has been specified.  Dump all. */
682         if (!diag->nr_maps)
683                 return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
684                                                    res_diag_size);
685
686         rcu_read_lock();
687         sk_storage = rcu_dereference(sk->sk_bpf_storage);
688         if (!sk_storage || hlist_empty(&sk_storage->list)) {
689                 rcu_read_unlock();
690                 return 0;
691         }
692
693         nla_stgs = nla_nest_start(skb, stg_array_type);
694         if (!nla_stgs)
695                 /* Continue to learn diag_size */
696                 err = -EMSGSIZE;
697
698         saved_len = skb->len;
699         for (i = 0; i < diag->nr_maps; i++) {
700                 sdata = bpf_local_storage_lookup(sk_storage,
701                                 (struct bpf_local_storage_map *)diag->maps[i],
702                                 false);
703
704                 if (!sdata)
705                         continue;
706
707                 diag_size += nla_value_size(diag->maps[i]->value_size);
708
709                 if (nla_stgs && diag_get(sdata, skb))
710                         /* Continue to learn diag_size */
711                         err = -EMSGSIZE;
712         }
713         rcu_read_unlock();
714
715         if (nla_stgs) {
716                 if (saved_len == skb->len)
717                         nla_nest_cancel(skb, nla_stgs);
718                 else
719                         nla_nest_end(skb, nla_stgs);
720         }
721
722         if (diag_size == nla_total_size(0)) {
723                 *res_diag_size = 0;
724                 return 0;
725         }
726
727         *res_diag_size = diag_size;
728         return err;
729 }
730 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);
731
732 struct bpf_iter_seq_sk_storage_map_info {
733         struct bpf_map *map;
734         unsigned int bucket_id;
735         unsigned skip_elems;
736 };
737
738 static struct bpf_local_storage_elem *
739 bpf_sk_storage_map_seq_find_next(struct bpf_iter_seq_sk_storage_map_info *info,
740                                  struct bpf_local_storage_elem *prev_selem)
741         __acquires(RCU) __releases(RCU)
742 {
743         struct bpf_local_storage *sk_storage;
744         struct bpf_local_storage_elem *selem;
745         u32 skip_elems = info->skip_elems;
746         struct bpf_local_storage_map *smap;
747         u32 bucket_id = info->bucket_id;
748         u32 i, count, n_buckets;
749         struct bpf_local_storage_map_bucket *b;
750
751         smap = (struct bpf_local_storage_map *)info->map;
752         n_buckets = 1U << smap->bucket_log;
753         if (bucket_id >= n_buckets)
754                 return NULL;
755
756         /* try to find next selem in the same bucket */
757         selem = prev_selem;
758         count = 0;
759         while (selem) {
760                 selem = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&selem->map_node)),
761                                          struct bpf_local_storage_elem, map_node);
762                 if (!selem) {
763                         /* not found, unlock and go to the next bucket */
764                         b = &smap->buckets[bucket_id++];
765                         rcu_read_unlock();
766                         skip_elems = 0;
767                         break;
768                 }
769                 sk_storage = rcu_dereference(selem->local_storage);
770                 if (sk_storage) {
771                         info->skip_elems = skip_elems + count;
772                         return selem;
773                 }
774                 count++;
775         }
776
777         for (i = bucket_id; i < (1U << smap->bucket_log); i++) {
778                 b = &smap->buckets[i];
779                 rcu_read_lock();
780                 count = 0;
781                 hlist_for_each_entry_rcu(selem, &b->list, map_node) {
782                         sk_storage = rcu_dereference(selem->local_storage);
783                         if (sk_storage && count >= skip_elems) {
784                                 info->bucket_id = i;
785                                 info->skip_elems = count;
786                                 return selem;
787                         }
788                         count++;
789                 }
790                 rcu_read_unlock();
791                 skip_elems = 0;
792         }
793
794         info->bucket_id = i;
795         info->skip_elems = 0;
796         return NULL;
797 }
798
799 static void *bpf_sk_storage_map_seq_start(struct seq_file *seq, loff_t *pos)
800 {
801         struct bpf_local_storage_elem *selem;
802
803         selem = bpf_sk_storage_map_seq_find_next(seq->private, NULL);
804         if (!selem)
805                 return NULL;
806
807         if (*pos == 0)
808                 ++*pos;
809         return selem;
810 }
811
812 static void *bpf_sk_storage_map_seq_next(struct seq_file *seq, void *v,
813                                          loff_t *pos)
814 {
815         struct bpf_iter_seq_sk_storage_map_info *info = seq->private;
816
817         ++*pos;
818         ++info->skip_elems;
819         return bpf_sk_storage_map_seq_find_next(seq->private, v);
820 }
821
822 struct bpf_iter__bpf_sk_storage_map {
823         __bpf_md_ptr(struct bpf_iter_meta *, meta);
824         __bpf_md_ptr(struct bpf_map *, map);
825         __bpf_md_ptr(struct sock *, sk);
826         __bpf_md_ptr(void *, value);
827 };
828
829 DEFINE_BPF_ITER_FUNC(bpf_sk_storage_map, struct bpf_iter_meta *meta,
830                      struct bpf_map *map, struct sock *sk,
831                      void *value)
832
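/* Illustrative sketch (not part of this file): a BPF iterator program for
 * this target.  libbpf conventions and the BPF_SEQ_PRINTF helper macro
 * from the kernel's BPF selftests are assumed; the __u64 value layout is
 * hypothetical:
 *
 *	SEC("iter/bpf_sk_storage_map")
 *	int dump_sk_stg(struct bpf_iter__bpf_sk_storage_map *ctx)
 *	{
 *		struct sock *sk = ctx->sk;
 *		__u64 *val = ctx->value;
 *
 *		if (!sk || !val)
 *			return 0;
 *		BPF_SEQ_PRINTF(ctx->meta->seq, "sk=%p val=%llu\n", sk, *val);
 *		return 0;
 *	}
 *
 * Userspace creates the iterator link with a union bpf_iter_link_info
 * carrying the sk_storage map fd (validated in bpf_iter_attach_map()
 * below) and then reads the output via the fd from bpf_iter_create().
 */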
833 static int __bpf_sk_storage_map_seq_show(struct seq_file *seq,
834                                          struct bpf_local_storage_elem *selem)
835 {
836         struct bpf_iter_seq_sk_storage_map_info *info = seq->private;
837         struct bpf_iter__bpf_sk_storage_map ctx = {};
838         struct bpf_local_storage *sk_storage;
839         struct bpf_iter_meta meta;
840         struct bpf_prog *prog;
841         int ret = 0;
842
843         meta.seq = seq;
844         prog = bpf_iter_get_info(&meta, selem == NULL);
845         if (prog) {
846                 ctx.meta = &meta;
847                 ctx.map = info->map;
848                 if (selem) {
849                         sk_storage = rcu_dereference(selem->local_storage);
850                         ctx.sk = sk_storage->owner;
851                         ctx.value = SDATA(selem)->data;
852                 }
853                 ret = bpf_iter_run_prog(prog, &ctx);
854         }
855
856         return ret;
857 }
858
859 static int bpf_sk_storage_map_seq_show(struct seq_file *seq, void *v)
860 {
861         return __bpf_sk_storage_map_seq_show(seq, v);
862 }
863
864 static void bpf_sk_storage_map_seq_stop(struct seq_file *seq, void *v)
865         __releases(RCU)
866 {
867         if (!v)
868                 (void)__bpf_sk_storage_map_seq_show(seq, v);
869         else
870                 rcu_read_unlock();
871 }
872
873 static int bpf_iter_init_sk_storage_map(void *priv_data,
874                                         struct bpf_iter_aux_info *aux)
875 {
876         struct bpf_iter_seq_sk_storage_map_info *seq_info = priv_data;
877
878         seq_info->map = aux->map;
879         return 0;
880 }
881
882 static int bpf_iter_attach_map(struct bpf_prog *prog,
883                                union bpf_iter_link_info *linfo,
884                                struct bpf_iter_aux_info *aux)
885 {
886         struct bpf_map *map;
887         int err = -EINVAL;
888
889         if (!linfo->map.map_fd)
890                 return -EBADF;
891
892         map = bpf_map_get_with_uref(linfo->map.map_fd);
893         if (IS_ERR(map))
894                 return PTR_ERR(map);
895
896         if (map->map_type != BPF_MAP_TYPE_SK_STORAGE)
897                 goto put_map;
898
899         if (prog->aux->max_rdonly_access > map->value_size) {
900                 err = -EACCES;
901                 goto put_map;
902         }
903
904         aux->map = map;
905         return 0;
906
907 put_map:
908         bpf_map_put_with_uref(map);
909         return err;
910 }
911
912 static void bpf_iter_detach_map(struct bpf_iter_aux_info *aux)
913 {
914         bpf_map_put_with_uref(aux->map);
915 }
916
917 static const struct seq_operations bpf_sk_storage_map_seq_ops = {
918         .start  = bpf_sk_storage_map_seq_start,
919         .next   = bpf_sk_storage_map_seq_next,
920         .stop   = bpf_sk_storage_map_seq_stop,
921         .show   = bpf_sk_storage_map_seq_show,
922 };
923
924 static const struct bpf_iter_seq_info iter_seq_info = {
925         .seq_ops                = &bpf_sk_storage_map_seq_ops,
926         .init_seq_private       = bpf_iter_init_sk_storage_map,
927         .fini_seq_private       = NULL,
928         .seq_priv_size          = sizeof(struct bpf_iter_seq_sk_storage_map_info),
929 };
930
931 static struct bpf_iter_reg bpf_sk_storage_map_reg_info = {
932         .target                 = "bpf_sk_storage_map",
933         .attach_target          = bpf_iter_attach_map,
934         .detach_target          = bpf_iter_detach_map,
935         .show_fdinfo            = bpf_iter_map_show_fdinfo,
936         .fill_link_info         = bpf_iter_map_fill_link_info,
937         .ctx_arg_info_size      = 2,
938         .ctx_arg_info           = {
939                 { offsetof(struct bpf_iter__bpf_sk_storage_map, sk),
940                   PTR_TO_BTF_ID_OR_NULL },
941                 { offsetof(struct bpf_iter__bpf_sk_storage_map, value),
942                   PTR_TO_BUF | PTR_MAYBE_NULL },
943         },
944         .seq_info               = &iter_seq_info,
945 };
946
947 static int __init bpf_sk_storage_map_iter_init(void)
948 {
949         bpf_sk_storage_map_reg_info.ctx_arg_info[0].btf_id =
950                 btf_sock_ids[BTF_SOCK_TYPE_SOCK];
951         return bpf_iter_reg_target(&bpf_sk_storage_map_reg_info);
952 }
953 late_initcall(bpf_sk_storage_map_iter_init);