net/core/bpf_sk_storage.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook  */
3 #include <linux/rculist.h>
4 #include <linux/list.h>
5 #include <linux/hash.h>
6 #include <linux/types.h>
7 #include <linux/spinlock.h>
8 #include <linux/bpf.h>
9 #include <net/bpf_sk_storage.h>
10 #include <net/sock.h>
11 #include <uapi/linux/sock_diag.h>
12 #include <uapi/linux/btf.h>
13
14 static atomic_t cache_idx;
15
16 #define SK_STORAGE_CREATE_FLAG_MASK                                     \
17         (BPF_F_NO_PREALLOC | BPF_F_CLONE)
18
19 struct bucket {
20         struct hlist_head list;
21         raw_spinlock_t lock;
22 };
23
24 /* The map is not the primary owner of a bpf_sk_storage_elem.
25  * Instead, the sk->sk_bpf_storage is.
26  *
27  * The map (bpf_sk_storage_map) is for two purposes
28  * 1. Define the size of the "sk local storage".  It is
29  *    the map's value_size.
30  *
31  * 2. Maintain a list to keep track of all elems such
32  *    that they can be cleaned up during the map destruction.
33  *
34  * When a bpf local storage is being looked up for a
35  * particular sk,  the "bpf_map" pointer is actually used
36  * as the "key" to search in the list of elem in
37  * sk->sk_bpf_storage.
38  *
39  * Hence, consider sk->sk_bpf_storage as the mini-map,
40  * with the "bpf_map" pointer as the searching key.
41  */
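/* Illustrative only (not part of the original file): a minimal sketch of how
 * a BPF program consumes this map type, assuming a libbpf-style BTF map
 * definition.  The names "my_stg", "sk_stg" and "sockops_prog" are
 * hypothetical; BPF_MAP_TYPE_SK_STORAGE, BPF_F_NO_PREALLOC and
 * bpf_sk_storage_get() are the real interfaces implemented below.
 *
 *	struct my_stg {
 *		__u64 cnt;
 *	};
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SK_STORAGE);
 *		__uint(map_flags, BPF_F_NO_PREALLOC);
 *		__type(key, int);
 *		__type(value, struct my_stg);
 *	} sk_stg SEC(".maps");
 *
 *	SEC("sockops")
 *	int sockops_prog(struct bpf_sock_ops *ctx)
 *	{
 *		struct bpf_sock *sk = ctx->sk;
 *		struct my_stg *stg;
 *
 *		if (!sk)
 *			return 1;
 *		stg = bpf_sk_storage_get(&sk_stg, sk, NULL,
 *					 BPF_SK_STORAGE_GET_F_CREATE);
 *		if (stg)
 *			stg->cnt++;
 *		return 1;
 *	}
 */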
42 struct bpf_sk_storage_map {
43         struct bpf_map map;
44         /* Lookup elem does not require accessing the map.
45          *
46          * Updating/Deleting requires a bucket lock to
47          * link/unlink the elem from the map.  Multiple
48          * buckets are used to reduce lock contention.
49          */
50         struct bucket *buckets;
51         u32 bucket_log;
52         u16 elem_size;
53         u16 cache_idx;
54 };
55
56 struct bpf_sk_storage_data {
57         /* smap is used as the searching key when looking up
58          * from sk->sk_bpf_storage.
59          *
60          * Put it in the same cacheline as the data to minimize
61          * the number of cacheline accesses during the cache-hit case.
62          */
63         struct bpf_sk_storage_map __rcu *smap;
64         u8 data[] __aligned(8);
65 };
66
67 /* Linked to bpf_sk_storage and bpf_sk_storage_map */
68 struct bpf_sk_storage_elem {
69         struct hlist_node map_node;     /* Linked to bpf_sk_storage_map */
70         struct hlist_node snode;        /* Linked to bpf_sk_storage */
71         struct bpf_sk_storage __rcu *sk_storage;
72         struct rcu_head rcu;
73         /* 8 bytes hole */
74          * The data is stored in another cacheline to minimize
75          * the number of cacheline accesses during a cache hit.
76          */
77         struct bpf_sk_storage_data sdata ____cacheline_aligned;
78 };
79
80 #define SELEM(_SDATA) container_of((_SDATA), struct bpf_sk_storage_elem, sdata)
81 #define SDATA(_SELEM) (&(_SELEM)->sdata)
82 #define BPF_SK_STORAGE_CACHE_SIZE       16
83
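/* Each sk keeps a small cache of BPF_SK_STORAGE_CACHE_SIZE sdata pointers,
 * indexed by the owning map's cache_idx.  Distinct maps may end up sharing a
 * slot; that only costs the fast path a miss, never correctness, because
 * __sk_storage_lookup() re-checks that the cached sdata->smap matches the
 * map being looked up before trusting the cache.
 */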
84 struct bpf_sk_storage {
85         struct bpf_sk_storage_data __rcu *cache[BPF_SK_STORAGE_CACHE_SIZE];
86         struct hlist_head list; /* List of bpf_sk_storage_elem */
87         struct sock *sk;        /* The sk that owns the above "list" of
88                                  * bpf_sk_storage_elem.
89                                  */
90         struct rcu_head rcu;
91         raw_spinlock_t lock;    /* Protect adding/removing from the "list" */
92 };
93
94 static struct bucket *select_bucket(struct bpf_sk_storage_map *smap,
95                                     struct bpf_sk_storage_elem *selem)
96 {
97         return &smap->buckets[hash_ptr(selem, smap->bucket_log)];
98 }
99
100 static int omem_charge(struct sock *sk, unsigned int size)
101 {
102         /* same check as in sock_kmalloc() */
103         if (size <= sysctl_optmem_max &&
104             atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
105                 atomic_add(size, &sk->sk_omem_alloc);
106                 return 0;
107         }
108
109         return -ENOMEM;
110 }
111
112 static bool selem_linked_to_sk(const struct bpf_sk_storage_elem *selem)
113 {
114         return !hlist_unhashed(&selem->snode);
115 }
116
117 static bool selem_linked_to_map(const struct bpf_sk_storage_elem *selem)
118 {
119         return !hlist_unhashed(&selem->map_node);
120 }
121
122 static struct bpf_sk_storage_elem *selem_alloc(struct bpf_sk_storage_map *smap,
123                                                struct sock *sk, void *value,
124                                                bool charge_omem)
125 {
126         struct bpf_sk_storage_elem *selem;
127
128         if (charge_omem && omem_charge(sk, smap->elem_size))
129                 return NULL;
130
131         selem = kzalloc(smap->elem_size, GFP_ATOMIC | __GFP_NOWARN);
132         if (selem) {
133                 if (value)
134                         memcpy(SDATA(selem)->data, value, smap->map.value_size);
135                 return selem;
136         }
137
138         if (charge_omem)
139                 atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
140
141         return NULL;
142 }
143
144 /* sk_storage->lock must be held and selem->sk_storage == sk_storage.
145  * The caller must ensure selem->smap is still valid to be
146  * dereferenced for its smap->elem_size and smap->cache_idx.
147  */
148 static bool __selem_unlink_sk(struct bpf_sk_storage *sk_storage,
149                               struct bpf_sk_storage_elem *selem,
150                               bool uncharge_omem)
151 {
152         struct bpf_sk_storage_map *smap;
153         bool free_sk_storage;
154         struct sock *sk;
155
156         smap = rcu_dereference(SDATA(selem)->smap);
157         sk = sk_storage->sk;
158
159         /* All uncharging on sk->sk_omem_alloc must be done first.
160          * sk may be freed once the last selem is unlinked from sk_storage.
161          */
162         if (uncharge_omem)
163                 atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
164
165         free_sk_storage = hlist_is_singular_node(&selem->snode,
166                                                  &sk_storage->list);
167         if (free_sk_storage) {
168                 atomic_sub(sizeof(struct bpf_sk_storage), &sk->sk_omem_alloc);
169                 sk_storage->sk = NULL;
170                 /* After this RCU_INIT, sk may be freed and cannot be used */
171                 RCU_INIT_POINTER(sk->sk_bpf_storage, NULL);
172
173                 /* sk_storage is not freed now.  sk_storage->lock is
174                  * still held and raw_spin_unlock_bh(&sk_storage->lock)
175                  * will be done by the caller.
176                  *
177                  * Although the unlock will be done under
178          * rcu_read_lock(), it is more intuitive to
179                  * read if kfree_rcu(sk_storage, rcu) is done
180                  * after the raw_spin_unlock_bh(&sk_storage->lock).
181                  *
182                  * Hence, a "bool free_sk_storage" is returned
183                  * to the caller which then calls the kfree_rcu()
184                  * after unlock.
185                  */
186         }
187         hlist_del_init_rcu(&selem->snode);
188         if (rcu_access_pointer(sk_storage->cache[smap->cache_idx]) ==
189             SDATA(selem))
190                 RCU_INIT_POINTER(sk_storage->cache[smap->cache_idx], NULL);
191
192         kfree_rcu(selem, rcu);
193
194         return free_sk_storage;
195 }
196
197 static void selem_unlink_sk(struct bpf_sk_storage_elem *selem)
198 {
199         struct bpf_sk_storage *sk_storage;
200         bool free_sk_storage = false;
201
202         if (unlikely(!selem_linked_to_sk(selem)))
203                 /* selem has already been unlinked from sk */
204                 return;
205
206         sk_storage = rcu_dereference(selem->sk_storage);
207         raw_spin_lock_bh(&sk_storage->lock);
208         if (likely(selem_linked_to_sk(selem)))
209                 free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
210         raw_spin_unlock_bh(&sk_storage->lock);
211
212         if (free_sk_storage)
213                 kfree_rcu(sk_storage, rcu);
214 }
215
216 static void __selem_link_sk(struct bpf_sk_storage *sk_storage,
217                             struct bpf_sk_storage_elem *selem)
218 {
219         RCU_INIT_POINTER(selem->sk_storage, sk_storage);
220         hlist_add_head(&selem->snode, &sk_storage->list);
221 }
222
223 static void selem_unlink_map(struct bpf_sk_storage_elem *selem)
224 {
225         struct bpf_sk_storage_map *smap;
226         struct bucket *b;
227
228         if (unlikely(!selem_linked_to_map(selem)))
229                 /* selem has already been unlinked from smap */
230                 return;
231
232         smap = rcu_dereference(SDATA(selem)->smap);
233         b = select_bucket(smap, selem);
234         raw_spin_lock_bh(&b->lock);
235         if (likely(selem_linked_to_map(selem)))
236                 hlist_del_init_rcu(&selem->map_node);
237         raw_spin_unlock_bh(&b->lock);
238 }
239
240 static void selem_link_map(struct bpf_sk_storage_map *smap,
241                            struct bpf_sk_storage_elem *selem)
242 {
243         struct bucket *b = select_bucket(smap, selem);
244
245         raw_spin_lock_bh(&b->lock);
246         RCU_INIT_POINTER(SDATA(selem)->smap, smap);
247         hlist_add_head_rcu(&selem->map_node, &b->list);
248         raw_spin_unlock_bh(&b->lock);
249 }
250
251 static void selem_unlink(struct bpf_sk_storage_elem *selem)
252 {
253         /* Always unlink from map before unlinking from sk_storage
254          * because selem will be freed after being successfully unlinked from
255          * the sk_storage.
256          */
257         selem_unlink_map(selem);
258         selem_unlink_sk(selem);
259 }
260
261 static struct bpf_sk_storage_data *
262 __sk_storage_lookup(struct bpf_sk_storage *sk_storage,
263                     struct bpf_sk_storage_map *smap,
264                     bool cacheit_lockit)
265 {
266         struct bpf_sk_storage_data *sdata;
267         struct bpf_sk_storage_elem *selem;
268
269         /* Fast path (cache hit) */
270         sdata = rcu_dereference(sk_storage->cache[smap->cache_idx]);
271         if (sdata && rcu_access_pointer(sdata->smap) == smap)
272                 return sdata;
273
274         /* Slow path (cache miss) */
275         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode)
276                 if (rcu_access_pointer(SDATA(selem)->smap) == smap)
277                         break;
278
279         if (!selem)
280                 return NULL;
281
282         sdata = SDATA(selem);
283         if (cacheit_lockit) {
284                 /* spinlock is needed to avoid racing with the
285                  * parallel delete.  Otherwise, publishing an already
286                  * deleted sdata to the cache will become a use-after-free
287                  * problem in the next __sk_storage_lookup().
288                  */
289                 raw_spin_lock_bh(&sk_storage->lock);
290                 if (selem_linked_to_sk(selem))
291                         rcu_assign_pointer(sk_storage->cache[smap->cache_idx],
292                                            sdata);
293                 raw_spin_unlock_bh(&sk_storage->lock);
294         }
295
296         return sdata;
297 }
298
299 static struct bpf_sk_storage_data *
300 sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
301 {
302         struct bpf_sk_storage *sk_storage;
303         struct bpf_sk_storage_map *smap;
304
305         sk_storage = rcu_dereference(sk->sk_bpf_storage);
306         if (!sk_storage)
307                 return NULL;
308
309         smap = (struct bpf_sk_storage_map *)map;
310         return __sk_storage_lookup(sk_storage, smap, cacheit_lockit);
311 }
312
313 static int check_flags(const struct bpf_sk_storage_data *old_sdata,
314                        u64 map_flags)
315 {
316         if (old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_NOEXIST)
317                 /* elem already exists */
318                 return -EEXIST;
319
320         if (!old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_EXIST)
321                 /* elem doesn't exist, cannot update it */
322                 return -ENOENT;
323
324         return 0;
325 }
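/* Illustrative only: from the userspace syscall side (see
 * bpf_fd_sk_storage_update_elem() below, where the key is a socket fd),
 * these checks surface as, e.g.:
 *
 *	bpf_map_update_elem(map_fd, &sk_fd, &val, BPF_NOEXIST);
 *		fails with -EEXIST if this socket already has an elem in this map
 *	bpf_map_update_elem(map_fd, &sk_fd, &val, BPF_EXIST);
 *		fails with -ENOENT if this socket has no elem in this map yet
 *
 * map_fd, sk_fd and val are hypothetical; bpf_map_update_elem() is the usual
 * libbpf wrapper around the BPF_MAP_UPDATE_ELEM command.
 */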
326
327 static int sk_storage_alloc(struct sock *sk,
328                             struct bpf_sk_storage_map *smap,
329                             struct bpf_sk_storage_elem *first_selem)
330 {
331         struct bpf_sk_storage *prev_sk_storage, *sk_storage;
332         int err;
333
334         err = omem_charge(sk, sizeof(*sk_storage));
335         if (err)
336                 return err;
337
338         sk_storage = kzalloc(sizeof(*sk_storage), GFP_ATOMIC | __GFP_NOWARN);
339         if (!sk_storage) {
340                 err = -ENOMEM;
341                 goto uncharge;
342         }
343         INIT_HLIST_HEAD(&sk_storage->list);
344         raw_spin_lock_init(&sk_storage->lock);
345         sk_storage->sk = sk;
346
347         __selem_link_sk(sk_storage, first_selem);
348         selem_link_map(smap, first_selem);
349         /* Publish sk_storage to sk.  sk->sk_lock cannot be acquired.
350          * Hence, an atomic op is used to set sk->sk_bpf_storage
351          * from NULL to the newly allocated sk_storage ptr.
352          *
353          * From now on, the sk->sk_bpf_storage pointer is protected
354          * by the sk_storage->lock.  Hence,  when freeing
355          * the sk->sk_bpf_storage, the sk_storage->lock must
356          * be held before setting sk->sk_bpf_storage to NULL.
357          */
358         prev_sk_storage = cmpxchg((struct bpf_sk_storage **)&sk->sk_bpf_storage,
359                                   NULL, sk_storage);
360         if (unlikely(prev_sk_storage)) {
361                 selem_unlink_map(first_selem);
362                 err = -EAGAIN;
363                 goto uncharge;
364
365                 /* Note that even though first_selem was linked to smap's
366                  * bucket->list, first_selem can be freed immediately
367                  * (instead of kfree_rcu) because
368                  * bpf_sk_storage_map_free() does a
369                  * synchronize_rcu() before walking the bucket->list.
370                  * Hence, no one is accessing selem from the
371                  * bucket->list under rcu_read_lock().
372                  */
373         }
374
375         return 0;
376
377 uncharge:
378         kfree(sk_storage);
379         atomic_sub(sizeof(*sk_storage), &sk->sk_omem_alloc);
380         return err;
381 }
382
383 /* sk cannot be going away while it is linking a new elem
384  * to sk->sk_bpf_storage (i.e. sk->sk_refcnt cannot be 0).
385  * Otherwise, it would become a leak (and cause other memory issues
386  * during map destruction).
387  */
388 static struct bpf_sk_storage_data *sk_storage_update(struct sock *sk,
389                                                      struct bpf_map *map,
390                                                      void *value,
391                                                      u64 map_flags)
392 {
393         struct bpf_sk_storage_data *old_sdata = NULL;
394         struct bpf_sk_storage_elem *selem;
395         struct bpf_sk_storage *sk_storage;
396         struct bpf_sk_storage_map *smap;
397         int err;
398
399         /* BPF_EXIST and BPF_NOEXIST cannot be both set */
400         if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
401             /* BPF_F_LOCK can only be used in a value with spin_lock */
402             unlikely((map_flags & BPF_F_LOCK) && !map_value_has_spin_lock(map)))
403                 return ERR_PTR(-EINVAL);
404
405         smap = (struct bpf_sk_storage_map *)map;
406         sk_storage = rcu_dereference(sk->sk_bpf_storage);
407         if (!sk_storage || hlist_empty(&sk_storage->list)) {
408                 /* Very first elem for this sk */
409                 err = check_flags(NULL, map_flags);
410                 if (err)
411                         return ERR_PTR(err);
412
413                 selem = selem_alloc(smap, sk, value, true);
414                 if (!selem)
415                         return ERR_PTR(-ENOMEM);
416
417                 err = sk_storage_alloc(sk, smap, selem);
418                 if (err) {
419                         kfree(selem);
420                         atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
421                         return ERR_PTR(err);
422                 }
423
424                 return SDATA(selem);
425         }
426
427         if ((map_flags & BPF_F_LOCK) && !(map_flags & BPF_NOEXIST)) {
428                 /* Hoping to find an old_sdata to do inline update
429                  * such that it can avoid taking the sk_storage->lock
430                  * and changing the lists.
431                  */
432                 old_sdata = __sk_storage_lookup(sk_storage, smap, false);
433                 err = check_flags(old_sdata, map_flags);
434                 if (err)
435                         return ERR_PTR(err);
436                 if (old_sdata && selem_linked_to_sk(SELEM(old_sdata))) {
437                         copy_map_value_locked(map, old_sdata->data,
438                                               value, false);
439                         return old_sdata;
440                 }
441         }
442
443         raw_spin_lock_bh(&sk_storage->lock);
444
445         /* Recheck sk_storage->list under sk_storage->lock */
446         if (unlikely(hlist_empty(&sk_storage->list))) {
447                 /* A parallel del is happening and sk_storage is going
448                  * away.  This was just checked above, so it is very
449                  * unlikely.  Return instead of retry to keep things
450                  * simple.
451                  */
452                 err = -EAGAIN;
453                 goto unlock_err;
454         }
455
456         old_sdata = __sk_storage_lookup(sk_storage, smap, false);
457         err = check_flags(old_sdata, map_flags);
458         if (err)
459                 goto unlock_err;
460
461         if (old_sdata && (map_flags & BPF_F_LOCK)) {
462                 copy_map_value_locked(map, old_sdata->data, value, false);
463                 selem = SELEM(old_sdata);
464                 goto unlock;
465         }
466
467         /* sk_storage->lock is held.  Hence, we are sure
468          * we can unlink and uncharge the old_sdata successfully
469          * later.  Hence, instead of charging the new selem now
470          * and then uncharging the old selem later (which may cause
471          * a potential but unnecessary charge failure),  avoid taking
472          * a charge at all here (the "!old_sdata" check) and the
473          * old_sdata will not be uncharged later during __selem_unlink_sk().
474          */
475         selem = selem_alloc(smap, sk, value, !old_sdata);
476         if (!selem) {
477                 err = -ENOMEM;
478                 goto unlock_err;
479         }
480
481         /* First, link the new selem to the map */
482         selem_link_map(smap, selem);
483
484         /* Second, link (and publish) the new selem to sk_storage */
485         __selem_link_sk(sk_storage, selem);
486
487         /* Third, remove old selem, SELEM(old_sdata) */
488         if (old_sdata) {
489                 selem_unlink_map(SELEM(old_sdata));
490                 __selem_unlink_sk(sk_storage, SELEM(old_sdata), false);
491         }
492
493 unlock:
494         raw_spin_unlock_bh(&sk_storage->lock);
495         return SDATA(selem);
496
497 unlock_err:
498         raw_spin_unlock_bh(&sk_storage->lock);
499         return ERR_PTR(err);
500 }
501
502 static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
503 {
504         struct bpf_sk_storage_data *sdata;
505
506         sdata = sk_storage_lookup(sk, map, false);
507         if (!sdata)
508                 return -ENOENT;
509
510         selem_unlink(SELEM(sdata));
511
512         return 0;
513 }
514
515 /* Called by __sk_destruct() & bpf_sk_storage_clone() */
516 void bpf_sk_storage_free(struct sock *sk)
517 {
518         struct bpf_sk_storage_elem *selem;
519         struct bpf_sk_storage *sk_storage;
520         bool free_sk_storage = false;
521         struct hlist_node *n;
522
523         rcu_read_lock();
524         sk_storage = rcu_dereference(sk->sk_bpf_storage);
525         if (!sk_storage) {
526                 rcu_read_unlock();
527                 return;
528         }
529
530         /* Neither the bpf_prog nor the bpf-map's syscall
531          * could be modifying the sk_storage->list now.
532          * Thus, no elem can be added-to or deleted-from the
533          * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
534          *
535          * It is racing with bpf_sk_storage_map_free() alone
536          * when unlinking elem from the sk_storage->list and
537          * the map's bucket->list.
538          */
539         raw_spin_lock_bh(&sk_storage->lock);
540         hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
541                 /* Always unlink from map before unlinking from
542                  * sk_storage.
543                  */
544                 selem_unlink_map(selem);
545                 free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
546         }
547         raw_spin_unlock_bh(&sk_storage->lock);
548         rcu_read_unlock();
549
550         if (free_sk_storage)
551                 kfree_rcu(sk_storage, rcu);
552 }
553
554 static void bpf_sk_storage_map_free(struct bpf_map *map)
555 {
556         struct bpf_sk_storage_elem *selem;
557         struct bpf_sk_storage_map *smap;
558         struct bucket *b;
559         unsigned int i;
560
561         smap = (struct bpf_sk_storage_map *)map;
562
563         /* Note that this map might be concurrently cloned from
564          * bpf_sk_storage_clone. Wait for any existing bpf_sk_storage_clone
565          * RCU read section to finish before proceeding. New RCU
566          * read sections should be prevented via bpf_map_inc_not_zero.
567          */
568         synchronize_rcu();
569
570         /* bpf prog and the userspace can no longer access this map
571          * now.  No new selem (of this map) can be added
572          * to the sk->sk_bpf_storage or to the map bucket's list.
573          *
574          * The elem of this map can be cleaned up here
575          * or
576          * by bpf_sk_storage_free() during __sk_destruct().
577          */
578         for (i = 0; i < (1U << smap->bucket_log); i++) {
579                 b = &smap->buckets[i];
580
581                 rcu_read_lock();
582                 /* No one is adding to b->list now */
583                 while ((selem = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(&b->list)),
584                                                  struct bpf_sk_storage_elem,
585                                                  map_node))) {
586                         selem_unlink(selem);
587                         cond_resched_rcu();
588                 }
589                 rcu_read_unlock();
590         }
591
592         /* bpf_sk_storage_free() may still need to access the map.
593          * e.g. bpf_sk_storage_free() has unlinked selem from the map
594          * which then made the above while((selem = ...)) loop
595          * exit immediately.
596          *
597          * However, the bpf_sk_storage_free() still needs to access
598          * the smap->elem_size to do the uncharging in
599          * __selem_unlink_sk().
600          *
601          * Hence, wait another rcu grace period for the
602          * bpf_sk_storage_free() to finish.
603          */
604         synchronize_rcu();
605
606         kvfree(smap->buckets);
607         kfree(map);
608 }
609
610 /* U16_MAX is much more than enough for sk local storage
611  * considering a tcp_sock is ~2k.
612  */
613 #define MAX_VALUE_SIZE                                                  \
614         min_t(u32,                                                      \
615               (KMALLOC_MAX_SIZE - MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem)), \
616               (U16_MAX - sizeof(struct bpf_sk_storage_elem)))
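/* Rough worked example (assuming a 64-bit build with 64-byte cachelines,
 * where the ____cacheline_aligned sdata pushes sizeof(struct
 * bpf_sk_storage_elem) to two cachelines, i.e. 128 bytes): the U16_MAX term
 * caps value_size at roughly 64K - 128 bytes, and is normally the smaller of
 * the two terms.
 */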
617
618 static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
619 {
620         if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
621             !(attr->map_flags & BPF_F_NO_PREALLOC) ||
622             attr->max_entries ||
623             attr->key_size != sizeof(int) || !attr->value_size ||
624             /* Enforce BTF for userspace sk dumping */
625             !attr->btf_key_type_id || !attr->btf_value_type_id)
626                 return -EINVAL;
627
628         if (!capable(CAP_SYS_ADMIN))
629                 return -EPERM;
630
631         if (attr->value_size > MAX_VALUE_SIZE)
632                 return -E2BIG;
633
634         return 0;
635 }
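/* In other words, a map of this type can only be created with:
 *	map_type    = BPF_MAP_TYPE_SK_STORAGE
 *	key_size    = 4 (an "int"; from userspace the key is a socket fd)
 *	max_entries = 0 (storage follows the sockets, there is no prealloc pool)
 *	map_flags   = BPF_F_NO_PREALLOC, optionally | BPF_F_CLONE
 *	btf_key_type_id / btf_value_type_id set (BTF is mandatory)
 * by a CAP_SYS_ADMIN caller, with a non-zero value_size bounded by
 * MAX_VALUE_SIZE.
 */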
636
637 static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
638 {
639         struct bpf_sk_storage_map *smap;
640         unsigned int i;
641         u32 nbuckets;
642         u64 cost;
643         int ret;
644
645         smap = kzalloc(sizeof(*smap), GFP_USER | __GFP_NOWARN);
646         if (!smap)
647                 return ERR_PTR(-ENOMEM);
648         bpf_map_init_from_attr(&smap->map, attr);
649
650         nbuckets = roundup_pow_of_two(num_possible_cpus());
651         /* Use at least 2 buckets, select_bucket() is undefined behavior with 1 bucket */
652         nbuckets = max_t(u32, 2, nbuckets);
653         smap->bucket_log = ilog2(nbuckets);
654         cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
655
656         ret = bpf_map_charge_init(&smap->map.memory, cost);
657         if (ret < 0) {
658                 kfree(smap);
659                 return ERR_PTR(ret);
660         }
661
662         smap->buckets = kvcalloc(sizeof(*smap->buckets), nbuckets,
663                                  GFP_USER | __GFP_NOWARN);
664         if (!smap->buckets) {
665                 bpf_map_charge_finish(&smap->map.memory);
666                 kfree(smap);
667                 return ERR_PTR(-ENOMEM);
668         }
669
670         for (i = 0; i < nbuckets; i++) {
671                 INIT_HLIST_HEAD(&smap->buckets[i].list);
672                 raw_spin_lock_init(&smap->buckets[i].lock);
673         }
674
675         smap->elem_size = sizeof(struct bpf_sk_storage_elem) + attr->value_size;
676         smap->cache_idx = (unsigned int)atomic_inc_return(&cache_idx) %
677                 BPF_SK_STORAGE_CACHE_SIZE;
678
679         return &smap->map;
680 }
681
682 static int notsupp_get_next_key(struct bpf_map *map, void *key,
683                                 void *next_key)
684 {
685         return -ENOTSUPP;
686 }
687
688 static int bpf_sk_storage_map_check_btf(const struct bpf_map *map,
689                                         const struct btf *btf,
690                                         const struct btf_type *key_type,
691                                         const struct btf_type *value_type)
692 {
693         u32 int_data;
694
695         if (BTF_INFO_KIND(key_type->info) != BTF_KIND_INT)
696                 return -EINVAL;
697
698         int_data = *(u32 *)(key_type + 1);
699         if (BTF_INT_BITS(int_data) != 32 || BTF_INT_OFFSET(int_data))
700                 return -EINVAL;
701
702         return 0;
703 }
704
705 static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
706 {
707         struct bpf_sk_storage_data *sdata;
708         struct socket *sock;
709         int fd, err;
710
711         fd = *(int *)key;
712         sock = sockfd_lookup(fd, &err);
713         if (sock) {
714                 sdata = sk_storage_lookup(sock->sk, map, true);
715                 sockfd_put(sock);
716                 return sdata ? sdata->data : NULL;
717         }
718
719         return ERR_PTR(err);
720 }
721
722 static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
723                                          void *value, u64 map_flags)
724 {
725         struct bpf_sk_storage_data *sdata;
726         struct socket *sock;
727         int fd, err;
728
729         fd = *(int *)key;
730         sock = sockfd_lookup(fd, &err);
731         if (sock) {
732                 sdata = sk_storage_update(sock->sk, map, value, map_flags);
733                 sockfd_put(sock);
734                 return PTR_ERR_OR_ZERO(sdata);
735         }
736
737         return err;
738 }
739
740 static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
741 {
742         struct socket *sock;
743         int fd, err;
744
745         fd = *(int *)key;
746         sock = sockfd_lookup(fd, &err);
747         if (sock) {
748                 err = sk_storage_delete(sock->sk, map);
749                 sockfd_put(sock);
750                 return err;
751         }
752
753         return err;
754 }
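/* Userspace access sketch (illustrative; map_fd, sk_fd and struct my_stg are
 * hypothetical).  For this map type the lookup/update/delete key is a socket
 * fd owned by the calling process, resolved above via sockfd_lookup():
 *
 *	int sk_fd = socket(AF_INET, SOCK_STREAM, 0);
 *	struct my_stg val = {};
 *
 *	bpf_map_update_elem(map_fd, &sk_fd, &val, BPF_ANY);
 *	bpf_map_lookup_elem(map_fd, &sk_fd, &val);
 *	bpf_map_delete_elem(map_fd, &sk_fd);
 */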
755
756 static struct bpf_sk_storage_elem *
757 bpf_sk_storage_clone_elem(struct sock *newsk,
758                           struct bpf_sk_storage_map *smap,
759                           struct bpf_sk_storage_elem *selem)
760 {
761         struct bpf_sk_storage_elem *copy_selem;
762
763         copy_selem = selem_alloc(smap, newsk, NULL, true);
764         if (!copy_selem)
765                 return NULL;
766
767         if (map_value_has_spin_lock(&smap->map))
768                 copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
769                                       SDATA(selem)->data, true);
770         else
771                 copy_map_value(&smap->map, SDATA(copy_selem)->data,
772                                SDATA(selem)->data);
773
774         return copy_selem;
775 }
776
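/* Copy the parent sk's storages into the newly cloned newsk (e.g. when a
 * listener socket creates a child).  Only elements belonging to maps created
 * with BPF_F_CLONE are copied; everything else is skipped.
 */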
777 int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
778 {
779         struct bpf_sk_storage *new_sk_storage = NULL;
780         struct bpf_sk_storage *sk_storage;
781         struct bpf_sk_storage_elem *selem;
782         int ret = 0;
783
784         RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
785
786         rcu_read_lock();
787         sk_storage = rcu_dereference(sk->sk_bpf_storage);
788
789         if (!sk_storage || hlist_empty(&sk_storage->list))
790                 goto out;
791
792         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
793                 struct bpf_sk_storage_elem *copy_selem;
794                 struct bpf_sk_storage_map *smap;
795                 struct bpf_map *map;
796
797                 smap = rcu_dereference(SDATA(selem)->smap);
798                 if (!(smap->map.map_flags & BPF_F_CLONE))
799                         continue;
800
801                 /* Note that for lockless listeners adding new element
802                  * here can race with cleanup in bpf_sk_storage_map_free.
803                  * Try to grab map refcnt to make sure that it's still
804                  * alive and prevent concurrent removal.
805                  */
806                 map = bpf_map_inc_not_zero(&smap->map);
807                 if (IS_ERR(map))
808                         continue;
809
810                 copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
811                 if (!copy_selem) {
812                         ret = -ENOMEM;
813                         bpf_map_put(map);
814                         goto out;
815                 }
816
817                 if (new_sk_storage) {
818                         selem_link_map(smap, copy_selem);
819                         __selem_link_sk(new_sk_storage, copy_selem);
820                 } else {
821                         ret = sk_storage_alloc(newsk, smap, copy_selem);
822                         if (ret) {
823                                 kfree(copy_selem);
824                                 atomic_sub(smap->elem_size,
825                                            &newsk->sk_omem_alloc);
826                                 bpf_map_put(map);
827                                 goto out;
828                         }
829
830                         new_sk_storage = rcu_dereference(copy_selem->sk_storage);
831                 }
832                 bpf_map_put(map);
833         }
834
835 out:
836         rcu_read_unlock();
837
838         /* In case of an error, don't free anything explicitly here, the
839          * caller is responsible to call bpf_sk_storage_free.
840          */
841
842         return ret;
843 }
844
845 BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
846            void *, value, u64, flags)
847 {
848         struct bpf_sk_storage_data *sdata;
849
850         if (flags > BPF_SK_STORAGE_GET_F_CREATE)
851                 return (unsigned long)NULL;
852
853         sdata = sk_storage_lookup(sk, map, true);
854         if (sdata)
855                 return (unsigned long)sdata->data;
856
857         if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
858             /* Cannot add new elem to a going away sk.
859              * Otherwise, the new elem may become a leak
860              * (and also cause other memory issues during map
861              *  destruction).
862              */
863             refcount_inc_not_zero(&sk->sk_refcnt)) {
864                 sdata = sk_storage_update(sk, map, value, BPF_NOEXIST);
865                 /* sk must be a fullsock (guaranteed by verifier),
866                  * so sock_gen_put() is unnecessary.
867                  */
868                 sock_put(sk);
869                 return IS_ERR(sdata) ?
870                         (unsigned long)NULL : (unsigned long)sdata->data;
871         }
872
873         return (unsigned long)NULL;
874 }
875
876 BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
877 {
878         if (refcount_inc_not_zero(&sk->sk_refcnt)) {
879                 int err;
880
881                 err = sk_storage_delete(sk, map);
882                 sock_put(sk);
883                 return err;
884         }
885
886         return -ENOENT;
887 }
888
889 const struct bpf_map_ops sk_storage_map_ops = {
890         .map_alloc_check = bpf_sk_storage_map_alloc_check,
891         .map_alloc = bpf_sk_storage_map_alloc,
892         .map_free = bpf_sk_storage_map_free,
893         .map_get_next_key = notsupp_get_next_key,
894         .map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
895         .map_update_elem = bpf_fd_sk_storage_update_elem,
896         .map_delete_elem = bpf_fd_sk_storage_delete_elem,
897         .map_check_btf = bpf_sk_storage_map_check_btf,
898 };
899
900 const struct bpf_func_proto bpf_sk_storage_get_proto = {
901         .func           = bpf_sk_storage_get,
902         .gpl_only       = false,
903         .ret_type       = RET_PTR_TO_MAP_VALUE_OR_NULL,
904         .arg1_type      = ARG_CONST_MAP_PTR,
905         .arg2_type      = ARG_PTR_TO_SOCKET,
906         .arg3_type      = ARG_PTR_TO_MAP_VALUE_OR_NULL,
907         .arg4_type      = ARG_ANYTHING,
908 };
909
910 const struct bpf_func_proto bpf_sk_storage_delete_proto = {
911         .func           = bpf_sk_storage_delete,
912         .gpl_only       = false,
913         .ret_type       = RET_INTEGER,
914         .arg1_type      = ARG_CONST_MAP_PTR,
915         .arg2_type      = ARG_PTR_TO_SOCKET,
916 };
917
918 struct bpf_sk_storage_diag {
919         u32 nr_maps;
920         struct bpf_map *maps[];
921 };
922
923 /* The reply will be like:
924  * INET_DIAG_BPF_SK_STORAGES (nla_nest)
925  *      SK_DIAG_BPF_STORAGE (nla_nest)
926  *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
927  *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
928  *      SK_DIAG_BPF_STORAGE (nla_nest)
929  *              SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
930  *              SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
931  *      ....
932  */
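/* The request side mirrors this layout: a nested attribute carrying zero or
 * more SK_DIAG_BPF_STORAGE_REQ_MAP_FD (nla_put_u32) entries, which
 * bpf_sk_storage_diag_alloc() below turns into map references.  An empty
 * request (no map fds) means "dump every map", handled by
 * bpf_sk_storage_diag_put_all().
 */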
933 static int nla_value_size(u32 value_size)
934 {
935         /* SK_DIAG_BPF_STORAGE (nla_nest)
936          *      SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
937          *      SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
938          */
939         return nla_total_size(0) + nla_total_size(sizeof(u32)) +
940                 nla_total_size_64bit(value_size);
941 }
942
943 void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
944 {
945         u32 i;
946
947         if (!diag)
948                 return;
949
950         for (i = 0; i < diag->nr_maps; i++)
951                 bpf_map_put(diag->maps[i]);
952
953         kfree(diag);
954 }
955 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
956
957 static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
958                            const struct bpf_map *map)
959 {
960         u32 i;
961
962         for (i = 0; i < diag->nr_maps; i++) {
963                 if (diag->maps[i] == map)
964                         return true;
965         }
966
967         return false;
968 }
969
970 struct bpf_sk_storage_diag *
971 bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
972 {
973         struct bpf_sk_storage_diag *diag;
974         struct nlattr *nla;
975         u32 nr_maps = 0;
976         int rem, err;
977
978         /* bpf_sk_storage_map is currently limited to CAP_SYS_ADMIN,
979          * matching the check done on the map_alloc_check() side.
980          */
981         if (!capable(CAP_SYS_ADMIN))
982                 return ERR_PTR(-EPERM);
983
984         nla_for_each_nested(nla, nla_stgs, rem) {
985                 if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
986                         nr_maps++;
987         }
988
989         diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
990                        GFP_KERNEL);
991         if (!diag)
992                 return ERR_PTR(-ENOMEM);
993
994         nla_for_each_nested(nla, nla_stgs, rem) {
995                 struct bpf_map *map;
996                 int map_fd;
997
998                 if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
999                         continue;
1000
1001                 map_fd = nla_get_u32(nla);
1002                 map = bpf_map_get(map_fd);
1003                 if (IS_ERR(map)) {
1004                         err = PTR_ERR(map);
1005                         goto err_free;
1006                 }
1007                 if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
1008                         bpf_map_put(map);
1009                         err = -EINVAL;
1010                         goto err_free;
1011                 }
1012                 if (diag_check_dup(diag, map)) {
1013                         bpf_map_put(map);
1014                         err = -EEXIST;
1015                         goto err_free;
1016                 }
1017                 diag->maps[diag->nr_maps++] = map;
1018         }
1019
1020         return diag;
1021
1022 err_free:
1023         bpf_sk_storage_diag_free(diag);
1024         return ERR_PTR(err);
1025 }
1026 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
1027
1028 static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
1029 {
1030         struct nlattr *nla_stg, *nla_value;
1031         struct bpf_sk_storage_map *smap;
1032
1033         /* It cannot exceed the max nlattr payload */
1034         BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
1035
1036         nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
1037         if (!nla_stg)
1038                 return -EMSGSIZE;
1039
1040         smap = rcu_dereference(sdata->smap);
1041         if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
1042                 goto errout;
1043
1044         nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
1045                                       smap->map.value_size,
1046                                       SK_DIAG_BPF_STORAGE_PAD);
1047         if (!nla_value)
1048                 goto errout;
1049
1050         if (map_value_has_spin_lock(&smap->map))
1051                 copy_map_value_locked(&smap->map, nla_data(nla_value),
1052                                       sdata->data, true);
1053         else
1054                 copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
1055
1056         nla_nest_end(skb, nla_stg);
1057         return 0;
1058
1059 errout:
1060         nla_nest_cancel(skb, nla_stg);
1061         return -EMSGSIZE;
1062 }
1063
1064 static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
1065                                        int stg_array_type,
1066                                        unsigned int *res_diag_size)
1067 {
1068         /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1069         unsigned int diag_size = nla_total_size(0);
1070         struct bpf_sk_storage *sk_storage;
1071         struct bpf_sk_storage_elem *selem;
1072         struct bpf_sk_storage_map *smap;
1073         struct nlattr *nla_stgs;
1074         unsigned int saved_len;
1075         int err = 0;
1076
1077         rcu_read_lock();
1078
1079         sk_storage = rcu_dereference(sk->sk_bpf_storage);
1080         if (!sk_storage || hlist_empty(&sk_storage->list)) {
1081                 rcu_read_unlock();
1082                 return 0;
1083         }
1084
1085         nla_stgs = nla_nest_start(skb, stg_array_type);
1086         if (!nla_stgs)
1087                 /* Continue to learn diag_size */
1088                 err = -EMSGSIZE;
1089
1090         saved_len = skb->len;
1091         hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
1092                 smap = rcu_dereference(SDATA(selem)->smap);
1093                 diag_size += nla_value_size(smap->map.value_size);
1094
1095                 if (nla_stgs && diag_get(SDATA(selem), skb))
1096                         /* Continue to learn diag_size */
1097                         err = -EMSGSIZE;
1098         }
1099
1100         rcu_read_unlock();
1101
1102         if (nla_stgs) {
1103                 if (saved_len == skb->len)
1104                         nla_nest_cancel(skb, nla_stgs);
1105                 else
1106                         nla_nest_end(skb, nla_stgs);
1107         }
1108
1109         if (diag_size == nla_total_size(0)) {
1110                 *res_diag_size = 0;
1111                 return 0;
1112         }
1113
1114         *res_diag_size = diag_size;
1115         return err;
1116 }
1117
1118 int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
1119                             struct sock *sk, struct sk_buff *skb,
1120                             int stg_array_type,
1121                             unsigned int *res_diag_size)
1122 {
1123         /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1124         unsigned int diag_size = nla_total_size(0);
1125         struct bpf_sk_storage *sk_storage;
1126         struct bpf_sk_storage_data *sdata;
1127         struct nlattr *nla_stgs;
1128         unsigned int saved_len;
1129         int err = 0;
1130         u32 i;
1131
1132         *res_diag_size = 0;
1133
1134         /* No map has been specified.  Dump all. */
1135         if (!diag->nr_maps)
1136                 return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
1137                                                    res_diag_size);
1138
1139         rcu_read_lock();
1140         sk_storage = rcu_dereference(sk->sk_bpf_storage);
1141         if (!sk_storage || hlist_empty(&sk_storage->list)) {
1142                 rcu_read_unlock();
1143                 return 0;
1144         }
1145
1146         nla_stgs = nla_nest_start(skb, stg_array_type);
1147         if (!nla_stgs)
1148                 /* Continue to learn diag_size */
1149                 err = -EMSGSIZE;
1150
1151         saved_len = skb->len;
1152         for (i = 0; i < diag->nr_maps; i++) {
1153                 sdata = __sk_storage_lookup(sk_storage,
1154                                 (struct bpf_sk_storage_map *)diag->maps[i],
1155                                 false);
1156
1157                 if (!sdata)
1158                         continue;
1159
1160                 diag_size += nla_value_size(diag->maps[i]->value_size);
1161
1162                 if (nla_stgs && diag_get(sdata, skb))
1163                         /* Continue to learn diag_size */
1164                         err = -EMSGSIZE;
1165         }
1166         rcu_read_unlock();
1167
1168         if (nla_stgs) {
1169                 if (saved_len == skb->len)
1170                         nla_nest_cancel(skb, nla_stgs);
1171                 else
1172                         nla_nest_end(skb, nla_stgs);
1173         }
1174
1175         if (diag_size == nla_total_size(0)) {
1176                 *res_diag_size = 0;
1177                 return 0;
1178         }
1179
1180         *res_diag_size = diag_size;
1181         return err;
1182 }
1183 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);