Merge tag 'nds32-for-linus-4.18' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-microblaze.git] / kernel / bpf / sockmap.c
1 /* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
2  *
3  * This program is free software; you can redistribute it and/or
4  * modify it under the terms of version 2 of the GNU General Public
5  * License as published by the Free Software Foundation.
6  *
7  * This program is distributed in the hope that it will be useful, but
8  * WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10  * General Public License for more details.
11  */
12
13 /* A BPF sock_map is used to store sock objects. This is primarly used
14  * for doing socket redirect with BPF helper routines.
15  *
16  * A sock map may have BPF programs attached to it, currently a program
17  * used to parse packets and a program to provide a verdict and redirect
18  * decision on the packet are supported. Any programs attached to a sock
19  * map are inherited by sock objects when they are added to the map. If
20  * no BPF programs are attached the sock object may only be used for sock
21  * redirect.
22  *
23  * A sock object may be in multiple maps, but can only inherit a single
24  * parse or verdict program. If adding a sock object to a map would result
25  * in having multiple parsing programs the update will return an EBUSY error.
26  *
27  * For reference this program is similar to devmap used in XDP context
28  * reviewing these together may be useful. For an example please review
29  * ./samples/bpf/sockmap/.
30  */
31 #include <linux/bpf.h>
32 #include <net/sock.h>
33 #include <linux/filter.h>
34 #include <linux/errno.h>
35 #include <linux/file.h>
36 #include <linux/kernel.h>
37 #include <linux/net.h>
38 #include <linux/skbuff.h>
39 #include <linux/workqueue.h>
40 #include <linux/list.h>
41 #include <linux/mm.h>
42 #include <net/strparser.h>
43 #include <net/tcp.h>
44 #include <linux/ptr_ring.h>
45 #include <net/inet_common.h>
46 #include <linux/sched/signal.h>
47
48 #define SOCK_CREATE_FLAG_MASK \
49         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
51 struct bpf_sock_progs {
52         struct bpf_prog *bpf_tx_msg;
53         struct bpf_prog *bpf_parse;
54         struct bpf_prog *bpf_verdict;
55 };
56
57 struct bpf_stab {
58         struct bpf_map map;
59         struct sock **sock_map;
60         struct bpf_sock_progs progs;
61 };
62
63 struct bucket {
64         struct hlist_head head;
65         raw_spinlock_t lock;
66 };
67
68 struct bpf_htab {
69         struct bpf_map map;
70         struct bucket *buckets;
71         atomic_t count;
72         u32 n_buckets;
73         u32 elem_size;
74         struct bpf_sock_progs progs;
75         struct rcu_head rcu;
76 };
77
78 struct htab_elem {
79         struct rcu_head rcu;
80         struct hlist_node hash_node;
81         u32 hash;
82         struct sock *sk;
83         char key[0];
84 };
85
86 enum smap_psock_state {
87         SMAP_TX_RUNNING,
88 };
89
90 struct smap_psock_map_entry {
91         struct list_head list;
92         struct sock **entry;
93         struct htab_elem __rcu *hash_link;
94         struct bpf_htab __rcu *htab;
95 };
96
97 struct smap_psock {
98         struct rcu_head rcu;
99         refcount_t refcnt;
100
101         /* datapath variables */
102         struct sk_buff_head rxqueue;
103         bool strp_enabled;
104
105         /* datapath error path cache across tx work invocations */
106         int save_rem;
107         int save_off;
108         struct sk_buff *save_skb;
109
110         /* datapath variables for tx_msg ULP */
111         struct sock *sk_redir;
112         int apply_bytes;
113         int cork_bytes;
114         int sg_size;
115         int eval;
116         struct sk_msg_buff *cork;
117         struct list_head ingress;
118
119         struct strparser strp;
120         struct bpf_prog *bpf_tx_msg;
121         struct bpf_prog *bpf_parse;
122         struct bpf_prog *bpf_verdict;
123         struct list_head maps;
124         spinlock_t maps_lock;
125
126         /* Back reference used when sock callback trigger sockmap operations */
127         struct sock *sock;
128         unsigned long state;
129
130         struct work_struct tx_work;
131         struct work_struct gc_work;
132
133         struct proto *sk_proto;
134         void (*save_close)(struct sock *sk, long timeout);
135         void (*save_data_ready)(struct sock *sk);
136         void (*save_write_space)(struct sock *sk);
137 };
138
139 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
140 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
141                            int nonblock, int flags, int *addr_len);
142 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
143 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
144                             int offset, size_t size, int flags);
145 static void bpf_tcp_close(struct sock *sk, long timeout);
146
147 static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
148 {
149         return rcu_dereference_sk_user_data(sk);
150 }
151
152 static bool bpf_tcp_stream_read(const struct sock *sk)
153 {
154         struct smap_psock *psock;
155         bool empty = true;
156
157         rcu_read_lock();
158         psock = smap_psock_sk(sk);
159         if (unlikely(!psock))
160                 goto out;
161         empty = list_empty(&psock->ingress);
162 out:
163         rcu_read_unlock();
164         return !empty;
165 }
166
167 enum {
168         SOCKMAP_IPV4,
169         SOCKMAP_IPV6,
170         SOCKMAP_NUM_PROTS,
171 };
172
173 enum {
174         SOCKMAP_BASE,
175         SOCKMAP_TX,
176         SOCKMAP_NUM_CONFIGS,
177 };
178
179 static struct proto *saved_tcpv6_prot __read_mostly;
180 static DEFINE_SPINLOCK(tcpv6_prot_lock);
181 static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
182 static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
183                          struct proto *base)
184 {
185         prot[SOCKMAP_BASE]                      = *base;
186         prot[SOCKMAP_BASE].close                = bpf_tcp_close;
187         prot[SOCKMAP_BASE].recvmsg              = bpf_tcp_recvmsg;
188         prot[SOCKMAP_BASE].stream_memory_read   = bpf_tcp_stream_read;
189
190         prot[SOCKMAP_TX]                        = prot[SOCKMAP_BASE];
191         prot[SOCKMAP_TX].sendmsg                = bpf_tcp_sendmsg;
192         prot[SOCKMAP_TX].sendpage               = bpf_tcp_sendpage;
193 }
194
195 static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
196 {
197         int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
198         int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
199
200         sk->sk_prot = &bpf_tcp_prots[family][conf];
201 }
202
203 static int bpf_tcp_init(struct sock *sk)
204 {
205         struct smap_psock *psock;
206
207         rcu_read_lock();
208         psock = smap_psock_sk(sk);
209         if (unlikely(!psock)) {
210                 rcu_read_unlock();
211                 return -EINVAL;
212         }
213
214         if (unlikely(psock->sk_proto)) {
215                 rcu_read_unlock();
216                 return -EBUSY;
217         }
218
219         psock->save_close = sk->sk_prot->close;
220         psock->sk_proto = sk->sk_prot;
221
222         /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
223         if (sk->sk_family == AF_INET6 &&
224             unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
225                 spin_lock_bh(&tcpv6_prot_lock);
226                 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
227                         build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
228                         smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
229                 }
230                 spin_unlock_bh(&tcpv6_prot_lock);
231         }
232         update_sk_prot(sk, psock);
233         rcu_read_unlock();
234         return 0;
235 }
236
237 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
238 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
239
240 static void bpf_tcp_release(struct sock *sk)
241 {
242         struct smap_psock *psock;
243
244         rcu_read_lock();
245         psock = smap_psock_sk(sk);
246         if (unlikely(!psock))
247                 goto out;
248
249         if (psock->cork) {
250                 free_start_sg(psock->sock, psock->cork);
251                 kfree(psock->cork);
252                 psock->cork = NULL;
253         }
254
255         if (psock->sk_proto) {
256                 sk->sk_prot = psock->sk_proto;
257                 psock->sk_proto = NULL;
258         }
259 out:
260         rcu_read_unlock();
261 }
262
263 static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
264                                          u32 hash, void *key, u32 key_size)
265 {
266         struct htab_elem *l;
267
268         hlist_for_each_entry_rcu(l, head, hash_node) {
269                 if (l->hash == hash && !memcmp(&l->key, key, key_size))
270                         return l;
271         }
272
273         return NULL;
274 }
275
276 static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
277 {
278         return &htab->buckets[hash & (htab->n_buckets - 1)];
279 }
280
281 static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
282 {
283         return &__select_bucket(htab, hash)->head;
284 }
285
286 static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
287 {
288         atomic_dec(&htab->count);
289         kfree_rcu(l, rcu);
290 }
291
292 static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
293                                                   struct smap_psock *psock)
294 {
295         struct smap_psock_map_entry *e;
296
297         spin_lock_bh(&psock->maps_lock);
298         e = list_first_entry_or_null(&psock->maps,
299                                      struct smap_psock_map_entry,
300                                      list);
301         if (e)
302                 list_del(&e->list);
303         spin_unlock_bh(&psock->maps_lock);
304         return e;
305 }
306
307 static void bpf_tcp_close(struct sock *sk, long timeout)
308 {
309         void (*close_fun)(struct sock *sk, long timeout);
310         struct smap_psock_map_entry *e;
311         struct sk_msg_buff *md, *mtmp;
312         struct smap_psock *psock;
313         struct sock *osk;
314
315         lock_sock(sk);
316         rcu_read_lock();
317         psock = smap_psock_sk(sk);
318         if (unlikely(!psock)) {
319                 rcu_read_unlock();
320                 release_sock(sk);
321                 return sk->sk_prot->close(sk, timeout);
322         }
323
324         /* The psock may be destroyed anytime after exiting the RCU critial
325          * section so by the time we use close_fun the psock may no longer
326          * be valid. However, bpf_tcp_close is called with the sock lock
327          * held so the close hook and sk are still valid.
328          */
329         close_fun = psock->save_close;
330
331         if (psock->cork) {
332                 free_start_sg(psock->sock, psock->cork);
333                 kfree(psock->cork);
334                 psock->cork = NULL;
335         }
336
337         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
338                 list_del(&md->list);
339                 free_start_sg(psock->sock, md);
340                 kfree(md);
341         }
342
343         e = psock_map_pop(sk, psock);
344         while (e) {
345                 if (e->entry) {
346                         osk = cmpxchg(e->entry, sk, NULL);
347                         if (osk == sk) {
348                                 smap_release_sock(psock, sk);
349                         }
350                 } else {
351                         struct htab_elem *link = rcu_dereference(e->hash_link);
352                         struct bpf_htab *htab = rcu_dereference(e->htab);
353                         struct hlist_head *head;
354                         struct htab_elem *l;
355                         struct bucket *b;
356
357                         b = __select_bucket(htab, link->hash);
358                         head = &b->head;
359                         raw_spin_lock_bh(&b->lock);
360                         l = lookup_elem_raw(head,
361                                             link->hash, link->key,
362                                             htab->map.key_size);
363                         /* If another thread deleted this object skip deletion.
364                          * The refcnt on psock may or may not be zero.
365                          */
366                         if (l) {
367                                 hlist_del_rcu(&link->hash_node);
368                                 smap_release_sock(psock, link->sk);
369                                 free_htab_elem(htab, link);
370                         }
371                         raw_spin_unlock_bh(&b->lock);
372                 }
373                 e = psock_map_pop(sk, psock);
374         }
375         rcu_read_unlock();
376         release_sock(sk);
377         close_fun(sk, timeout);
378 }
379
380 enum __sk_action {
381         __SK_DROP = 0,
382         __SK_PASS,
383         __SK_REDIRECT,
384         __SK_NONE,
385 };
386
387 static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
388         .name           = "bpf_tcp",
389         .uid            = TCP_ULP_BPF,
390         .user_visible   = false,
391         .owner          = NULL,
392         .init           = bpf_tcp_init,
393         .release        = bpf_tcp_release,
394 };
395
396 static int memcopy_from_iter(struct sock *sk,
397                              struct sk_msg_buff *md,
398                              struct iov_iter *from, int bytes)
399 {
400         struct scatterlist *sg = md->sg_data;
401         int i = md->sg_curr, rc = -ENOSPC;
402
403         do {
404                 int copy;
405                 char *to;
406
407                 if (md->sg_copybreak >= sg[i].length) {
408                         md->sg_copybreak = 0;
409
410                         if (++i == MAX_SKB_FRAGS)
411                                 i = 0;
412
413                         if (i == md->sg_end)
414                                 break;
415                 }
416
417                 copy = sg[i].length - md->sg_copybreak;
418                 to = sg_virt(&sg[i]) + md->sg_copybreak;
419                 md->sg_copybreak += copy;
420
421                 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
422                         rc = copy_from_iter_nocache(to, copy, from);
423                 else
424                         rc = copy_from_iter(to, copy, from);
425
426                 if (rc != copy) {
427                         rc = -EFAULT;
428                         goto out;
429                 }
430
431                 bytes -= copy;
432                 if (!bytes)
433                         break;
434
435                 md->sg_copybreak = 0;
436                 if (++i == MAX_SKB_FRAGS)
437                         i = 0;
438         } while (i != md->sg_end);
439 out:
440         md->sg_curr = i;
441         return rc;
442 }
443
444 static int bpf_tcp_push(struct sock *sk, int apply_bytes,
445                         struct sk_msg_buff *md,
446                         int flags, bool uncharge)
447 {
448         bool apply = apply_bytes;
449         struct scatterlist *sg;
450         int offset, ret = 0;
451         struct page *p;
452         size_t size;
453
454         while (1) {
455                 sg = md->sg_data + md->sg_start;
456                 size = (apply && apply_bytes < sg->length) ?
457                         apply_bytes : sg->length;
458                 offset = sg->offset;
459
460                 tcp_rate_check_app_limited(sk);
461                 p = sg_page(sg);
462 retry:
463                 ret = do_tcp_sendpages(sk, p, offset, size, flags);
464                 if (ret != size) {
465                         if (ret > 0) {
466                                 if (apply)
467                                         apply_bytes -= ret;
468
469                                 sg->offset += ret;
470                                 sg->length -= ret;
471                                 size -= ret;
472                                 offset += ret;
473                                 if (uncharge)
474                                         sk_mem_uncharge(sk, ret);
475                                 goto retry;
476                         }
477
478                         return ret;
479                 }
480
481                 if (apply)
482                         apply_bytes -= ret;
483                 sg->offset += ret;
484                 sg->length -= ret;
485                 if (uncharge)
486                         sk_mem_uncharge(sk, ret);
487
488                 if (!sg->length) {
489                         put_page(p);
490                         md->sg_start++;
491                         if (md->sg_start == MAX_SKB_FRAGS)
492                                 md->sg_start = 0;
493                         sg_init_table(sg, 1);
494
495                         if (md->sg_start == md->sg_end)
496                                 break;
497                 }
498
499                 if (apply && !apply_bytes)
500                         break;
501         }
502         return 0;
503 }
504
505 static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
506 {
507         struct scatterlist *sg = md->sg_data + md->sg_start;
508
509         if (md->sg_copy[md->sg_start]) {
510                 md->data = md->data_end = 0;
511         } else {
512                 md->data = sg_virt(sg);
513                 md->data_end = md->data + sg->length;
514         }
515 }
516
517 static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
518 {
519         struct scatterlist *sg = md->sg_data;
520         int i = md->sg_start;
521
522         do {
523                 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
524
525                 sk_mem_uncharge(sk, uncharge);
526                 bytes -= uncharge;
527                 if (!bytes)
528                         break;
529                 i++;
530                 if (i == MAX_SKB_FRAGS)
531                         i = 0;
532         } while (i != md->sg_end);
533 }
534
535 static void free_bytes_sg(struct sock *sk, int bytes,
536                           struct sk_msg_buff *md, bool charge)
537 {
538         struct scatterlist *sg = md->sg_data;
539         int i = md->sg_start, free;
540
541         while (bytes && sg[i].length) {
542                 free = sg[i].length;
543                 if (bytes < free) {
544                         sg[i].length -= bytes;
545                         sg[i].offset += bytes;
546                         if (charge)
547                                 sk_mem_uncharge(sk, bytes);
548                         break;
549                 }
550
551                 if (charge)
552                         sk_mem_uncharge(sk, sg[i].length);
553                 put_page(sg_page(&sg[i]));
554                 bytes -= sg[i].length;
555                 sg[i].length = 0;
556                 sg[i].page_link = 0;
557                 sg[i].offset = 0;
558                 i++;
559
560                 if (i == MAX_SKB_FRAGS)
561                         i = 0;
562         }
563         md->sg_start = i;
564 }
565
566 static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
567 {
568         struct scatterlist *sg = md->sg_data;
569         int i = start, free = 0;
570
571         while (sg[i].length) {
572                 free += sg[i].length;
573                 sk_mem_uncharge(sk, sg[i].length);
574                 if (!md->skb)
575                         put_page(sg_page(&sg[i]));
576                 sg[i].length = 0;
577                 sg[i].page_link = 0;
578                 sg[i].offset = 0;
579                 i++;
580
581                 if (i == MAX_SKB_FRAGS)
582                         i = 0;
583         }
584         if (md->skb)
585                 consume_skb(md->skb);
586
587         return free;
588 }
589
590 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
591 {
592         int free = free_sg(sk, md->sg_start, md);
593
594         md->sg_start = md->sg_end;
595         return free;
596 }
597
598 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
599 {
600         return free_sg(sk, md->sg_curr, md);
601 }
602
603 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
604 {
605         return ((_rc == SK_PASS) ?
606                (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
607                __SK_DROP);
608 }
609
610 static unsigned int smap_do_tx_msg(struct sock *sk,
611                                    struct smap_psock *psock,
612                                    struct sk_msg_buff *md)
613 {
614         struct bpf_prog *prog;
615         unsigned int rc, _rc;
616
617         preempt_disable();
618         rcu_read_lock();
619
620         /* If the policy was removed mid-send then default to 'accept' */
621         prog = READ_ONCE(psock->bpf_tx_msg);
622         if (unlikely(!prog)) {
623                 _rc = SK_PASS;
624                 goto verdict;
625         }
626
627         bpf_compute_data_pointers_sg(md);
628         md->sk = sk;
629         rc = (*prog->bpf_func)(md, prog->insnsi);
630         psock->apply_bytes = md->apply_bytes;
631
632         /* Moving return codes from UAPI namespace into internal namespace */
633         _rc = bpf_map_msg_verdict(rc, md);
634
635         /* The psock has a refcount on the sock but not on the map and because
636          * we need to drop rcu read lock here its possible the map could be
637          * removed between here and when we need it to execute the sock
638          * redirect. So do the map lookup now for future use.
639          */
640         if (_rc == __SK_REDIRECT) {
641                 if (psock->sk_redir)
642                         sock_put(psock->sk_redir);
643                 psock->sk_redir = do_msg_redirect_map(md);
644                 if (!psock->sk_redir) {
645                         _rc = __SK_DROP;
646                         goto verdict;
647                 }
648                 sock_hold(psock->sk_redir);
649         }
650 verdict:
651         rcu_read_unlock();
652         preempt_enable();
653
654         return _rc;
655 }
656
657 static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
658                            struct smap_psock *psock,
659                            struct sk_msg_buff *md, int flags)
660 {
661         bool apply = apply_bytes;
662         size_t size, copied = 0;
663         struct sk_msg_buff *r;
664         int err = 0, i;
665
666         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
667         if (unlikely(!r))
668                 return -ENOMEM;
669
670         lock_sock(sk);
671         r->sg_start = md->sg_start;
672         i = md->sg_start;
673
674         do {
675                 size = (apply && apply_bytes < md->sg_data[i].length) ?
676                         apply_bytes : md->sg_data[i].length;
677
678                 if (!sk_wmem_schedule(sk, size)) {
679                         if (!copied)
680                                 err = -ENOMEM;
681                         break;
682                 }
683
684                 sk_mem_charge(sk, size);
685                 r->sg_data[i] = md->sg_data[i];
686                 r->sg_data[i].length = size;
687                 md->sg_data[i].length -= size;
688                 md->sg_data[i].offset += size;
689                 copied += size;
690
691                 if (md->sg_data[i].length) {
692                         get_page(sg_page(&r->sg_data[i]));
693                         r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
694                 } else {
695                         i++;
696                         if (i == MAX_SKB_FRAGS)
697                                 i = 0;
698                         r->sg_end = i;
699                 }
700
701                 if (apply) {
702                         apply_bytes -= size;
703                         if (!apply_bytes)
704                                 break;
705                 }
706         } while (i != md->sg_end);
707
708         md->sg_start = i;
709
710         if (!err) {
711                 list_add_tail(&r->list, &psock->ingress);
712                 sk->sk_data_ready(sk);
713         } else {
714                 free_start_sg(sk, r);
715                 kfree(r);
716         }
717
718         release_sock(sk);
719         return err;
720 }
721
722 static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
723                                        struct sk_msg_buff *md,
724                                        int flags)
725 {
726         bool ingress = !!(md->flags & BPF_F_INGRESS);
727         struct smap_psock *psock;
728         struct scatterlist *sg;
729         int err = 0;
730
731         sg = md->sg_data;
732
733         rcu_read_lock();
734         psock = smap_psock_sk(sk);
735         if (unlikely(!psock))
736                 goto out_rcu;
737
738         if (!refcount_inc_not_zero(&psock->refcnt))
739                 goto out_rcu;
740
741         rcu_read_unlock();
742
743         if (ingress) {
744                 err = bpf_tcp_ingress(sk, send, psock, md, flags);
745         } else {
746                 lock_sock(sk);
747                 err = bpf_tcp_push(sk, send, md, flags, false);
748                 release_sock(sk);
749         }
750         smap_release_sock(psock, sk);
751         if (unlikely(err))
752                 goto out;
753         return 0;
754 out_rcu:
755         rcu_read_unlock();
756 out:
757         free_bytes_sg(NULL, send, md, false);
758         return err;
759 }
760
761 static inline void bpf_md_init(struct smap_psock *psock)
762 {
763         if (!psock->apply_bytes) {
764                 psock->eval =  __SK_NONE;
765                 if (psock->sk_redir) {
766                         sock_put(psock->sk_redir);
767                         psock->sk_redir = NULL;
768                 }
769         }
770 }
771
772 static void apply_bytes_dec(struct smap_psock *psock, int i)
773 {
774         if (psock->apply_bytes) {
775                 if (psock->apply_bytes < i)
776                         psock->apply_bytes = 0;
777                 else
778                         psock->apply_bytes -= i;
779         }
780 }
781
782 static int bpf_exec_tx_verdict(struct smap_psock *psock,
783                                struct sk_msg_buff *m,
784                                struct sock *sk,
785                                int *copied, int flags)
786 {
787         bool cork = false, enospc = (m->sg_start == m->sg_end);
788         struct sock *redir;
789         int err = 0;
790         int send;
791
792 more_data:
793         if (psock->eval == __SK_NONE)
794                 psock->eval = smap_do_tx_msg(sk, psock, m);
795
796         if (m->cork_bytes &&
797             m->cork_bytes > psock->sg_size && !enospc) {
798                 psock->cork_bytes = m->cork_bytes - psock->sg_size;
799                 if (!psock->cork) {
800                         psock->cork = kcalloc(1,
801                                         sizeof(struct sk_msg_buff),
802                                         GFP_ATOMIC | __GFP_NOWARN);
803
804                         if (!psock->cork) {
805                                 err = -ENOMEM;
806                                 goto out_err;
807                         }
808                 }
809                 memcpy(psock->cork, m, sizeof(*m));
810                 goto out_err;
811         }
812
813         send = psock->sg_size;
814         if (psock->apply_bytes && psock->apply_bytes < send)
815                 send = psock->apply_bytes;
816
817         switch (psock->eval) {
818         case __SK_PASS:
819                 err = bpf_tcp_push(sk, send, m, flags, true);
820                 if (unlikely(err)) {
821                         *copied -= free_start_sg(sk, m);
822                         break;
823                 }
824
825                 apply_bytes_dec(psock, send);
826                 psock->sg_size -= send;
827                 break;
828         case __SK_REDIRECT:
829                 redir = psock->sk_redir;
830                 apply_bytes_dec(psock, send);
831
832                 if (psock->cork) {
833                         cork = true;
834                         psock->cork = NULL;
835                 }
836
837                 return_mem_sg(sk, send, m);
838                 release_sock(sk);
839
840                 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
841                 lock_sock(sk);
842
843                 if (unlikely(err < 0)) {
844                         free_start_sg(sk, m);
845                         psock->sg_size = 0;
846                         if (!cork)
847                                 *copied -= send;
848                 } else {
849                         psock->sg_size -= send;
850                 }
851
852                 if (cork) {
853                         free_start_sg(sk, m);
854                         psock->sg_size = 0;
855                         kfree(m);
856                         m = NULL;
857                         err = 0;
858                 }
859                 break;
860         case __SK_DROP:
861         default:
862                 free_bytes_sg(sk, send, m, true);
863                 apply_bytes_dec(psock, send);
864                 *copied -= send;
865                 psock->sg_size -= send;
866                 err = -EACCES;
867                 break;
868         }
869
870         if (likely(!err)) {
871                 bpf_md_init(psock);
872                 if (m &&
873                     m->sg_data[m->sg_start].page_link &&
874                     m->sg_data[m->sg_start].length)
875                         goto more_data;
876         }
877
878 out_err:
879         return err;
880 }
881
882 static int bpf_wait_data(struct sock *sk,
883                          struct smap_psock *psk, int flags,
884                          long timeo, int *err)
885 {
886         int rc;
887
888         DEFINE_WAIT_FUNC(wait, woken_wake_function);
889
890         add_wait_queue(sk_sleep(sk), &wait);
891         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
892         rc = sk_wait_event(sk, &timeo,
893                            !list_empty(&psk->ingress) ||
894                            !skb_queue_empty(&sk->sk_receive_queue),
895                            &wait);
896         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
897         remove_wait_queue(sk_sleep(sk), &wait);
898
899         return rc;
900 }
901
902 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
903                            int nonblock, int flags, int *addr_len)
904 {
905         struct iov_iter *iter = &msg->msg_iter;
906         struct smap_psock *psock;
907         int copied = 0;
908
909         if (unlikely(flags & MSG_ERRQUEUE))
910                 return inet_recv_error(sk, msg, len, addr_len);
911
912         rcu_read_lock();
913         psock = smap_psock_sk(sk);
914         if (unlikely(!psock))
915                 goto out;
916
917         if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
918                 goto out;
919         rcu_read_unlock();
920
921         if (!skb_queue_empty(&sk->sk_receive_queue))
922                 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
923
924         lock_sock(sk);
925 bytes_ready:
926         while (copied != len) {
927                 struct scatterlist *sg;
928                 struct sk_msg_buff *md;
929                 int i;
930
931                 md = list_first_entry_or_null(&psock->ingress,
932                                               struct sk_msg_buff, list);
933                 if (unlikely(!md))
934                         break;
935                 i = md->sg_start;
936                 do {
937                         struct page *page;
938                         int n, copy;
939
940                         sg = &md->sg_data[i];
941                         copy = sg->length;
942                         page = sg_page(sg);
943
944                         if (copied + copy > len)
945                                 copy = len - copied;
946
947                         n = copy_page_to_iter(page, sg->offset, copy, iter);
948                         if (n != copy) {
949                                 md->sg_start = i;
950                                 release_sock(sk);
951                                 smap_release_sock(psock, sk);
952                                 return -EFAULT;
953                         }
954
955                         copied += copy;
956                         sg->offset += copy;
957                         sg->length -= copy;
958                         sk_mem_uncharge(sk, copy);
959
960                         if (!sg->length) {
961                                 i++;
962                                 if (i == MAX_SKB_FRAGS)
963                                         i = 0;
964                                 if (!md->skb)
965                                         put_page(page);
966                         }
967                         if (copied == len)
968                                 break;
969                 } while (i != md->sg_end);
970                 md->sg_start = i;
971
972                 if (!sg->length && md->sg_start == md->sg_end) {
973                         list_del(&md->list);
974                         if (md->skb)
975                                 consume_skb(md->skb);
976                         kfree(md);
977                 }
978         }
979
980         if (!copied) {
981                 long timeo;
982                 int data;
983                 int err = 0;
984
985                 timeo = sock_rcvtimeo(sk, nonblock);
986                 data = bpf_wait_data(sk, psock, flags, timeo, &err);
987
988                 if (data) {
989                         if (!skb_queue_empty(&sk->sk_receive_queue)) {
990                                 release_sock(sk);
991                                 smap_release_sock(psock, sk);
992                                 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
993                                 return copied;
994                         }
995                         goto bytes_ready;
996                 }
997
998                 if (err)
999                         copied = err;
1000         }
1001
1002         release_sock(sk);
1003         smap_release_sock(psock, sk);
1004         return copied;
1005 out:
1006         rcu_read_unlock();
1007         return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1008 }
1009
1010
1011 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1012 {
1013         int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1014         struct sk_msg_buff md = {0};
1015         unsigned int sg_copy = 0;
1016         struct smap_psock *psock;
1017         int copied = 0, err = 0;
1018         struct scatterlist *sg;
1019         long timeo;
1020
1021         /* Its possible a sock event or user removed the psock _but_ the ops
1022          * have not been reprogrammed yet so we get here. In this case fallback
1023          * to tcp_sendmsg. Note this only works because we _only_ ever allow
1024          * a single ULP there is no hierarchy here.
1025          */
1026         rcu_read_lock();
1027         psock = smap_psock_sk(sk);
1028         if (unlikely(!psock)) {
1029                 rcu_read_unlock();
1030                 return tcp_sendmsg(sk, msg, size);
1031         }
1032
1033         /* Increment the psock refcnt to ensure its not released while sending a
1034          * message. Required because sk lookup and bpf programs are used in
1035          * separate rcu critical sections. Its OK if we lose the map entry
1036          * but we can't lose the sock reference.
1037          */
1038         if (!refcount_inc_not_zero(&psock->refcnt)) {
1039                 rcu_read_unlock();
1040                 return tcp_sendmsg(sk, msg, size);
1041         }
1042
1043         sg = md.sg_data;
1044         sg_init_marker(sg, MAX_SKB_FRAGS);
1045         rcu_read_unlock();
1046
1047         lock_sock(sk);
1048         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1049
1050         while (msg_data_left(msg)) {
1051                 struct sk_msg_buff *m;
1052                 bool enospc = false;
1053                 int copy;
1054
1055                 if (sk->sk_err) {
1056                         err = sk->sk_err;
1057                         goto out_err;
1058                 }
1059
1060                 copy = msg_data_left(msg);
1061                 if (!sk_stream_memory_free(sk))
1062                         goto wait_for_sndbuf;
1063
1064                 m = psock->cork_bytes ? psock->cork : &md;
1065                 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1066                 err = sk_alloc_sg(sk, copy, m->sg_data,
1067                                   m->sg_start, &m->sg_end, &sg_copy,
1068                                   m->sg_end - 1);
1069                 if (err) {
1070                         if (err != -ENOSPC)
1071                                 goto wait_for_memory;
1072                         enospc = true;
1073                         copy = sg_copy;
1074                 }
1075
1076                 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1077                 if (err < 0) {
1078                         free_curr_sg(sk, m);
1079                         goto out_err;
1080                 }
1081
1082                 psock->sg_size += copy;
1083                 copied += copy;
1084                 sg_copy = 0;
1085
1086                 /* When bytes are being corked skip running BPF program and
1087                  * applying verdict unless there is no more buffer space. In
1088                  * the ENOSPC case simply run BPF prorgram with currently
1089                  * accumulated data. We don't have much choice at this point
1090                  * we could try extending the page frags or chaining complex
1091                  * frags but even in these cases _eventually_ we will hit an
1092                  * OOM scenario. More complex recovery schemes may be
1093                  * implemented in the future, but BPF programs must handle
1094                  * the case where apply_cork requests are not honored. The
1095                  * canonical method to verify this is to check data length.
1096                  */
1097                 if (psock->cork_bytes) {
1098                         if (copy > psock->cork_bytes)
1099                                 psock->cork_bytes = 0;
1100                         else
1101                                 psock->cork_bytes -= copy;
1102
1103                         if (psock->cork_bytes && !enospc)
1104                                 goto out_cork;
1105
1106                         /* All cork bytes accounted for re-run filter */
1107                         psock->eval = __SK_NONE;
1108                         psock->cork_bytes = 0;
1109                 }
1110
1111                 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1112                 if (unlikely(err < 0))
1113                         goto out_err;
1114                 continue;
1115 wait_for_sndbuf:
1116                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1117 wait_for_memory:
1118                 err = sk_stream_wait_memory(sk, &timeo);
1119                 if (err)
1120                         goto out_err;
1121         }
1122 out_err:
1123         if (err < 0)
1124                 err = sk_stream_error(sk, msg->msg_flags, err);
1125 out_cork:
1126         release_sock(sk);
1127         smap_release_sock(psock, sk);
1128         return copied ? copied : err;
1129 }
1130
1131 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1132                             int offset, size_t size, int flags)
1133 {
1134         struct sk_msg_buff md = {0}, *m = NULL;
1135         int err = 0, copied = 0;
1136         struct smap_psock *psock;
1137         struct scatterlist *sg;
1138         bool enospc = false;
1139
1140         rcu_read_lock();
1141         psock = smap_psock_sk(sk);
1142         if (unlikely(!psock))
1143                 goto accept;
1144
1145         if (!refcount_inc_not_zero(&psock->refcnt))
1146                 goto accept;
1147         rcu_read_unlock();
1148
1149         lock_sock(sk);
1150
1151         if (psock->cork_bytes) {
1152                 m = psock->cork;
1153                 sg = &m->sg_data[m->sg_end];
1154         } else {
1155                 m = &md;
1156                 sg = m->sg_data;
1157                 sg_init_marker(sg, MAX_SKB_FRAGS);
1158         }
1159
1160         /* Catch case where ring is full and sendpage is stalled. */
1161         if (unlikely(m->sg_end == m->sg_start &&
1162             m->sg_data[m->sg_end].length))
1163                 goto out_err;
1164
1165         psock->sg_size += size;
1166         sg_set_page(sg, page, size, offset);
1167         get_page(page);
1168         m->sg_copy[m->sg_end] = true;
1169         sk_mem_charge(sk, size);
1170         m->sg_end++;
1171         copied = size;
1172
1173         if (m->sg_end == MAX_SKB_FRAGS)
1174                 m->sg_end = 0;
1175
1176         if (m->sg_end == m->sg_start)
1177                 enospc = true;
1178
1179         if (psock->cork_bytes) {
1180                 if (size > psock->cork_bytes)
1181                         psock->cork_bytes = 0;
1182                 else
1183                         psock->cork_bytes -= size;
1184
1185                 if (psock->cork_bytes && !enospc)
1186                         goto out_err;
1187
1188                 /* All cork bytes accounted for re-run filter */
1189                 psock->eval = __SK_NONE;
1190                 psock->cork_bytes = 0;
1191         }
1192
1193         err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1194 out_err:
1195         release_sock(sk);
1196         smap_release_sock(psock, sk);
1197         return copied ? copied : err;
1198 accept:
1199         rcu_read_unlock();
1200         return tcp_sendpage(sk, page, offset, size, flags);
1201 }
1202
1203 static void bpf_tcp_msg_add(struct smap_psock *psock,
1204                             struct sock *sk,
1205                             struct bpf_prog *tx_msg)
1206 {
1207         struct bpf_prog *orig_tx_msg;
1208
1209         orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1210         if (orig_tx_msg)
1211                 bpf_prog_put(orig_tx_msg);
1212 }
1213
1214 static int bpf_tcp_ulp_register(void)
1215 {
1216         build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
1217         /* Once BPF TX ULP is registered it is never unregistered. It
1218          * will be in the ULP list for the lifetime of the system. Doing
1219          * duplicate registers is not a problem.
1220          */
1221         return tcp_register_ulp(&bpf_tcp_ulp_ops);
1222 }
1223
1224 static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1225 {
1226         struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1227         int rc;
1228
1229         if (unlikely(!prog))
1230                 return __SK_DROP;
1231
1232         skb_orphan(skb);
1233         /* We need to ensure that BPF metadata for maps is also cleared
1234          * when we orphan the skb so that we don't have the possibility
1235          * to reference a stale map.
1236          */
1237         TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1238         skb->sk = psock->sock;
1239         bpf_compute_data_end_sk_skb(skb);
1240         preempt_disable();
1241         rc = (*prog->bpf_func)(skb, prog->insnsi);
1242         preempt_enable();
1243         skb->sk = NULL;
1244
1245         /* Moving return codes from UAPI namespace into internal namespace */
1246         return rc == SK_PASS ?
1247                 (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1248                 __SK_DROP;
1249 }
1250
1251 static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1252 {
1253         struct sock *sk = psock->sock;
1254         int copied = 0, num_sg;
1255         struct sk_msg_buff *r;
1256
1257         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1258         if (unlikely(!r))
1259                 return -EAGAIN;
1260
1261         if (!sk_rmem_schedule(sk, skb, skb->len)) {
1262                 kfree(r);
1263                 return -EAGAIN;
1264         }
1265
1266         sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1267         num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1268         if (unlikely(num_sg < 0)) {
1269                 kfree(r);
1270                 return num_sg;
1271         }
1272         sk_mem_charge(sk, skb->len);
1273         copied = skb->len;
1274         r->sg_start = 0;
1275         r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1276         r->skb = skb;
1277         list_add_tail(&r->list, &psock->ingress);
1278         sk->sk_data_ready(sk);
1279         return copied;
1280 }
1281
1282 static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1283 {
1284         struct smap_psock *peer;
1285         struct sock *sk;
1286         __u32 in;
1287         int rc;
1288
1289         rc = smap_verdict_func(psock, skb);
1290         switch (rc) {
1291         case __SK_REDIRECT:
1292                 sk = do_sk_redirect_map(skb);
1293                 if (!sk) {
1294                         kfree_skb(skb);
1295                         break;
1296                 }
1297
1298                 peer = smap_psock_sk(sk);
1299                 in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1300
1301                 if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1302                              !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1303                         kfree_skb(skb);
1304                         break;
1305                 }
1306
1307                 if (!in && sock_writeable(sk)) {
1308                         skb_set_owner_w(skb, sk);
1309                         skb_queue_tail(&peer->rxqueue, skb);
1310                         schedule_work(&peer->tx_work);
1311                         break;
1312                 } else if (in &&
1313                            atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1314                         skb_queue_tail(&peer->rxqueue, skb);
1315                         schedule_work(&peer->tx_work);
1316                         break;
1317                 }
1318         /* Fall through and free skb otherwise */
1319         case __SK_DROP:
1320         default:
1321                 kfree_skb(skb);
1322         }
1323 }
1324
1325 static void smap_report_sk_error(struct smap_psock *psock, int err)
1326 {
1327         struct sock *sk = psock->sock;
1328
1329         sk->sk_err = err;
1330         sk->sk_error_report(sk);
1331 }
1332
1333 static void smap_read_sock_strparser(struct strparser *strp,
1334                                      struct sk_buff *skb)
1335 {
1336         struct smap_psock *psock;
1337
1338         rcu_read_lock();
1339         psock = container_of(strp, struct smap_psock, strp);
1340         smap_do_verdict(psock, skb);
1341         rcu_read_unlock();
1342 }
1343
1344 /* Called with lock held on socket */
1345 static void smap_data_ready(struct sock *sk)
1346 {
1347         struct smap_psock *psock;
1348
1349         rcu_read_lock();
1350         psock = smap_psock_sk(sk);
1351         if (likely(psock)) {
1352                 write_lock_bh(&sk->sk_callback_lock);
1353                 strp_data_ready(&psock->strp);
1354                 write_unlock_bh(&sk->sk_callback_lock);
1355         }
1356         rcu_read_unlock();
1357 }
1358
1359 static void smap_tx_work(struct work_struct *w)
1360 {
1361         struct smap_psock *psock;
1362         struct sk_buff *skb;
1363         int rem, off, n;
1364
1365         psock = container_of(w, struct smap_psock, tx_work);
1366
1367         /* lock sock to avoid losing sk_socket at some point during loop */
1368         lock_sock(psock->sock);
1369         if (psock->save_skb) {
1370                 skb = psock->save_skb;
1371                 rem = psock->save_rem;
1372                 off = psock->save_off;
1373                 psock->save_skb = NULL;
1374                 goto start;
1375         }
1376
1377         while ((skb = skb_dequeue(&psock->rxqueue))) {
1378                 __u32 flags;
1379
1380                 rem = skb->len;
1381                 off = 0;
1382 start:
1383                 flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1384                 do {
1385                         if (likely(psock->sock->sk_socket)) {
1386                                 if (flags)
1387                                         n = smap_do_ingress(psock, skb);
1388                                 else
1389                                         n = skb_send_sock_locked(psock->sock,
1390                                                                  skb, off, rem);
1391                         } else {
1392                                 n = -EINVAL;
1393                         }
1394
1395                         if (n <= 0) {
1396                                 if (n == -EAGAIN) {
1397                                         /* Retry when space is available */
1398                                         psock->save_skb = skb;
1399                                         psock->save_rem = rem;
1400                                         psock->save_off = off;
1401                                         goto out;
1402                                 }
1403                                 /* Hard errors break pipe and stop xmit */
1404                                 smap_report_sk_error(psock, n ? -n : EPIPE);
1405                                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1406                                 kfree_skb(skb);
1407                                 goto out;
1408                         }
1409                         rem -= n;
1410                         off += n;
1411                 } while (rem);
1412
1413                 if (!flags)
1414                         kfree_skb(skb);
1415         }
1416 out:
1417         release_sock(psock->sock);
1418 }
1419
1420 static void smap_write_space(struct sock *sk)
1421 {
1422         struct smap_psock *psock;
1423
1424         rcu_read_lock();
1425         psock = smap_psock_sk(sk);
1426         if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1427                 schedule_work(&psock->tx_work);
1428         rcu_read_unlock();
1429 }
1430
1431 static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1432 {
1433         if (!psock->strp_enabled)
1434                 return;
1435         sk->sk_data_ready = psock->save_data_ready;
1436         sk->sk_write_space = psock->save_write_space;
1437         psock->save_data_ready = NULL;
1438         psock->save_write_space = NULL;
1439         strp_stop(&psock->strp);
1440         psock->strp_enabled = false;
1441 }
1442
1443 static void smap_destroy_psock(struct rcu_head *rcu)
1444 {
1445         struct smap_psock *psock = container_of(rcu,
1446                                                   struct smap_psock, rcu);
1447
1448         /* Now that a grace period has passed there is no longer
1449          * any reference to this sock in the sockmap so we can
1450          * destroy the psock, strparser, and bpf programs. But,
1451          * because we use workqueue sync operations we can not
1452          * do it in rcu context
1453          */
1454         schedule_work(&psock->gc_work);
1455 }
1456
1457 static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1458 {
1459         if (refcount_dec_and_test(&psock->refcnt)) {
1460                 tcp_cleanup_ulp(sock);
1461                 write_lock_bh(&sock->sk_callback_lock);
1462                 smap_stop_sock(psock, sock);
1463                 write_unlock_bh(&sock->sk_callback_lock);
1464                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1465                 rcu_assign_sk_user_data(sock, NULL);
1466                 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1467         }
1468 }
1469
1470 static int smap_parse_func_strparser(struct strparser *strp,
1471                                        struct sk_buff *skb)
1472 {
1473         struct smap_psock *psock;
1474         struct bpf_prog *prog;
1475         int rc;
1476
1477         rcu_read_lock();
1478         psock = container_of(strp, struct smap_psock, strp);
1479         prog = READ_ONCE(psock->bpf_parse);
1480
1481         if (unlikely(!prog)) {
1482                 rcu_read_unlock();
1483                 return skb->len;
1484         }
1485
1486         /* Attach socket for bpf program to use if needed we can do this
1487          * because strparser clones the skb before handing it to a upper
1488          * layer, meaning skb_orphan has been called. We NULL sk on the
1489          * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1490          * later and because we are not charging the memory of this skb to
1491          * any socket yet.
1492          */
1493         skb->sk = psock->sock;
1494         bpf_compute_data_end_sk_skb(skb);
1495         rc = (*prog->bpf_func)(skb, prog->insnsi);
1496         skb->sk = NULL;
1497         rcu_read_unlock();
1498         return rc;
1499 }
1500
1501 static int smap_read_sock_done(struct strparser *strp, int err)
1502 {
1503         return err;
1504 }
1505
1506 static int smap_init_sock(struct smap_psock *psock,
1507                           struct sock *sk)
1508 {
1509         static const struct strp_callbacks cb = {
1510                 .rcv_msg = smap_read_sock_strparser,
1511                 .parse_msg = smap_parse_func_strparser,
1512                 .read_sock_done = smap_read_sock_done,
1513         };
1514
1515         return strp_init(&psock->strp, sk, &cb);
1516 }
1517
1518 static void smap_init_progs(struct smap_psock *psock,
1519                             struct bpf_prog *verdict,
1520                             struct bpf_prog *parse)
1521 {
1522         struct bpf_prog *orig_parse, *orig_verdict;
1523
1524         orig_parse = xchg(&psock->bpf_parse, parse);
1525         orig_verdict = xchg(&psock->bpf_verdict, verdict);
1526
1527         if (orig_verdict)
1528                 bpf_prog_put(orig_verdict);
1529         if (orig_parse)
1530                 bpf_prog_put(orig_parse);
1531 }
1532
1533 static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1534 {
1535         if (sk->sk_data_ready == smap_data_ready)
1536                 return;
1537         psock->save_data_ready = sk->sk_data_ready;
1538         psock->save_write_space = sk->sk_write_space;
1539         sk->sk_data_ready = smap_data_ready;
1540         sk->sk_write_space = smap_write_space;
1541         psock->strp_enabled = true;
1542 }
1543
1544 static void sock_map_remove_complete(struct bpf_stab *stab)
1545 {
1546         bpf_map_area_free(stab->sock_map);
1547         kfree(stab);
1548 }
1549
1550 static void smap_gc_work(struct work_struct *w)
1551 {
1552         struct smap_psock_map_entry *e, *tmp;
1553         struct sk_msg_buff *md, *mtmp;
1554         struct smap_psock *psock;
1555
1556         psock = container_of(w, struct smap_psock, gc_work);
1557
1558         /* no callback lock needed because we already detached sockmap ops */
1559         if (psock->strp_enabled)
1560                 strp_done(&psock->strp);
1561
1562         cancel_work_sync(&psock->tx_work);
1563         __skb_queue_purge(&psock->rxqueue);
1564
1565         /* At this point all strparser and xmit work must be complete */
1566         if (psock->bpf_parse)
1567                 bpf_prog_put(psock->bpf_parse);
1568         if (psock->bpf_verdict)
1569                 bpf_prog_put(psock->bpf_verdict);
1570         if (psock->bpf_tx_msg)
1571                 bpf_prog_put(psock->bpf_tx_msg);
1572
1573         if (psock->cork) {
1574                 free_start_sg(psock->sock, psock->cork);
1575                 kfree(psock->cork);
1576         }
1577
1578         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1579                 list_del(&md->list);
1580                 free_start_sg(psock->sock, md);
1581                 kfree(md);
1582         }
1583
1584         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1585                 list_del(&e->list);
1586                 kfree(e);
1587         }
1588
1589         if (psock->sk_redir)
1590                 sock_put(psock->sk_redir);
1591
1592         sock_put(psock->sock);
1593         kfree(psock);
1594 }
1595
1596 static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1597 {
1598         struct smap_psock *psock;
1599
1600         psock = kzalloc_node(sizeof(struct smap_psock),
1601                              GFP_ATOMIC | __GFP_NOWARN,
1602                              node);
1603         if (!psock)
1604                 return ERR_PTR(-ENOMEM);
1605
1606         psock->eval =  __SK_NONE;
1607         psock->sock = sock;
1608         skb_queue_head_init(&psock->rxqueue);
1609         INIT_WORK(&psock->tx_work, smap_tx_work);
1610         INIT_WORK(&psock->gc_work, smap_gc_work);
1611         INIT_LIST_HEAD(&psock->maps);
1612         INIT_LIST_HEAD(&psock->ingress);
1613         refcount_set(&psock->refcnt, 1);
1614         spin_lock_init(&psock->maps_lock);
1615
1616         rcu_assign_sk_user_data(sock, psock);
1617         sock_hold(sock);
1618         return psock;
1619 }
1620
1621 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1622 {
1623         struct bpf_stab *stab;
1624         u64 cost;
1625         int err;
1626
1627         if (!capable(CAP_NET_ADMIN))
1628                 return ERR_PTR(-EPERM);
1629
1630         /* check sanity of attributes */
1631         if (attr->max_entries == 0 || attr->key_size != 4 ||
1632             attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1633                 return ERR_PTR(-EINVAL);
1634
1635         err = bpf_tcp_ulp_register();
1636         if (err && err != -EEXIST)
1637                 return ERR_PTR(err);
1638
1639         stab = kzalloc(sizeof(*stab), GFP_USER);
1640         if (!stab)
1641                 return ERR_PTR(-ENOMEM);
1642
1643         bpf_map_init_from_attr(&stab->map, attr);
1644
1645         /* make sure page count doesn't overflow */
1646         cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1647         err = -EINVAL;
1648         if (cost >= U32_MAX - PAGE_SIZE)
1649                 goto free_stab;
1650
1651         stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1652
1653         /* if map size is larger than memlock limit, reject it early */
1654         err = bpf_map_precharge_memlock(stab->map.pages);
1655         if (err)
1656                 goto free_stab;
1657
1658         err = -ENOMEM;
1659         stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1660                                             sizeof(struct sock *),
1661                                             stab->map.numa_node);
1662         if (!stab->sock_map)
1663                 goto free_stab;
1664
1665         return &stab->map;
1666 free_stab:
1667         kfree(stab);
1668         return ERR_PTR(err);
1669 }
1670
1671 static void smap_list_map_remove(struct smap_psock *psock,
1672                                  struct sock **entry)
1673 {
1674         struct smap_psock_map_entry *e, *tmp;
1675
1676         spin_lock_bh(&psock->maps_lock);
1677         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1678                 if (e->entry == entry)
1679                         list_del(&e->list);
1680         }
1681         spin_unlock_bh(&psock->maps_lock);
1682 }
1683
1684 static void smap_list_hash_remove(struct smap_psock *psock,
1685                                   struct htab_elem *hash_link)
1686 {
1687         struct smap_psock_map_entry *e, *tmp;
1688
1689         spin_lock_bh(&psock->maps_lock);
1690         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1691                 struct htab_elem *c = rcu_dereference(e->hash_link);
1692
1693                 if (c == hash_link)
1694                         list_del(&e->list);
1695         }
1696         spin_unlock_bh(&psock->maps_lock);
1697 }
1698
1699 static void sock_map_free(struct bpf_map *map)
1700 {
1701         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1702         int i;
1703
1704         synchronize_rcu();
1705
1706         /* At this point no update, lookup or delete operations can happen.
1707          * However, be aware we can still get a socket state event updates,
1708          * and data ready callabacks that reference the psock from sk_user_data
1709          * Also psock worker threads are still in-flight. So smap_release_sock
1710          * will only free the psock after cancel_sync on the worker threads
1711          * and a grace period expire to ensure psock is really safe to remove.
1712          */
1713         rcu_read_lock();
1714         for (i = 0; i < stab->map.max_entries; i++) {
1715                 struct smap_psock *psock;
1716                 struct sock *sock;
1717
1718                 sock = xchg(&stab->sock_map[i], NULL);
1719                 if (!sock)
1720                         continue;
1721
1722                 psock = smap_psock_sk(sock);
1723                 /* This check handles a racing sock event that can get the
1724                  * sk_callback_lock before this case but after xchg happens
1725                  * causing the refcnt to hit zero and sock user data (psock)
1726                  * to be null and queued for garbage collection.
1727                  */
1728                 if (likely(psock)) {
1729                         smap_list_map_remove(psock, &stab->sock_map[i]);
1730                         smap_release_sock(psock, sock);
1731                 }
1732         }
1733         rcu_read_unlock();
1734
1735         sock_map_remove_complete(stab);
1736 }
1737
1738 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1739 {
1740         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1741         u32 i = key ? *(u32 *)key : U32_MAX;
1742         u32 *next = (u32 *)next_key;
1743
1744         if (i >= stab->map.max_entries) {
1745                 *next = 0;
1746                 return 0;
1747         }
1748
1749         if (i == stab->map.max_entries - 1)
1750                 return -ENOENT;
1751
1752         *next = i + 1;
1753         return 0;
1754 }
1755
1756 struct sock  *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1757 {
1758         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1759
1760         if (key >= map->max_entries)
1761                 return NULL;
1762
1763         return READ_ONCE(stab->sock_map[key]);
1764 }
1765
1766 static int sock_map_delete_elem(struct bpf_map *map, void *key)
1767 {
1768         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1769         struct smap_psock *psock;
1770         int k = *(u32 *)key;
1771         struct sock *sock;
1772
1773         if (k >= map->max_entries)
1774                 return -EINVAL;
1775
1776         sock = xchg(&stab->sock_map[k], NULL);
1777         if (!sock)
1778                 return -EINVAL;
1779
1780         psock = smap_psock_sk(sock);
1781         if (!psock)
1782                 goto out;
1783
1784         if (psock->bpf_parse)
1785                 smap_stop_sock(psock, sock);
1786         smap_list_map_remove(psock, &stab->sock_map[k]);
1787         smap_release_sock(psock, sock);
1788 out:
1789         return 0;
1790 }
1791
1792 /* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1793  * done inside rcu critical sections. This ensures on updates that the psock
1794  * will not be released via smap_release_sock() until concurrent updates/deletes
1795  * complete. All operations operate on sock_map using cmpxchg and xchg
1796  * operations to ensure we do not get stale references. Any reads into the
1797  * map must be done with READ_ONCE() because of this.
1798  *
1799  * A psock is destroyed via call_rcu and after any worker threads are cancelled
1800  * and syncd so we are certain all references from the update/lookup/delete
1801  * operations as well as references in the data path are no longer in use.
1802  *
1803  * Psocks may exist in multiple maps, but only a single set of parse/verdict
1804  * programs may be inherited from the maps it belongs to. A reference count
1805  * is kept with the total number of references to the psock from all maps. The
1806  * psock will not be released until this reaches zero. The psock and sock
1807  * user data data use the sk_callback_lock to protect critical data structures
1808  * from concurrent access. This allows us to avoid two updates from modifying
1809  * the user data in sock and the lock is required anyways for modifying
1810  * callbacks, we simply increase its scope slightly.
1811  *
1812  * Rules to follow,
1813  *  - psock must always be read inside RCU critical section
1814  *  - sk_user_data must only be modified inside sk_callback_lock and read
1815  *    inside RCU critical section.
1816  *  - psock->maps list must only be read & modified inside sk_callback_lock
1817  *  - sock_map must use READ_ONCE and (cmp)xchg operations
1818  *  - BPF verdict/parse programs must use READ_ONCE and xchg operations
1819  */
1820
1821 static int __sock_map_ctx_update_elem(struct bpf_map *map,
1822                                       struct bpf_sock_progs *progs,
1823                                       struct sock *sock,
1824                                       struct sock **map_link,
1825                                       void *key)
1826 {
1827         struct bpf_prog *verdict, *parse, *tx_msg;
1828         struct smap_psock_map_entry *e = NULL;
1829         struct smap_psock *psock;
1830         bool new = false;
1831         int err = 0;
1832
1833         /* 1. If sock map has BPF programs those will be inherited by the
1834          * sock being added. If the sock is already attached to BPF programs
1835          * this results in an error.
1836          */
1837         verdict = READ_ONCE(progs->bpf_verdict);
1838         parse = READ_ONCE(progs->bpf_parse);
1839         tx_msg = READ_ONCE(progs->bpf_tx_msg);
1840
1841         if (parse && verdict) {
1842                 /* bpf prog refcnt may be zero if a concurrent attach operation
1843                  * removes the program after the above READ_ONCE() but before
1844                  * we increment the refcnt. If this is the case abort with an
1845                  * error.
1846                  */
1847                 verdict = bpf_prog_inc_not_zero(verdict);
1848                 if (IS_ERR(verdict))
1849                         return PTR_ERR(verdict);
1850
1851                 parse = bpf_prog_inc_not_zero(parse);
1852                 if (IS_ERR(parse)) {
1853                         bpf_prog_put(verdict);
1854                         return PTR_ERR(parse);
1855                 }
1856         }
1857
1858         if (tx_msg) {
1859                 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1860                 if (IS_ERR(tx_msg)) {
1861                         if (parse && verdict) {
1862                                 bpf_prog_put(parse);
1863                                 bpf_prog_put(verdict);
1864                         }
1865                         return PTR_ERR(tx_msg);
1866                 }
1867         }
1868
1869         psock = smap_psock_sk(sock);
1870
1871         /* 2. Do not allow inheriting programs if psock exists and has
1872          * already inherited programs. This would create confusion on
1873          * which parser/verdict program is running. If no psock exists
1874          * create one. Inside sk_callback_lock to ensure concurrent create
1875          * doesn't update user data.
1876          */
1877         if (psock) {
1878                 if (READ_ONCE(psock->bpf_parse) && parse) {
1879                         err = -EBUSY;
1880                         goto out_progs;
1881                 }
1882                 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1883                         err = -EBUSY;
1884                         goto out_progs;
1885                 }
1886                 if (!refcount_inc_not_zero(&psock->refcnt)) {
1887                         err = -EAGAIN;
1888                         goto out_progs;
1889                 }
1890         } else {
1891                 psock = smap_init_psock(sock, map->numa_node);
1892                 if (IS_ERR(psock)) {
1893                         err = PTR_ERR(psock);
1894                         goto out_progs;
1895                 }
1896
1897                 set_bit(SMAP_TX_RUNNING, &psock->state);
1898                 new = true;
1899         }
1900
1901         if (map_link) {
1902                 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
1903                 if (!e) {
1904                         err = -ENOMEM;
1905                         goto out_free;
1906                 }
1907         }
1908
1909         /* 3. At this point we have a reference to a valid psock that is
1910          * running. Attach any BPF programs needed.
1911          */
1912         if (tx_msg)
1913                 bpf_tcp_msg_add(psock, sock, tx_msg);
1914         if (new) {
1915                 err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1916                 if (err)
1917                         goto out_free;
1918         }
1919
1920         if (parse && verdict && !psock->strp_enabled) {
1921                 err = smap_init_sock(psock, sock);
1922                 if (err)
1923                         goto out_free;
1924                 smap_init_progs(psock, verdict, parse);
1925                 write_lock_bh(&sock->sk_callback_lock);
1926                 smap_start_sock(psock, sock);
1927                 write_unlock_bh(&sock->sk_callback_lock);
1928         }
1929
1930         /* 4. Place psock in sockmap for use and stop any programs on
1931          * the old sock assuming its not the same sock we are replacing
1932          * it with. Because we can only have a single set of programs if
1933          * old_sock has a strp we can stop it.
1934          */
1935         if (map_link) {
1936                 e->entry = map_link;
1937                 spin_lock_bh(&psock->maps_lock);
1938                 list_add_tail(&e->list, &psock->maps);
1939                 spin_unlock_bh(&psock->maps_lock);
1940         }
1941         return err;
1942 out_free:
1943         smap_release_sock(psock, sock);
1944 out_progs:
1945         if (parse && verdict) {
1946                 bpf_prog_put(parse);
1947                 bpf_prog_put(verdict);
1948         }
1949         if (tx_msg)
1950                 bpf_prog_put(tx_msg);
1951         kfree(e);
1952         return err;
1953 }
1954
1955 static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1956                                     struct bpf_map *map,
1957                                     void *key, u64 flags)
1958 {
1959         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1960         struct bpf_sock_progs *progs = &stab->progs;
1961         struct sock *osock, *sock;
1962         u32 i = *(u32 *)key;
1963         int err;
1964
1965         if (unlikely(flags > BPF_EXIST))
1966                 return -EINVAL;
1967
1968         if (unlikely(i >= stab->map.max_entries))
1969                 return -E2BIG;
1970
1971         sock = READ_ONCE(stab->sock_map[i]);
1972         if (flags == BPF_EXIST && !sock)
1973                 return -ENOENT;
1974         else if (flags == BPF_NOEXIST && sock)
1975                 return -EEXIST;
1976
1977         sock = skops->sk;
1978         err = __sock_map_ctx_update_elem(map, progs, sock, &stab->sock_map[i],
1979                                          key);
1980         if (err)
1981                 goto out;
1982
1983         osock = xchg(&stab->sock_map[i], sock);
1984         if (osock) {
1985                 struct smap_psock *opsock = smap_psock_sk(osock);
1986
1987                 smap_list_map_remove(opsock, &stab->sock_map[i]);
1988                 smap_release_sock(opsock, osock);
1989         }
1990 out:
1991         return err;
1992 }
1993
1994 int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
1995 {
1996         struct bpf_sock_progs *progs;
1997         struct bpf_prog *orig;
1998
1999         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2000                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2001
2002                 progs = &stab->progs;
2003         } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
2004                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2005
2006                 progs = &htab->progs;
2007         } else {
2008                 return -EINVAL;
2009         }
2010
2011         switch (type) {
2012         case BPF_SK_MSG_VERDICT:
2013                 orig = xchg(&progs->bpf_tx_msg, prog);
2014                 break;
2015         case BPF_SK_SKB_STREAM_PARSER:
2016                 orig = xchg(&progs->bpf_parse, prog);
2017                 break;
2018         case BPF_SK_SKB_STREAM_VERDICT:
2019                 orig = xchg(&progs->bpf_verdict, prog);
2020                 break;
2021         default:
2022                 return -EOPNOTSUPP;
2023         }
2024
2025         if (orig)
2026                 bpf_prog_put(orig);
2027
2028         return 0;
2029 }
2030
2031 int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2032                         struct bpf_prog *prog)
2033 {
2034         int ufd = attr->target_fd;
2035         struct bpf_map *map;
2036         struct fd f;
2037         int err;
2038
2039         f = fdget(ufd);
2040         map = __bpf_map_get(f);
2041         if (IS_ERR(map))
2042                 return PTR_ERR(map);
2043
2044         err = sock_map_prog(map, prog, attr->attach_type);
2045         fdput(f);
2046         return err;
2047 }
2048
2049 static void *sock_map_lookup(struct bpf_map *map, void *key)
2050 {
2051         return NULL;
2052 }
2053
2054 static int sock_map_update_elem(struct bpf_map *map,
2055                                 void *key, void *value, u64 flags)
2056 {
2057         struct bpf_sock_ops_kern skops;
2058         u32 fd = *(u32 *)value;
2059         struct socket *socket;
2060         int err;
2061
2062         socket = sockfd_lookup(fd, &err);
2063         if (!socket)
2064                 return err;
2065
2066         skops.sk = socket->sk;
2067         if (!skops.sk) {
2068                 fput(socket->file);
2069                 return -EINVAL;
2070         }
2071
2072         if (skops.sk->sk_type != SOCK_STREAM ||
2073             skops.sk->sk_protocol != IPPROTO_TCP) {
2074                 fput(socket->file);
2075                 return -EOPNOTSUPP;
2076         }
2077
2078         lock_sock(skops.sk);
2079         preempt_disable();
2080         rcu_read_lock();
2081         err = sock_map_ctx_update_elem(&skops, map, key, flags);
2082         rcu_read_unlock();
2083         preempt_enable();
2084         release_sock(skops.sk);
2085         fput(socket->file);
2086         return err;
2087 }
2088
2089 static void sock_map_release(struct bpf_map *map)
2090 {
2091         struct bpf_sock_progs *progs;
2092         struct bpf_prog *orig;
2093
2094         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2095                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2096
2097                 progs = &stab->progs;
2098         } else {
2099                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2100
2101                 progs = &htab->progs;
2102         }
2103
2104         orig = xchg(&progs->bpf_parse, NULL);
2105         if (orig)
2106                 bpf_prog_put(orig);
2107         orig = xchg(&progs->bpf_verdict, NULL);
2108         if (orig)
2109                 bpf_prog_put(orig);
2110
2111         orig = xchg(&progs->bpf_tx_msg, NULL);
2112         if (orig)
2113                 bpf_prog_put(orig);
2114 }
2115
2116 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2117 {
2118         struct bpf_htab *htab;
2119         int i, err;
2120         u64 cost;
2121
2122         if (!capable(CAP_NET_ADMIN))
2123                 return ERR_PTR(-EPERM);
2124
2125         /* check sanity of attributes */
2126         if (attr->max_entries == 0 || attr->value_size != 4 ||
2127             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2128                 return ERR_PTR(-EINVAL);
2129
2130         if (attr->key_size > MAX_BPF_STACK)
2131                 /* eBPF programs initialize keys on stack, so they cannot be
2132                  * larger than max stack size
2133                  */
2134                 return ERR_PTR(-E2BIG);
2135
2136         err = bpf_tcp_ulp_register();
2137         if (err && err != -EEXIST)
2138                 return ERR_PTR(err);
2139
2140         htab = kzalloc(sizeof(*htab), GFP_USER);
2141         if (!htab)
2142                 return ERR_PTR(-ENOMEM);
2143
2144         bpf_map_init_from_attr(&htab->map, attr);
2145
2146         htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2147         htab->elem_size = sizeof(struct htab_elem) +
2148                           round_up(htab->map.key_size, 8);
2149         err = -EINVAL;
2150         if (htab->n_buckets == 0 ||
2151             htab->n_buckets > U32_MAX / sizeof(struct bucket))
2152                 goto free_htab;
2153
2154         cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2155                (u64) htab->elem_size * htab->map.max_entries;
2156
2157         if (cost >= U32_MAX - PAGE_SIZE)
2158                 goto free_htab;
2159
2160         htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2161         err = bpf_map_precharge_memlock(htab->map.pages);
2162         if (err)
2163                 goto free_htab;
2164
2165         err = -ENOMEM;
2166         htab->buckets = bpf_map_area_alloc(
2167                                 htab->n_buckets * sizeof(struct bucket),
2168                                 htab->map.numa_node);
2169         if (!htab->buckets)
2170                 goto free_htab;
2171
2172         for (i = 0; i < htab->n_buckets; i++) {
2173                 INIT_HLIST_HEAD(&htab->buckets[i].head);
2174                 raw_spin_lock_init(&htab->buckets[i].lock);
2175         }
2176
2177         return &htab->map;
2178 free_htab:
2179         kfree(htab);
2180         return ERR_PTR(err);
2181 }
2182
2183 static void __bpf_htab_free(struct rcu_head *rcu)
2184 {
2185         struct bpf_htab *htab;
2186
2187         htab = container_of(rcu, struct bpf_htab, rcu);
2188         bpf_map_area_free(htab->buckets);
2189         kfree(htab);
2190 }
2191
2192 static void sock_hash_free(struct bpf_map *map)
2193 {
2194         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2195         int i;
2196
2197         synchronize_rcu();
2198
2199         /* At this point no update, lookup or delete operations can happen.
2200          * However, be aware we can still get a socket state event updates,
2201          * and data ready callabacks that reference the psock from sk_user_data
2202          * Also psock worker threads are still in-flight. So smap_release_sock
2203          * will only free the psock after cancel_sync on the worker threads
2204          * and a grace period expire to ensure psock is really safe to remove.
2205          */
2206         rcu_read_lock();
2207         for (i = 0; i < htab->n_buckets; i++) {
2208                 struct bucket *b = __select_bucket(htab, i);
2209                 struct hlist_head *head;
2210                 struct hlist_node *n;
2211                 struct htab_elem *l;
2212
2213                 raw_spin_lock_bh(&b->lock);
2214                 head = &b->head;
2215                 hlist_for_each_entry_safe(l, n, head, hash_node) {
2216                         struct sock *sock = l->sk;
2217                         struct smap_psock *psock;
2218
2219                         hlist_del_rcu(&l->hash_node);
2220                         psock = smap_psock_sk(sock);
2221                         /* This check handles a racing sock event that can get
2222                          * the sk_callback_lock before this case but after xchg
2223                          * causing the refcnt to hit zero and sock user data
2224                          * (psock) to be null and queued for garbage collection.
2225                          */
2226                         if (likely(psock)) {
2227                                 smap_list_hash_remove(psock, l);
2228                                 smap_release_sock(psock, sock);
2229                         }
2230                         free_htab_elem(htab, l);
2231                 }
2232                 raw_spin_unlock_bh(&b->lock);
2233         }
2234         rcu_read_unlock();
2235         call_rcu(&htab->rcu, __bpf_htab_free);
2236 }
2237
2238 static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2239                                               void *key, u32 key_size, u32 hash,
2240                                               struct sock *sk,
2241                                               struct htab_elem *old_elem)
2242 {
2243         struct htab_elem *l_new;
2244
2245         if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2246                 if (!old_elem) {
2247                         atomic_dec(&htab->count);
2248                         return ERR_PTR(-E2BIG);
2249                 }
2250         }
2251         l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2252                              htab->map.numa_node);
2253         if (!l_new)
2254                 return ERR_PTR(-ENOMEM);
2255
2256         memcpy(l_new->key, key, key_size);
2257         l_new->sk = sk;
2258         l_new->hash = hash;
2259         return l_new;
2260 }
2261
2262 static inline u32 htab_map_hash(const void *key, u32 key_len)
2263 {
2264         return jhash(key, key_len, 0);
2265 }
2266
2267 static int sock_hash_get_next_key(struct bpf_map *map,
2268                                   void *key, void *next_key)
2269 {
2270         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2271         struct htab_elem *l, *next_l;
2272         struct hlist_head *h;
2273         u32 hash, key_size;
2274         int i = 0;
2275
2276         WARN_ON_ONCE(!rcu_read_lock_held());
2277
2278         key_size = map->key_size;
2279         if (!key)
2280                 goto find_first_elem;
2281         hash = htab_map_hash(key, key_size);
2282         h = select_bucket(htab, hash);
2283
2284         l = lookup_elem_raw(h, hash, key, key_size);
2285         if (!l)
2286                 goto find_first_elem;
2287         next_l = hlist_entry_safe(
2288                      rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2289                      struct htab_elem, hash_node);
2290         if (next_l) {
2291                 memcpy(next_key, next_l->key, key_size);
2292                 return 0;
2293         }
2294
2295         /* no more elements in this hash list, go to the next bucket */
2296         i = hash & (htab->n_buckets - 1);
2297         i++;
2298
2299 find_first_elem:
2300         /* iterate over buckets */
2301         for (; i < htab->n_buckets; i++) {
2302                 h = select_bucket(htab, i);
2303
2304                 /* pick first element in the bucket */
2305                 next_l = hlist_entry_safe(
2306                                 rcu_dereference_raw(hlist_first_rcu(h)),
2307                                 struct htab_elem, hash_node);
2308                 if (next_l) {
2309                         /* if it's not empty, just return it */
2310                         memcpy(next_key, next_l->key, key_size);
2311                         return 0;
2312                 }
2313         }
2314
2315         /* iterated over all buckets and all elements */
2316         return -ENOENT;
2317 }
2318
2319 static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2320                                      struct bpf_map *map,
2321                                      void *key, u64 map_flags)
2322 {
2323         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2324         struct bpf_sock_progs *progs = &htab->progs;
2325         struct htab_elem *l_new = NULL, *l_old;
2326         struct smap_psock_map_entry *e = NULL;
2327         struct hlist_head *head;
2328         struct smap_psock *psock;
2329         u32 key_size, hash;
2330         struct sock *sock;
2331         struct bucket *b;
2332         int err;
2333
2334         sock = skops->sk;
2335
2336         if (sock->sk_type != SOCK_STREAM ||
2337             sock->sk_protocol != IPPROTO_TCP)
2338                 return -EOPNOTSUPP;
2339
2340         if (unlikely(map_flags > BPF_EXIST))
2341                 return -EINVAL;
2342
2343         e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2344         if (!e)
2345                 return -ENOMEM;
2346
2347         WARN_ON_ONCE(!rcu_read_lock_held());
2348         key_size = map->key_size;
2349         hash = htab_map_hash(key, key_size);
2350         b = __select_bucket(htab, hash);
2351         head = &b->head;
2352
2353         err = __sock_map_ctx_update_elem(map, progs, sock, NULL, key);
2354         if (err)
2355                 goto err;
2356
2357         /* psock is valid here because otherwise above *ctx_update_elem would
2358          * have thrown an error. It is safe to skip error check.
2359          */
2360         psock = smap_psock_sk(sock);
2361         raw_spin_lock_bh(&b->lock);
2362         l_old = lookup_elem_raw(head, hash, key, key_size);
2363         if (l_old && map_flags == BPF_NOEXIST) {
2364                 err = -EEXIST;
2365                 goto bucket_err;
2366         }
2367         if (!l_old && map_flags == BPF_EXIST) {
2368                 err = -ENOENT;
2369                 goto bucket_err;
2370         }
2371
2372         l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2373         if (IS_ERR(l_new)) {
2374                 err = PTR_ERR(l_new);
2375                 goto bucket_err;
2376         }
2377
2378         rcu_assign_pointer(e->hash_link, l_new);
2379         rcu_assign_pointer(e->htab,
2380                            container_of(map, struct bpf_htab, map));
2381         spin_lock_bh(&psock->maps_lock);
2382         list_add_tail(&e->list, &psock->maps);
2383         spin_unlock_bh(&psock->maps_lock);
2384
2385         /* add new element to the head of the list, so that
2386          * concurrent search will find it before old elem
2387          */
2388         hlist_add_head_rcu(&l_new->hash_node, head);
2389         if (l_old) {
2390                 psock = smap_psock_sk(l_old->sk);
2391
2392                 hlist_del_rcu(&l_old->hash_node);
2393                 smap_list_hash_remove(psock, l_old);
2394                 smap_release_sock(psock, l_old->sk);
2395                 free_htab_elem(htab, l_old);
2396         }
2397         raw_spin_unlock_bh(&b->lock);
2398         return 0;
2399 bucket_err:
2400         smap_release_sock(psock, sock);
2401         raw_spin_unlock_bh(&b->lock);
2402 err:
2403         kfree(e);
2404         return err;
2405 }
2406
2407 static int sock_hash_update_elem(struct bpf_map *map,
2408                                 void *key, void *value, u64 flags)
2409 {
2410         struct bpf_sock_ops_kern skops;
2411         u32 fd = *(u32 *)value;
2412         struct socket *socket;
2413         int err;
2414
2415         socket = sockfd_lookup(fd, &err);
2416         if (!socket)
2417                 return err;
2418
2419         skops.sk = socket->sk;
2420         if (!skops.sk) {
2421                 fput(socket->file);
2422                 return -EINVAL;
2423         }
2424
2425         lock_sock(skops.sk);
2426         preempt_disable();
2427         rcu_read_lock();
2428         err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2429         rcu_read_unlock();
2430         preempt_enable();
2431         release_sock(skops.sk);
2432         fput(socket->file);
2433         return err;
2434 }
2435
2436 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2437 {
2438         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2439         struct hlist_head *head;
2440         struct bucket *b;
2441         struct htab_elem *l;
2442         u32 hash, key_size;
2443         int ret = -ENOENT;
2444
2445         key_size = map->key_size;
2446         hash = htab_map_hash(key, key_size);
2447         b = __select_bucket(htab, hash);
2448         head = &b->head;
2449
2450         raw_spin_lock_bh(&b->lock);
2451         l = lookup_elem_raw(head, hash, key, key_size);
2452         if (l) {
2453                 struct sock *sock = l->sk;
2454                 struct smap_psock *psock;
2455
2456                 hlist_del_rcu(&l->hash_node);
2457                 psock = smap_psock_sk(sock);
2458                 /* This check handles a racing sock event that can get the
2459                  * sk_callback_lock before this case but after xchg happens
2460                  * causing the refcnt to hit zero and sock user data (psock)
2461                  * to be null and queued for garbage collection.
2462                  */
2463                 if (likely(psock)) {
2464                         smap_list_hash_remove(psock, l);
2465                         smap_release_sock(psock, sock);
2466                 }
2467                 free_htab_elem(htab, l);
2468                 ret = 0;
2469         }
2470         raw_spin_unlock_bh(&b->lock);
2471         return ret;
2472 }
2473
2474 struct sock  *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2475 {
2476         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2477         struct hlist_head *head;
2478         struct htab_elem *l;
2479         u32 key_size, hash;
2480         struct bucket *b;
2481         struct sock *sk;
2482
2483         key_size = map->key_size;
2484         hash = htab_map_hash(key, key_size);
2485         b = __select_bucket(htab, hash);
2486         head = &b->head;
2487
2488         l = lookup_elem_raw(head, hash, key, key_size);
2489         sk = l ? l->sk : NULL;
2490         return sk;
2491 }
2492
2493 const struct bpf_map_ops sock_map_ops = {
2494         .map_alloc = sock_map_alloc,
2495         .map_free = sock_map_free,
2496         .map_lookup_elem = sock_map_lookup,
2497         .map_get_next_key = sock_map_get_next_key,
2498         .map_update_elem = sock_map_update_elem,
2499         .map_delete_elem = sock_map_delete_elem,
2500         .map_release_uref = sock_map_release,
2501 };
2502
2503 const struct bpf_map_ops sock_hash_ops = {
2504         .map_alloc = sock_hash_alloc,
2505         .map_free = sock_hash_free,
2506         .map_lookup_elem = sock_map_lookup,
2507         .map_get_next_key = sock_hash_get_next_key,
2508         .map_update_elem = sock_hash_update_elem,
2509         .map_delete_elem = sock_hash_delete_elem,
2510         .map_release_uref = sock_map_release,
2511 };
2512
2513 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2514            struct bpf_map *, map, void *, key, u64, flags)
2515 {
2516         WARN_ON_ONCE(!rcu_read_lock_held());
2517         return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2518 }
2519
2520 const struct bpf_func_proto bpf_sock_map_update_proto = {
2521         .func           = bpf_sock_map_update,
2522         .gpl_only       = false,
2523         .pkt_access     = true,
2524         .ret_type       = RET_INTEGER,
2525         .arg1_type      = ARG_PTR_TO_CTX,
2526         .arg2_type      = ARG_CONST_MAP_PTR,
2527         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2528         .arg4_type      = ARG_ANYTHING,
2529 };
2530
2531 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2532            struct bpf_map *, map, void *, key, u64, flags)
2533 {
2534         WARN_ON_ONCE(!rcu_read_lock_held());
2535         return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2536 }
2537
2538 const struct bpf_func_proto bpf_sock_hash_update_proto = {
2539         .func           = bpf_sock_hash_update,
2540         .gpl_only       = false,
2541         .pkt_access     = true,
2542         .ret_type       = RET_INTEGER,
2543         .arg1_type      = ARG_PTR_TO_CTX,
2544         .arg2_type      = ARG_CONST_MAP_PTR,
2545         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2546         .arg4_type      = ARG_ANYTHING,
2547 };