Merge tag 'drm-misc-fixes-2018-07-13' of git://anongit.freedesktop.org/drm/drm-misc...
[linux-2.6-microblaze.git] / kernel / bpf / sockmap.c
1 /* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
2  *
3  * This program is free software; you can redistribute it and/or
4  * modify it under the terms of version 2 of the GNU General Public
5  * License as published by the Free Software Foundation.
6  *
7  * This program is distributed in the hope that it will be useful, but
8  * WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10  * General Public License for more details.
11  */
12
13 /* A BPF sock_map is used to store sock objects. This is primarly used
14  * for doing socket redirect with BPF helper routines.
15  *
16  * A sock map may have BPF programs attached to it, currently a program
17  * used to parse packets and a program to provide a verdict and redirect
18  * decision on the packet are supported. Any programs attached to a sock
19  * map are inherited by sock objects when they are added to the map. If
20  * no BPF programs are attached the sock object may only be used for sock
21  * redirect.
22  *
23  * A sock object may be in multiple maps, but can only inherit a single
24  * parse or verdict program. If adding a sock object to a map would result
25  * in having multiple parsing programs the update will return an EBUSY error.
26  *
27  * For reference this program is similar to devmap used in XDP context
28  * reviewing these together may be useful. For an example please review
29  * ./samples/bpf/sockmap/.
30  */
31 #include <linux/bpf.h>
32 #include <net/sock.h>
33 #include <linux/filter.h>
34 #include <linux/errno.h>
35 #include <linux/file.h>
36 #include <linux/kernel.h>
37 #include <linux/net.h>
38 #include <linux/skbuff.h>
39 #include <linux/workqueue.h>
40 #include <linux/list.h>
41 #include <linux/mm.h>
42 #include <net/strparser.h>
43 #include <net/tcp.h>
44 #include <linux/ptr_ring.h>
45 #include <net/inet_common.h>
46 #include <linux/sched/signal.h>
47
48 #define SOCK_CREATE_FLAG_MASK \
49         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
50
51 struct bpf_sock_progs {
52         struct bpf_prog *bpf_tx_msg;
53         struct bpf_prog *bpf_parse;
54         struct bpf_prog *bpf_verdict;
55 };
56
57 struct bpf_stab {
58         struct bpf_map map;
59         struct sock **sock_map;
60         struct bpf_sock_progs progs;
61 };
62
63 struct bucket {
64         struct hlist_head head;
65         raw_spinlock_t lock;
66 };
67
68 struct bpf_htab {
69         struct bpf_map map;
70         struct bucket *buckets;
71         atomic_t count;
72         u32 n_buckets;
73         u32 elem_size;
74         struct bpf_sock_progs progs;
75         struct rcu_head rcu;
76 };
77
78 struct htab_elem {
79         struct rcu_head rcu;
80         struct hlist_node hash_node;
81         u32 hash;
82         struct sock *sk;
83         char key[0];
84 };
85
86 enum smap_psock_state {
87         SMAP_TX_RUNNING,
88 };
89
90 struct smap_psock_map_entry {
91         struct list_head list;
92         struct sock **entry;
93         struct htab_elem __rcu *hash_link;
94         struct bpf_htab __rcu *htab;
95 };
96
97 struct smap_psock {
98         struct rcu_head rcu;
99         refcount_t refcnt;
100
101         /* datapath variables */
102         struct sk_buff_head rxqueue;
103         bool strp_enabled;
104
105         /* datapath error path cache across tx work invocations */
106         int save_rem;
107         int save_off;
108         struct sk_buff *save_skb;
109
110         /* datapath variables for tx_msg ULP */
111         struct sock *sk_redir;
112         int apply_bytes;
113         int cork_bytes;
114         int sg_size;
115         int eval;
116         struct sk_msg_buff *cork;
117         struct list_head ingress;
118
119         struct strparser strp;
120         struct bpf_prog *bpf_tx_msg;
121         struct bpf_prog *bpf_parse;
122         struct bpf_prog *bpf_verdict;
123         struct list_head maps;
124         spinlock_t maps_lock;
125
126         /* Back reference used when sock callback trigger sockmap operations */
127         struct sock *sock;
128         unsigned long state;
129
130         struct work_struct tx_work;
131         struct work_struct gc_work;
132
133         struct proto *sk_proto;
134         void (*save_close)(struct sock *sk, long timeout);
135         void (*save_data_ready)(struct sock *sk);
136         void (*save_write_space)(struct sock *sk);
137 };
138
139 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
140 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
141                            int nonblock, int flags, int *addr_len);
142 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
143 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
144                             int offset, size_t size, int flags);
145 static void bpf_tcp_close(struct sock *sk, long timeout);
146
147 static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
148 {
149         return rcu_dereference_sk_user_data(sk);
150 }
151
152 static bool bpf_tcp_stream_read(const struct sock *sk)
153 {
154         struct smap_psock *psock;
155         bool empty = true;
156
157         rcu_read_lock();
158         psock = smap_psock_sk(sk);
159         if (unlikely(!psock))
160                 goto out;
161         empty = list_empty(&psock->ingress);
162 out:
163         rcu_read_unlock();
164         return !empty;
165 }
166
167 enum {
168         SOCKMAP_IPV4,
169         SOCKMAP_IPV6,
170         SOCKMAP_NUM_PROTS,
171 };
172
173 enum {
174         SOCKMAP_BASE,
175         SOCKMAP_TX,
176         SOCKMAP_NUM_CONFIGS,
177 };
178
179 static struct proto *saved_tcpv6_prot __read_mostly;
180 static DEFINE_SPINLOCK(tcpv6_prot_lock);
181 static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
182 static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
183                          struct proto *base)
184 {
185         prot[SOCKMAP_BASE]                      = *base;
186         prot[SOCKMAP_BASE].close                = bpf_tcp_close;
187         prot[SOCKMAP_BASE].recvmsg              = bpf_tcp_recvmsg;
188         prot[SOCKMAP_BASE].stream_memory_read   = bpf_tcp_stream_read;
189
190         prot[SOCKMAP_TX]                        = prot[SOCKMAP_BASE];
191         prot[SOCKMAP_TX].sendmsg                = bpf_tcp_sendmsg;
192         prot[SOCKMAP_TX].sendpage               = bpf_tcp_sendpage;
193 }
194
195 static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
196 {
197         int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
198         int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
199
200         sk->sk_prot = &bpf_tcp_prots[family][conf];
201 }
202
203 static int bpf_tcp_init(struct sock *sk)
204 {
205         struct smap_psock *psock;
206
207         rcu_read_lock();
208         psock = smap_psock_sk(sk);
209         if (unlikely(!psock)) {
210                 rcu_read_unlock();
211                 return -EINVAL;
212         }
213
214         if (unlikely(psock->sk_proto)) {
215                 rcu_read_unlock();
216                 return -EBUSY;
217         }
218
219         psock->save_close = sk->sk_prot->close;
220         psock->sk_proto = sk->sk_prot;
221
222         /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
223         if (sk->sk_family == AF_INET6 &&
224             unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
225                 spin_lock_bh(&tcpv6_prot_lock);
226                 if (likely(sk->sk_prot != saved_tcpv6_prot)) {
227                         build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
228                         smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
229                 }
230                 spin_unlock_bh(&tcpv6_prot_lock);
231         }
232         update_sk_prot(sk, psock);
233         rcu_read_unlock();
234         return 0;
235 }
236
237 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
238 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
239
240 static void bpf_tcp_release(struct sock *sk)
241 {
242         struct smap_psock *psock;
243
244         rcu_read_lock();
245         psock = smap_psock_sk(sk);
246         if (unlikely(!psock))
247                 goto out;
248
249         if (psock->cork) {
250                 free_start_sg(psock->sock, psock->cork);
251                 kfree(psock->cork);
252                 psock->cork = NULL;
253         }
254
255         if (psock->sk_proto) {
256                 sk->sk_prot = psock->sk_proto;
257                 psock->sk_proto = NULL;
258         }
259 out:
260         rcu_read_unlock();
261 }
262
263 static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
264                                          u32 hash, void *key, u32 key_size)
265 {
266         struct htab_elem *l;
267
268         hlist_for_each_entry_rcu(l, head, hash_node) {
269                 if (l->hash == hash && !memcmp(&l->key, key, key_size))
270                         return l;
271         }
272
273         return NULL;
274 }
275
276 static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
277 {
278         return &htab->buckets[hash & (htab->n_buckets - 1)];
279 }
280
281 static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
282 {
283         return &__select_bucket(htab, hash)->head;
284 }
285
286 static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
287 {
288         atomic_dec(&htab->count);
289         kfree_rcu(l, rcu);
290 }
291
292 static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
293                                                   struct smap_psock *psock)
294 {
295         struct smap_psock_map_entry *e;
296
297         spin_lock_bh(&psock->maps_lock);
298         e = list_first_entry_or_null(&psock->maps,
299                                      struct smap_psock_map_entry,
300                                      list);
301         if (e)
302                 list_del(&e->list);
303         spin_unlock_bh(&psock->maps_lock);
304         return e;
305 }
306
307 static void bpf_tcp_close(struct sock *sk, long timeout)
308 {
309         void (*close_fun)(struct sock *sk, long timeout);
310         struct smap_psock_map_entry *e;
311         struct sk_msg_buff *md, *mtmp;
312         struct smap_psock *psock;
313         struct sock *osk;
314
315         rcu_read_lock();
316         psock = smap_psock_sk(sk);
317         if (unlikely(!psock)) {
318                 rcu_read_unlock();
319                 return sk->sk_prot->close(sk, timeout);
320         }
321
322         /* The psock may be destroyed anytime after exiting the RCU critial
323          * section so by the time we use close_fun the psock may no longer
324          * be valid. However, bpf_tcp_close is called with the sock lock
325          * held so the close hook and sk are still valid.
326          */
327         close_fun = psock->save_close;
328
329         if (psock->cork) {
330                 free_start_sg(psock->sock, psock->cork);
331                 kfree(psock->cork);
332                 psock->cork = NULL;
333         }
334
335         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
336                 list_del(&md->list);
337                 free_start_sg(psock->sock, md);
338                 kfree(md);
339         }
340
341         e = psock_map_pop(sk, psock);
342         while (e) {
343                 if (e->entry) {
344                         osk = cmpxchg(e->entry, sk, NULL);
345                         if (osk == sk) {
346                                 smap_release_sock(psock, sk);
347                         }
348                 } else {
349                         struct htab_elem *link = rcu_dereference(e->hash_link);
350                         struct bpf_htab *htab = rcu_dereference(e->htab);
351                         struct hlist_head *head;
352                         struct htab_elem *l;
353                         struct bucket *b;
354
355                         b = __select_bucket(htab, link->hash);
356                         head = &b->head;
357                         raw_spin_lock_bh(&b->lock);
358                         l = lookup_elem_raw(head,
359                                             link->hash, link->key,
360                                             htab->map.key_size);
361                         /* If another thread deleted this object skip deletion.
362                          * The refcnt on psock may or may not be zero.
363                          */
364                         if (l) {
365                                 hlist_del_rcu(&link->hash_node);
366                                 smap_release_sock(psock, link->sk);
367                                 free_htab_elem(htab, link);
368                         }
369                         raw_spin_unlock_bh(&b->lock);
370                 }
371                 e = psock_map_pop(sk, psock);
372         }
373         rcu_read_unlock();
374         close_fun(sk, timeout);
375 }
376
377 enum __sk_action {
378         __SK_DROP = 0,
379         __SK_PASS,
380         __SK_REDIRECT,
381         __SK_NONE,
382 };
383
384 static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
385         .name           = "bpf_tcp",
386         .uid            = TCP_ULP_BPF,
387         .user_visible   = false,
388         .owner          = NULL,
389         .init           = bpf_tcp_init,
390         .release        = bpf_tcp_release,
391 };
392
393 static int memcopy_from_iter(struct sock *sk,
394                              struct sk_msg_buff *md,
395                              struct iov_iter *from, int bytes)
396 {
397         struct scatterlist *sg = md->sg_data;
398         int i = md->sg_curr, rc = -ENOSPC;
399
400         do {
401                 int copy;
402                 char *to;
403
404                 if (md->sg_copybreak >= sg[i].length) {
405                         md->sg_copybreak = 0;
406
407                         if (++i == MAX_SKB_FRAGS)
408                                 i = 0;
409
410                         if (i == md->sg_end)
411                                 break;
412                 }
413
414                 copy = sg[i].length - md->sg_copybreak;
415                 to = sg_virt(&sg[i]) + md->sg_copybreak;
416                 md->sg_copybreak += copy;
417
418                 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
419                         rc = copy_from_iter_nocache(to, copy, from);
420                 else
421                         rc = copy_from_iter(to, copy, from);
422
423                 if (rc != copy) {
424                         rc = -EFAULT;
425                         goto out;
426                 }
427
428                 bytes -= copy;
429                 if (!bytes)
430                         break;
431
432                 md->sg_copybreak = 0;
433                 if (++i == MAX_SKB_FRAGS)
434                         i = 0;
435         } while (i != md->sg_end);
436 out:
437         md->sg_curr = i;
438         return rc;
439 }
440
441 static int bpf_tcp_push(struct sock *sk, int apply_bytes,
442                         struct sk_msg_buff *md,
443                         int flags, bool uncharge)
444 {
445         bool apply = apply_bytes;
446         struct scatterlist *sg;
447         int offset, ret = 0;
448         struct page *p;
449         size_t size;
450
451         while (1) {
452                 sg = md->sg_data + md->sg_start;
453                 size = (apply && apply_bytes < sg->length) ?
454                         apply_bytes : sg->length;
455                 offset = sg->offset;
456
457                 tcp_rate_check_app_limited(sk);
458                 p = sg_page(sg);
459 retry:
460                 ret = do_tcp_sendpages(sk, p, offset, size, flags);
461                 if (ret != size) {
462                         if (ret > 0) {
463                                 if (apply)
464                                         apply_bytes -= ret;
465
466                                 sg->offset += ret;
467                                 sg->length -= ret;
468                                 size -= ret;
469                                 offset += ret;
470                                 if (uncharge)
471                                         sk_mem_uncharge(sk, ret);
472                                 goto retry;
473                         }
474
475                         return ret;
476                 }
477
478                 if (apply)
479                         apply_bytes -= ret;
480                 sg->offset += ret;
481                 sg->length -= ret;
482                 if (uncharge)
483                         sk_mem_uncharge(sk, ret);
484
485                 if (!sg->length) {
486                         put_page(p);
487                         md->sg_start++;
488                         if (md->sg_start == MAX_SKB_FRAGS)
489                                 md->sg_start = 0;
490                         sg_init_table(sg, 1);
491
492                         if (md->sg_start == md->sg_end)
493                                 break;
494                 }
495
496                 if (apply && !apply_bytes)
497                         break;
498         }
499         return 0;
500 }
501
502 static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
503 {
504         struct scatterlist *sg = md->sg_data + md->sg_start;
505
506         if (md->sg_copy[md->sg_start]) {
507                 md->data = md->data_end = 0;
508         } else {
509                 md->data = sg_virt(sg);
510                 md->data_end = md->data + sg->length;
511         }
512 }
513
514 static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
515 {
516         struct scatterlist *sg = md->sg_data;
517         int i = md->sg_start;
518
519         do {
520                 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
521
522                 sk_mem_uncharge(sk, uncharge);
523                 bytes -= uncharge;
524                 if (!bytes)
525                         break;
526                 i++;
527                 if (i == MAX_SKB_FRAGS)
528                         i = 0;
529         } while (i != md->sg_end);
530 }
531
532 static void free_bytes_sg(struct sock *sk, int bytes,
533                           struct sk_msg_buff *md, bool charge)
534 {
535         struct scatterlist *sg = md->sg_data;
536         int i = md->sg_start, free;
537
538         while (bytes && sg[i].length) {
539                 free = sg[i].length;
540                 if (bytes < free) {
541                         sg[i].length -= bytes;
542                         sg[i].offset += bytes;
543                         if (charge)
544                                 sk_mem_uncharge(sk, bytes);
545                         break;
546                 }
547
548                 if (charge)
549                         sk_mem_uncharge(sk, sg[i].length);
550                 put_page(sg_page(&sg[i]));
551                 bytes -= sg[i].length;
552                 sg[i].length = 0;
553                 sg[i].page_link = 0;
554                 sg[i].offset = 0;
555                 i++;
556
557                 if (i == MAX_SKB_FRAGS)
558                         i = 0;
559         }
560         md->sg_start = i;
561 }
562
563 static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
564 {
565         struct scatterlist *sg = md->sg_data;
566         int i = start, free = 0;
567
568         while (sg[i].length) {
569                 free += sg[i].length;
570                 sk_mem_uncharge(sk, sg[i].length);
571                 put_page(sg_page(&sg[i]));
572                 sg[i].length = 0;
573                 sg[i].page_link = 0;
574                 sg[i].offset = 0;
575                 i++;
576
577                 if (i == MAX_SKB_FRAGS)
578                         i = 0;
579         }
580
581         return free;
582 }
583
584 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
585 {
586         int free = free_sg(sk, md->sg_start, md);
587
588         md->sg_start = md->sg_end;
589         return free;
590 }
591
592 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
593 {
594         return free_sg(sk, md->sg_curr, md);
595 }
596
597 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
598 {
599         return ((_rc == SK_PASS) ?
600                (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
601                __SK_DROP);
602 }
603
604 static unsigned int smap_do_tx_msg(struct sock *sk,
605                                    struct smap_psock *psock,
606                                    struct sk_msg_buff *md)
607 {
608         struct bpf_prog *prog;
609         unsigned int rc, _rc;
610
611         preempt_disable();
612         rcu_read_lock();
613
614         /* If the policy was removed mid-send then default to 'accept' */
615         prog = READ_ONCE(psock->bpf_tx_msg);
616         if (unlikely(!prog)) {
617                 _rc = SK_PASS;
618                 goto verdict;
619         }
620
621         bpf_compute_data_pointers_sg(md);
622         md->sk = sk;
623         rc = (*prog->bpf_func)(md, prog->insnsi);
624         psock->apply_bytes = md->apply_bytes;
625
626         /* Moving return codes from UAPI namespace into internal namespace */
627         _rc = bpf_map_msg_verdict(rc, md);
628
629         /* The psock has a refcount on the sock but not on the map and because
630          * we need to drop rcu read lock here its possible the map could be
631          * removed between here and when we need it to execute the sock
632          * redirect. So do the map lookup now for future use.
633          */
634         if (_rc == __SK_REDIRECT) {
635                 if (psock->sk_redir)
636                         sock_put(psock->sk_redir);
637                 psock->sk_redir = do_msg_redirect_map(md);
638                 if (!psock->sk_redir) {
639                         _rc = __SK_DROP;
640                         goto verdict;
641                 }
642                 sock_hold(psock->sk_redir);
643         }
644 verdict:
645         rcu_read_unlock();
646         preempt_enable();
647
648         return _rc;
649 }
650
651 static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
652                            struct smap_psock *psock,
653                            struct sk_msg_buff *md, int flags)
654 {
655         bool apply = apply_bytes;
656         size_t size, copied = 0;
657         struct sk_msg_buff *r;
658         int err = 0, i;
659
660         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
661         if (unlikely(!r))
662                 return -ENOMEM;
663
664         lock_sock(sk);
665         r->sg_start = md->sg_start;
666         i = md->sg_start;
667
668         do {
669                 size = (apply && apply_bytes < md->sg_data[i].length) ?
670                         apply_bytes : md->sg_data[i].length;
671
672                 if (!sk_wmem_schedule(sk, size)) {
673                         if (!copied)
674                                 err = -ENOMEM;
675                         break;
676                 }
677
678                 sk_mem_charge(sk, size);
679                 r->sg_data[i] = md->sg_data[i];
680                 r->sg_data[i].length = size;
681                 md->sg_data[i].length -= size;
682                 md->sg_data[i].offset += size;
683                 copied += size;
684
685                 if (md->sg_data[i].length) {
686                         get_page(sg_page(&r->sg_data[i]));
687                         r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
688                 } else {
689                         i++;
690                         if (i == MAX_SKB_FRAGS)
691                                 i = 0;
692                         r->sg_end = i;
693                 }
694
695                 if (apply) {
696                         apply_bytes -= size;
697                         if (!apply_bytes)
698                                 break;
699                 }
700         } while (i != md->sg_end);
701
702         md->sg_start = i;
703
704         if (!err) {
705                 list_add_tail(&r->list, &psock->ingress);
706                 sk->sk_data_ready(sk);
707         } else {
708                 free_start_sg(sk, r);
709                 kfree(r);
710         }
711
712         release_sock(sk);
713         return err;
714 }
715
716 static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
717                                        struct sk_msg_buff *md,
718                                        int flags)
719 {
720         bool ingress = !!(md->flags & BPF_F_INGRESS);
721         struct smap_psock *psock;
722         struct scatterlist *sg;
723         int err = 0;
724
725         sg = md->sg_data;
726
727         rcu_read_lock();
728         psock = smap_psock_sk(sk);
729         if (unlikely(!psock))
730                 goto out_rcu;
731
732         if (!refcount_inc_not_zero(&psock->refcnt))
733                 goto out_rcu;
734
735         rcu_read_unlock();
736
737         if (ingress) {
738                 err = bpf_tcp_ingress(sk, send, psock, md, flags);
739         } else {
740                 lock_sock(sk);
741                 err = bpf_tcp_push(sk, send, md, flags, false);
742                 release_sock(sk);
743         }
744         smap_release_sock(psock, sk);
745         if (unlikely(err))
746                 goto out;
747         return 0;
748 out_rcu:
749         rcu_read_unlock();
750 out:
751         free_bytes_sg(NULL, send, md, false);
752         return err;
753 }
754
755 static inline void bpf_md_init(struct smap_psock *psock)
756 {
757         if (!psock->apply_bytes) {
758                 psock->eval =  __SK_NONE;
759                 if (psock->sk_redir) {
760                         sock_put(psock->sk_redir);
761                         psock->sk_redir = NULL;
762                 }
763         }
764 }
765
766 static void apply_bytes_dec(struct smap_psock *psock, int i)
767 {
768         if (psock->apply_bytes) {
769                 if (psock->apply_bytes < i)
770                         psock->apply_bytes = 0;
771                 else
772                         psock->apply_bytes -= i;
773         }
774 }
775
776 static int bpf_exec_tx_verdict(struct smap_psock *psock,
777                                struct sk_msg_buff *m,
778                                struct sock *sk,
779                                int *copied, int flags)
780 {
781         bool cork = false, enospc = (m->sg_start == m->sg_end);
782         struct sock *redir;
783         int err = 0;
784         int send;
785
786 more_data:
787         if (psock->eval == __SK_NONE)
788                 psock->eval = smap_do_tx_msg(sk, psock, m);
789
790         if (m->cork_bytes &&
791             m->cork_bytes > psock->sg_size && !enospc) {
792                 psock->cork_bytes = m->cork_bytes - psock->sg_size;
793                 if (!psock->cork) {
794                         psock->cork = kcalloc(1,
795                                         sizeof(struct sk_msg_buff),
796                                         GFP_ATOMIC | __GFP_NOWARN);
797
798                         if (!psock->cork) {
799                                 err = -ENOMEM;
800                                 goto out_err;
801                         }
802                 }
803                 memcpy(psock->cork, m, sizeof(*m));
804                 goto out_err;
805         }
806
807         send = psock->sg_size;
808         if (psock->apply_bytes && psock->apply_bytes < send)
809                 send = psock->apply_bytes;
810
811         switch (psock->eval) {
812         case __SK_PASS:
813                 err = bpf_tcp_push(sk, send, m, flags, true);
814                 if (unlikely(err)) {
815                         *copied -= free_start_sg(sk, m);
816                         break;
817                 }
818
819                 apply_bytes_dec(psock, send);
820                 psock->sg_size -= send;
821                 break;
822         case __SK_REDIRECT:
823                 redir = psock->sk_redir;
824                 apply_bytes_dec(psock, send);
825
826                 if (psock->cork) {
827                         cork = true;
828                         psock->cork = NULL;
829                 }
830
831                 return_mem_sg(sk, send, m);
832                 release_sock(sk);
833
834                 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
835                 lock_sock(sk);
836
837                 if (unlikely(err < 0)) {
838                         free_start_sg(sk, m);
839                         psock->sg_size = 0;
840                         if (!cork)
841                                 *copied -= send;
842                 } else {
843                         psock->sg_size -= send;
844                 }
845
846                 if (cork) {
847                         free_start_sg(sk, m);
848                         psock->sg_size = 0;
849                         kfree(m);
850                         m = NULL;
851                         err = 0;
852                 }
853                 break;
854         case __SK_DROP:
855         default:
856                 free_bytes_sg(sk, send, m, true);
857                 apply_bytes_dec(psock, send);
858                 *copied -= send;
859                 psock->sg_size -= send;
860                 err = -EACCES;
861                 break;
862         }
863
864         if (likely(!err)) {
865                 bpf_md_init(psock);
866                 if (m &&
867                     m->sg_data[m->sg_start].page_link &&
868                     m->sg_data[m->sg_start].length)
869                         goto more_data;
870         }
871
872 out_err:
873         return err;
874 }
875
876 static int bpf_wait_data(struct sock *sk,
877                          struct smap_psock *psk, int flags,
878                          long timeo, int *err)
879 {
880         int rc;
881
882         DEFINE_WAIT_FUNC(wait, woken_wake_function);
883
884         add_wait_queue(sk_sleep(sk), &wait);
885         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
886         rc = sk_wait_event(sk, &timeo,
887                            !list_empty(&psk->ingress) ||
888                            !skb_queue_empty(&sk->sk_receive_queue),
889                            &wait);
890         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
891         remove_wait_queue(sk_sleep(sk), &wait);
892
893         return rc;
894 }
895
896 static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
897                            int nonblock, int flags, int *addr_len)
898 {
899         struct iov_iter *iter = &msg->msg_iter;
900         struct smap_psock *psock;
901         int copied = 0;
902
903         if (unlikely(flags & MSG_ERRQUEUE))
904                 return inet_recv_error(sk, msg, len, addr_len);
905
906         rcu_read_lock();
907         psock = smap_psock_sk(sk);
908         if (unlikely(!psock))
909                 goto out;
910
911         if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
912                 goto out;
913         rcu_read_unlock();
914
915         if (!skb_queue_empty(&sk->sk_receive_queue))
916                 return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
917
918         lock_sock(sk);
919 bytes_ready:
920         while (copied != len) {
921                 struct scatterlist *sg;
922                 struct sk_msg_buff *md;
923                 int i;
924
925                 md = list_first_entry_or_null(&psock->ingress,
926                                               struct sk_msg_buff, list);
927                 if (unlikely(!md))
928                         break;
929                 i = md->sg_start;
930                 do {
931                         struct page *page;
932                         int n, copy;
933
934                         sg = &md->sg_data[i];
935                         copy = sg->length;
936                         page = sg_page(sg);
937
938                         if (copied + copy > len)
939                                 copy = len - copied;
940
941                         n = copy_page_to_iter(page, sg->offset, copy, iter);
942                         if (n != copy) {
943                                 md->sg_start = i;
944                                 release_sock(sk);
945                                 smap_release_sock(psock, sk);
946                                 return -EFAULT;
947                         }
948
949                         copied += copy;
950                         sg->offset += copy;
951                         sg->length -= copy;
952                         sk_mem_uncharge(sk, copy);
953
954                         if (!sg->length) {
955                                 i++;
956                                 if (i == MAX_SKB_FRAGS)
957                                         i = 0;
958                                 if (!md->skb)
959                                         put_page(page);
960                         }
961                         if (copied == len)
962                                 break;
963                 } while (i != md->sg_end);
964                 md->sg_start = i;
965
966                 if (!sg->length && md->sg_start == md->sg_end) {
967                         list_del(&md->list);
968                         if (md->skb)
969                                 consume_skb(md->skb);
970                         kfree(md);
971                 }
972         }
973
974         if (!copied) {
975                 long timeo;
976                 int data;
977                 int err = 0;
978
979                 timeo = sock_rcvtimeo(sk, nonblock);
980                 data = bpf_wait_data(sk, psock, flags, timeo, &err);
981
982                 if (data) {
983                         if (!skb_queue_empty(&sk->sk_receive_queue)) {
984                                 release_sock(sk);
985                                 smap_release_sock(psock, sk);
986                                 copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
987                                 return copied;
988                         }
989                         goto bytes_ready;
990                 }
991
992                 if (err)
993                         copied = err;
994         }
995
996         release_sock(sk);
997         smap_release_sock(psock, sk);
998         return copied;
999 out:
1000         rcu_read_unlock();
1001         return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1002 }
1003
1004
1005 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1006 {
1007         int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1008         struct sk_msg_buff md = {0};
1009         unsigned int sg_copy = 0;
1010         struct smap_psock *psock;
1011         int copied = 0, err = 0;
1012         struct scatterlist *sg;
1013         long timeo;
1014
1015         /* Its possible a sock event or user removed the psock _but_ the ops
1016          * have not been reprogrammed yet so we get here. In this case fallback
1017          * to tcp_sendmsg. Note this only works because we _only_ ever allow
1018          * a single ULP there is no hierarchy here.
1019          */
1020         rcu_read_lock();
1021         psock = smap_psock_sk(sk);
1022         if (unlikely(!psock)) {
1023                 rcu_read_unlock();
1024                 return tcp_sendmsg(sk, msg, size);
1025         }
1026
1027         /* Increment the psock refcnt to ensure its not released while sending a
1028          * message. Required because sk lookup and bpf programs are used in
1029          * separate rcu critical sections. Its OK if we lose the map entry
1030          * but we can't lose the sock reference.
1031          */
1032         if (!refcount_inc_not_zero(&psock->refcnt)) {
1033                 rcu_read_unlock();
1034                 return tcp_sendmsg(sk, msg, size);
1035         }
1036
1037         sg = md.sg_data;
1038         sg_init_marker(sg, MAX_SKB_FRAGS);
1039         rcu_read_unlock();
1040
1041         lock_sock(sk);
1042         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1043
1044         while (msg_data_left(msg)) {
1045                 struct sk_msg_buff *m;
1046                 bool enospc = false;
1047                 int copy;
1048
1049                 if (sk->sk_err) {
1050                         err = sk->sk_err;
1051                         goto out_err;
1052                 }
1053
1054                 copy = msg_data_left(msg);
1055                 if (!sk_stream_memory_free(sk))
1056                         goto wait_for_sndbuf;
1057
1058                 m = psock->cork_bytes ? psock->cork : &md;
1059                 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1060                 err = sk_alloc_sg(sk, copy, m->sg_data,
1061                                   m->sg_start, &m->sg_end, &sg_copy,
1062                                   m->sg_end - 1);
1063                 if (err) {
1064                         if (err != -ENOSPC)
1065                                 goto wait_for_memory;
1066                         enospc = true;
1067                         copy = sg_copy;
1068                 }
1069
1070                 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1071                 if (err < 0) {
1072                         free_curr_sg(sk, m);
1073                         goto out_err;
1074                 }
1075
1076                 psock->sg_size += copy;
1077                 copied += copy;
1078                 sg_copy = 0;
1079
1080                 /* When bytes are being corked skip running BPF program and
1081                  * applying verdict unless there is no more buffer space. In
1082                  * the ENOSPC case simply run BPF prorgram with currently
1083                  * accumulated data. We don't have much choice at this point
1084                  * we could try extending the page frags or chaining complex
1085                  * frags but even in these cases _eventually_ we will hit an
1086                  * OOM scenario. More complex recovery schemes may be
1087                  * implemented in the future, but BPF programs must handle
1088                  * the case where apply_cork requests are not honored. The
1089                  * canonical method to verify this is to check data length.
1090                  */
1091                 if (psock->cork_bytes) {
1092                         if (copy > psock->cork_bytes)
1093                                 psock->cork_bytes = 0;
1094                         else
1095                                 psock->cork_bytes -= copy;
1096
1097                         if (psock->cork_bytes && !enospc)
1098                                 goto out_cork;
1099
1100                         /* All cork bytes accounted for re-run filter */
1101                         psock->eval = __SK_NONE;
1102                         psock->cork_bytes = 0;
1103                 }
1104
1105                 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1106                 if (unlikely(err < 0))
1107                         goto out_err;
1108                 continue;
1109 wait_for_sndbuf:
1110                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1111 wait_for_memory:
1112                 err = sk_stream_wait_memory(sk, &timeo);
1113                 if (err)
1114                         goto out_err;
1115         }
1116 out_err:
1117         if (err < 0)
1118                 err = sk_stream_error(sk, msg->msg_flags, err);
1119 out_cork:
1120         release_sock(sk);
1121         smap_release_sock(psock, sk);
1122         return copied ? copied : err;
1123 }
1124
1125 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1126                             int offset, size_t size, int flags)
1127 {
1128         struct sk_msg_buff md = {0}, *m = NULL;
1129         int err = 0, copied = 0;
1130         struct smap_psock *psock;
1131         struct scatterlist *sg;
1132         bool enospc = false;
1133
1134         rcu_read_lock();
1135         psock = smap_psock_sk(sk);
1136         if (unlikely(!psock))
1137                 goto accept;
1138
1139         if (!refcount_inc_not_zero(&psock->refcnt))
1140                 goto accept;
1141         rcu_read_unlock();
1142
1143         lock_sock(sk);
1144
1145         if (psock->cork_bytes) {
1146                 m = psock->cork;
1147                 sg = &m->sg_data[m->sg_end];
1148         } else {
1149                 m = &md;
1150                 sg = m->sg_data;
1151                 sg_init_marker(sg, MAX_SKB_FRAGS);
1152         }
1153
1154         /* Catch case where ring is full and sendpage is stalled. */
1155         if (unlikely(m->sg_end == m->sg_start &&
1156             m->sg_data[m->sg_end].length))
1157                 goto out_err;
1158
1159         psock->sg_size += size;
1160         sg_set_page(sg, page, size, offset);
1161         get_page(page);
1162         m->sg_copy[m->sg_end] = true;
1163         sk_mem_charge(sk, size);
1164         m->sg_end++;
1165         copied = size;
1166
1167         if (m->sg_end == MAX_SKB_FRAGS)
1168                 m->sg_end = 0;
1169
1170         if (m->sg_end == m->sg_start)
1171                 enospc = true;
1172
1173         if (psock->cork_bytes) {
1174                 if (size > psock->cork_bytes)
1175                         psock->cork_bytes = 0;
1176                 else
1177                         psock->cork_bytes -= size;
1178
1179                 if (psock->cork_bytes && !enospc)
1180                         goto out_err;
1181
1182                 /* All cork bytes accounted for re-run filter */
1183                 psock->eval = __SK_NONE;
1184                 psock->cork_bytes = 0;
1185         }
1186
1187         err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1188 out_err:
1189         release_sock(sk);
1190         smap_release_sock(psock, sk);
1191         return copied ? copied : err;
1192 accept:
1193         rcu_read_unlock();
1194         return tcp_sendpage(sk, page, offset, size, flags);
1195 }
1196
1197 static void bpf_tcp_msg_add(struct smap_psock *psock,
1198                             struct sock *sk,
1199                             struct bpf_prog *tx_msg)
1200 {
1201         struct bpf_prog *orig_tx_msg;
1202
1203         orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1204         if (orig_tx_msg)
1205                 bpf_prog_put(orig_tx_msg);
1206 }
1207
1208 static int bpf_tcp_ulp_register(void)
1209 {
1210         build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
1211         /* Once BPF TX ULP is registered it is never unregistered. It
1212          * will be in the ULP list for the lifetime of the system. Doing
1213          * duplicate registers is not a problem.
1214          */
1215         return tcp_register_ulp(&bpf_tcp_ulp_ops);
1216 }
1217
1218 static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1219 {
1220         struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1221         int rc;
1222
1223         if (unlikely(!prog))
1224                 return __SK_DROP;
1225
1226         skb_orphan(skb);
1227         /* We need to ensure that BPF metadata for maps is also cleared
1228          * when we orphan the skb so that we don't have the possibility
1229          * to reference a stale map.
1230          */
1231         TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1232         skb->sk = psock->sock;
1233         bpf_compute_data_pointers(skb);
1234         preempt_disable();
1235         rc = (*prog->bpf_func)(skb, prog->insnsi);
1236         preempt_enable();
1237         skb->sk = NULL;
1238
1239         /* Moving return codes from UAPI namespace into internal namespace */
1240         return rc == SK_PASS ?
1241                 (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1242                 __SK_DROP;
1243 }
1244
1245 static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1246 {
1247         struct sock *sk = psock->sock;
1248         int copied = 0, num_sg;
1249         struct sk_msg_buff *r;
1250
1251         r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1252         if (unlikely(!r))
1253                 return -EAGAIN;
1254
1255         if (!sk_rmem_schedule(sk, skb, skb->len)) {
1256                 kfree(r);
1257                 return -EAGAIN;
1258         }
1259
1260         sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1261         num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1262         if (unlikely(num_sg < 0)) {
1263                 kfree(r);
1264                 return num_sg;
1265         }
1266         sk_mem_charge(sk, skb->len);
1267         copied = skb->len;
1268         r->sg_start = 0;
1269         r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1270         r->skb = skb;
1271         list_add_tail(&r->list, &psock->ingress);
1272         sk->sk_data_ready(sk);
1273         return copied;
1274 }
1275
1276 static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1277 {
1278         struct smap_psock *peer;
1279         struct sock *sk;
1280         __u32 in;
1281         int rc;
1282
1283         rc = smap_verdict_func(psock, skb);
1284         switch (rc) {
1285         case __SK_REDIRECT:
1286                 sk = do_sk_redirect_map(skb);
1287                 if (!sk) {
1288                         kfree_skb(skb);
1289                         break;
1290                 }
1291
1292                 peer = smap_psock_sk(sk);
1293                 in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1294
1295                 if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1296                              !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1297                         kfree_skb(skb);
1298                         break;
1299                 }
1300
1301                 if (!in && sock_writeable(sk)) {
1302                         skb_set_owner_w(skb, sk);
1303                         skb_queue_tail(&peer->rxqueue, skb);
1304                         schedule_work(&peer->tx_work);
1305                         break;
1306                 } else if (in &&
1307                            atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1308                         skb_queue_tail(&peer->rxqueue, skb);
1309                         schedule_work(&peer->tx_work);
1310                         break;
1311                 }
1312         /* Fall through and free skb otherwise */
1313         case __SK_DROP:
1314         default:
1315                 kfree_skb(skb);
1316         }
1317 }
1318
1319 static void smap_report_sk_error(struct smap_psock *psock, int err)
1320 {
1321         struct sock *sk = psock->sock;
1322
1323         sk->sk_err = err;
1324         sk->sk_error_report(sk);
1325 }
1326
1327 static void smap_read_sock_strparser(struct strparser *strp,
1328                                      struct sk_buff *skb)
1329 {
1330         struct smap_psock *psock;
1331
1332         rcu_read_lock();
1333         psock = container_of(strp, struct smap_psock, strp);
1334         smap_do_verdict(psock, skb);
1335         rcu_read_unlock();
1336 }
1337
1338 /* Called with lock held on socket */
1339 static void smap_data_ready(struct sock *sk)
1340 {
1341         struct smap_psock *psock;
1342
1343         rcu_read_lock();
1344         psock = smap_psock_sk(sk);
1345         if (likely(psock)) {
1346                 write_lock_bh(&sk->sk_callback_lock);
1347                 strp_data_ready(&psock->strp);
1348                 write_unlock_bh(&sk->sk_callback_lock);
1349         }
1350         rcu_read_unlock();
1351 }
1352
1353 static void smap_tx_work(struct work_struct *w)
1354 {
1355         struct smap_psock *psock;
1356         struct sk_buff *skb;
1357         int rem, off, n;
1358
1359         psock = container_of(w, struct smap_psock, tx_work);
1360
1361         /* lock sock to avoid losing sk_socket at some point during loop */
1362         lock_sock(psock->sock);
1363         if (psock->save_skb) {
1364                 skb = psock->save_skb;
1365                 rem = psock->save_rem;
1366                 off = psock->save_off;
1367                 psock->save_skb = NULL;
1368                 goto start;
1369         }
1370
1371         while ((skb = skb_dequeue(&psock->rxqueue))) {
1372                 __u32 flags;
1373
1374                 rem = skb->len;
1375                 off = 0;
1376 start:
1377                 flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1378                 do {
1379                         if (likely(psock->sock->sk_socket)) {
1380                                 if (flags)
1381                                         n = smap_do_ingress(psock, skb);
1382                                 else
1383                                         n = skb_send_sock_locked(psock->sock,
1384                                                                  skb, off, rem);
1385                         } else {
1386                                 n = -EINVAL;
1387                         }
1388
1389                         if (n <= 0) {
1390                                 if (n == -EAGAIN) {
1391                                         /* Retry when space is available */
1392                                         psock->save_skb = skb;
1393                                         psock->save_rem = rem;
1394                                         psock->save_off = off;
1395                                         goto out;
1396                                 }
1397                                 /* Hard errors break pipe and stop xmit */
1398                                 smap_report_sk_error(psock, n ? -n : EPIPE);
1399                                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1400                                 kfree_skb(skb);
1401                                 goto out;
1402                         }
1403                         rem -= n;
1404                         off += n;
1405                 } while (rem);
1406
1407                 if (!flags)
1408                         kfree_skb(skb);
1409         }
1410 out:
1411         release_sock(psock->sock);
1412 }
1413
1414 static void smap_write_space(struct sock *sk)
1415 {
1416         struct smap_psock *psock;
1417
1418         rcu_read_lock();
1419         psock = smap_psock_sk(sk);
1420         if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1421                 schedule_work(&psock->tx_work);
1422         rcu_read_unlock();
1423 }
1424
1425 static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1426 {
1427         if (!psock->strp_enabled)
1428                 return;
1429         sk->sk_data_ready = psock->save_data_ready;
1430         sk->sk_write_space = psock->save_write_space;
1431         psock->save_data_ready = NULL;
1432         psock->save_write_space = NULL;
1433         strp_stop(&psock->strp);
1434         psock->strp_enabled = false;
1435 }
1436
1437 static void smap_destroy_psock(struct rcu_head *rcu)
1438 {
1439         struct smap_psock *psock = container_of(rcu,
1440                                                   struct smap_psock, rcu);
1441
1442         /* Now that a grace period has passed there is no longer
1443          * any reference to this sock in the sockmap so we can
1444          * destroy the psock, strparser, and bpf programs. But,
1445          * because we use workqueue sync operations we can not
1446          * do it in rcu context
1447          */
1448         schedule_work(&psock->gc_work);
1449 }
1450
1451 static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1452 {
1453         if (refcount_dec_and_test(&psock->refcnt)) {
1454                 tcp_cleanup_ulp(sock);
1455                 write_lock_bh(&sock->sk_callback_lock);
1456                 smap_stop_sock(psock, sock);
1457                 write_unlock_bh(&sock->sk_callback_lock);
1458                 clear_bit(SMAP_TX_RUNNING, &psock->state);
1459                 rcu_assign_sk_user_data(sock, NULL);
1460                 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1461         }
1462 }
1463
1464 static int smap_parse_func_strparser(struct strparser *strp,
1465                                        struct sk_buff *skb)
1466 {
1467         struct smap_psock *psock;
1468         struct bpf_prog *prog;
1469         int rc;
1470
1471         rcu_read_lock();
1472         psock = container_of(strp, struct smap_psock, strp);
1473         prog = READ_ONCE(psock->bpf_parse);
1474
1475         if (unlikely(!prog)) {
1476                 rcu_read_unlock();
1477                 return skb->len;
1478         }
1479
1480         /* Attach socket for bpf program to use if needed we can do this
1481          * because strparser clones the skb before handing it to a upper
1482          * layer, meaning skb_orphan has been called. We NULL sk on the
1483          * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1484          * later and because we are not charging the memory of this skb to
1485          * any socket yet.
1486          */
1487         skb->sk = psock->sock;
1488         bpf_compute_data_pointers(skb);
1489         rc = (*prog->bpf_func)(skb, prog->insnsi);
1490         skb->sk = NULL;
1491         rcu_read_unlock();
1492         return rc;
1493 }
1494
1495 static int smap_read_sock_done(struct strparser *strp, int err)
1496 {
1497         return err;
1498 }
1499
1500 static int smap_init_sock(struct smap_psock *psock,
1501                           struct sock *sk)
1502 {
1503         static const struct strp_callbacks cb = {
1504                 .rcv_msg = smap_read_sock_strparser,
1505                 .parse_msg = smap_parse_func_strparser,
1506                 .read_sock_done = smap_read_sock_done,
1507         };
1508
1509         return strp_init(&psock->strp, sk, &cb);
1510 }
1511
1512 static void smap_init_progs(struct smap_psock *psock,
1513                             struct bpf_prog *verdict,
1514                             struct bpf_prog *parse)
1515 {
1516         struct bpf_prog *orig_parse, *orig_verdict;
1517
1518         orig_parse = xchg(&psock->bpf_parse, parse);
1519         orig_verdict = xchg(&psock->bpf_verdict, verdict);
1520
1521         if (orig_verdict)
1522                 bpf_prog_put(orig_verdict);
1523         if (orig_parse)
1524                 bpf_prog_put(orig_parse);
1525 }
1526
1527 static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1528 {
1529         if (sk->sk_data_ready == smap_data_ready)
1530                 return;
1531         psock->save_data_ready = sk->sk_data_ready;
1532         psock->save_write_space = sk->sk_write_space;
1533         sk->sk_data_ready = smap_data_ready;
1534         sk->sk_write_space = smap_write_space;
1535         psock->strp_enabled = true;
1536 }
1537
1538 static void sock_map_remove_complete(struct bpf_stab *stab)
1539 {
1540         bpf_map_area_free(stab->sock_map);
1541         kfree(stab);
1542 }
1543
1544 static void smap_gc_work(struct work_struct *w)
1545 {
1546         struct smap_psock_map_entry *e, *tmp;
1547         struct sk_msg_buff *md, *mtmp;
1548         struct smap_psock *psock;
1549
1550         psock = container_of(w, struct smap_psock, gc_work);
1551
1552         /* no callback lock needed because we already detached sockmap ops */
1553         if (psock->strp_enabled)
1554                 strp_done(&psock->strp);
1555
1556         cancel_work_sync(&psock->tx_work);
1557         __skb_queue_purge(&psock->rxqueue);
1558
1559         /* At this point all strparser and xmit work must be complete */
1560         if (psock->bpf_parse)
1561                 bpf_prog_put(psock->bpf_parse);
1562         if (psock->bpf_verdict)
1563                 bpf_prog_put(psock->bpf_verdict);
1564         if (psock->bpf_tx_msg)
1565                 bpf_prog_put(psock->bpf_tx_msg);
1566
1567         if (psock->cork) {
1568                 free_start_sg(psock->sock, psock->cork);
1569                 kfree(psock->cork);
1570         }
1571
1572         list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1573                 list_del(&md->list);
1574                 free_start_sg(psock->sock, md);
1575                 kfree(md);
1576         }
1577
1578         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1579                 list_del(&e->list);
1580                 kfree(e);
1581         }
1582
1583         if (psock->sk_redir)
1584                 sock_put(psock->sk_redir);
1585
1586         sock_put(psock->sock);
1587         kfree(psock);
1588 }
1589
1590 static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1591 {
1592         struct smap_psock *psock;
1593
1594         psock = kzalloc_node(sizeof(struct smap_psock),
1595                              GFP_ATOMIC | __GFP_NOWARN,
1596                              node);
1597         if (!psock)
1598                 return ERR_PTR(-ENOMEM);
1599
1600         psock->eval =  __SK_NONE;
1601         psock->sock = sock;
1602         skb_queue_head_init(&psock->rxqueue);
1603         INIT_WORK(&psock->tx_work, smap_tx_work);
1604         INIT_WORK(&psock->gc_work, smap_gc_work);
1605         INIT_LIST_HEAD(&psock->maps);
1606         INIT_LIST_HEAD(&psock->ingress);
1607         refcount_set(&psock->refcnt, 1);
1608         spin_lock_init(&psock->maps_lock);
1609
1610         rcu_assign_sk_user_data(sock, psock);
1611         sock_hold(sock);
1612         return psock;
1613 }
1614
1615 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1616 {
1617         struct bpf_stab *stab;
1618         u64 cost;
1619         int err;
1620
1621         if (!capable(CAP_NET_ADMIN))
1622                 return ERR_PTR(-EPERM);
1623
1624         /* check sanity of attributes */
1625         if (attr->max_entries == 0 || attr->key_size != 4 ||
1626             attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1627                 return ERR_PTR(-EINVAL);
1628
1629         err = bpf_tcp_ulp_register();
1630         if (err && err != -EEXIST)
1631                 return ERR_PTR(err);
1632
1633         stab = kzalloc(sizeof(*stab), GFP_USER);
1634         if (!stab)
1635                 return ERR_PTR(-ENOMEM);
1636
1637         bpf_map_init_from_attr(&stab->map, attr);
1638
1639         /* make sure page count doesn't overflow */
1640         cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1641         err = -EINVAL;
1642         if (cost >= U32_MAX - PAGE_SIZE)
1643                 goto free_stab;
1644
1645         stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1646
1647         /* if map size is larger than memlock limit, reject it early */
1648         err = bpf_map_precharge_memlock(stab->map.pages);
1649         if (err)
1650                 goto free_stab;
1651
1652         err = -ENOMEM;
1653         stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1654                                             sizeof(struct sock *),
1655                                             stab->map.numa_node);
1656         if (!stab->sock_map)
1657                 goto free_stab;
1658
1659         return &stab->map;
1660 free_stab:
1661         kfree(stab);
1662         return ERR_PTR(err);
1663 }
1664
1665 static void smap_list_map_remove(struct smap_psock *psock,
1666                                  struct sock **entry)
1667 {
1668         struct smap_psock_map_entry *e, *tmp;
1669
1670         spin_lock_bh(&psock->maps_lock);
1671         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1672                 if (e->entry == entry)
1673                         list_del(&e->list);
1674         }
1675         spin_unlock_bh(&psock->maps_lock);
1676 }
1677
1678 static void smap_list_hash_remove(struct smap_psock *psock,
1679                                   struct htab_elem *hash_link)
1680 {
1681         struct smap_psock_map_entry *e, *tmp;
1682
1683         spin_lock_bh(&psock->maps_lock);
1684         list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1685                 struct htab_elem *c = rcu_dereference(e->hash_link);
1686
1687                 if (c == hash_link)
1688                         list_del(&e->list);
1689         }
1690         spin_unlock_bh(&psock->maps_lock);
1691 }
1692
1693 static void sock_map_free(struct bpf_map *map)
1694 {
1695         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1696         int i;
1697
1698         synchronize_rcu();
1699
1700         /* At this point no update, lookup or delete operations can happen.
1701          * However, be aware we can still get a socket state event updates,
1702          * and data ready callabacks that reference the psock from sk_user_data
1703          * Also psock worker threads are still in-flight. So smap_release_sock
1704          * will only free the psock after cancel_sync on the worker threads
1705          * and a grace period expire to ensure psock is really safe to remove.
1706          */
1707         rcu_read_lock();
1708         for (i = 0; i < stab->map.max_entries; i++) {
1709                 struct smap_psock *psock;
1710                 struct sock *sock;
1711
1712                 sock = xchg(&stab->sock_map[i], NULL);
1713                 if (!sock)
1714                         continue;
1715
1716                 psock = smap_psock_sk(sock);
1717                 /* This check handles a racing sock event that can get the
1718                  * sk_callback_lock before this case but after xchg happens
1719                  * causing the refcnt to hit zero and sock user data (psock)
1720                  * to be null and queued for garbage collection.
1721                  */
1722                 if (likely(psock)) {
1723                         smap_list_map_remove(psock, &stab->sock_map[i]);
1724                         smap_release_sock(psock, sock);
1725                 }
1726         }
1727         rcu_read_unlock();
1728
1729         sock_map_remove_complete(stab);
1730 }
1731
1732 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1733 {
1734         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1735         u32 i = key ? *(u32 *)key : U32_MAX;
1736         u32 *next = (u32 *)next_key;
1737
1738         if (i >= stab->map.max_entries) {
1739                 *next = 0;
1740                 return 0;
1741         }
1742
1743         if (i == stab->map.max_entries - 1)
1744                 return -ENOENT;
1745
1746         *next = i + 1;
1747         return 0;
1748 }
1749
1750 struct sock  *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1751 {
1752         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1753
1754         if (key >= map->max_entries)
1755                 return NULL;
1756
1757         return READ_ONCE(stab->sock_map[key]);
1758 }
1759
1760 static int sock_map_delete_elem(struct bpf_map *map, void *key)
1761 {
1762         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1763         struct smap_psock *psock;
1764         int k = *(u32 *)key;
1765         struct sock *sock;
1766
1767         if (k >= map->max_entries)
1768                 return -EINVAL;
1769
1770         sock = xchg(&stab->sock_map[k], NULL);
1771         if (!sock)
1772                 return -EINVAL;
1773
1774         psock = smap_psock_sk(sock);
1775         if (!psock)
1776                 goto out;
1777
1778         if (psock->bpf_parse)
1779                 smap_stop_sock(psock, sock);
1780         smap_list_map_remove(psock, &stab->sock_map[k]);
1781         smap_release_sock(psock, sock);
1782 out:
1783         return 0;
1784 }
1785
1786 /* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1787  * done inside rcu critical sections. This ensures on updates that the psock
1788  * will not be released via smap_release_sock() until concurrent updates/deletes
1789  * complete. All operations operate on sock_map using cmpxchg and xchg
1790  * operations to ensure we do not get stale references. Any reads into the
1791  * map must be done with READ_ONCE() because of this.
1792  *
1793  * A psock is destroyed via call_rcu and after any worker threads are cancelled
1794  * and syncd so we are certain all references from the update/lookup/delete
1795  * operations as well as references in the data path are no longer in use.
1796  *
1797  * Psocks may exist in multiple maps, but only a single set of parse/verdict
1798  * programs may be inherited from the maps it belongs to. A reference count
1799  * is kept with the total number of references to the psock from all maps. The
1800  * psock will not be released until this reaches zero. The psock and sock
1801  * user data data use the sk_callback_lock to protect critical data structures
1802  * from concurrent access. This allows us to avoid two updates from modifying
1803  * the user data in sock and the lock is required anyways for modifying
1804  * callbacks, we simply increase its scope slightly.
1805  *
1806  * Rules to follow,
1807  *  - psock must always be read inside RCU critical section
1808  *  - sk_user_data must only be modified inside sk_callback_lock and read
1809  *    inside RCU critical section.
1810  *  - psock->maps list must only be read & modified inside sk_callback_lock
1811  *  - sock_map must use READ_ONCE and (cmp)xchg operations
1812  *  - BPF verdict/parse programs must use READ_ONCE and xchg operations
1813  */
1814
1815 static int __sock_map_ctx_update_elem(struct bpf_map *map,
1816                                       struct bpf_sock_progs *progs,
1817                                       struct sock *sock,
1818                                       struct sock **map_link,
1819                                       void *key)
1820 {
1821         struct bpf_prog *verdict, *parse, *tx_msg;
1822         struct smap_psock_map_entry *e = NULL;
1823         struct smap_psock *psock;
1824         bool new = false;
1825         int err = 0;
1826
1827         /* 1. If sock map has BPF programs those will be inherited by the
1828          * sock being added. If the sock is already attached to BPF programs
1829          * this results in an error.
1830          */
1831         verdict = READ_ONCE(progs->bpf_verdict);
1832         parse = READ_ONCE(progs->bpf_parse);
1833         tx_msg = READ_ONCE(progs->bpf_tx_msg);
1834
1835         if (parse && verdict) {
1836                 /* bpf prog refcnt may be zero if a concurrent attach operation
1837                  * removes the program after the above READ_ONCE() but before
1838                  * we increment the refcnt. If this is the case abort with an
1839                  * error.
1840                  */
1841                 verdict = bpf_prog_inc_not_zero(verdict);
1842                 if (IS_ERR(verdict))
1843                         return PTR_ERR(verdict);
1844
1845                 parse = bpf_prog_inc_not_zero(parse);
1846                 if (IS_ERR(parse)) {
1847                         bpf_prog_put(verdict);
1848                         return PTR_ERR(parse);
1849                 }
1850         }
1851
1852         if (tx_msg) {
1853                 tx_msg = bpf_prog_inc_not_zero(tx_msg);
1854                 if (IS_ERR(tx_msg)) {
1855                         if (parse && verdict) {
1856                                 bpf_prog_put(parse);
1857                                 bpf_prog_put(verdict);
1858                         }
1859                         return PTR_ERR(tx_msg);
1860                 }
1861         }
1862
1863         psock = smap_psock_sk(sock);
1864
1865         /* 2. Do not allow inheriting programs if psock exists and has
1866          * already inherited programs. This would create confusion on
1867          * which parser/verdict program is running. If no psock exists
1868          * create one. Inside sk_callback_lock to ensure concurrent create
1869          * doesn't update user data.
1870          */
1871         if (psock) {
1872                 if (READ_ONCE(psock->bpf_parse) && parse) {
1873                         err = -EBUSY;
1874                         goto out_progs;
1875                 }
1876                 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1877                         err = -EBUSY;
1878                         goto out_progs;
1879                 }
1880                 if (!refcount_inc_not_zero(&psock->refcnt)) {
1881                         err = -EAGAIN;
1882                         goto out_progs;
1883                 }
1884         } else {
1885                 psock = smap_init_psock(sock, map->numa_node);
1886                 if (IS_ERR(psock)) {
1887                         err = PTR_ERR(psock);
1888                         goto out_progs;
1889                 }
1890
1891                 set_bit(SMAP_TX_RUNNING, &psock->state);
1892                 new = true;
1893         }
1894
1895         if (map_link) {
1896                 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
1897                 if (!e) {
1898                         err = -ENOMEM;
1899                         goto out_progs;
1900                 }
1901         }
1902
1903         /* 3. At this point we have a reference to a valid psock that is
1904          * running. Attach any BPF programs needed.
1905          */
1906         if (tx_msg)
1907                 bpf_tcp_msg_add(psock, sock, tx_msg);
1908         if (new) {
1909                 err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1910                 if (err)
1911                         goto out_free;
1912         }
1913
1914         if (parse && verdict && !psock->strp_enabled) {
1915                 err = smap_init_sock(psock, sock);
1916                 if (err)
1917                         goto out_free;
1918                 smap_init_progs(psock, verdict, parse);
1919                 write_lock_bh(&sock->sk_callback_lock);
1920                 smap_start_sock(psock, sock);
1921                 write_unlock_bh(&sock->sk_callback_lock);
1922         }
1923
1924         /* 4. Place psock in sockmap for use and stop any programs on
1925          * the old sock assuming its not the same sock we are replacing
1926          * it with. Because we can only have a single set of programs if
1927          * old_sock has a strp we can stop it.
1928          */
1929         if (map_link) {
1930                 e->entry = map_link;
1931                 spin_lock_bh(&psock->maps_lock);
1932                 list_add_tail(&e->list, &psock->maps);
1933                 spin_unlock_bh(&psock->maps_lock);
1934         }
1935         return err;
1936 out_free:
1937         smap_release_sock(psock, sock);
1938 out_progs:
1939         if (parse && verdict) {
1940                 bpf_prog_put(parse);
1941                 bpf_prog_put(verdict);
1942         }
1943         if (tx_msg)
1944                 bpf_prog_put(tx_msg);
1945         kfree(e);
1946         return err;
1947 }
1948
1949 static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1950                                     struct bpf_map *map,
1951                                     void *key, u64 flags)
1952 {
1953         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1954         struct bpf_sock_progs *progs = &stab->progs;
1955         struct sock *osock, *sock;
1956         u32 i = *(u32 *)key;
1957         int err;
1958
1959         if (unlikely(flags > BPF_EXIST))
1960                 return -EINVAL;
1961
1962         if (unlikely(i >= stab->map.max_entries))
1963                 return -E2BIG;
1964
1965         sock = READ_ONCE(stab->sock_map[i]);
1966         if (flags == BPF_EXIST && !sock)
1967                 return -ENOENT;
1968         else if (flags == BPF_NOEXIST && sock)
1969                 return -EEXIST;
1970
1971         sock = skops->sk;
1972         err = __sock_map_ctx_update_elem(map, progs, sock, &stab->sock_map[i],
1973                                          key);
1974         if (err)
1975                 goto out;
1976
1977         osock = xchg(&stab->sock_map[i], sock);
1978         if (osock) {
1979                 struct smap_psock *opsock = smap_psock_sk(osock);
1980
1981                 smap_list_map_remove(opsock, &stab->sock_map[i]);
1982                 smap_release_sock(opsock, osock);
1983         }
1984 out:
1985         return err;
1986 }
1987
1988 int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
1989 {
1990         struct bpf_sock_progs *progs;
1991         struct bpf_prog *orig;
1992
1993         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
1994                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1995
1996                 progs = &stab->progs;
1997         } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
1998                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
1999
2000                 progs = &htab->progs;
2001         } else {
2002                 return -EINVAL;
2003         }
2004
2005         switch (type) {
2006         case BPF_SK_MSG_VERDICT:
2007                 orig = xchg(&progs->bpf_tx_msg, prog);
2008                 break;
2009         case BPF_SK_SKB_STREAM_PARSER:
2010                 orig = xchg(&progs->bpf_parse, prog);
2011                 break;
2012         case BPF_SK_SKB_STREAM_VERDICT:
2013                 orig = xchg(&progs->bpf_verdict, prog);
2014                 break;
2015         default:
2016                 return -EOPNOTSUPP;
2017         }
2018
2019         if (orig)
2020                 bpf_prog_put(orig);
2021
2022         return 0;
2023 }
2024
2025 int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2026                         struct bpf_prog *prog)
2027 {
2028         int ufd = attr->target_fd;
2029         struct bpf_map *map;
2030         struct fd f;
2031         int err;
2032
2033         f = fdget(ufd);
2034         map = __bpf_map_get(f);
2035         if (IS_ERR(map))
2036                 return PTR_ERR(map);
2037
2038         err = sock_map_prog(map, prog, attr->attach_type);
2039         fdput(f);
2040         return err;
2041 }
2042
2043 static void *sock_map_lookup(struct bpf_map *map, void *key)
2044 {
2045         return NULL;
2046 }
2047
2048 static int sock_map_update_elem(struct bpf_map *map,
2049                                 void *key, void *value, u64 flags)
2050 {
2051         struct bpf_sock_ops_kern skops;
2052         u32 fd = *(u32 *)value;
2053         struct socket *socket;
2054         int err;
2055
2056         socket = sockfd_lookup(fd, &err);
2057         if (!socket)
2058                 return err;
2059
2060         skops.sk = socket->sk;
2061         if (!skops.sk) {
2062                 fput(socket->file);
2063                 return -EINVAL;
2064         }
2065
2066         if (skops.sk->sk_type != SOCK_STREAM ||
2067             skops.sk->sk_protocol != IPPROTO_TCP) {
2068                 fput(socket->file);
2069                 return -EOPNOTSUPP;
2070         }
2071
2072         err = sock_map_ctx_update_elem(&skops, map, key, flags);
2073         fput(socket->file);
2074         return err;
2075 }
2076
2077 static void sock_map_release(struct bpf_map *map)
2078 {
2079         struct bpf_sock_progs *progs;
2080         struct bpf_prog *orig;
2081
2082         if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2083                 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2084
2085                 progs = &stab->progs;
2086         } else {
2087                 struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2088
2089                 progs = &htab->progs;
2090         }
2091
2092         orig = xchg(&progs->bpf_parse, NULL);
2093         if (orig)
2094                 bpf_prog_put(orig);
2095         orig = xchg(&progs->bpf_verdict, NULL);
2096         if (orig)
2097                 bpf_prog_put(orig);
2098
2099         orig = xchg(&progs->bpf_tx_msg, NULL);
2100         if (orig)
2101                 bpf_prog_put(orig);
2102 }
2103
2104 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2105 {
2106         struct bpf_htab *htab;
2107         int i, err;
2108         u64 cost;
2109
2110         if (!capable(CAP_NET_ADMIN))
2111                 return ERR_PTR(-EPERM);
2112
2113         /* check sanity of attributes */
2114         if (attr->max_entries == 0 || attr->value_size != 4 ||
2115             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2116                 return ERR_PTR(-EINVAL);
2117
2118         if (attr->key_size > MAX_BPF_STACK)
2119                 /* eBPF programs initialize keys on stack, so they cannot be
2120                  * larger than max stack size
2121                  */
2122                 return ERR_PTR(-E2BIG);
2123
2124         err = bpf_tcp_ulp_register();
2125         if (err && err != -EEXIST)
2126                 return ERR_PTR(err);
2127
2128         htab = kzalloc(sizeof(*htab), GFP_USER);
2129         if (!htab)
2130                 return ERR_PTR(-ENOMEM);
2131
2132         bpf_map_init_from_attr(&htab->map, attr);
2133
2134         htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2135         htab->elem_size = sizeof(struct htab_elem) +
2136                           round_up(htab->map.key_size, 8);
2137         err = -EINVAL;
2138         if (htab->n_buckets == 0 ||
2139             htab->n_buckets > U32_MAX / sizeof(struct bucket))
2140                 goto free_htab;
2141
2142         cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2143                (u64) htab->elem_size * htab->map.max_entries;
2144
2145         if (cost >= U32_MAX - PAGE_SIZE)
2146                 goto free_htab;
2147
2148         htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2149         err = bpf_map_precharge_memlock(htab->map.pages);
2150         if (err)
2151                 goto free_htab;
2152
2153         err = -ENOMEM;
2154         htab->buckets = bpf_map_area_alloc(
2155                                 htab->n_buckets * sizeof(struct bucket),
2156                                 htab->map.numa_node);
2157         if (!htab->buckets)
2158                 goto free_htab;
2159
2160         for (i = 0; i < htab->n_buckets; i++) {
2161                 INIT_HLIST_HEAD(&htab->buckets[i].head);
2162                 raw_spin_lock_init(&htab->buckets[i].lock);
2163         }
2164
2165         return &htab->map;
2166 free_htab:
2167         kfree(htab);
2168         return ERR_PTR(err);
2169 }
2170
2171 static void __bpf_htab_free(struct rcu_head *rcu)
2172 {
2173         struct bpf_htab *htab;
2174
2175         htab = container_of(rcu, struct bpf_htab, rcu);
2176         bpf_map_area_free(htab->buckets);
2177         kfree(htab);
2178 }
2179
2180 static void sock_hash_free(struct bpf_map *map)
2181 {
2182         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2183         int i;
2184
2185         synchronize_rcu();
2186
2187         /* At this point no update, lookup or delete operations can happen.
2188          * However, be aware we can still get a socket state event updates,
2189          * and data ready callabacks that reference the psock from sk_user_data
2190          * Also psock worker threads are still in-flight. So smap_release_sock
2191          * will only free the psock after cancel_sync on the worker threads
2192          * and a grace period expire to ensure psock is really safe to remove.
2193          */
2194         rcu_read_lock();
2195         for (i = 0; i < htab->n_buckets; i++) {
2196                 struct bucket *b = __select_bucket(htab, i);
2197                 struct hlist_head *head;
2198                 struct hlist_node *n;
2199                 struct htab_elem *l;
2200
2201                 raw_spin_lock_bh(&b->lock);
2202                 head = &b->head;
2203                 hlist_for_each_entry_safe(l, n, head, hash_node) {
2204                         struct sock *sock = l->sk;
2205                         struct smap_psock *psock;
2206
2207                         hlist_del_rcu(&l->hash_node);
2208                         psock = smap_psock_sk(sock);
2209                         /* This check handles a racing sock event that can get
2210                          * the sk_callback_lock before this case but after xchg
2211                          * causing the refcnt to hit zero and sock user data
2212                          * (psock) to be null and queued for garbage collection.
2213                          */
2214                         if (likely(psock)) {
2215                                 smap_list_hash_remove(psock, l);
2216                                 smap_release_sock(psock, sock);
2217                         }
2218                         free_htab_elem(htab, l);
2219                 }
2220                 raw_spin_unlock_bh(&b->lock);
2221         }
2222         rcu_read_unlock();
2223         call_rcu(&htab->rcu, __bpf_htab_free);
2224 }
2225
2226 static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2227                                               void *key, u32 key_size, u32 hash,
2228                                               struct sock *sk,
2229                                               struct htab_elem *old_elem)
2230 {
2231         struct htab_elem *l_new;
2232
2233         if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2234                 if (!old_elem) {
2235                         atomic_dec(&htab->count);
2236                         return ERR_PTR(-E2BIG);
2237                 }
2238         }
2239         l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2240                              htab->map.numa_node);
2241         if (!l_new)
2242                 return ERR_PTR(-ENOMEM);
2243
2244         memcpy(l_new->key, key, key_size);
2245         l_new->sk = sk;
2246         l_new->hash = hash;
2247         return l_new;
2248 }
2249
2250 static inline u32 htab_map_hash(const void *key, u32 key_len)
2251 {
2252         return jhash(key, key_len, 0);
2253 }
2254
2255 static int sock_hash_get_next_key(struct bpf_map *map,
2256                                   void *key, void *next_key)
2257 {
2258         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2259         struct htab_elem *l, *next_l;
2260         struct hlist_head *h;
2261         u32 hash, key_size;
2262         int i = 0;
2263
2264         WARN_ON_ONCE(!rcu_read_lock_held());
2265
2266         key_size = map->key_size;
2267         if (!key)
2268                 goto find_first_elem;
2269         hash = htab_map_hash(key, key_size);
2270         h = select_bucket(htab, hash);
2271
2272         l = lookup_elem_raw(h, hash, key, key_size);
2273         if (!l)
2274                 goto find_first_elem;
2275         next_l = hlist_entry_safe(
2276                      rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2277                      struct htab_elem, hash_node);
2278         if (next_l) {
2279                 memcpy(next_key, next_l->key, key_size);
2280                 return 0;
2281         }
2282
2283         /* no more elements in this hash list, go to the next bucket */
2284         i = hash & (htab->n_buckets - 1);
2285         i++;
2286
2287 find_first_elem:
2288         /* iterate over buckets */
2289         for (; i < htab->n_buckets; i++) {
2290                 h = select_bucket(htab, i);
2291
2292                 /* pick first element in the bucket */
2293                 next_l = hlist_entry_safe(
2294                                 rcu_dereference_raw(hlist_first_rcu(h)),
2295                                 struct htab_elem, hash_node);
2296                 if (next_l) {
2297                         /* if it's not empty, just return it */
2298                         memcpy(next_key, next_l->key, key_size);
2299                         return 0;
2300                 }
2301         }
2302
2303         /* iterated over all buckets and all elements */
2304         return -ENOENT;
2305 }
2306
2307 static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2308                                      struct bpf_map *map,
2309                                      void *key, u64 map_flags)
2310 {
2311         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2312         struct bpf_sock_progs *progs = &htab->progs;
2313         struct htab_elem *l_new = NULL, *l_old;
2314         struct smap_psock_map_entry *e = NULL;
2315         struct hlist_head *head;
2316         struct smap_psock *psock;
2317         u32 key_size, hash;
2318         struct sock *sock;
2319         struct bucket *b;
2320         int err;
2321
2322         sock = skops->sk;
2323
2324         if (sock->sk_type != SOCK_STREAM ||
2325             sock->sk_protocol != IPPROTO_TCP)
2326                 return -EOPNOTSUPP;
2327
2328         if (unlikely(map_flags > BPF_EXIST))
2329                 return -EINVAL;
2330
2331         e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2332         if (!e)
2333                 return -ENOMEM;
2334
2335         WARN_ON_ONCE(!rcu_read_lock_held());
2336         key_size = map->key_size;
2337         hash = htab_map_hash(key, key_size);
2338         b = __select_bucket(htab, hash);
2339         head = &b->head;
2340
2341         err = __sock_map_ctx_update_elem(map, progs, sock, NULL, key);
2342         if (err)
2343                 goto err;
2344
2345         /* bpf_map_update_elem() can be called in_irq() */
2346         raw_spin_lock_bh(&b->lock);
2347         l_old = lookup_elem_raw(head, hash, key, key_size);
2348         if (l_old && map_flags == BPF_NOEXIST) {
2349                 err = -EEXIST;
2350                 goto bucket_err;
2351         }
2352         if (!l_old && map_flags == BPF_EXIST) {
2353                 err = -ENOENT;
2354                 goto bucket_err;
2355         }
2356
2357         l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2358         if (IS_ERR(l_new)) {
2359                 err = PTR_ERR(l_new);
2360                 goto bucket_err;
2361         }
2362
2363         psock = smap_psock_sk(sock);
2364         if (unlikely(!psock)) {
2365                 err = -EINVAL;
2366                 goto bucket_err;
2367         }
2368
2369         rcu_assign_pointer(e->hash_link, l_new);
2370         rcu_assign_pointer(e->htab,
2371                            container_of(map, struct bpf_htab, map));
2372         spin_lock_bh(&psock->maps_lock);
2373         list_add_tail(&e->list, &psock->maps);
2374         spin_unlock_bh(&psock->maps_lock);
2375
2376         /* add new element to the head of the list, so that
2377          * concurrent search will find it before old elem
2378          */
2379         hlist_add_head_rcu(&l_new->hash_node, head);
2380         if (l_old) {
2381                 psock = smap_psock_sk(l_old->sk);
2382
2383                 hlist_del_rcu(&l_old->hash_node);
2384                 smap_list_hash_remove(psock, l_old);
2385                 smap_release_sock(psock, l_old->sk);
2386                 free_htab_elem(htab, l_old);
2387         }
2388         raw_spin_unlock_bh(&b->lock);
2389         return 0;
2390 bucket_err:
2391         raw_spin_unlock_bh(&b->lock);
2392 err:
2393         kfree(e);
2394         psock = smap_psock_sk(sock);
2395         if (psock)
2396                 smap_release_sock(psock, sock);
2397         return err;
2398 }
2399
2400 static int sock_hash_update_elem(struct bpf_map *map,
2401                                 void *key, void *value, u64 flags)
2402 {
2403         struct bpf_sock_ops_kern skops;
2404         u32 fd = *(u32 *)value;
2405         struct socket *socket;
2406         int err;
2407
2408         socket = sockfd_lookup(fd, &err);
2409         if (!socket)
2410                 return err;
2411
2412         skops.sk = socket->sk;
2413         if (!skops.sk) {
2414                 fput(socket->file);
2415                 return -EINVAL;
2416         }
2417
2418         err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2419         fput(socket->file);
2420         return err;
2421 }
2422
2423 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2424 {
2425         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2426         struct hlist_head *head;
2427         struct bucket *b;
2428         struct htab_elem *l;
2429         u32 hash, key_size;
2430         int ret = -ENOENT;
2431
2432         key_size = map->key_size;
2433         hash = htab_map_hash(key, key_size);
2434         b = __select_bucket(htab, hash);
2435         head = &b->head;
2436
2437         raw_spin_lock_bh(&b->lock);
2438         l = lookup_elem_raw(head, hash, key, key_size);
2439         if (l) {
2440                 struct sock *sock = l->sk;
2441                 struct smap_psock *psock;
2442
2443                 hlist_del_rcu(&l->hash_node);
2444                 psock = smap_psock_sk(sock);
2445                 /* This check handles a racing sock event that can get the
2446                  * sk_callback_lock before this case but after xchg happens
2447                  * causing the refcnt to hit zero and sock user data (psock)
2448                  * to be null and queued for garbage collection.
2449                  */
2450                 if (likely(psock)) {
2451                         smap_list_hash_remove(psock, l);
2452                         smap_release_sock(psock, sock);
2453                 }
2454                 free_htab_elem(htab, l);
2455                 ret = 0;
2456         }
2457         raw_spin_unlock_bh(&b->lock);
2458         return ret;
2459 }
2460
2461 struct sock  *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2462 {
2463         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2464         struct hlist_head *head;
2465         struct htab_elem *l;
2466         u32 key_size, hash;
2467         struct bucket *b;
2468         struct sock *sk;
2469
2470         key_size = map->key_size;
2471         hash = htab_map_hash(key, key_size);
2472         b = __select_bucket(htab, hash);
2473         head = &b->head;
2474
2475         raw_spin_lock_bh(&b->lock);
2476         l = lookup_elem_raw(head, hash, key, key_size);
2477         sk = l ? l->sk : NULL;
2478         raw_spin_unlock_bh(&b->lock);
2479         return sk;
2480 }
2481
2482 const struct bpf_map_ops sock_map_ops = {
2483         .map_alloc = sock_map_alloc,
2484         .map_free = sock_map_free,
2485         .map_lookup_elem = sock_map_lookup,
2486         .map_get_next_key = sock_map_get_next_key,
2487         .map_update_elem = sock_map_update_elem,
2488         .map_delete_elem = sock_map_delete_elem,
2489         .map_release_uref = sock_map_release,
2490 };
2491
2492 const struct bpf_map_ops sock_hash_ops = {
2493         .map_alloc = sock_hash_alloc,
2494         .map_free = sock_hash_free,
2495         .map_lookup_elem = sock_map_lookup,
2496         .map_get_next_key = sock_hash_get_next_key,
2497         .map_update_elem = sock_hash_update_elem,
2498         .map_delete_elem = sock_hash_delete_elem,
2499         .map_release_uref = sock_map_release,
2500 };
2501
2502 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2503            struct bpf_map *, map, void *, key, u64, flags)
2504 {
2505         WARN_ON_ONCE(!rcu_read_lock_held());
2506         return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2507 }
2508
2509 const struct bpf_func_proto bpf_sock_map_update_proto = {
2510         .func           = bpf_sock_map_update,
2511         .gpl_only       = false,
2512         .pkt_access     = true,
2513         .ret_type       = RET_INTEGER,
2514         .arg1_type      = ARG_PTR_TO_CTX,
2515         .arg2_type      = ARG_CONST_MAP_PTR,
2516         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2517         .arg4_type      = ARG_ANYTHING,
2518 };
2519
2520 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2521            struct bpf_map *, map, void *, key, u64, flags)
2522 {
2523         WARN_ON_ONCE(!rcu_read_lock_held());
2524         return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2525 }
2526
2527 const struct bpf_func_proto bpf_sock_hash_update_proto = {
2528         .func           = bpf_sock_hash_update,
2529         .gpl_only       = false,
2530         .pkt_access     = true,
2531         .ret_type       = RET_INTEGER,
2532         .arg1_type      = ARG_PTR_TO_CTX,
2533         .arg2_type      = ARG_CONST_MAP_PTR,
2534         .arg3_type      = ARG_PTR_TO_MAP_KEY,
2535         .arg4_type      = ARG_ANYTHING,
2536 };