bpf, sockmap: Allow skipping sk_skb parser program
net/core/sock_map.c (linux-2.6-microblaze.git)
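This revision of sock_map.c allows an sk_skb verdict program to run without a stream parser: sock_map_link() accepts skb_verdict alone and installs it via sk_psock_start_verdict() instead of the strparser path. A minimal sketch of the verdict-only setup, assuming libbpf and the illustrative map/section names below (they are not taken from this file):

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
        __uint(type, BPF_MAP_TYPE_SOCKMAP);
        __uint(max_entries, 2);
        __type(key, __u32);
        __type(value, __u64);
} sock_map SEC(".maps");

/* Verdict-only sk_skb program; no BPF_SK_SKB_STREAM_PARSER is attached. */
SEC("sk_skb/stream_verdict")
int prog_stream_verdict(struct __sk_buff *skb)
{
        /* Forward every skb to the socket stored at index 0; returning
         * SK_PASS or SK_DROP directly is equally valid.
         */
        return bpf_sk_redirect_map(skb, &sock_map, 0, 0);
}

char _license[] SEC("license") = "GPL";

Userspace then attaches only the verdict program, for example with bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0), and never attaches a parser.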
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4 #include <linux/bpf.h>
5 #include <linux/btf_ids.h>
6 #include <linux/filter.h>
7 #include <linux/errno.h>
8 #include <linux/file.h>
9 #include <linux/net.h>
10 #include <linux/workqueue.h>
11 #include <linux/skmsg.h>
12 #include <linux/list.h>
13 #include <linux/jhash.h>
14 #include <linux/sock_diag.h>
15 #include <net/udp.h>
16
17 struct bpf_stab {
18         struct bpf_map map;
19         struct sock **sks;
20         struct sk_psock_progs progs;
21         raw_spinlock_t lock;
22 };
23
24 #define SOCK_CREATE_FLAG_MASK                           \
25         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
26
27 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
28 {
29         struct bpf_stab *stab;
30         u64 cost;
31         int err;
32
33         if (!capable(CAP_NET_ADMIN))
34                 return ERR_PTR(-EPERM);
35         if (attr->max_entries == 0 ||
36             attr->key_size    != 4 ||
37             (attr->value_size != sizeof(u32) &&
38              attr->value_size != sizeof(u64)) ||
39             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
40                 return ERR_PTR(-EINVAL);
41
42         stab = kzalloc(sizeof(*stab), GFP_USER);
43         if (!stab)
44                 return ERR_PTR(-ENOMEM);
45
46         bpf_map_init_from_attr(&stab->map, attr);
47         raw_spin_lock_init(&stab->lock);
48
49         /* Make sure page count doesn't overflow. */
50         cost = (u64) stab->map.max_entries * sizeof(struct sock *);
51         err = bpf_map_charge_init(&stab->map.memory, cost);
52         if (err)
53                 goto free_stab;
54
55         stab->sks = bpf_map_area_alloc(stab->map.max_entries *
56                                        sizeof(struct sock *),
57                                        stab->map.numa_node);
58         if (stab->sks)
59                 return &stab->map;
60         err = -ENOMEM;
61         bpf_map_charge_finish(&stab->map.memory);
62 free_stab:
63         kfree(stab);
64         return ERR_PTR(err);
65 }
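/* Userspace sketch of a map spec that passes the checks in sock_map_alloc()
 * above: 4-byte keys and 4- or 8-byte values (an 8-byte value additionally
 * lets the syscall-side lookup further down return the socket cookie).
 * Uses the legacy libbpf helper; name and sizes are illustrative:
 *
 *      int map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP,
 *                                  sizeof(__u32), sizeof(__u64), 64, 0);
 */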
66
67 int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
68 {
69         u32 ufd = attr->target_fd;
70         struct bpf_map *map;
71         struct fd f;
72         int ret;
73
74         if (attr->attach_flags || attr->replace_bpf_fd)
75                 return -EINVAL;
76
77         f = fdget(ufd);
78         map = __bpf_map_get(f);
79         if (IS_ERR(map))
80                 return PTR_ERR(map);
81         ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
82         fdput(f);
83         return ret;
84 }
85
86 int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
87 {
88         u32 ufd = attr->target_fd;
89         struct bpf_prog *prog;
90         struct bpf_map *map;
91         struct fd f;
92         int ret;
93
94         if (attr->attach_flags || attr->replace_bpf_fd)
95                 return -EINVAL;
96
97         f = fdget(ufd);
98         map = __bpf_map_get(f);
99         if (IS_ERR(map))
100                 return PTR_ERR(map);
101
102         prog = bpf_prog_get(attr->attach_bpf_fd);
103         if (IS_ERR(prog)) {
104                 ret = PTR_ERR(prog);
105                 goto put_map;
106         }
107
108         if (prog->type != ptype) {
109                 ret = -EINVAL;
110                 goto put_prog;
111         }
112
113         ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
114 put_prog:
115         bpf_prog_put(prog);
116 put_map:
117         fdput(f);
118         return ret;
119 }
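/* sock_map_get_from_fd() and sock_map_prog_detach() above back the
 * BPF_PROG_ATTACH/BPF_PROG_DETACH commands when target_fd names a
 * sockmap/sockhash. Hedged libbpf sketch (file descriptors assumed valid):
 *
 *      bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *      ...
 *      bpf_prog_detach2(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT);
 */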
120
121 static void sock_map_sk_acquire(struct sock *sk)
122         __acquires(&sk->sk_lock.slock)
123 {
124         lock_sock(sk);
125         preempt_disable();
126         rcu_read_lock();
127 }
128
129 static void sock_map_sk_release(struct sock *sk)
130         __releases(&sk->sk_lock.slock)
131 {
132         rcu_read_unlock();
133         preempt_enable();
134         release_sock(sk);
135 }
136
137 static void sock_map_add_link(struct sk_psock *psock,
138                               struct sk_psock_link *link,
139                               struct bpf_map *map, void *link_raw)
140 {
141         link->link_raw = link_raw;
142         link->map = map;
143         spin_lock_bh(&psock->link_lock);
144         list_add_tail(&link->list, &psock->link);
145         spin_unlock_bh(&psock->link_lock);
146 }
147
148 static void sock_map_del_link(struct sock *sk,
149                               struct sk_psock *psock, void *link_raw)
150 {
151         bool strp_stop = false, verdict_stop = false;
152         struct sk_psock_link *link, *tmp;
153
154         spin_lock_bh(&psock->link_lock);
155         list_for_each_entry_safe(link, tmp, &psock->link, list) {
156                 if (link->link_raw == link_raw) {
157                         struct bpf_map *map = link->map;
158                         struct bpf_stab *stab = container_of(map, struct bpf_stab,
159                                                              map);
160                         if (psock->parser.enabled && stab->progs.skb_parser)
161                                 strp_stop = true;
162                         if (psock->parser.enabled && stab->progs.skb_verdict)
163                                 verdict_stop = true;
164                         list_del(&link->list);
165                         sk_psock_free_link(link);
166                 }
167         }
168         spin_unlock_bh(&psock->link_lock);
169         if (strp_stop || verdict_stop) {
170                 write_lock_bh(&sk->sk_callback_lock);
171                 if (strp_stop)
172                         sk_psock_stop_strp(sk, psock);
173                 else
174                         sk_psock_stop_verdict(sk, psock);
175                 write_unlock_bh(&sk->sk_callback_lock);
176         }
177 }
178
179 static void sock_map_unref(struct sock *sk, void *link_raw)
180 {
181         struct sk_psock *psock = sk_psock(sk);
182
183         if (likely(psock)) {
184                 sock_map_del_link(sk, psock, link_raw);
185                 sk_psock_put(sk, psock);
186         }
187 }
188
189 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
190 {
191         struct proto *prot;
192
193         switch (sk->sk_type) {
194         case SOCK_STREAM:
195                 prot = tcp_bpf_get_proto(sk, psock);
196                 break;
197
198         case SOCK_DGRAM:
199                 prot = udp_bpf_get_proto(sk, psock);
200                 break;
201
202         default:
203                 return -EINVAL;
204         }
205
206         if (IS_ERR(prot))
207                 return PTR_ERR(prot);
208
209         sk_psock_update_proto(sk, psock, prot);
210         return 0;
211 }
212
213 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
214 {
215         struct sk_psock *psock;
216
217         rcu_read_lock();
218         psock = sk_psock(sk);
219         if (psock) {
220                 if (sk->sk_prot->close != sock_map_close) {
221                         psock = ERR_PTR(-EBUSY);
222                         goto out;
223                 }
224
225                 if (!refcount_inc_not_zero(&psock->refcnt))
226                         psock = ERR_PTR(-EBUSY);
227         }
228 out:
229         rcu_read_unlock();
230         return psock;
231 }
232
233 static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
234                          struct sock *sk)
235 {
236         struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
237         struct sk_psock *psock;
238         int ret;
239
240         skb_verdict = READ_ONCE(progs->skb_verdict);
241         skb_parser = READ_ONCE(progs->skb_parser);
242         if (skb_verdict) {
243                 skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
244                 if (IS_ERR(skb_verdict))
245                         return PTR_ERR(skb_verdict);
246         }
247         if (skb_parser) {
248                 skb_parser = bpf_prog_inc_not_zero(skb_parser);
249                 if (IS_ERR(skb_parser)) {
250                         bpf_prog_put(skb_verdict);
251                         return PTR_ERR(skb_parser);
252                 }
253         }
254
255         msg_parser = READ_ONCE(progs->msg_parser);
256         if (msg_parser) {
257                 msg_parser = bpf_prog_inc_not_zero(msg_parser);
258                 if (IS_ERR(msg_parser)) {
259                         ret = PTR_ERR(msg_parser);
260                         goto out;
261                 }
262         }
263
264         psock = sock_map_psock_get_checked(sk);
265         if (IS_ERR(psock)) {
266                 ret = PTR_ERR(psock);
267                 goto out_progs;
268         }
269
270         if (psock) {
271                 if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
272                     (skb_parser  && READ_ONCE(psock->progs.skb_parser)) ||
273                     (skb_verdict && READ_ONCE(psock->progs.skb_verdict))) {
274                         sk_psock_put(sk, psock);
275                         ret = -EBUSY;
276                         goto out_progs;
277                 }
278         } else {
279                 psock = sk_psock_init(sk, map->numa_node);
280                 if (IS_ERR(psock)) {
281                         ret = PTR_ERR(psock);
282                         goto out_progs;
283                 }
284         }
285
286         if (msg_parser)
287                 psock_set_prog(&psock->progs.msg_parser, msg_parser);
288
289         ret = sock_map_init_proto(sk, psock);
290         if (ret < 0)
291                 goto out_drop;
292
293         write_lock_bh(&sk->sk_callback_lock);
294         if (skb_parser && skb_verdict && !psock->parser.enabled) {
295                 ret = sk_psock_init_strp(sk, psock);
296                 if (ret)
297                         goto out_unlock_drop;
298                 psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
299                 psock_set_prog(&psock->progs.skb_parser, skb_parser);
300                 sk_psock_start_strp(sk, psock);
301         } else if (!skb_parser && skb_verdict && !psock->parser.enabled) {
302                 psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
303                 sk_psock_start_verdict(sk, psock);
304         }
305         write_unlock_bh(&sk->sk_callback_lock);
306         return 0;
307 out_unlock_drop:
308         write_unlock_bh(&sk->sk_callback_lock);
309 out_drop:
310         sk_psock_put(sk, psock);
311 out_progs:
312         if (msg_parser)
313                 bpf_prog_put(msg_parser);
314 out:
315         if (skb_verdict)
316                 bpf_prog_put(skb_verdict);
317         if (skb_parser)
318                 bpf_prog_put(skb_parser);
319         return ret;
320 }
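/* The branches above yield three possible read-side configurations:
 *
 *  - skb_parser && skb_verdict: run data through the stream parser and then
 *    the verdict program (sk_psock_init_strp() + sk_psock_start_strp()).
 *  - skb_verdict only: skip the parser and call the verdict program directly
 *    from the sk_data_ready path (sk_psock_start_verdict()).
 *  - neither: only the proto callbacks are replaced; no read hook is
 *    installed.
 */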
321
322 static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
323 {
324         struct sk_psock *psock;
325         int ret;
326
327         psock = sock_map_psock_get_checked(sk);
328         if (IS_ERR(psock))
329                 return PTR_ERR(psock);
330
331         if (!psock) {
332                 psock = sk_psock_init(sk, map->numa_node);
333                 if (IS_ERR(psock))
334                         return PTR_ERR(psock);
335         }
336
337         ret = sock_map_init_proto(sk, psock);
338         if (ret < 0)
339                 sk_psock_put(sk, psock);
340         return ret;
341 }
342
343 static void sock_map_free(struct bpf_map *map)
344 {
345         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
346         int i;
347
348         /* After the sync no updates or deletes will be in-flight so it
349          * is safe to walk map and remove entries without risking a race
350          * in EEXIST update case.
351          */
352         synchronize_rcu();
353         for (i = 0; i < stab->map.max_entries; i++) {
354                 struct sock **psk = &stab->sks[i];
355                 struct sock *sk;
356
357                 sk = xchg(psk, NULL);
358                 if (sk) {
359                         lock_sock(sk);
360                         rcu_read_lock();
361                         sock_map_unref(sk, psk);
362                         rcu_read_unlock();
363                         release_sock(sk);
364                 }
365         }
366
367         /* wait for psock readers accessing its map link */
368         synchronize_rcu();
369
370         bpf_map_area_free(stab->sks);
371         kfree(stab);
372 }
373
374 static void sock_map_release_progs(struct bpf_map *map)
375 {
376         psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
377 }
378
379 static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
380 {
381         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
382
383         WARN_ON_ONCE(!rcu_read_lock_held());
384
385         if (unlikely(key >= map->max_entries))
386                 return NULL;
387         return READ_ONCE(stab->sks[key]);
388 }
389
390 static void *sock_map_lookup(struct bpf_map *map, void *key)
391 {
392         struct sock *sk;
393
394         sk = __sock_map_lookup_elem(map, *(u32 *)key);
395         if (!sk)
396                 return NULL;
397         if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
398                 return NULL;
399         return sk;
400 }
401
402 static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
403 {
404         struct sock *sk;
405
406         if (map->value_size != sizeof(u64))
407                 return ERR_PTR(-ENOSPC);
408
409         sk = __sock_map_lookup_elem(map, *(u32 *)key);
410         if (!sk)
411                 return ERR_PTR(-ENOENT);
412
413         __sock_gen_cookie(sk);
414         return &sk->sk_cookie;
415 }
416
417 static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
418                              struct sock **psk)
419 {
420         struct sock *sk;
421         int err = 0;
422
423         raw_spin_lock_bh(&stab->lock);
424         sk = *psk;
425         if (!sk_test || sk_test == sk)
426                 sk = xchg(psk, NULL);
427
428         if (likely(sk))
429                 sock_map_unref(sk, psk);
430         else
431                 err = -EINVAL;
432
433         raw_spin_unlock_bh(&stab->lock);
434         return err;
435 }
436
437 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
438                                       void *link_raw)
439 {
440         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
441
442         __sock_map_delete(stab, sk, link_raw);
443 }
444
445 static int sock_map_delete_elem(struct bpf_map *map, void *key)
446 {
447         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
448         u32 i = *(u32 *)key;
449         struct sock **psk;
450
451         if (unlikely(i >= map->max_entries))
452                 return -EINVAL;
453
454         psk = &stab->sks[i];
455         return __sock_map_delete(stab, NULL, psk);
456 }
457
458 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
459 {
460         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
461         u32 i = key ? *(u32 *)key : U32_MAX;
462         u32 *key_next = next;
463
464         if (i == stab->map.max_entries - 1)
465                 return -ENOENT;
466         if (i >= stab->map.max_entries)
467                 *key_next = 0;
468         else
469                 *key_next = i + 1;
470         return 0;
471 }
472
473 static bool sock_map_redirect_allowed(const struct sock *sk);
474
475 static int sock_map_update_common(struct bpf_map *map, u32 idx,
476                                   struct sock *sk, u64 flags)
477 {
478         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
479         struct sk_psock_link *link;
480         struct sk_psock *psock;
481         struct sock *osk;
482         int ret;
483
484         WARN_ON_ONCE(!rcu_read_lock_held());
485         if (unlikely(flags > BPF_EXIST))
486                 return -EINVAL;
487         if (unlikely(idx >= map->max_entries))
488                 return -E2BIG;
489
490         link = sk_psock_init_link();
491         if (!link)
492                 return -ENOMEM;
493
494         /* Only sockets we can redirect into/from in BPF need to hold
495          * refs to parser/verdict progs and have their sk_data_ready
496          * and sk_write_space callbacks overridden.
497          */
498         if (sock_map_redirect_allowed(sk))
499                 ret = sock_map_link(map, &stab->progs, sk);
500         else
501                 ret = sock_map_link_no_progs(map, sk);
502         if (ret < 0)
503                 goto out_free;
504
505         psock = sk_psock(sk);
506         WARN_ON_ONCE(!psock);
507
508         raw_spin_lock_bh(&stab->lock);
509         osk = stab->sks[idx];
510         if (osk && flags == BPF_NOEXIST) {
511                 ret = -EEXIST;
512                 goto out_unlock;
513         } else if (!osk && flags == BPF_EXIST) {
514                 ret = -ENOENT;
515                 goto out_unlock;
516         }
517
518         sock_map_add_link(psock, link, map, &stab->sks[idx]);
519         stab->sks[idx] = sk;
520         if (osk)
521                 sock_map_unref(osk, &stab->sks[idx]);
522         raw_spin_unlock_bh(&stab->lock);
523         return 0;
524 out_unlock:
525         raw_spin_unlock_bh(&stab->lock);
526         if (psock)
527                 sk_psock_put(sk, psock);
528 out_free:
529         sk_psock_free_link(link);
530         return ret;
531 }
532
533 static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
534 {
535         return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
536                ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
537                ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
538 }
539
540 static bool sk_is_tcp(const struct sock *sk)
541 {
542         return sk->sk_type == SOCK_STREAM &&
543                sk->sk_protocol == IPPROTO_TCP;
544 }
545
546 static bool sk_is_udp(const struct sock *sk)
547 {
548         return sk->sk_type == SOCK_DGRAM &&
549                sk->sk_protocol == IPPROTO_UDP;
550 }
551
552 static bool sock_map_redirect_allowed(const struct sock *sk)
553 {
554         return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
555 }
556
557 static bool sock_map_sk_is_suitable(const struct sock *sk)
558 {
559         return sk_is_tcp(sk) || sk_is_udp(sk);
560 }
561
562 static bool sock_map_sk_state_allowed(const struct sock *sk)
563 {
564         if (sk_is_tcp(sk))
565                 return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
566         else if (sk_is_udp(sk))
567                 return sk_hashed(sk);
568
569         return false;
570 }
571
572 static int sock_hash_update_common(struct bpf_map *map, void *key,
573                                    struct sock *sk, u64 flags);
574
575 int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
576                              u64 flags)
577 {
578         struct socket *sock;
579         struct sock *sk;
580         int ret;
581         u64 ufd;
582
583         if (map->value_size == sizeof(u64))
584                 ufd = *(u64 *)value;
585         else
586                 ufd = *(u32 *)value;
587         if (ufd > S32_MAX)
588                 return -EINVAL;
589
590         sock = sockfd_lookup(ufd, &ret);
591         if (!sock)
592                 return ret;
593         sk = sock->sk;
594         if (!sk) {
595                 ret = -EINVAL;
596                 goto out;
597         }
598         if (!sock_map_sk_is_suitable(sk)) {
599                 ret = -EOPNOTSUPP;
600                 goto out;
601         }
602
603         sock_map_sk_acquire(sk);
604         if (!sock_map_sk_state_allowed(sk))
605                 ret = -EOPNOTSUPP;
606         else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
607                 ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
608         else
609                 ret = sock_hash_update_common(map, key, sk, flags);
610         sock_map_sk_release(sk);
611 out:
612         fput(sock->file);
613         return ret;
614 }
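/* Userspace view of the update path above, assuming the map was created with
 * 8-byte values (sketch, error handling omitted, descriptors illustrative):
 *
 *      __u32 key = 0;
 *      __u64 value = sock_fd;          // fd of an established TCP socket
 *
 *      bpf_map_update_elem(map_fd, &key, &value, BPF_ANY);
 *
 * sock_map_sk_state_allowed() restricts this to ESTABLISHED or LISTEN TCP
 * sockets and hashed UDP sockets.
 */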
615
616 static int sock_map_update_elem(struct bpf_map *map, void *key,
617                                 void *value, u64 flags)
618 {
619         struct sock *sk = (struct sock *)value;
620         int ret;
621
622         if (unlikely(!sk || !sk_fullsock(sk)))
623                 return -EINVAL;
624
625         if (!sock_map_sk_is_suitable(sk))
626                 return -EOPNOTSUPP;
627
628         local_bh_disable();
629         bh_lock_sock(sk);
630         if (!sock_map_sk_state_allowed(sk))
631                 ret = -EOPNOTSUPP;
632         else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
633                 ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
634         else
635                 ret = sock_hash_update_common(map, key, sk, flags);
636         bh_unlock_sock(sk);
637         local_bh_enable();
638         return ret;
639 }
640
641 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
642            struct bpf_map *, map, void *, key, u64, flags)
643 {
644         WARN_ON_ONCE(!rcu_read_lock_held());
645
646         if (likely(sock_map_sk_is_suitable(sops->sk) &&
647                    sock_map_op_okay(sops)))
648                 return sock_map_update_common(map, *(u32 *)key, sops->sk,
649                                               flags);
650         return -EOPNOTSUPP;
651 }
652
653 const struct bpf_func_proto bpf_sock_map_update_proto = {
654         .func           = bpf_sock_map_update,
655         .gpl_only       = false,
656         .pkt_access     = true,
657         .ret_type       = RET_INTEGER,
658         .arg1_type      = ARG_PTR_TO_CTX,
659         .arg2_type      = ARG_CONST_MAP_PTR,
660         .arg3_type      = ARG_PTR_TO_MAP_KEY,
661         .arg4_type      = ARG_ANYTHING,
662 };
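/* Typical caller of this helper (sketch; map and section names are
 * illustrative): a sockops program that inserts freshly established
 * connections into the map:
 *
 *      SEC("sockops")
 *      int add_established(struct bpf_sock_ops *ops)
 *      {
 *              __u32 key = 0;
 *
 *              if (ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
 *                  ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB)
 *                      bpf_sock_map_update(ops, &sock_map, &key, BPF_ANY);
 *              return 0;
 *      }
 */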
663
664 BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
665            struct bpf_map *, map, u32, key, u64, flags)
666 {
667         struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
668         struct sock *sk;
669
670         if (unlikely(flags & ~(BPF_F_INGRESS)))
671                 return SK_DROP;
672
673         sk = __sock_map_lookup_elem(map, key);
674         if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
675                 return SK_DROP;
676
677         tcb->bpf.flags = flags;
678         tcb->bpf.sk_redir = sk;
679         return SK_PASS;
680 }
681
682 const struct bpf_func_proto bpf_sk_redirect_map_proto = {
683         .func           = bpf_sk_redirect_map,
684         .gpl_only       = false,
685         .ret_type       = RET_INTEGER,
686         .arg1_type      = ARG_PTR_TO_CTX,
687         .arg2_type      = ARG_CONST_MAP_PTR,
688         .arg3_type      = ARG_ANYTHING,
689         .arg4_type      = ARG_ANYTHING,
690 };
691
692 BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
693            struct bpf_map *, map, u32, key, u64, flags)
694 {
695         struct sock *sk;
696
697         if (unlikely(flags & ~(BPF_F_INGRESS)))
698                 return SK_DROP;
699
700         sk = __sock_map_lookup_elem(map, key);
701         if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
702                 return SK_DROP;
703
704         msg->flags = flags;
705         msg->sk_redir = sk;
706         return SK_PASS;
707 }
708
709 const struct bpf_func_proto bpf_msg_redirect_map_proto = {
710         .func           = bpf_msg_redirect_map,
711         .gpl_only       = false,
712         .ret_type       = RET_INTEGER,
713         .arg1_type      = ARG_PTR_TO_CTX,
714         .arg2_type      = ARG_CONST_MAP_PTR,
715         .arg3_type      = ARG_ANYTHING,
716         .arg4_type      = ARG_ANYTHING,
717 };
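/* Typical caller (sketch): an sk_msg program, attached with
 * bpf_prog_attach(msg_fd, map_fd, BPF_SK_MSG_VERDICT, 0), redirecting
 * sendmsg data to the ingress queue of another sockmap entry:
 *
 *      SEC("sk_msg")
 *      int prog_msg(struct sk_msg_md *msg)
 *      {
 *              return bpf_msg_redirect_map(msg, &sock_map, 1, BPF_F_INGRESS);
 *      }
 */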
718
719 struct sock_map_seq_info {
720         struct bpf_map *map;
721         struct sock *sk;
722         u32 index;
723 };
724
725 struct bpf_iter__sockmap {
726         __bpf_md_ptr(struct bpf_iter_meta *, meta);
727         __bpf_md_ptr(struct bpf_map *, map);
728         __bpf_md_ptr(void *, key);
729         __bpf_md_ptr(struct sock *, sk);
730 };
731
732 DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
733                      struct bpf_map *map, void *key,
734                      struct sock *sk)
735
736 static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
737 {
738         if (unlikely(info->index >= info->map->max_entries))
739                 return NULL;
740
741         info->sk = __sock_map_lookup_elem(info->map, info->index);
742
743         /* can't return sk directly, since that might be NULL */
744         return info;
745 }
746
747 static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
748 {
749         struct sock_map_seq_info *info = seq->private;
750
751         if (*pos == 0)
752                 ++*pos;
753
754         /* pairs with sock_map_seq_stop */
755         rcu_read_lock();
756         return sock_map_seq_lookup_elem(info);
757 }
758
759 static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
760 {
761         struct sock_map_seq_info *info = seq->private;
762
763         ++*pos;
764         ++info->index;
765
766         return sock_map_seq_lookup_elem(info);
767 }
768
769 static int sock_map_seq_show(struct seq_file *seq, void *v)
770 {
771         struct sock_map_seq_info *info = seq->private;
772         struct bpf_iter__sockmap ctx = {};
773         struct bpf_iter_meta meta;
774         struct bpf_prog *prog;
775
776         meta.seq = seq;
777         prog = bpf_iter_get_info(&meta, !v);
778         if (!prog)
779                 return 0;
780
781         ctx.meta = &meta;
782         ctx.map = info->map;
783         if (v) {
784                 ctx.key = &info->index;
785                 ctx.sk = info->sk;
786         }
787
788         return bpf_iter_run_prog(prog, &ctx);
789 }
790
791 static void sock_map_seq_stop(struct seq_file *seq, void *v)
792 {
793         if (!v)
794                 (void)sock_map_seq_show(seq, NULL);
795
796         /* pairs with sock_map_seq_start */
797         rcu_read_unlock();
798 }
799
800 static const struct seq_operations sock_map_seq_ops = {
801         .start  = sock_map_seq_start,
802         .next   = sock_map_seq_next,
803         .stop   = sock_map_seq_stop,
804         .show   = sock_map_seq_show,
805 };
806
807 static int sock_map_init_seq_private(void *priv_data,
808                                      struct bpf_iter_aux_info *aux)
809 {
810         struct sock_map_seq_info *info = priv_data;
811
812         info->map = aux->map;
813         return 0;
814 }
815
816 static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
817         .seq_ops                = &sock_map_seq_ops,
818         .init_seq_private       = sock_map_init_seq_private,
819         .seq_priv_size          = sizeof(struct sock_map_seq_info),
820 };
821
822 static int sock_map_btf_id;
823 const struct bpf_map_ops sock_map_ops = {
824         .map_meta_equal         = bpf_map_meta_equal,
825         .map_alloc              = sock_map_alloc,
826         .map_free               = sock_map_free,
827         .map_get_next_key       = sock_map_get_next_key,
828         .map_lookup_elem_sys_only = sock_map_lookup_sys,
829         .map_update_elem        = sock_map_update_elem,
830         .map_delete_elem        = sock_map_delete_elem,
831         .map_lookup_elem        = sock_map_lookup,
832         .map_release_uref       = sock_map_release_progs,
833         .map_check_btf          = map_check_no_btf,
834         .map_btf_name           = "bpf_stab",
835         .map_btf_id             = &sock_map_btf_id,
836         .iter_seq_info          = &sock_map_iter_seq_info,
837 };
838
839 struct bpf_shtab_elem {
840         struct rcu_head rcu;
841         u32 hash;
842         struct sock *sk;
843         struct hlist_node node;
844         u8 key[];
845 };
846
847 struct bpf_shtab_bucket {
848         struct hlist_head head;
849         raw_spinlock_t lock;
850 };
851
852 struct bpf_shtab {
853         struct bpf_map map;
854         struct bpf_shtab_bucket *buckets;
855         u32 buckets_num;
856         u32 elem_size;
857         struct sk_psock_progs progs;
858         atomic_t count;
859 };
860
861 static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
862 {
863         return jhash(key, len, 0);
864 }
865
866 static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
867                                                         u32 hash)
868 {
869         return &htab->buckets[hash & (htab->buckets_num - 1)];
870 }
871
872 static struct bpf_shtab_elem *
873 sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
874                           u32 key_size)
875 {
876         struct bpf_shtab_elem *elem;
877
878         hlist_for_each_entry_rcu(elem, head, node) {
879                 if (elem->hash == hash &&
880                     !memcmp(&elem->key, key, key_size))
881                         return elem;
882         }
883
884         return NULL;
885 }
886
887 static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
888 {
889         struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
890         u32 key_size = map->key_size, hash;
891         struct bpf_shtab_bucket *bucket;
892         struct bpf_shtab_elem *elem;
893
894         WARN_ON_ONCE(!rcu_read_lock_held());
895
896         hash = sock_hash_bucket_hash(key, key_size);
897         bucket = sock_hash_select_bucket(htab, hash);
898         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
899
900         return elem ? elem->sk : NULL;
901 }
902
903 static void sock_hash_free_elem(struct bpf_shtab *htab,
904                                 struct bpf_shtab_elem *elem)
905 {
906         atomic_dec(&htab->count);
907         kfree_rcu(elem, rcu);
908 }
909
910 static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
911                                        void *link_raw)
912 {
913         struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
914         struct bpf_shtab_elem *elem_probe, *elem = link_raw;
915         struct bpf_shtab_bucket *bucket;
916
917         WARN_ON_ONCE(!rcu_read_lock_held());
918         bucket = sock_hash_select_bucket(htab, elem->hash);
919
920         /* elem may be deleted in parallel from the map, but access here
921          * is okay since it's going away only after RCU grace period.
922          * However, we need to check whether it's still present.
923          */
924         raw_spin_lock_bh(&bucket->lock);
925         elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
926                                                elem->key, map->key_size);
927         if (elem_probe && elem_probe == elem) {
928                 hlist_del_rcu(&elem->node);
929                 sock_map_unref(elem->sk, elem);
930                 sock_hash_free_elem(htab, elem);
931         }
932         raw_spin_unlock_bh(&bucket->lock);
933 }
934
935 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
936 {
937         struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
938         u32 hash, key_size = map->key_size;
939         struct bpf_shtab_bucket *bucket;
940         struct bpf_shtab_elem *elem;
941         int ret = -ENOENT;
942
943         hash = sock_hash_bucket_hash(key, key_size);
944         bucket = sock_hash_select_bucket(htab, hash);
945
946         raw_spin_lock_bh(&bucket->lock);
947         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
948         if (elem) {
949                 hlist_del_rcu(&elem->node);
950                 sock_map_unref(elem->sk, elem);
951                 sock_hash_free_elem(htab, elem);
952                 ret = 0;
953         }
954         raw_spin_unlock_bh(&bucket->lock);
955         return ret;
956 }
957
958 static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
959                                                    void *key, u32 key_size,
960                                                    u32 hash, struct sock *sk,
961                                                    struct bpf_shtab_elem *old)
962 {
963         struct bpf_shtab_elem *new;
964
965         if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
966                 if (!old) {
967                         atomic_dec(&htab->count);
968                         return ERR_PTR(-E2BIG);
969                 }
970         }
971
972         new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
973                            htab->map.numa_node);
974         if (!new) {
975                 atomic_dec(&htab->count);
976                 return ERR_PTR(-ENOMEM);
977         }
978         memcpy(new->key, key, key_size);
979         new->sk = sk;
980         new->hash = hash;
981         return new;
982 }
983
984 static int sock_hash_update_common(struct bpf_map *map, void *key,
985                                    struct sock *sk, u64 flags)
986 {
987         struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
988         u32 key_size = map->key_size, hash;
989         struct bpf_shtab_elem *elem, *elem_new;
990         struct bpf_shtab_bucket *bucket;
991         struct sk_psock_link *link;
992         struct sk_psock *psock;
993         int ret;
994
995         WARN_ON_ONCE(!rcu_read_lock_held());
996         if (unlikely(flags > BPF_EXIST))
997                 return -EINVAL;
998
999         link = sk_psock_init_link();
1000         if (!link)
1001                 return -ENOMEM;
1002
1003         /* Only sockets we can redirect into/from in BPF need to hold
1004          * refs to parser/verdict progs and have their sk_data_ready
1005          * and sk_write_space callbacks overridden.
1006          */
1007         if (sock_map_redirect_allowed(sk))
1008                 ret = sock_map_link(map, &htab->progs, sk);
1009         else
1010                 ret = sock_map_link_no_progs(map, sk);
1011         if (ret < 0)
1012                 goto out_free;
1013
1014         psock = sk_psock(sk);
1015         WARN_ON_ONCE(!psock);
1016
1017         hash = sock_hash_bucket_hash(key, key_size);
1018         bucket = sock_hash_select_bucket(htab, hash);
1019
1020         raw_spin_lock_bh(&bucket->lock);
1021         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
1022         if (elem && flags == BPF_NOEXIST) {
1023                 ret = -EEXIST;
1024                 goto out_unlock;
1025         } else if (!elem && flags == BPF_EXIST) {
1026                 ret = -ENOENT;
1027                 goto out_unlock;
1028         }
1029
1030         elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
1031         if (IS_ERR(elem_new)) {
1032                 ret = PTR_ERR(elem_new);
1033                 goto out_unlock;
1034         }
1035
1036         sock_map_add_link(psock, link, map, elem_new);
1037         /* Add new element to the head of the list, so that
1038          * concurrent search will find it before old elem.
1039          */
1040         hlist_add_head_rcu(&elem_new->node, &bucket->head);
1041         if (elem) {
1042                 hlist_del_rcu(&elem->node);
1043                 sock_map_unref(elem->sk, elem);
1044                 sock_hash_free_elem(htab, elem);
1045         }
1046         raw_spin_unlock_bh(&bucket->lock);
1047         return 0;
1048 out_unlock:
1049         raw_spin_unlock_bh(&bucket->lock);
1050         sk_psock_put(sk, psock);
1051 out_free:
1052         sk_psock_free_link(link);
1053         return ret;
1054 }
1055
1056 static int sock_hash_get_next_key(struct bpf_map *map, void *key,
1057                                   void *key_next)
1058 {
1059         struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
1060         struct bpf_shtab_elem *elem, *elem_next;
1061         u32 hash, key_size = map->key_size;
1062         struct hlist_head *head;
1063         int i = 0;
1064
1065         if (!key)
1066                 goto find_first_elem;
1067         hash = sock_hash_bucket_hash(key, key_size);
1068         head = &sock_hash_select_bucket(htab, hash)->head;
1069         elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
1070         if (!elem)
1071                 goto find_first_elem;
1072
1073         elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
1074                                      struct bpf_shtab_elem, node);
1075         if (elem_next) {
1076                 memcpy(key_next, elem_next->key, key_size);
1077                 return 0;
1078         }
1079
1080         i = hash & (htab->buckets_num - 1);
1081         i++;
1082 find_first_elem:
1083         for (; i < htab->buckets_num; i++) {
1084                 head = &sock_hash_select_bucket(htab, i)->head;
1085                 elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
1086                                              struct bpf_shtab_elem, node);
1087                 if (elem_next) {
1088                         memcpy(key_next, elem_next->key, key_size);
1089                         return 0;
1090                 }
1091         }
1092
1093         return -ENOENT;
1094 }
1095
1096 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
1097 {
1098         struct bpf_shtab *htab;
1099         int i, err;
1100         u64 cost;
1101
1102         if (!capable(CAP_NET_ADMIN))
1103                 return ERR_PTR(-EPERM);
1104         if (attr->max_entries == 0 ||
1105             attr->key_size    == 0 ||
1106             (attr->value_size != sizeof(u32) &&
1107              attr->value_size != sizeof(u64)) ||
1108             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1109                 return ERR_PTR(-EINVAL);
1110         if (attr->key_size > MAX_BPF_STACK)
1111                 return ERR_PTR(-E2BIG);
1112
1113         htab = kzalloc(sizeof(*htab), GFP_USER);
1114         if (!htab)
1115                 return ERR_PTR(-ENOMEM);
1116
1117         bpf_map_init_from_attr(&htab->map, attr);
1118
1119         htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
1120         htab->elem_size = sizeof(struct bpf_shtab_elem) +
1121                           round_up(htab->map.key_size, 8);
1122         if (htab->buckets_num == 0 ||
1123             htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
1124                 err = -EINVAL;
1125                 goto free_htab;
1126         }
1127
1128         cost = (u64) htab->buckets_num * sizeof(struct bpf_shtab_bucket) +
1129                (u64) htab->elem_size * htab->map.max_entries;
1130         if (cost >= U32_MAX - PAGE_SIZE) {
1131                 err = -EINVAL;
1132                 goto free_htab;
1133         }
1134         err = bpf_map_charge_init(&htab->map.memory, cost);
1135         if (err)
1136                 goto free_htab;
1137
1138         htab->buckets = bpf_map_area_alloc(htab->buckets_num *
1139                                            sizeof(struct bpf_shtab_bucket),
1140                                            htab->map.numa_node);
1141         if (!htab->buckets) {
1142                 bpf_map_charge_finish(&htab->map.memory);
1143                 err = -ENOMEM;
1144                 goto free_htab;
1145         }
1146
1147         for (i = 0; i < htab->buckets_num; i++) {
1148                 INIT_HLIST_HEAD(&htab->buckets[i].head);
1149                 raw_spin_lock_init(&htab->buckets[i].lock);
1150         }
1151
1152         return &htab->map;
1153 free_htab:
1154         kfree(htab);
1155         return ERR_PTR(err);
1156 }
1157
1158 static void sock_hash_free(struct bpf_map *map)
1159 {
1160         struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
1161         struct bpf_shtab_bucket *bucket;
1162         struct hlist_head unlink_list;
1163         struct bpf_shtab_elem *elem;
1164         struct hlist_node *node;
1165         int i;
1166
1167         /* After the sync no updates or deletes will be in-flight so it
1168          * is safe to walk map and remove entries without risking a race
1169          * in EEXIST update case.
1170          */
1171         synchronize_rcu();
1172         for (i = 0; i < htab->buckets_num; i++) {
1173                 bucket = sock_hash_select_bucket(htab, i);
1174
1175                 /* We are racing with sock_hash_delete_from_link to
1176                  * enter the spin-lock critical section. Every socket on
1177                  * the list is still linked to sockhash. Since link
1178                  * exists, psock exists and holds a ref to socket. That
1179                  * lets us to grab a socket ref too.
1180                  */
1181                 raw_spin_lock_bh(&bucket->lock);
1182                 hlist_for_each_entry(elem, &bucket->head, node)
1183                         sock_hold(elem->sk);
1184                 hlist_move_list(&bucket->head, &unlink_list);
1185                 raw_spin_unlock_bh(&bucket->lock);
1186
1187                 /* Process removed entries out of atomic context to
1188                  * block for socket lock before deleting the psock's
1189                  * link to sockhash.
1190                  */
1191                 hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
1192                         hlist_del(&elem->node);
1193                         lock_sock(elem->sk);
1194                         rcu_read_lock();
1195                         sock_map_unref(elem->sk, elem);
1196                         rcu_read_unlock();
1197                         release_sock(elem->sk);
1198                         sock_put(elem->sk);
1199                         sock_hash_free_elem(htab, elem);
1200                 }
1201         }
1202
1203         /* wait for psock readers accessing its map link */
1204         synchronize_rcu();
1205
1206         bpf_map_area_free(htab->buckets);
1207         kfree(htab);
1208 }
1209
1210 static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
1211 {
1212         struct sock *sk;
1213
1214         if (map->value_size != sizeof(u64))
1215                 return ERR_PTR(-ENOSPC);
1216
1217         sk = __sock_hash_lookup_elem(map, key);
1218         if (!sk)
1219                 return ERR_PTR(-ENOENT);
1220
1221         __sock_gen_cookie(sk);
1222         return &sk->sk_cookie;
1223 }
1224
1225 static void *sock_hash_lookup(struct bpf_map *map, void *key)
1226 {
1227         struct sock *sk;
1228
1229         sk = __sock_hash_lookup_elem(map, key);
1230         if (!sk)
1231                 return NULL;
1232         if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
1233                 return NULL;
1234         return sk;
1235 }
1236
1237 static void sock_hash_release_progs(struct bpf_map *map)
1238 {
1239         psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
1240 }
1241
1242 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
1243            struct bpf_map *, map, void *, key, u64, flags)
1244 {
1245         WARN_ON_ONCE(!rcu_read_lock_held());
1246
1247         if (likely(sock_map_sk_is_suitable(sops->sk) &&
1248                    sock_map_op_okay(sops)))
1249                 return sock_hash_update_common(map, key, sops->sk, flags);
1250         return -EOPNOTSUPP;
1251 }
1252
1253 const struct bpf_func_proto bpf_sock_hash_update_proto = {
1254         .func           = bpf_sock_hash_update,
1255         .gpl_only       = false,
1256         .pkt_access     = true,
1257         .ret_type       = RET_INTEGER,
1258         .arg1_type      = ARG_PTR_TO_CTX,
1259         .arg2_type      = ARG_CONST_MAP_PTR,
1260         .arg3_type      = ARG_PTR_TO_MAP_KEY,
1261         .arg4_type      = ARG_ANYTHING,
1262 };
1263
1264 BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
1265            struct bpf_map *, map, void *, key, u64, flags)
1266 {
1267         struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
1268         struct sock *sk;
1269
1270         if (unlikely(flags & ~(BPF_F_INGRESS)))
1271                 return SK_DROP;
1272
1273         sk = __sock_hash_lookup_elem(map, key);
1274         if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
1275                 return SK_DROP;
1276
1277         tcb->bpf.flags = flags;
1278         tcb->bpf.sk_redir = sk;
1279         return SK_PASS;
1280 }
1281
1282 const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
1283         .func           = bpf_sk_redirect_hash,
1284         .gpl_only       = false,
1285         .ret_type       = RET_INTEGER,
1286         .arg1_type      = ARG_PTR_TO_CTX,
1287         .arg2_type      = ARG_CONST_MAP_PTR,
1288         .arg3_type      = ARG_PTR_TO_MAP_KEY,
1289         .arg4_type      = ARG_ANYTHING,
1290 };
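/* sockhash counterpart of bpf_sk_redirect_map(): the key is an arbitrary
 * byte string rather than an array index. Sketch with a made-up key layout:
 *
 *      struct sock_key { __u32 ip4; __u32 port; } key = { ... };
 *
 *      return bpf_sk_redirect_hash(skb, &sock_hash, &key, 0);
 */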
1291
1292 BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
1293            struct bpf_map *, map, void *, key, u64, flags)
1294 {
1295         struct sock *sk;
1296
1297         if (unlikely(flags & ~(BPF_F_INGRESS)))
1298                 return SK_DROP;
1299
1300         sk = __sock_hash_lookup_elem(map, key);
1301         if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
1302                 return SK_DROP;
1303
1304         msg->flags = flags;
1305         msg->sk_redir = sk;
1306         return SK_PASS;
1307 }
1308
1309 const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
1310         .func           = bpf_msg_redirect_hash,
1311         .gpl_only       = false,
1312         .ret_type       = RET_INTEGER,
1313         .arg1_type      = ARG_PTR_TO_CTX,
1314         .arg2_type      = ARG_CONST_MAP_PTR,
1315         .arg3_type      = ARG_PTR_TO_MAP_KEY,
1316         .arg4_type      = ARG_ANYTHING,
1317 };
1318
1319 struct sock_hash_seq_info {
1320         struct bpf_map *map;
1321         struct bpf_shtab *htab;
1322         u32 bucket_id;
1323 };
1324
1325 static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
1326                                      struct bpf_shtab_elem *prev_elem)
1327 {
1328         const struct bpf_shtab *htab = info->htab;
1329         struct bpf_shtab_bucket *bucket;
1330         struct bpf_shtab_elem *elem;
1331         struct hlist_node *node;
1332
1333         /* try to find next elem in the same bucket */
1334         if (prev_elem) {
1335                 node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
1336                 elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
1337                 if (elem)
1338                         return elem;
1339
1340                 /* no more elements, continue in the next bucket */
1341                 info->bucket_id++;
1342         }
1343
1344         for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
1345                 bucket = &htab->buckets[info->bucket_id];
1346                 node = rcu_dereference(hlist_first_rcu(&bucket->head));
1347                 elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
1348                 if (elem)
1349                         return elem;
1350         }
1351
1352         return NULL;
1353 }
1354
1355 static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
1356 {
1357         struct sock_hash_seq_info *info = seq->private;
1358
1359         if (*pos == 0)
1360                 ++*pos;
1361
1362         /* pairs with sock_hash_seq_stop */
1363         rcu_read_lock();
1364         return sock_hash_seq_find_next(info, NULL);
1365 }
1366
1367 static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
1368 {
1369         struct sock_hash_seq_info *info = seq->private;
1370
1371         ++*pos;
1372         return sock_hash_seq_find_next(info, v);
1373 }
1374
1375 static int sock_hash_seq_show(struct seq_file *seq, void *v)
1376 {
1377         struct sock_hash_seq_info *info = seq->private;
1378         struct bpf_iter__sockmap ctx = {};
1379         struct bpf_shtab_elem *elem = v;
1380         struct bpf_iter_meta meta;
1381         struct bpf_prog *prog;
1382
1383         meta.seq = seq;
1384         prog = bpf_iter_get_info(&meta, !elem);
1385         if (!prog)
1386                 return 0;
1387
1388         ctx.meta = &meta;
1389         ctx.map = info->map;
1390         if (elem) {
1391                 ctx.key = elem->key;
1392                 ctx.sk = elem->sk;
1393         }
1394
1395         return bpf_iter_run_prog(prog, &ctx);
1396 }
1397
1398 static void sock_hash_seq_stop(struct seq_file *seq, void *v)
1399 {
1400         if (!v)
1401                 (void)sock_hash_seq_show(seq, NULL);
1402
1403         /* pairs with sock_hash_seq_start */
1404         rcu_read_unlock();
1405 }
1406
1407 static const struct seq_operations sock_hash_seq_ops = {
1408         .start  = sock_hash_seq_start,
1409         .next   = sock_hash_seq_next,
1410         .stop   = sock_hash_seq_stop,
1411         .show   = sock_hash_seq_show,
1412 };
1413
1414 static int sock_hash_init_seq_private(void *priv_data,
1415                                      struct bpf_iter_aux_info *aux)
1416 {
1417         struct sock_hash_seq_info *info = priv_data;
1418
1419         info->map = aux->map;
1420         info->htab = container_of(aux->map, struct bpf_shtab, map);
1421         return 0;
1422 }
1423
1424 static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
1425         .seq_ops                = &sock_hash_seq_ops,
1426         .init_seq_private       = sock_hash_init_seq_private,
1427         .seq_priv_size          = sizeof(struct sock_hash_seq_info),
1428 };
1429
1430 static int sock_hash_map_btf_id;
1431 const struct bpf_map_ops sock_hash_ops = {
1432         .map_meta_equal         = bpf_map_meta_equal,
1433         .map_alloc              = sock_hash_alloc,
1434         .map_free               = sock_hash_free,
1435         .map_get_next_key       = sock_hash_get_next_key,
1436         .map_update_elem        = sock_map_update_elem,
1437         .map_delete_elem        = sock_hash_delete_elem,
1438         .map_lookup_elem        = sock_hash_lookup,
1439         .map_lookup_elem_sys_only = sock_hash_lookup_sys,
1440         .map_release_uref       = sock_hash_release_progs,
1441         .map_check_btf          = map_check_no_btf,
1442         .map_btf_name           = "bpf_shtab",
1443         .map_btf_id             = &sock_hash_map_btf_id,
1444         .iter_seq_info          = &sock_hash_iter_seq_info,
1445 };
1446
1447 static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
1448 {
1449         switch (map->map_type) {
1450         case BPF_MAP_TYPE_SOCKMAP:
1451                 return &container_of(map, struct bpf_stab, map)->progs;
1452         case BPF_MAP_TYPE_SOCKHASH:
1453                 return &container_of(map, struct bpf_shtab, map)->progs;
1454         default:
1455                 break;
1456         }
1457
1458         return NULL;
1459 }
1460
1461 int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
1462                          struct bpf_prog *old, u32 which)
1463 {
1464         struct sk_psock_progs *progs = sock_map_progs(map);
1465         struct bpf_prog **pprog;
1466
1467         if (!progs)
1468                 return -EOPNOTSUPP;
1469
1470         switch (which) {
1471         case BPF_SK_MSG_VERDICT:
1472                 pprog = &progs->msg_parser;
1473                 break;
1474         case BPF_SK_SKB_STREAM_PARSER:
1475                 pprog = &progs->skb_parser;
1476                 break;
1477         case BPF_SK_SKB_STREAM_VERDICT:
1478                 pprog = &progs->skb_verdict;
1479                 break;
1480         default:
1481                 return -EOPNOTSUPP;
1482         }
1483
1484         if (old)
1485                 return psock_replace_prog(pprog, prog, old);
1486
1487         psock_set_prog(pprog, prog);
1488         return 0;
1489 }
1490
1491 static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
1492 {
1493         switch (link->map->map_type) {
1494         case BPF_MAP_TYPE_SOCKMAP:
1495                 return sock_map_delete_from_link(link->map, sk,
1496                                                  link->link_raw);
1497         case BPF_MAP_TYPE_SOCKHASH:
1498                 return sock_hash_delete_from_link(link->map, sk,
1499                                                   link->link_raw);
1500         default:
1501                 break;
1502         }
1503 }
1504
1505 static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
1506 {
1507         struct sk_psock_link *link;
1508
1509         while ((link = sk_psock_link_pop(psock))) {
1510                 sock_map_unlink(sk, link);
1511                 sk_psock_free_link(link);
1512         }
1513 }
1514
1515 void sock_map_unhash(struct sock *sk)
1516 {
1517         void (*saved_unhash)(struct sock *sk);
1518         struct sk_psock *psock;
1519
1520         rcu_read_lock();
1521         psock = sk_psock(sk);
1522         if (unlikely(!psock)) {
1523                 rcu_read_unlock();
1524                 if (sk->sk_prot->unhash)
1525                         sk->sk_prot->unhash(sk);
1526                 return;
1527         }
1528
1529         saved_unhash = psock->saved_unhash;
1530         sock_map_remove_links(sk, psock);
1531         rcu_read_unlock();
1532         saved_unhash(sk);
1533 }
1534
1535 void sock_map_close(struct sock *sk, long timeout)
1536 {
1537         void (*saved_close)(struct sock *sk, long timeout);
1538         struct sk_psock *psock;
1539
1540         lock_sock(sk);
1541         rcu_read_lock();
1542         psock = sk_psock(sk);
1543         if (unlikely(!psock)) {
1544                 rcu_read_unlock();
1545                 release_sock(sk);
1546                 return sk->sk_prot->close(sk, timeout);
1547         }
1548
1549         saved_close = psock->saved_close;
1550         sock_map_remove_links(sk, psock);
1551         rcu_read_unlock();
1552         release_sock(sk);
1553         saved_close(sk, timeout);
1554 }
1555
1556 static int sock_map_iter_attach_target(struct bpf_prog *prog,
1557                                        union bpf_iter_link_info *linfo,
1558                                        struct bpf_iter_aux_info *aux)
1559 {
1560         struct bpf_map *map;
1561         int err = -EINVAL;
1562
1563         if (!linfo->map.map_fd)
1564                 return -EBADF;
1565
1566         map = bpf_map_get_with_uref(linfo->map.map_fd);
1567         if (IS_ERR(map))
1568                 return PTR_ERR(map);
1569
1570         if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
1571             map->map_type != BPF_MAP_TYPE_SOCKHASH)
1572                 goto put_map;
1573
1574         if (prog->aux->max_rdonly_access > map->key_size) {
1575                 err = -EACCES;
1576                 goto put_map;
1577         }
1578
1579         aux->map = map;
1580         return 0;
1581
1582 put_map:
1583         bpf_map_put_with_uref(map);
1584         return err;
1585 }
1586
1587 static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
1588 {
1589         bpf_map_put_with_uref(aux->map);
1590 }
1591
1592 static struct bpf_iter_reg sock_map_iter_reg = {
1593         .target                 = "sockmap",
1594         .attach_target          = sock_map_iter_attach_target,
1595         .detach_target          = sock_map_iter_detach_target,
1596         .show_fdinfo            = bpf_iter_map_show_fdinfo,
1597         .fill_link_info         = bpf_iter_map_fill_link_info,
1598         .ctx_arg_info_size      = 2,
1599         .ctx_arg_info           = {
1600                 { offsetof(struct bpf_iter__sockmap, key),
1601                   PTR_TO_RDONLY_BUF_OR_NULL },
1602                 { offsetof(struct bpf_iter__sockmap, sk),
1603                   PTR_TO_BTF_ID_OR_NULL },
1604         },
1605 };
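/* An iterator program sees the context declared by DEFINE_BPF_ITER_FUNC()
 * earlier in this file; key and sk may be NULL on the final invocation.
 * Sketch in the style of the selftests (names illustrative):
 *
 *      SEC("iter/sockmap")
 *      int dump(struct bpf_iter__sockmap *ctx)
 *      {
 *              struct sock *sk = ctx->sk;
 *              __u32 *key = ctx->key;
 *
 *              if (!key || !sk)
 *                      return 0;
 *              // inspect *key and sk here, e.g. count map entries
 *              return 0;
 *      }
 */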
1606
1607 static int __init bpf_sockmap_iter_init(void)
1608 {
1609         sock_map_iter_reg.ctx_arg_info[1].btf_id =
1610                 btf_sock_ids[BTF_SOCK_TYPE_SOCK];
1611         return bpf_iter_reg_target(&sock_map_iter_reg);
1612 }
1613 late_initcall(bpf_sockmap_iter_init);