net/xdp/xsk.c
// SPDX-License-Identifier: GPL-2.0
/* XDP sockets
 *
 * AF_XDP sockets allow a channel between XDP programs and userspace
 * applications.
 * Copyright(c) 2018 Intel Corporation.
 *
 * Author(s): Björn Töpel <bjorn.topel@intel.com>
 *            Magnus Karlsson <magnus.karlsson@intel.com>
 */

#define pr_fmt(fmt) "AF_XDP: %s: " fmt, __func__

#include <linux/if_xdp.h>
#include <linux/init.h>
#include <linux/sched/mm.h>
#include <linux/sched/signal.h>
#include <linux/sched/task.h>
#include <linux/socket.h>
#include <linux/file.h>
#include <linux/uaccess.h>
#include <linux/net.h>
#include <linux/netdevice.h>
#include <linux/rculist.h>
#include <net/xdp_sock.h>
#include <net/xdp.h>

#include "xsk_queue.h"
#include "xdp_umem.h"

#define TX_BATCH_SIZE 16

static struct xdp_sock *xdp_sk(struct sock *sk)
{
        return (struct xdp_sock *)sk;
}

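/* True once the socket has an Rx ring and a umem with a fill queue,
 * i.e. everything needed before an XDP program may redirect to it.
 */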
bool xsk_is_setup_for_bpf_map(struct xdp_sock *xs)
{
        return READ_ONCE(xs->rx) && READ_ONCE(xs->umem) &&
                READ_ONCE(xs->umem->fq);
}

u64 *xsk_umem_peek_addr(struct xdp_umem *umem, u64 *addr)
{
        return xskq_peek_addr(umem->fq, addr);
}
EXPORT_SYMBOL(xsk_umem_peek_addr);

void xsk_umem_discard_addr(struct xdp_umem *umem)
{
        xskq_discard_addr(umem->fq);
}
EXPORT_SYMBOL(xsk_umem_discard_addr);

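/* Copy receive path: take an address off the fill queue, copy the frame
 * into the umem chunk and post a descriptor on the Rx ring. The frame is
 * counted as dropped if no fill entry is available or it does not fit.
 */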
static int __xsk_rcv(struct xdp_sock *xs, struct xdp_buff *xdp, u32 len)
{
        void *buffer;
        u64 addr;
        int err;

        if (!xskq_peek_addr(xs->umem->fq, &addr) ||
            len > xs->umem->chunk_size_nohr) {
                xs->rx_dropped++;
                return -ENOSPC;
        }

        addr += xs->umem->headroom;

        buffer = xdp_umem_get_data(xs->umem, addr);
        memcpy(buffer, xdp->data, len);
        err = xskq_produce_batch_desc(xs->rx, addr, len);
        if (!err) {
                xskq_discard_addr(xs->umem->fq);
                xdp_return_buff(xdp);
                return 0;
        }

        xs->rx_dropped++;
        return err;
}

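/* Zero-copy receive path: the frame already resides in umem, so only a
 * descriptor referencing its handle is posted on the Rx ring.
 */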
static int __xsk_rcv_zc(struct xdp_sock *xs, struct xdp_buff *xdp, u32 len)
{
        int err = xskq_produce_batch_desc(xs->rx, (u64)xdp->handle, len);

        if (err) {
                xdp_return_buff(xdp);
                xs->rx_dropped++;
        }

        return err;
}

int xsk_rcv(struct xdp_sock *xs, struct xdp_buff *xdp)
{
        u32 len;

        if (xs->dev != xdp->rxq->dev || xs->queue_id != xdp->rxq->queue_index)
                return -EINVAL;

        len = xdp->data_end - xdp->data;

        return (xdp->rxq->mem.type == MEM_TYPE_ZERO_COPY) ?
                __xsk_rcv_zc(xs, xdp, len) : __xsk_rcv(xs, xdp, len);
}

void xsk_flush(struct xdp_sock *xs)
{
        xskq_produce_flush_desc(xs->rx);
        xs->sk.sk_data_ready(&xs->sk);
}

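/* Receive path for generic (skb-mode) XDP: same copy scheme as __xsk_rcv,
 * but the Rx ring is flushed and the socket woken up immediately.
 */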
int xsk_generic_rcv(struct xdp_sock *xs, struct xdp_buff *xdp)
{
        u32 len = xdp->data_end - xdp->data;
        void *buffer;
        u64 addr;
        int err;

        if (xs->dev != xdp->rxq->dev || xs->queue_id != xdp->rxq->queue_index)
                return -EINVAL;

        if (!xskq_peek_addr(xs->umem->fq, &addr) ||
            len > xs->umem->chunk_size_nohr) {
                xs->rx_dropped++;
                return -ENOSPC;
        }

        addr += xs->umem->headroom;

        buffer = xdp_umem_get_data(xs->umem, addr);
        memcpy(buffer, xdp->data, len);
        err = xskq_produce_batch_desc(xs->rx, addr, len);
        if (!err) {
                xskq_discard_addr(xs->umem->fq);
                xsk_flush(xs);
                return 0;
        }

        xs->rx_dropped++;
        return err;
}

void xsk_umem_complete_tx(struct xdp_umem *umem, u32 nb_entries)
{
        xskq_produce_flush_addr_n(umem->cq, nb_entries);
}
EXPORT_SYMBOL(xsk_umem_complete_tx);

void xsk_umem_consume_tx_done(struct xdp_umem *umem)
{
        struct xdp_sock *xs;

        rcu_read_lock();
        list_for_each_entry_rcu(xs, &umem->xsk_list, list) {
                xs->sk.sk_write_space(&xs->sk);
        }
        rcu_read_unlock();
}
EXPORT_SYMBOL(xsk_umem_consume_tx_done);

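/* Hand one Tx descriptor to the driver: space is reserved in the
 * completion ring first, then the frame's DMA address and length are
 * returned and the descriptor is consumed from the Tx ring.
 */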
bool xsk_umem_consume_tx(struct xdp_umem *umem, dma_addr_t *dma, u32 *len)
{
        struct xdp_desc desc;
        struct xdp_sock *xs;

        rcu_read_lock();
        list_for_each_entry_rcu(xs, &umem->xsk_list, list) {
                if (!xskq_peek_desc(xs->tx, &desc))
                        continue;

                if (xskq_produce_addr_lazy(umem->cq, desc.addr))
                        goto out;

                *dma = xdp_umem_get_dma(umem, desc.addr);
                *len = desc.len;

                xskq_discard_desc(xs->tx);
                rcu_read_unlock();
                return true;
        }

out:
        rcu_read_unlock();
        return false;
}
EXPORT_SYMBOL(xsk_umem_consume_tx);

static int xsk_zc_xmit(struct sock *sk)
{
        struct xdp_sock *xs = xdp_sk(sk);
        struct net_device *dev = xs->dev;

        return dev->netdev_ops->ndo_xsk_async_xmit(dev, xs->queue_id);
}

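/* skb destructor for copy-mode Tx: once the skb is freed the frame has
 * left the stack, so its umem address is posted on the completion ring.
 */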
static void xsk_destruct_skb(struct sk_buff *skb)
{
        u64 addr = (u64)(long)skb_shinfo(skb)->destructor_arg;
        struct xdp_sock *xs = xdp_sk(skb->sk);

        WARN_ON_ONCE(xskq_produce_addr(xs->umem->cq, addr));

        sock_wfree(skb);
}

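/* Copy-mode Tx: drain up to TX_BATCH_SIZE descriptors from the Tx ring,
 * copy each frame into a freshly allocated skb and send it directly on
 * the bound queue with dev_direct_xmit().
 */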
static int xsk_generic_xmit(struct sock *sk, struct msghdr *m,
                            size_t total_len)
{
        u32 max_batch = TX_BATCH_SIZE;
        struct xdp_sock *xs = xdp_sk(sk);
        bool sent_frame = false;
        struct xdp_desc desc;
        struct sk_buff *skb;
        int err = 0;

        if (unlikely(!xs->tx))
                return -ENOBUFS;

        mutex_lock(&xs->mutex);

        while (xskq_peek_desc(xs->tx, &desc)) {
                char *buffer;
                u64 addr;
                u32 len;

                if (max_batch-- == 0) {
                        err = -EAGAIN;
                        goto out;
                }

                if (xskq_reserve_addr(xs->umem->cq)) {
                        err = -EAGAIN;
                        goto out;
                }

                len = desc.len;
                if (unlikely(len > xs->dev->mtu)) {
                        err = -EMSGSIZE;
                        goto out;
                }

                if (xs->queue_id >= xs->dev->real_num_tx_queues) {
                        err = -ENXIO;
                        goto out;
                }

                skb = sock_alloc_send_skb(sk, len, 1, &err);
                if (unlikely(!skb)) {
                        err = -EAGAIN;
                        goto out;
                }

                skb_put(skb, len);
                addr = desc.addr;
                buffer = xdp_umem_get_data(xs->umem, addr);
                err = skb_store_bits(skb, 0, buffer, len);
                if (unlikely(err)) {
                        kfree_skb(skb);
                        goto out;
                }

                skb->dev = xs->dev;
                skb->priority = sk->sk_priority;
                skb->mark = sk->sk_mark;
                skb_shinfo(skb)->destructor_arg = (void *)(long)addr;
                skb->destructor = xsk_destruct_skb;

                err = dev_direct_xmit(skb, xs->queue_id);
                /* Ignore NET_XMIT_CN as packet might have been sent */
                if (err == NET_XMIT_DROP || err == NETDEV_TX_BUSY) {
                        err = -EAGAIN;
                        /* SKB consumed by dev_direct_xmit() */
                        goto out;
                }

                sent_frame = true;
                xskq_discard_desc(xs->tx);
        }

out:
        if (sent_frame)
                sk->sk_write_space(sk);

        mutex_unlock(&xs->mutex);
        return err;
}

static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len)
{
        bool need_wait = !(m->msg_flags & MSG_DONTWAIT);
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);

        if (unlikely(!xs->dev))
                return -ENXIO;
        if (unlikely(!(xs->dev->flags & IFF_UP)))
                return -ENETDOWN;
        if (need_wait)
                return -EOPNOTSUPP;

        return (xs->zc) ? xsk_zc_xmit(sk) : xsk_generic_xmit(sk, m, total_len);
}

static unsigned int xsk_poll(struct file *file, struct socket *sock,
                             struct poll_table_struct *wait)
{
        unsigned int mask = datagram_poll(file, sock, wait);
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);

        if (xs->rx && !xskq_empty_desc(xs->rx))
                mask |= POLLIN | POLLRDNORM;
        if (xs->tx && !xskq_full_desc(xs->tx))
                mask |= POLLOUT | POLLWRNORM;

        return mask;
}

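/* Allocate a descriptor ring. The number of entries must be a non-zero
 * power of two and the queue must not already exist.
 */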
static int xsk_init_queue(u32 entries, struct xsk_queue **queue,
                          bool umem_queue)
{
        struct xsk_queue *q;

        if (entries == 0 || *queue || !is_power_of_2(entries))
                return -EINVAL;

        q = xskq_create(entries, umem_queue);
        if (!q)
                return -ENOMEM;

        /* Make sure queue is ready before it can be seen by others */
        smp_wmb();
        *queue = q;
        return 0;
}

static int xsk_release(struct socket *sock)
{
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        struct net *net;

        if (!sk)
                return 0;

        net = sock_net(sk);

        local_bh_disable();
        sock_prot_inuse_add(net, sk->sk_prot, -1);
        local_bh_enable();

        if (xs->dev) {
                /* Wait for driver to stop using the xdp socket. */
                synchronize_net();
                dev_put(xs->dev);
                xs->dev = NULL;
        }

        sock_orphan(sk);
        sock->sk = NULL;

        sk_refcnt_debug_release(sk);
        sock_put(sk);

        return 0;
}

static struct socket *xsk_lookup_xsk_from_fd(int fd)
{
        struct socket *sock;
        int err;

        sock = sockfd_lookup(fd, &err);
        if (!sock)
                return ERR_PTR(-ENOTSOCK);

        if (sock->sk->sk_family != PF_XDP) {
                sockfd_put(sock);
                return ERR_PTR(-ENOPROTOOPT);
        }

        return sock;
}

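/* Bind the socket to a device and queue id. With XDP_SHARED_UMEM the umem
 * (and its device/queue binding) is inherited from another XDP socket;
 * otherwise the socket's own umem is assigned to the device here.
 */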
static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
        struct sockaddr_xdp *sxdp = (struct sockaddr_xdp *)addr;
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        struct net_device *dev;
        u32 flags, qid;
        int err = 0;

        if (addr_len < sizeof(struct sockaddr_xdp))
                return -EINVAL;
        if (sxdp->sxdp_family != AF_XDP)
                return -EINVAL;

        mutex_lock(&xs->mutex);
        if (xs->dev) {
                err = -EBUSY;
                goto out_release;
        }

        dev = dev_get_by_index(sock_net(sk), sxdp->sxdp_ifindex);
        if (!dev) {
                err = -ENODEV;
                goto out_release;
        }

        if (!xs->rx && !xs->tx) {
                err = -EINVAL;
                goto out_unlock;
        }

        qid = sxdp->sxdp_queue_id;

        if ((xs->rx && qid >= dev->real_num_rx_queues) ||
            (xs->tx && qid >= dev->real_num_tx_queues)) {
                err = -EINVAL;
                goto out_unlock;
        }

        flags = sxdp->sxdp_flags;

        if (flags & XDP_SHARED_UMEM) {
                struct xdp_sock *umem_xs;
                struct socket *sock;

                if ((flags & XDP_COPY) || (flags & XDP_ZEROCOPY)) {
                        /* Cannot specify flags for shared sockets. */
                        err = -EINVAL;
                        goto out_unlock;
                }

                if (xs->umem) {
                        /* We already have our own. */
                        err = -EINVAL;
                        goto out_unlock;
                }

                sock = xsk_lookup_xsk_from_fd(sxdp->sxdp_shared_umem_fd);
                if (IS_ERR(sock)) {
                        err = PTR_ERR(sock);
                        goto out_unlock;
                }

                umem_xs = xdp_sk(sock->sk);
                if (!umem_xs->umem) {
                        /* No umem to inherit. */
                        err = -EBADF;
                        sockfd_put(sock);
                        goto out_unlock;
                } else if (umem_xs->dev != dev || umem_xs->queue_id != qid) {
                        err = -EINVAL;
                        sockfd_put(sock);
                        goto out_unlock;
                }

                xdp_get_umem(umem_xs->umem);
                xs->umem = umem_xs->umem;
                sockfd_put(sock);
        } else if (!xs->umem || !xdp_umem_validate_queues(xs->umem)) {
                err = -EINVAL;
                goto out_unlock;
        } else {
                /* This xsk has its own umem. */
                xskq_set_umem(xs->umem->fq, &xs->umem->props);
                xskq_set_umem(xs->umem->cq, &xs->umem->props);

                err = xdp_umem_assign_dev(xs->umem, dev, qid, flags);
                if (err)
                        goto out_unlock;
        }

        xs->dev = dev;
        xs->zc = xs->umem->zc;
        xs->queue_id = qid;
        xskq_set_umem(xs->rx, &xs->umem->props);
        xskq_set_umem(xs->tx, &xs->umem->props);
        xdp_add_sk_umem(xs->umem, xs);

out_unlock:
        if (err)
                dev_put(dev);
out_release:
        mutex_unlock(&xs->mutex);
        return err;
}

static int xsk_setsockopt(struct socket *sock, int level, int optname,
                          char __user *optval, unsigned int optlen)
{
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        int err;

        if (level != SOL_XDP)
                return -ENOPROTOOPT;

        switch (optname) {
        case XDP_RX_RING:
        case XDP_TX_RING:
        {
                struct xsk_queue **q;
                int entries;

                if (optlen < sizeof(entries))
                        return -EINVAL;
                if (copy_from_user(&entries, optval, sizeof(entries)))
                        return -EFAULT;

                mutex_lock(&xs->mutex);
                q = (optname == XDP_TX_RING) ? &xs->tx : &xs->rx;
                err = xsk_init_queue(entries, q, false);
                mutex_unlock(&xs->mutex);
                return err;
        }
        case XDP_UMEM_REG:
        {
                struct xdp_umem_reg mr;
                struct xdp_umem *umem;

                if (copy_from_user(&mr, optval, sizeof(mr)))
                        return -EFAULT;

                mutex_lock(&xs->mutex);
                if (xs->umem) {
                        mutex_unlock(&xs->mutex);
                        return -EBUSY;
                }

                umem = xdp_umem_create(&mr);
                if (IS_ERR(umem)) {
                        mutex_unlock(&xs->mutex);
                        return PTR_ERR(umem);
                }

                /* Make sure umem is ready before it can be seen by others */
                smp_wmb();
                xs->umem = umem;
                mutex_unlock(&xs->mutex);
                return 0;
        }
        case XDP_UMEM_FILL_RING:
        case XDP_UMEM_COMPLETION_RING:
        {
                struct xsk_queue **q;
                int entries;

                if (copy_from_user(&entries, optval, sizeof(entries)))
                        return -EFAULT;

                mutex_lock(&xs->mutex);
                if (!xs->umem) {
                        mutex_unlock(&xs->mutex);
                        return -EINVAL;
                }

                q = (optname == XDP_UMEM_FILL_RING) ? &xs->umem->fq :
                        &xs->umem->cq;
                err = xsk_init_queue(entries, q, true);
                mutex_unlock(&xs->mutex);
                return err;
        }
        default:
                break;
        }

        return -ENOPROTOOPT;
}

static int xsk_getsockopt(struct socket *sock, int level, int optname,
                          char __user *optval, int __user *optlen)
{
        struct sock *sk = sock->sk;
        struct xdp_sock *xs = xdp_sk(sk);
        int len;

        if (level != SOL_XDP)
                return -ENOPROTOOPT;

        if (get_user(len, optlen))
                return -EFAULT;
        if (len < 0)
                return -EINVAL;

        switch (optname) {
        case XDP_STATISTICS:
        {
                struct xdp_statistics stats;

                if (len < sizeof(stats))
                        return -EINVAL;

                mutex_lock(&xs->mutex);
                stats.rx_dropped = xs->rx_dropped;
                stats.rx_invalid_descs = xskq_nb_invalid_descs(xs->rx);
                stats.tx_invalid_descs = xskq_nb_invalid_descs(xs->tx);
                mutex_unlock(&xs->mutex);

                if (copy_to_user(optval, &stats, sizeof(stats)))
                        return -EFAULT;
                if (put_user(sizeof(stats), optlen))
                        return -EFAULT;

                return 0;
        }
        case XDP_MMAP_OFFSETS:
        {
                struct xdp_mmap_offsets off;

                if (len < sizeof(off))
                        return -EINVAL;

                off.rx.producer = offsetof(struct xdp_rxtx_ring, ptrs.producer);
                off.rx.consumer = offsetof(struct xdp_rxtx_ring, ptrs.consumer);
                off.rx.desc     = offsetof(struct xdp_rxtx_ring, desc);
                off.tx.producer = offsetof(struct xdp_rxtx_ring, ptrs.producer);
                off.tx.consumer = offsetof(struct xdp_rxtx_ring, ptrs.consumer);
                off.tx.desc     = offsetof(struct xdp_rxtx_ring, desc);

                off.fr.producer = offsetof(struct xdp_umem_ring, ptrs.producer);
                off.fr.consumer = offsetof(struct xdp_umem_ring, ptrs.consumer);
                off.fr.desc     = offsetof(struct xdp_umem_ring, desc);
                off.cr.producer = offsetof(struct xdp_umem_ring, ptrs.producer);
                off.cr.consumer = offsetof(struct xdp_umem_ring, ptrs.consumer);
                off.cr.desc     = offsetof(struct xdp_umem_ring, desc);

                len = sizeof(off);
                if (copy_to_user(optval, &off, len))
                        return -EFAULT;
                if (put_user(len, optlen))
                        return -EFAULT;

                return 0;
        }
        default:
                break;
        }

        return -EOPNOTSUPP;
}

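/* Map one of the four rings (Rx, Tx, fill or completion) into userspace;
 * the ring is selected by the mmap offset.
 */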
static int xsk_mmap(struct file *file, struct socket *sock,
                    struct vm_area_struct *vma)
{
        loff_t offset = (loff_t)vma->vm_pgoff << PAGE_SHIFT;
        unsigned long size = vma->vm_end - vma->vm_start;
        struct xdp_sock *xs = xdp_sk(sock->sk);
        struct xsk_queue *q = NULL;
        struct xdp_umem *umem;
        unsigned long pfn;
        struct page *qpg;

        if (offset == XDP_PGOFF_RX_RING) {
                q = READ_ONCE(xs->rx);
        } else if (offset == XDP_PGOFF_TX_RING) {
                q = READ_ONCE(xs->tx);
        } else {
                umem = READ_ONCE(xs->umem);
                if (!umem)
                        return -EINVAL;

                if (offset == XDP_UMEM_PGOFF_FILL_RING)
                        q = READ_ONCE(umem->fq);
                else if (offset == XDP_UMEM_PGOFF_COMPLETION_RING)
                        q = READ_ONCE(umem->cq);
        }

        if (!q)
                return -EINVAL;

        qpg = virt_to_head_page(q->ring);
        if (size > (PAGE_SIZE << compound_order(qpg)))
                return -EINVAL;

        pfn = virt_to_phys(q->ring) >> PAGE_SHIFT;
        return remap_pfn_range(vma, vma->vm_start, pfn,
                               size, vma->vm_page_prot);
}

static struct proto xsk_proto = {
        .name =         "XDP",
        .owner =        THIS_MODULE,
        .obj_size =     sizeof(struct xdp_sock),
};

static const struct proto_ops xsk_proto_ops = {
        .family         = PF_XDP,
        .owner          = THIS_MODULE,
        .release        = xsk_release,
        .bind           = xsk_bind,
        .connect        = sock_no_connect,
        .socketpair     = sock_no_socketpair,
        .accept         = sock_no_accept,
        .getname        = sock_no_getname,
        .poll           = xsk_poll,
        .ioctl          = sock_no_ioctl,
        .listen         = sock_no_listen,
        .shutdown       = sock_no_shutdown,
        .setsockopt     = xsk_setsockopt,
        .getsockopt     = xsk_getsockopt,
        .sendmsg        = xsk_sendmsg,
        .recvmsg        = sock_no_recvmsg,
        .mmap           = xsk_mmap,
        .sendpage       = sock_no_sendpage,
};

static void xsk_destruct(struct sock *sk)
{
        struct xdp_sock *xs = xdp_sk(sk);

        if (!sock_flag(sk, SOCK_DEAD))
                return;

        xskq_destroy(xs->rx);
        xskq_destroy(xs->tx);
        xdp_del_sk_umem(xs->umem, xs);
        xdp_put_umem(xs->umem);

        sk_refcnt_debug_dec(sk);
}

static int xsk_create(struct net *net, struct socket *sock, int protocol,
                      int kern)
{
        struct sock *sk;
        struct xdp_sock *xs;

        if (!ns_capable(net->user_ns, CAP_NET_RAW))
                return -EPERM;
        if (sock->type != SOCK_RAW)
                return -ESOCKTNOSUPPORT;

        if (protocol)
                return -EPROTONOSUPPORT;

        sock->state = SS_UNCONNECTED;

        sk = sk_alloc(net, PF_XDP, GFP_KERNEL, &xsk_proto, kern);
        if (!sk)
                return -ENOBUFS;

        sock->ops = &xsk_proto_ops;

        sock_init_data(sock, sk);

        sk->sk_family = PF_XDP;

        sk->sk_destruct = xsk_destruct;
        sk_refcnt_debug_inc(sk);

        xs = xdp_sk(sk);
        mutex_init(&xs->mutex);

        local_bh_disable();
        sock_prot_inuse_add(net, &xsk_proto, 1);
        local_bh_enable();

        return 0;
}

static const struct net_proto_family xsk_family_ops = {
        .family = PF_XDP,
        .create = xsk_create,
        .owner  = THIS_MODULE,
};

static int __init xsk_init(void)
{
        int err;

        err = proto_register(&xsk_proto, 0 /* no slab */);
        if (err)
                goto out;

        err = sock_register(&xsk_family_ops);
        if (err)
                goto out_proto;

        return 0;

out_proto:
        proto_unregister(&xsk_proto);
out:
        return err;
}

fs_initcall(xsk_init);