// SPDX-License-Identifier: GPL-2.0-only
/*
 * VMware vSockets Driver
 *
 * Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
 */

/* Implementation notes:
 *
 * - There are two kinds of sockets: those created by user action (such as
 * calling socket(2)) and those created by incoming connection request packets.
 *
 * - There are two "global" tables, one for bound sockets (sockets that have
 * specified an address that they are responsible for) and one for connected
 * sockets (sockets that have established a connection with another socket).
 * These tables are "global" in that all sockets on the system are placed
 * within them.  Note, though, that the bound table contains an extra entry
 * for a list of unbound sockets, and SOCK_DGRAM sockets will always remain in
 * that list.  The bound table is used solely for looking up sockets when
 * packets are received, and that is unnecessary for SOCK_DGRAM sockets since
 * we create a datagram handle for each and need not perform a lookup.
 * Keeping SOCK_DGRAM sockets out of the bound hash buckets reduces the chance
 * of collisions when looking up SOCK_STREAM sockets and saves us from having
 * to check the socket type in the hash table lookups.
 *
 * - Sockets created by user action will either be "client" sockets that
 * initiate a connection or "server" sockets that listen for connections; we do
 * not support simultaneous connects (two "client" sockets connecting).
 *
 * - "Server" sockets are referred to as listener sockets throughout this
 * implementation because they are in the TCP_LISTEN state.  When a
 * connection request is received (the second kind of socket mentioned above),
 * we create a new socket and refer to it as a pending socket.  These pending
 * sockets are placed on the pending connection list of the listener socket.
 * When future packets are received for the address the listener socket is
 * bound to, we check whether the source of the packet matches an existing
 * pending connection.  If it does, we process the packet for the pending
 * socket.  When that socket reaches the connected state, it is removed from
 * the listener socket's pending list and enqueued in the listener socket's
 * accept queue.  Callers of accept(2) will accept connected sockets from the
 * listener socket's accept queue.  If the socket cannot be accepted for some
 * reason then it is marked rejected.  Once the connection is accepted, it is
 * owned by the user process and the responsibility for cleanup falls with
 * that user process.
 *
 * - It is possible that these pending sockets will never reach the connected
 * state; in fact, we may never receive another packet after the connection
 * request.  Because of this, we must schedule a cleanup function to run in the
 * future, after some amount of time passes where a connection should have been
 * established.  This function ensures that the socket is off all lists so it
 * cannot be retrieved, then drops all references to the socket so it is cleaned
 * up (sock_put() -> sk_free() -> our sk_destruct implementation).  Note that
 * this function also cleans up rejected sockets, those that reach the
 * connected state but leave it before they have been accepted.
 *
 * - Lock ordering for pending or accept queue sockets is:
 *
 *     lock_sock(listener);
 *     lock_sock_nested(pending, SINGLE_DEPTH_NESTING);
 *
 * Using explicit nested locking keeps lockdep happy since normally only one
 * lock of a given class may be taken at a time.
 *
 * - Sockets created by user action will be cleaned up when the user process
 * calls close(2), causing our release implementation to be called.  Our
 * release implementation will perform some cleanup then drop the last
 * reference so our sk_destruct implementation is invoked.  Our sk_destruct
 * implementation will perform additional cleanup that's common for both types
 * of sockets.
 *
 * - A socket's reference count is what ensures that the structure won't be
 * freed.  Each entry in a list (such as the "global" bound and connected
 * tables and the listener socket's pending list and connected queue) holds a
 * reference.  When we defer work until process context and pass a socket as
 * our argument, we must increase the reference count to ensure the socket
 * isn't freed before the function runs; the deferred function will then drop
 * the reference.
 *
 * - sk->sk_state uses the TCP state constants because they are widely used by
 * other address families and exposed to userspace tools like ss(8):
 *
 *   TCP_CLOSE - unconnected
 *   TCP_SYN_SENT - connecting
 *   TCP_ESTABLISHED - connected
 *   TCP_CLOSING - disconnecting
 *   TCP_LISTEN - listening
 */
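
/* Example (illustrative userspace sketch, not part of this file): a minimal
 * AF_VSOCK stream client.  The service port is hypothetical and error
 * handling is elided; <linux/vm_sockets.h> provides sockaddr_vm and the
 * VMADDR_* constants.
 *
 *     #include <sys/socket.h>
 *     #include <linux/vm_sockets.h>
 *
 *     int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 *     struct sockaddr_vm addr = {
 *             .svm_family = AF_VSOCK,
 *             .svm_cid = VMADDR_CID_HOST,     // CID 2: connect to the host
 *             .svm_port = 1234,               // hypothetical service port
 *     };
 *
 *     connect(fd, (struct sockaddr *)&addr, sizeof(addr));
 */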

#include <linux/types.h>
#include <linux/bitops.h>
#include <linux/cred.h>
#include <linux/init.h>
#include <linux/io.h>
#include <linux/kernel.h>
#include <linux/sched/signal.h>
#include <linux/kmod.h>
#include <linux/list.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/net.h>
#include <linux/poll.h>
#include <linux/random.h>
#include <linux/skbuff.h>
#include <linux/smp.h>
#include <linux/socket.h>
#include <linux/stddef.h>
#include <linux/unistd.h>
#include <linux/wait.h>
#include <linux/workqueue.h>
#include <net/sock.h>
#include <net/af_vsock.h>

static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr);
static void vsock_sk_destruct(struct sock *sk);
static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);

/* Protocol family. */
static struct proto vsock_proto = {
        .name = "AF_VSOCK",
        .owner = THIS_MODULE,
        .obj_size = sizeof(struct vsock_sock),
};

/* The default peer timeout indicates how long we will wait for a peer response
 * to a control message.
 */
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)

#define VSOCK_DEFAULT_BUFFER_SIZE     (1024 * 256)
#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128

/* Transport used for host->guest communication */
static const struct vsock_transport *transport_h2g;
/* Transport used for guest->host communication */
static const struct vsock_transport *transport_g2h;
/* Transport used for DGRAM communication */
static const struct vsock_transport *transport_dgram;
static DEFINE_MUTEX(vsock_register_mutex);

/**** UTILS ****/

/* Each bound VSocket is stored in the bind hash table and each connected
 * VSocket is stored in the connected hash table.
 *
 * Unbound sockets are all put on the same list attached to the end of the hash
 * table (vsock_unbound_sockets).  Bound sockets are added to the hash table in
 * the bucket that their local address hashes to (vsock_bound_sockets(addr)
 * represents the list that addr hashes to).
 *
 * Specifically, we initialize the vsock_bind_table array to a size of
 * VSOCK_HASH_SIZE + 1 so that vsock_bind_table[0] through
 * vsock_bind_table[VSOCK_HASH_SIZE - 1] are for bound sockets and
 * vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets.  The hash function
 * mods with VSOCK_HASH_SIZE to ensure this.
 */
#define MAX_PORT_RETRIES        24

#define VSOCK_HASH(addr)        ((addr)->svm_port % VSOCK_HASH_SIZE)
#define vsock_bound_sockets(addr) (&vsock_bind_table[VSOCK_HASH(addr)])
#define vsock_unbound_sockets     (&vsock_bind_table[VSOCK_HASH_SIZE])

/* XXX This can probably be implemented in a better way. */
#define VSOCK_CONN_HASH(src, dst)                               \
        (((src)->svm_cid ^ (dst)->svm_port) % VSOCK_HASH_SIZE)
#define vsock_connected_sockets(src, dst)               \
        (&vsock_connected_table[VSOCK_CONN_HASH(src, dst)])
#define vsock_connected_sockets_vsk(vsk)                                \
        vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr)
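
/* For instance, with VSOCK_HASH_SIZE defined as 251 in <net/af_vsock.h>,
 * a socket bound to port 5000 hashes to vsock_bind_table[5000 % 251],
 * while every unbound socket sits on vsock_bind_table[251], just past the
 * hashed buckets.  (The port value is illustrative.)
 */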

struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1];
EXPORT_SYMBOL_GPL(vsock_bind_table);
struct list_head vsock_connected_table[VSOCK_HASH_SIZE];
EXPORT_SYMBOL_GPL(vsock_connected_table);
DEFINE_SPINLOCK(vsock_table_lock);
EXPORT_SYMBOL_GPL(vsock_table_lock);

/* Autobind this socket to the local address if necessary. */
static int vsock_auto_bind(struct vsock_sock *vsk)
{
        struct sock *sk = sk_vsock(vsk);
        struct sockaddr_vm local_addr;

        if (vsock_addr_bound(&vsk->local_addr))
                return 0;
        vsock_addr_init(&local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
        return __vsock_bind(sk, &local_addr);
}

static void vsock_init_tables(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(vsock_bind_table); i++)
                INIT_LIST_HEAD(&vsock_bind_table[i]);

        for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
                INIT_LIST_HEAD(&vsock_connected_table[i]);
}

static void __vsock_insert_bound(struct list_head *list,
                                 struct vsock_sock *vsk)
{
        sock_hold(&vsk->sk);
        list_add(&vsk->bound_table, list);
}

static void __vsock_insert_connected(struct list_head *list,
                                     struct vsock_sock *vsk)
{
        sock_hold(&vsk->sk);
        list_add(&vsk->connected_table, list);
}

static void __vsock_remove_bound(struct vsock_sock *vsk)
{
        list_del_init(&vsk->bound_table);
        sock_put(&vsk->sk);
}

static void __vsock_remove_connected(struct vsock_sock *vsk)
{
        list_del_init(&vsk->connected_table);
        sock_put(&vsk->sk);
}

static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
{
        struct vsock_sock *vsk;

        list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
                if (vsock_addr_equals_addr(addr, &vsk->local_addr))
                        return sk_vsock(vsk);

                if (addr->svm_port == vsk->local_addr.svm_port &&
                    (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
                     addr->svm_cid == VMADDR_CID_ANY))
                        return sk_vsock(vsk);
        }

        return NULL;
}

static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
                                                  struct sockaddr_vm *dst)
{
        struct vsock_sock *vsk;

        list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
                            connected_table) {
                if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
                    dst->svm_port == vsk->local_addr.svm_port) {
                        return sk_vsock(vsk);
                }
        }

        return NULL;
}

static void vsock_insert_unbound(struct vsock_sock *vsk)
{
        spin_lock_bh(&vsock_table_lock);
        __vsock_insert_bound(vsock_unbound_sockets, vsk);
        spin_unlock_bh(&vsock_table_lock);
}

void vsock_insert_connected(struct vsock_sock *vsk)
{
        struct list_head *list = vsock_connected_sockets(
                &vsk->remote_addr, &vsk->local_addr);

        spin_lock_bh(&vsock_table_lock);
        __vsock_insert_connected(list, vsk);
        spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_insert_connected);

void vsock_remove_bound(struct vsock_sock *vsk)
{
        spin_lock_bh(&vsock_table_lock);
        if (__vsock_in_bound_table(vsk))
                __vsock_remove_bound(vsk);
        spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_remove_bound);

void vsock_remove_connected(struct vsock_sock *vsk)
{
        spin_lock_bh(&vsock_table_lock);
        if (__vsock_in_connected_table(vsk))
                __vsock_remove_connected(vsk);
        spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_remove_connected);

struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
{
        struct sock *sk;

        spin_lock_bh(&vsock_table_lock);
        sk = __vsock_find_bound_socket(addr);
        if (sk)
                sock_hold(sk);

        spin_unlock_bh(&vsock_table_lock);

        return sk;
}
EXPORT_SYMBOL_GPL(vsock_find_bound_socket);

struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
                                         struct sockaddr_vm *dst)
{
        struct sock *sk;

        spin_lock_bh(&vsock_table_lock);
        sk = __vsock_find_connected_socket(src, dst);
        if (sk)
                sock_hold(sk);

        spin_unlock_bh(&vsock_table_lock);

        return sk;
}
EXPORT_SYMBOL_GPL(vsock_find_connected_socket);
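
/* Both lookup helpers above return the socket with an extra reference
 * held, so a typical caller pairs the lookup with sock_put() when done,
 * e.g. (sketch):
 *
 *     struct sock *sk = vsock_find_bound_socket(&dst);
 *
 *     if (sk) {
 *             // ... deliver the packet to the socket ...
 *             sock_put(sk);
 *     }
 */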

void vsock_remove_sock(struct vsock_sock *vsk)
{
        vsock_remove_bound(vsk);
        vsock_remove_connected(vsk);
}
EXPORT_SYMBOL_GPL(vsock_remove_sock);

void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
{
        int i;

        spin_lock_bh(&vsock_table_lock);

        for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) {
                struct vsock_sock *vsk;

                list_for_each_entry(vsk, &vsock_connected_table[i],
                                    connected_table)
                        fn(sk_vsock(vsk));
        }

        spin_unlock_bh(&vsock_table_lock);
}
EXPORT_SYMBOL_GPL(vsock_for_each_connected_socket);

void vsock_add_pending(struct sock *listener, struct sock *pending)
{
        struct vsock_sock *vlistener;
        struct vsock_sock *vpending;

        vlistener = vsock_sk(listener);
        vpending = vsock_sk(pending);

        sock_hold(pending);
        sock_hold(listener);
        list_add_tail(&vpending->pending_links, &vlistener->pending_links);
}
EXPORT_SYMBOL_GPL(vsock_add_pending);

void vsock_remove_pending(struct sock *listener, struct sock *pending)
{
        struct vsock_sock *vpending = vsock_sk(pending);

        list_del_init(&vpending->pending_links);
        sock_put(listener);
        sock_put(pending);
}
EXPORT_SYMBOL_GPL(vsock_remove_pending);

void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
{
        struct vsock_sock *vlistener;
        struct vsock_sock *vconnected;

        vlistener = vsock_sk(listener);
        vconnected = vsock_sk(connected);

        sock_hold(connected);
        sock_hold(listener);
        list_add_tail(&vconnected->accept_queue, &vlistener->accept_queue);
}
EXPORT_SYMBOL_GPL(vsock_enqueue_accept);

static void vsock_deassign_transport(struct vsock_sock *vsk)
{
        if (!vsk->transport)
                return;

        vsk->transport->destruct(vsk);
        module_put(vsk->transport->module);
        vsk->transport = NULL;
}

/* Assign a transport to a socket and call the .init transport callback.
 *
 * Note: for stream sockets this must be called when vsk->remote_addr is set
 * (e.g. during connect() or when a connection request on a listener
 * socket is received).
 * vsk->remote_addr is used to decide which transport to use:
 *  - remote CID <= VMADDR_CID_HOST will use the guest->host transport;
 *  - remote CID == local_cid (of the guest->host transport) will use the
 *    guest->host transport for loopback (host->guest transports don't
 *    support loopback);
 *  - remote CID > VMADDR_CID_HOST will use the host->guest transport.
 */
int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
{
        const struct vsock_transport *new_transport;
        struct sock *sk = sk_vsock(vsk);
        unsigned int remote_cid = vsk->remote_addr.svm_cid;
        int ret;

        switch (sk->sk_type) {
        case SOCK_DGRAM:
                new_transport = transport_dgram;
                break;
        case SOCK_STREAM:
                if (remote_cid <= VMADDR_CID_HOST ||
                    (transport_g2h &&
                     remote_cid == transport_g2h->get_local_cid()))
                        new_transport = transport_g2h;
                else
                        new_transport = transport_h2g;
                break;
        default:
                return -ESOCKTNOSUPPORT;
        }

        if (vsk->transport) {
                if (vsk->transport == new_transport)
                        return 0;

                vsk->transport->release(vsk);
                vsock_deassign_transport(vsk);
        }

        /* We increase the module refcnt to prevent the transport unloading
         * while there are open sockets assigned to it.
         */
        if (!new_transport || !try_module_get(new_transport->module))
                return -ENODEV;

        ret = new_transport->init(vsk, psk);
        if (ret) {
                module_put(new_transport->module);
                return ret;
        }

        vsk->transport = new_transport;

        return 0;
}
EXPORT_SYMBOL_GPL(vsock_assign_transport);
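
/* Concretely (CIDs illustrative): in a guest whose guest->host transport
 * reports local CID 3, a stream connect to CID 2 (VMADDR_CID_HOST) or to
 * CID 3 itself (loopback) selects transport_g2h, while on the host a
 * connect to a guest CID such as 5 selects transport_h2g.  SOCK_DGRAM
 * sockets always use transport_dgram.
 */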

bool vsock_find_cid(unsigned int cid)
{
        if (transport_g2h && cid == transport_g2h->get_local_cid())
                return true;

        if (transport_h2g && cid == VMADDR_CID_HOST)
                return true;

        return false;
}
EXPORT_SYMBOL_GPL(vsock_find_cid);

static struct sock *vsock_dequeue_accept(struct sock *listener)
{
        struct vsock_sock *vlistener;
        struct vsock_sock *vconnected;

        vlistener = vsock_sk(listener);

        if (list_empty(&vlistener->accept_queue))
                return NULL;

        vconnected = list_entry(vlistener->accept_queue.next,
                                struct vsock_sock, accept_queue);

        list_del_init(&vconnected->accept_queue);
        sock_put(listener);
        /* The caller will need a reference on the connected socket so we let
         * it call sock_put().
         */

        return sk_vsock(vconnected);
}

static bool vsock_is_accept_queue_empty(struct sock *sk)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        return list_empty(&vsk->accept_queue);
}

static bool vsock_is_pending(struct sock *sk)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        return !list_empty(&vsk->pending_links);
}

static int vsock_send_shutdown(struct sock *sk, int mode)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        if (!vsk->transport)
                return -ENODEV;

        return vsk->transport->shutdown(vsk, mode);
}

static void vsock_pending_work(struct work_struct *work)
{
        struct sock *sk;
        struct sock *listener;
        struct vsock_sock *vsk;
        bool cleanup;

        vsk = container_of(work, struct vsock_sock, pending_work.work);
        sk = sk_vsock(vsk);
        listener = vsk->listener;
        cleanup = true;

        lock_sock(listener);
        lock_sock_nested(sk, SINGLE_DEPTH_NESTING);

        if (vsock_is_pending(sk)) {
                vsock_remove_pending(listener, sk);

                sk_acceptq_removed(listener);
        } else if (!vsk->rejected) {
                /* We are not on the pending list and accept() did not reject
                 * us, so we must have been accepted by our user process.  We
                 * just need to drop our references to the sockets and be on
                 * our way.
                 */
                cleanup = false;
                goto out;
        }

        /* We need to remove ourselves from the global connected sockets list so
         * incoming packets can't find this socket, and to reduce the reference
         * count.
         */
        vsock_remove_connected(vsk);

        sk->sk_state = TCP_CLOSE;

out:
        release_sock(sk);
        release_sock(listener);
        if (cleanup)
                sock_put(sk);

        sock_put(sk);
        sock_put(listener);
}

/**** SOCKET OPERATIONS ****/

static int __vsock_bind_stream(struct vsock_sock *vsk,
                               struct sockaddr_vm *addr)
{
        static u32 port;
        struct sockaddr_vm new_addr;

        if (!port)
                port = LAST_RESERVED_PORT + 1 +
                        prandom_u32_max(U32_MAX - LAST_RESERVED_PORT);

        vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port);

        if (addr->svm_port == VMADDR_PORT_ANY) {
                bool found = false;
                unsigned int i;

                for (i = 0; i < MAX_PORT_RETRIES; i++) {
                        if (port <= LAST_RESERVED_PORT)
                                port = LAST_RESERVED_PORT + 1;

                        new_addr.svm_port = port++;

                        if (!__vsock_find_bound_socket(&new_addr)) {
                                found = true;
                                break;
                        }
                }

                if (!found)
                        return -EADDRNOTAVAIL;
        } else {
                /* If port is in reserved range, ensure caller
                 * has necessary privileges.
                 */
                if (addr->svm_port <= LAST_RESERVED_PORT &&
                    !capable(CAP_NET_BIND_SERVICE)) {
                        return -EACCES;
                }

                if (__vsock_find_bound_socket(&new_addr))
                        return -EADDRINUSE;
        }

        vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port);

        /* Remove stream sockets from the unbound list and add them to the hash
         * table for easy lookup by their address.  The unbound list is simply
         * an extra entry at the end of the hash table, a trick used by AF_UNIX.
         */
        __vsock_remove_bound(vsk);
        __vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk);

        return 0;
}

static int __vsock_bind_dgram(struct vsock_sock *vsk,
                              struct sockaddr_vm *addr)
{
        return vsk->transport->dgram_bind(vsk, addr);
}

static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        int retval;

        /* First ensure this socket isn't already bound. */
        if (vsock_addr_bound(&vsk->local_addr))
                return -EINVAL;

        /* Now bind to the provided address or select appropriate values if
         * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY).  Note that,
         * just as AF_INET prevents binding to a non-local IP address (in most
         * cases), we only allow binding to a local CID.
         */
        if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
                return -EADDRNOTAVAIL;

        switch (sk->sk_socket->type) {
        case SOCK_STREAM:
                spin_lock_bh(&vsock_table_lock);
                retval = __vsock_bind_stream(vsk, addr);
                spin_unlock_bh(&vsock_table_lock);
                break;

        case SOCK_DGRAM:
                retval = __vsock_bind_dgram(vsk, addr);
                break;

        default:
                retval = -EINVAL;
                break;
        }

        return retval;
}

static void vsock_connect_timeout(struct work_struct *work);

static struct sock *__vsock_create(struct net *net,
                                   struct socket *sock,
                                   struct sock *parent,
                                   gfp_t priority,
                                   unsigned short type,
                                   int kern)
{
        struct sock *sk;
        struct vsock_sock *psk;
        struct vsock_sock *vsk;

        sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto, kern);
        if (!sk)
                return NULL;

        sock_init_data(sock, sk);

        /* sk->sk_type is normally set in sock_init_data, but only if sock is
         * non-NULL. We make sure that our sockets always have a type by
         * setting it here if needed.
         */
        if (!sock)
                sk->sk_type = type;

        vsk = vsock_sk(sk);
        vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
        vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);

        sk->sk_destruct = vsock_sk_destruct;
        sk->sk_backlog_rcv = vsock_queue_rcv_skb;
        sock_reset_flag(sk, SOCK_DONE);

        INIT_LIST_HEAD(&vsk->bound_table);
        INIT_LIST_HEAD(&vsk->connected_table);
        vsk->listener = NULL;
        INIT_LIST_HEAD(&vsk->pending_links);
        INIT_LIST_HEAD(&vsk->accept_queue);
        vsk->rejected = false;
        vsk->sent_request = false;
        vsk->ignore_connecting_rst = false;
        vsk->peer_shutdown = 0;
        INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout);
        INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work);

        psk = parent ? vsock_sk(parent) : NULL;
        if (parent) {
                vsk->trusted = psk->trusted;
                vsk->owner = get_cred(psk->owner);
                vsk->connect_timeout = psk->connect_timeout;
                vsk->buffer_size = psk->buffer_size;
                vsk->buffer_min_size = psk->buffer_min_size;
                vsk->buffer_max_size = psk->buffer_max_size;
        } else {
                vsk->trusted = capable(CAP_NET_ADMIN);
                vsk->owner = get_current_cred();
                vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
                vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE;
                vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE;
                vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
        }

        return sk;
}

static void __vsock_release(struct sock *sk, int level)
{
        if (sk) {
                struct sock *pending;
                struct vsock_sock *vsk;

                vsk = vsock_sk(sk);
                pending = NULL; /* Compiler warning. */

                /* The release call is supposed to use lock_sock_nested()
                 * rather than lock_sock(), if a sock lock should be acquired.
                 */
                if (vsk->transport)
                        vsk->transport->release(vsk);
                else if (sk->sk_type == SOCK_STREAM)
                        vsock_remove_sock(vsk);

                /* When "level" is SINGLE_DEPTH_NESTING, use the nested
                 * version to avoid the warning "possible recursive locking
                 * detected". When "level" is 0, lock_sock_nested(sk, level)
                 * is the same as lock_sock(sk).
                 */
                lock_sock_nested(sk, level);
                sock_orphan(sk);
                sk->sk_shutdown = SHUTDOWN_MASK;

                skb_queue_purge(&sk->sk_receive_queue);

                /* Clean up any sockets that never were accepted. */
                while ((pending = vsock_dequeue_accept(sk)) != NULL) {
                        __vsock_release(pending, SINGLE_DEPTH_NESTING);
                        sock_put(pending);
                }

                release_sock(sk);
                sock_put(sk);
        }
}

static void vsock_sk_destruct(struct sock *sk)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        vsock_deassign_transport(vsk);

        /* When clearing these addresses, there's no need to set the family and
         * possibly register the address family with the kernel.
         */
        vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
        vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);

        put_cred(vsk->owner);
}

static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
{
        int err;

        err = sock_queue_rcv_skb(sk, skb);
        if (err)
                kfree_skb(skb);

        return err;
}

struct sock *vsock_create_connected(struct sock *parent)
{
        return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL,
                              parent->sk_type, 0);
}
EXPORT_SYMBOL_GPL(vsock_create_connected);

s64 vsock_stream_has_data(struct vsock_sock *vsk)
{
        return vsk->transport->stream_has_data(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_data);

s64 vsock_stream_has_space(struct vsock_sock *vsk)
{
        return vsk->transport->stream_has_space(vsk);
}
EXPORT_SYMBOL_GPL(vsock_stream_has_space);

static int vsock_release(struct socket *sock)
{
        __vsock_release(sock->sk, 0);
        sock->sk = NULL;
        sock->state = SS_FREE;

        return 0;
}

static int
vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
        int err;
        struct sock *sk;
        struct sockaddr_vm *vm_addr;

        sk = sock->sk;

        if (vsock_addr_cast(addr, addr_len, &vm_addr) != 0)
                return -EINVAL;

        lock_sock(sk);
        err = __vsock_bind(sk, vm_addr);
        release_sock(sk);

        return err;
}

static int vsock_getname(struct socket *sock,
                         struct sockaddr *addr, int peer)
{
        int err;
        struct sock *sk;
        struct vsock_sock *vsk;
        struct sockaddr_vm *vm_addr;

        sk = sock->sk;
        vsk = vsock_sk(sk);
        err = 0;

        lock_sock(sk);

        if (peer) {
                if (sock->state != SS_CONNECTED) {
                        err = -ENOTCONN;
                        goto out;
                }
                vm_addr = &vsk->remote_addr;
        } else {
                vm_addr = &vsk->local_addr;
        }

        if (!vm_addr) {
                err = -EINVAL;
                goto out;
        }

        /* sys_getsockname() and sys_getpeername() pass us a
         * MAX_SOCK_ADDR-sized buffer and don't set addr_len.  Unfortunately
         * that macro is defined in socket.c instead of .h, so we hardcode its
         * value here.
         */
        BUILD_BUG_ON(sizeof(*vm_addr) > 128);
        memcpy(addr, vm_addr, sizeof(*vm_addr));
        err = sizeof(*vm_addr);

out:
        release_sock(sk);
        return err;
}

static int vsock_shutdown(struct socket *sock, int mode)
{
        int err;
        struct sock *sk;

        /* User level uses SHUT_RD (0) and SHUT_WR (1), but the kernel uses
         * RCV_SHUTDOWN (1) and SEND_SHUTDOWN (2), so we must increment mode
         * here like the other address families do.  Note also that the
         * increment makes SHUT_RDWR (2) into RCV_SHUTDOWN | SEND_SHUTDOWN (3),
         * which is what we want.
         */
        mode++;

        if ((mode & ~SHUTDOWN_MASK) || !mode)
                return -EINVAL;

        /* If this is a STREAM socket and it is not connected then bail out
         * immediately.  If it is a DGRAM socket then we must first kick the
         * socket so that it wakes up from any sleeping calls, for example
         * recv(), and then afterwards return the error.
         */

        sk = sock->sk;
        if (sock->state == SS_UNCONNECTED) {
                err = -ENOTCONN;
                if (sk->sk_type == SOCK_STREAM)
                        return err;
        } else {
                sock->state = SS_DISCONNECTING;
                err = 0;
        }

        /* Receive and send shutdowns are treated alike. */
        mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN);
        if (mode) {
                lock_sock(sk);
                sk->sk_shutdown |= mode;
                sk->sk_state_change(sk);
                release_sock(sk);

                if (sk->sk_type == SOCK_STREAM) {
                        sock_reset_flag(sk, SOCK_DONE);
                        vsock_send_shutdown(sk, mode);
                }
        }

        return err;
}

static __poll_t vsock_poll(struct file *file, struct socket *sock,
                               poll_table *wait)
{
        struct sock *sk;
        __poll_t mask;
        struct vsock_sock *vsk;

        sk = sock->sk;
        vsk = vsock_sk(sk);

        poll_wait(file, sk_sleep(sk), wait);
        mask = 0;

        if (sk->sk_err)
                /* Signify that there has been an error on this socket. */
                mask |= EPOLLERR;

        /* INET sockets treat local write shutdown and peer write shutdown as a
         * case where EPOLLHUP is set.
         */
        if ((sk->sk_shutdown == SHUTDOWN_MASK) ||
            ((sk->sk_shutdown & SEND_SHUTDOWN) &&
             (vsk->peer_shutdown & SEND_SHUTDOWN))) {
                mask |= EPOLLHUP;
        }

        if (sk->sk_shutdown & RCV_SHUTDOWN ||
            vsk->peer_shutdown & SEND_SHUTDOWN) {
                mask |= EPOLLRDHUP;
        }

        if (sock->type == SOCK_DGRAM) {
                /* For datagram sockets we can read if there is something in
                 * the queue and write as long as the socket isn't shutdown for
                 * sending.
                 */
                if (!skb_queue_empty_lockless(&sk->sk_receive_queue) ||
                    (sk->sk_shutdown & RCV_SHUTDOWN)) {
                        mask |= EPOLLIN | EPOLLRDNORM;
                }

                if (!(sk->sk_shutdown & SEND_SHUTDOWN))
                        mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
        } else if (sock->type == SOCK_STREAM) {
                const struct vsock_transport *transport = vsk->transport;

                lock_sock(sk);

                /* Listening sockets that have connections in their accept
                 * queue can be read.
                 */
                if (sk->sk_state == TCP_LISTEN &&
                    !vsock_is_accept_queue_empty(sk))
                        mask |= EPOLLIN | EPOLLRDNORM;

                /* If there is something in the queue then we can read. */
                if (transport && transport->stream_is_active(vsk) &&
                    !(sk->sk_shutdown & RCV_SHUTDOWN)) {
                        bool data_ready_now = false;
                        int ret = transport->notify_poll_in(
                                        vsk, 1, &data_ready_now);
                        if (ret < 0) {
                                mask |= EPOLLERR;
                        } else {
                                if (data_ready_now)
                                        mask |= EPOLLIN | EPOLLRDNORM;
                        }
                }

                /* Sockets whose connections have been closed, reset, or
                 * terminated should also be considered read, and we check the
                 * shutdown flag for that.
                 */
                if (sk->sk_shutdown & RCV_SHUTDOWN ||
                    vsk->peer_shutdown & SEND_SHUTDOWN) {
                        mask |= EPOLLIN | EPOLLRDNORM;
                }

                /* Connected sockets that can produce data can be written. */
                if (sk->sk_state == TCP_ESTABLISHED) {
                        if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {
                                bool space_avail_now = false;
                                int ret = transport->notify_poll_out(
                                                vsk, 1, &space_avail_now);
                                if (ret < 0) {
                                        mask |= EPOLLERR;
                                } else {
                                        if (space_avail_now)
                                                /* Remove EPOLLWRBAND since INET
                                                 * sockets are not setting it.
                                                 */
                                                mask |= EPOLLOUT | EPOLLWRNORM;
                                }
                        }
                }

                /* Simulate INET socket poll behaviors, which sets
                 * EPOLLOUT|EPOLLWRNORM when peer is closed and nothing to read,
                 * but local send is not shutdown.
                 */
                if (sk->sk_state == TCP_CLOSE || sk->sk_state == TCP_CLOSING) {
                        if (!(sk->sk_shutdown & SEND_SHUTDOWN))
                                mask |= EPOLLOUT | EPOLLWRNORM;
                }

                release_sock(sk);
        }

        return mask;
}

static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
                               size_t len)
{
        int err;
        struct sock *sk;
        struct vsock_sock *vsk;
        struct sockaddr_vm *remote_addr;
        const struct vsock_transport *transport;

        if (msg->msg_flags & MSG_OOB)
                return -EOPNOTSUPP;

        /* For now, MSG_DONTWAIT is always assumed... */
        err = 0;
        sk = sock->sk;
        vsk = vsock_sk(sk);
        transport = vsk->transport;

        lock_sock(sk);

        err = vsock_auto_bind(vsk);
        if (err)
                goto out;

        /* If the provided message contains an address, use that.  Otherwise
         * fall back on the socket's remote handle (if it has been connected).
         */
        if (msg->msg_name &&
            vsock_addr_cast(msg->msg_name, msg->msg_namelen,
                            &remote_addr) == 0) {
                /* Ensure this address is of the right type and is a valid
                 * destination.
                 */

                if (remote_addr->svm_cid == VMADDR_CID_ANY)
                        remote_addr->svm_cid = transport->get_local_cid();

                if (!vsock_addr_bound(remote_addr)) {
                        err = -EINVAL;
                        goto out;
                }
        } else if (sock->state == SS_CONNECTED) {
                remote_addr = &vsk->remote_addr;

                if (remote_addr->svm_cid == VMADDR_CID_ANY)
                        remote_addr->svm_cid = transport->get_local_cid();

                /* XXX Should connect() or this function ensure remote_addr is
                 * bound?
                 */
                if (!vsock_addr_bound(&vsk->remote_addr)) {
                        err = -EINVAL;
                        goto out;
                }
        } else {
                err = -EINVAL;
                goto out;
        }

        if (!transport->dgram_allow(remote_addr->svm_cid,
                                    remote_addr->svm_port)) {
                err = -EINVAL;
                goto out;
        }

        err = transport->dgram_enqueue(vsk, remote_addr, msg, len);

out:
        release_sock(sk);
        return err;
}

static int vsock_dgram_connect(struct socket *sock,
                               struct sockaddr *addr, int addr_len, int flags)
{
        int err;
        struct sock *sk;
        struct vsock_sock *vsk;
        struct sockaddr_vm *remote_addr;

        sk = sock->sk;
        vsk = vsock_sk(sk);

        err = vsock_addr_cast(addr, addr_len, &remote_addr);
        if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
                lock_sock(sk);
                vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
                                VMADDR_PORT_ANY);
                sock->state = SS_UNCONNECTED;
                release_sock(sk);
                return 0;
        } else if (err != 0) {
                return -EINVAL;
        }

        lock_sock(sk);

        err = vsock_auto_bind(vsk);
        if (err)
                goto out;

        if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
                                         remote_addr->svm_port)) {
                err = -EINVAL;
                goto out;
        }

        memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
        sock->state = SS_CONNECTED;

out:
        release_sock(sk);
        return err;
}

static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
                               size_t len, int flags)
{
        struct vsock_sock *vsk = vsock_sk(sock->sk);

        return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
}

static const struct proto_ops vsock_dgram_ops = {
        .family = PF_VSOCK,
        .owner = THIS_MODULE,
        .release = vsock_release,
        .bind = vsock_bind,
        .connect = vsock_dgram_connect,
        .socketpair = sock_no_socketpair,
        .accept = sock_no_accept,
        .getname = vsock_getname,
        .poll = vsock_poll,
        .ioctl = sock_no_ioctl,
        .listen = sock_no_listen,
        .shutdown = vsock_shutdown,
        .setsockopt = sock_no_setsockopt,
        .getsockopt = sock_no_getsockopt,
        .sendmsg = vsock_dgram_sendmsg,
        .recvmsg = vsock_dgram_recvmsg,
        .mmap = sock_no_mmap,
        .sendpage = sock_no_sendpage,
};
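
/* A minimal userspace datagram sketch (illustrative; SOCK_DGRAM support
 * depends on a registered dgram transport, e.g. VMCI):
 *
 *     int fd = socket(AF_VSOCK, SOCK_DGRAM, 0);
 *     struct sockaddr_vm to = {
 *             .svm_family = AF_VSOCK,
 *             .svm_cid = VMADDR_CID_HOST,
 *             .svm_port = 1234,               // hypothetical port
 *     };
 *
 *     sendto(fd, buf, len, 0, (struct sockaddr *)&to, sizeof(to));
 */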

static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
{
        const struct vsock_transport *transport = vsk->transport;

        if (!transport->cancel_pkt)
                return -EOPNOTSUPP;

        return transport->cancel_pkt(vsk);
}

static void vsock_connect_timeout(struct work_struct *work)
{
        struct sock *sk;
        struct vsock_sock *vsk;
        int cancel = 0;

        vsk = container_of(work, struct vsock_sock, connect_work.work);
        sk = sk_vsock(vsk);

        lock_sock(sk);
        if (sk->sk_state == TCP_SYN_SENT &&
            (sk->sk_shutdown != SHUTDOWN_MASK)) {
                sk->sk_state = TCP_CLOSE;
                sk->sk_err = ETIMEDOUT;
                sk->sk_error_report(sk);
                cancel = 1;
        }
        release_sock(sk);
        if (cancel)
                vsock_transport_cancel_pkt(vsk);

        sock_put(sk);
}

static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
                                int addr_len, int flags)
{
        int err;
        struct sock *sk;
        struct vsock_sock *vsk;
        const struct vsock_transport *transport;
        struct sockaddr_vm *remote_addr;
        long timeout;
        DEFINE_WAIT(wait);

        err = 0;
        sk = sock->sk;
        vsk = vsock_sk(sk);

        lock_sock(sk);

        /* XXX AF_UNSPEC should make us disconnect like AF_INET. */
        switch (sock->state) {
        case SS_CONNECTED:
                err = -EISCONN;
                goto out;
        case SS_DISCONNECTING:
                err = -EINVAL;
                goto out;
        case SS_CONNECTING:
                /* This continues on so we can move sock into the SS_CONNECTED
                 * state once the connection has completed (at which point err
                 * will be set to zero also).  Otherwise, we will either wait
                 * for the connection or return -EALREADY should this be a
                 * non-blocking call.
                 */
                err = -EALREADY;
                break;
        default:
                if ((sk->sk_state == TCP_LISTEN) ||
                    vsock_addr_cast(addr, addr_len, &remote_addr) != 0) {
                        err = -EINVAL;
                        goto out;
                }

                /* Set the remote address that we are connecting to. */
                memcpy(&vsk->remote_addr, remote_addr,
                       sizeof(vsk->remote_addr));

                err = vsock_assign_transport(vsk, NULL);
                if (err)
                        goto out;

                transport = vsk->transport;

                /* The hypervisor and well-known contexts do not have socket
                 * endpoints.
                 */
                if (!transport ||
                    !transport->stream_allow(remote_addr->svm_cid,
                                             remote_addr->svm_port)) {
                        err = -ENETUNREACH;
                        goto out;
                }

                err = vsock_auto_bind(vsk);
                if (err)
                        goto out;

                sk->sk_state = TCP_SYN_SENT;

                err = transport->connect(vsk);
                if (err < 0)
                        goto out;

                /* Mark sock as connecting and set the error code to in
                 * progress in case this is a non-blocking connect.
                 */
                sock->state = SS_CONNECTING;
                err = -EINPROGRESS;
        }

        /* The receive path will handle all communication until we are able to
         * enter the connected state.  Here we wait for the connection to be
         * completed or a notification of an error.
         */
        timeout = vsk->connect_timeout;
        prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);

        while (sk->sk_state != TCP_ESTABLISHED && sk->sk_err == 0) {
                if (flags & O_NONBLOCK) {
                        /* If we're not going to block, we schedule a timeout
                         * function to generate a timeout on the connection
                         * attempt, in case the peer doesn't respond in a
                         * timely manner. We hold on to the socket until the
                         * timeout fires.
                         */
                        sock_hold(sk);
                        schedule_delayed_work(&vsk->connect_work, timeout);

                        /* Skip ahead to preserve error code set above. */
                        goto out_wait;
                }

                release_sock(sk);
                timeout = schedule_timeout(timeout);
                lock_sock(sk);

                if (signal_pending(current)) {
                        err = sock_intr_errno(timeout);
                        sk->sk_state = TCP_CLOSE;
                        sock->state = SS_UNCONNECTED;
                        vsock_transport_cancel_pkt(vsk);
                        goto out_wait;
                } else if (timeout == 0) {
                        err = -ETIMEDOUT;
                        sk->sk_state = TCP_CLOSE;
                        sock->state = SS_UNCONNECTED;
                        vsock_transport_cancel_pkt(vsk);
                        goto out_wait;
                }

                prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
        }

        if (sk->sk_err) {
                err = -sk->sk_err;
                sk->sk_state = TCP_CLOSE;
                sock->state = SS_UNCONNECTED;
        } else {
                err = 0;
        }

out_wait:
        finish_wait(sk_sleep(sk), &wait);
out:
        release_sock(sk);
        return err;
}
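
/* From userspace, a non-blocking connect follows the usual socket pattern
 * (illustrative sketch; addr is a populated struct sockaddr_vm):
 *
 *     fcntl(fd, F_SETFL, O_NONBLOCK);
 *     if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0 &&
 *         errno == EINPROGRESS) {
 *             struct pollfd pfd = { .fd = fd, .events = POLLOUT };
 *             int err = 0;
 *             socklen_t len = sizeof(err);
 *
 *             poll(&pfd, 1, -1);
 *             getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &len);
 *             // err == 0 means the connection completed successfully.
 *     }
 */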

static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
                        bool kern)
{
        struct sock *listener;
        int err;
        struct sock *connected;
        struct vsock_sock *vconnected;
        long timeout;
        DEFINE_WAIT(wait);

        err = 0;
        listener = sock->sk;

        lock_sock(listener);

        if (sock->type != SOCK_STREAM) {
                err = -EOPNOTSUPP;
                goto out;
        }

        if (listener->sk_state != TCP_LISTEN) {
                err = -EINVAL;
                goto out;
        }

        /* Wait for child sockets to appear; these are the new sockets
         * created upon connection establishment.
         */
        timeout = sock_sndtimeo(listener, flags & O_NONBLOCK);
        prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);

        while ((connected = vsock_dequeue_accept(listener)) == NULL &&
               listener->sk_err == 0) {
                release_sock(listener);
                timeout = schedule_timeout(timeout);
                finish_wait(sk_sleep(listener), &wait);
                lock_sock(listener);

                if (signal_pending(current)) {
                        err = sock_intr_errno(timeout);
                        goto out;
                } else if (timeout == 0) {
                        err = -EAGAIN;
                        goto out;
                }

                prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
        }
        finish_wait(sk_sleep(listener), &wait);

        if (listener->sk_err)
                err = -listener->sk_err;

        if (connected) {
                sk_acceptq_removed(listener);

                lock_sock_nested(connected, SINGLE_DEPTH_NESTING);
                vconnected = vsock_sk(connected);

                /* If the listener socket has received an error, then we should
                 * reject this socket and return.  Note that we simply mark the
                 * socket rejected, drop our reference, and let the cleanup
                 * function handle the cleanup; the fact that we found it in
                 * the listener's accept queue guarantees that the cleanup
                 * function hasn't run yet.
                 */
                if (err) {
                        vconnected->rejected = true;
                } else {
                        newsock->state = SS_CONNECTED;
                        sock_graft(connected, newsock);
                }

                release_sock(connected);
                sock_put(connected);
        }

out:
        release_sock(listener);
        return err;
}

static int vsock_listen(struct socket *sock, int backlog)
{
        int err;
        struct sock *sk;
        struct vsock_sock *vsk;

        sk = sock->sk;

        lock_sock(sk);

        if (sock->type != SOCK_STREAM) {
                err = -EOPNOTSUPP;
                goto out;
        }

        if (sock->state != SS_UNCONNECTED) {
                err = -EINVAL;
                goto out;
        }

        vsk = vsock_sk(sk);

        if (!vsock_addr_bound(&vsk->local_addr)) {
                err = -EINVAL;
                goto out;
        }

        sk->sk_max_ack_backlog = backlog;
        sk->sk_state = TCP_LISTEN;

        err = 0;

out:
        release_sock(sk);
        return err;
}
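
/* Server-side usage mirrors AF_INET (illustrative sketch; the port is
 * hypothetical and error handling is elided):
 *
 *     int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 *     struct sockaddr_vm addr = {
 *             .svm_family = AF_VSOCK,
 *             .svm_cid = VMADDR_CID_ANY,      // accept from any CID
 *             .svm_port = 1234,               // hypothetical service port
 *     };
 *
 *     bind(fd, (struct sockaddr *)&addr, sizeof(addr));
 *     listen(fd, 8);
 *     int conn = accept(fd, NULL, NULL);
 */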
1477
1478 static void vsock_update_buffer_size(struct vsock_sock *vsk,
1479                                      const struct vsock_transport *transport,
1480                                      u64 val)
1481 {
1482         if (val > vsk->buffer_max_size)
1483                 val = vsk->buffer_max_size;
1484
1485         if (val < vsk->buffer_min_size)
1486                 val = vsk->buffer_min_size;
1487
1488         if (val != vsk->buffer_size &&
1489             transport && transport->notify_buffer_size)
1490                 transport->notify_buffer_size(vsk, &val);
1491
1492         vsk->buffer_size = val;
1493 }
1494
static int vsock_stream_setsockopt(struct socket *sock,
				   int level,
				   int optname,
				   char __user *optval,
				   unsigned int optlen)
{
	int err;
	struct sock *sk;
	struct vsock_sock *vsk;
	const struct vsock_transport *transport;
	u64 val;

	if (level != AF_VSOCK)
		return -ENOPROTOOPT;

#define COPY_IN(_v)                                       \
	do {                                              \
		if (optlen < sizeof(_v)) {                \
			err = -EINVAL;                    \
			goto exit;                        \
		}                                         \
		if (copy_from_user(&_v, optval, sizeof(_v)) != 0) {     \
			err = -EFAULT;                                  \
			goto exit;                                      \
		}                                                       \
	} while (0)

	err = 0;
	sk = sock->sk;
	vsk = vsock_sk(sk);
	transport = vsk->transport;

	lock_sock(sk);

	switch (optname) {
	case SO_VM_SOCKETS_BUFFER_SIZE:
		COPY_IN(val);
		vsock_update_buffer_size(vsk, transport, val);
		break;

	case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
		COPY_IN(val);
		vsk->buffer_max_size = val;
		vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
		break;

	case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
		COPY_IN(val);
		vsk->buffer_min_size = val;
		vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
		break;

	case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
		struct __kernel_old_timeval tv;
		COPY_IN(tv);
		if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC &&
		    tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) {
			vsk->connect_timeout = tv.tv_sec * HZ +
			    DIV_ROUND_UP(tv.tv_usec, (1000000 / HZ));
			if (vsk->connect_timeout == 0)
				vsk->connect_timeout =
				    VSOCK_DEFAULT_CONNECT_TIMEOUT;
		} else {
			err = -ERANGE;
		}
		break;
	}

	default:
		err = -ENOPROTOOPT;
		break;
	}

#undef COPY_IN

exit:
	release_sock(sk);
	return err;
}
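
/* Worked example for the SO_VM_SOCKETS_CONNECT_TIMEOUT conversion above,
 * assuming HZ == 1000: tv = { .tv_sec = 2, .tv_usec = 500000 } gives
 *
 *	connect_timeout = 2 * 1000 + DIV_ROUND_UP(500000, 1000000 / 1000)
 *	                = 2000 + 500 = 2500 jiffies (2.5s)
 *
 * A timeval that rounds down to 0 jiffies falls back to
 * VSOCK_DEFAULT_CONNECT_TIMEOUT rather than meaning "no timeout".
 */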

static int vsock_stream_getsockopt(struct socket *sock,
				   int level, int optname,
				   char __user *optval,
				   int __user *optlen)
{
	int err;
	int len;
	struct sock *sk;
	struct vsock_sock *vsk;
	u64 val;

	if (level != AF_VSOCK)
		return -ENOPROTOOPT;

	err = get_user(len, optlen);
	if (err != 0)
		return err;

#define COPY_OUT(_v)                            \
	do {                                    \
		if (len < sizeof(_v))           \
			return -EINVAL;         \
						\
		len = sizeof(_v);               \
		if (copy_to_user(optval, &_v, len) != 0)        \
			return -EFAULT;                         \
								\
	} while (0)

	err = 0;
	sk = sock->sk;
	vsk = vsock_sk(sk);

	switch (optname) {
	case SO_VM_SOCKETS_BUFFER_SIZE:
		val = vsk->buffer_size;
		COPY_OUT(val);
		break;

	case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
		val = vsk->buffer_max_size;
		COPY_OUT(val);
		break;

	case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
		val = vsk->buffer_min_size;
		COPY_OUT(val);
		break;

	case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
		struct __kernel_old_timeval tv;
		tv.tv_sec = vsk->connect_timeout / HZ;
		tv.tv_usec =
		    (vsk->connect_timeout -
		     tv.tv_sec * HZ) * (1000000 / HZ);
		COPY_OUT(tv);
		break;
	}
	default:
		return -ENOPROTOOPT;
	}

	err = put_user(len, optlen);
	if (err != 0)
		return -EFAULT;

#undef COPY_OUT

	return 0;
}

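/* Stream send path: after validating connection state, loop enqueueing data
 * while the transport reports free space, sleeping (with the socket lock
 * dropped) when the produce queue is full.  The transport's notify_send_*
 * hooks bracket each step so the transport can account for flow control.
 */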
static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
				size_t len)
{
	struct sock *sk;
	struct vsock_sock *vsk;
	const struct vsock_transport *transport;
	ssize_t total_written;
	long timeout;
	int err;
	struct vsock_transport_send_notify_data send_data;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	sk = sock->sk;
	vsk = vsock_sk(sk);
	transport = vsk->transport;
	total_written = 0;
	err = 0;

	if (msg->msg_flags & MSG_OOB)
		return -EOPNOTSUPP;

	lock_sock(sk);

	/* Callers should not provide a destination with stream sockets. */
	if (msg->msg_namelen) {
		err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
		goto out;
	}

	/* Send only if neither our send side nor the peer's receive side
	 * has been shut down.
	 */
	if (sk->sk_shutdown & SEND_SHUTDOWN ||
	    vsk->peer_shutdown & RCV_SHUTDOWN) {
		err = -EPIPE;
		goto out;
	}

	if (!transport || sk->sk_state != TCP_ESTABLISHED ||
	    !vsock_addr_bound(&vsk->local_addr)) {
		err = -ENOTCONN;
		goto out;
	}

	if (!vsock_addr_bound(&vsk->remote_addr)) {
		err = -EDESTADDRREQ;
		goto out;
	}

	/* Wait for room in the produce queue to enqueue our user's data. */
	timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);

	err = transport->notify_send_init(vsk, &send_data);
	if (err < 0)
		goto out;

	while (total_written < len) {
		ssize_t written;

		add_wait_queue(sk_sleep(sk), &wait);
		while (vsock_stream_has_space(vsk) == 0 &&
		       sk->sk_err == 0 &&
		       !(sk->sk_shutdown & SEND_SHUTDOWN) &&
		       !(vsk->peer_shutdown & RCV_SHUTDOWN)) {

			/* Don't wait for non-blocking sockets. */
			if (timeout == 0) {
				err = -EAGAIN;
				remove_wait_queue(sk_sleep(sk), &wait);
				goto out_err;
			}

			err = transport->notify_send_pre_block(vsk, &send_data);
			if (err < 0) {
				remove_wait_queue(sk_sleep(sk), &wait);
				goto out_err;
			}

			release_sock(sk);
			timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout);
			lock_sock(sk);
			if (signal_pending(current)) {
				err = sock_intr_errno(timeout);
				remove_wait_queue(sk_sleep(sk), &wait);
				goto out_err;
			} else if (timeout == 0) {
				err = -EAGAIN;
				remove_wait_queue(sk_sleep(sk), &wait);
				goto out_err;
			}
		}
		remove_wait_queue(sk_sleep(sk), &wait);

		/* These checks occur both as part of and after the loop
		 * conditional since we need to check before and after
		 * sleeping.
		 */
		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		} else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
			   (vsk->peer_shutdown & RCV_SHUTDOWN)) {
			err = -EPIPE;
			goto out_err;
		}

		err = transport->notify_send_pre_enqueue(vsk, &send_data);
		if (err < 0)
			goto out_err;

		/* Note that enqueue will only write as many bytes as are free
		 * in the produce queue, so we don't need to ensure len is
		 * smaller than the queue size.  It is the caller's
		 * responsibility to check how many bytes we were able to send.
		 */

		written = transport->stream_enqueue(
				vsk, msg,
				len - total_written);
		if (written < 0) {
			err = -ENOMEM;
			goto out_err;
		}

		total_written += written;

		err = transport->notify_send_post_enqueue(
				vsk, written, &send_data);
		if (err < 0)
			goto out_err;
	}

out_err:
	if (total_written > 0)
		err = total_written;
out:
	release_sock(sk);
	return err;
}
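/* Stream receive path.  The target below comes from SO_RCVLOWAT (capped at
 * len, or forced to len by MSG_WAITALL): we keep dequeueing until at least
 * target bytes have been copied, sleeping with the socket lock dropped
 * whenever the consume queue is empty.
 */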
static int
vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
		     int flags)
{
	struct sock *sk;
	struct vsock_sock *vsk;
	const struct vsock_transport *transport;
	int err;
	size_t target;
	ssize_t copied;
	long timeout;
	struct vsock_transport_recv_notify_data recv_data;

	DEFINE_WAIT(wait);

	sk = sock->sk;
	vsk = vsock_sk(sk);
	transport = vsk->transport;
	err = 0;

	lock_sock(sk);

	if (!transport || sk->sk_state != TCP_ESTABLISHED) {
		/* Recvmsg is supposed to return 0 if a peer performs an
		 * orderly shutdown.  Use the SOCK_DONE flag to distinguish
		 * that case from a peer that never connected or a local
		 * shutdown having occurred.
		 */
		if (sock_flag(sk, SOCK_DONE))
			err = 0;
		else
			err = -ENOTCONN;

		goto out;
	}

	if (flags & MSG_OOB) {
		err = -EOPNOTSUPP;
		goto out;
	}

	/* We don't check the peer_shutdown flag here since the peer may have
	 * shut down while data it already sent is still in the queue for the
	 * local socket to receive.
	 */
	if (sk->sk_shutdown & RCV_SHUTDOWN) {
		err = 0;
		goto out;
	}

	/* It is valid on Linux to pass in a zero-length receive buffer.  This
	 * is not an error.  We may as well bail out now.
	 */
	if (!len) {
		err = 0;
		goto out;
	}

	/* We must not copy less than target bytes into the user's buffer
	 * before returning successfully, so we wait for the consume queue to
	 * have that much data to consume before dequeueing.  Note that this
	 * makes it impossible to handle cases where target is greater than
	 * the queue size.
	 */
	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	if (target >= transport->stream_rcvhiwat(vsk)) {
		err = -ENOMEM;
		goto out;
	}
	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
	copied = 0;

	err = transport->notify_recv_init(vsk, target, &recv_data);
	if (err < 0)
		goto out;

	while (1) {
		s64 ready;

		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
		ready = vsock_stream_has_data(vsk);

		if (ready == 0) {
			if (sk->sk_err != 0 ||
			    (sk->sk_shutdown & RCV_SHUTDOWN) ||
			    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
				finish_wait(sk_sleep(sk), &wait);
				break;
			}
			/* Don't wait for non-blocking sockets. */
			if (timeout == 0) {
				err = -EAGAIN;
				finish_wait(sk_sleep(sk), &wait);
				break;
			}

			err = transport->notify_recv_pre_block(
					vsk, target, &recv_data);
			if (err < 0) {
				finish_wait(sk_sleep(sk), &wait);
				break;
			}
			release_sock(sk);
			timeout = schedule_timeout(timeout);
			lock_sock(sk);

			if (signal_pending(current)) {
				err = sock_intr_errno(timeout);
				finish_wait(sk_sleep(sk), &wait);
				break;
			} else if (timeout == 0) {
				err = -EAGAIN;
				finish_wait(sk_sleep(sk), &wait);
				break;
			}
		} else {
			ssize_t read;

			finish_wait(sk_sleep(sk), &wait);

			if (ready < 0) {
				/* Invalid queue pair content.  XXX This
				 * should be changed to a connection reset in
				 * a later change.
				 */
				err = -ENOMEM;
				goto out;
			}

			err = transport->notify_recv_pre_dequeue(
					vsk, target, &recv_data);
			if (err < 0)
				break;

			read = transport->stream_dequeue(
					vsk, msg,
					len - copied, flags);
			if (read < 0) {
				err = -ENOMEM;
				break;
			}

			copied += read;

			err = transport->notify_recv_post_dequeue(
					vsk, target, read,
					!(flags & MSG_PEEK), &recv_data);
			if (err < 0)
				goto out;

			if (read >= target || flags & MSG_PEEK)
				break;

			target -= read;
		}
	}

	if (sk->sk_err)
		err = -sk->sk_err;
	else if (sk->sk_shutdown & RCV_SHUTDOWN)
		err = 0;

	if (copied > 0)
		err = copied;

out:
	release_sock(sk);
	return err;
}
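
/* Illustrative userspace sketch (not part of this file): make recv() block
 * until at least 4 KiB is available on a connected vsock stream socket,
 * relying on the generic SOL_SOCKET handling of SO_RCVLOWAT.  Values are
 * only an example.
 *
 *	int lowat = 4096;
 *
 *	setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT, &lowat, sizeof(lowat));
 *	recv(fd, buf, sizeof(buf), 0);
 */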

static const struct proto_ops vsock_stream_ops = {
	.family = PF_VSOCK,
	.owner = THIS_MODULE,
	.release = vsock_release,
	.bind = vsock_bind,
	.connect = vsock_stream_connect,
	.socketpair = sock_no_socketpair,
	.accept = vsock_accept,
	.getname = vsock_getname,
	.poll = vsock_poll,
	.ioctl = sock_no_ioctl,
	.listen = vsock_listen,
	.shutdown = vsock_shutdown,
	.setsockopt = vsock_stream_setsockopt,
	.getsockopt = vsock_stream_getsockopt,
	.sendmsg = vsock_stream_sendmsg,
	.recvmsg = vsock_stream_recvmsg,
	.mmap = sock_no_mmap,
	.sendpage = sock_no_sendpage,
};

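/* Create a new AF_VSOCK socket.  Datagram sockets are assigned their
 * transport immediately; stream sockets defer transport assignment until
 * connect time or the arrival of a connection request, when the remote CID
 * is known.
 */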
static int vsock_create(struct net *net, struct socket *sock,
			int protocol, int kern)
{
	struct vsock_sock *vsk;
	struct sock *sk;
	int ret;

	if (!sock)
		return -EINVAL;

	if (protocol && protocol != PF_VSOCK)
		return -EPROTONOSUPPORT;

	switch (sock->type) {
	case SOCK_DGRAM:
		sock->ops = &vsock_dgram_ops;
		break;
	case SOCK_STREAM:
		sock->ops = &vsock_stream_ops;
		break;
	default:
		return -ESOCKTNOSUPPORT;
	}

	sock->state = SS_UNCONNECTED;

	sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
	if (!sk)
		return -ENOMEM;

	vsk = vsock_sk(sk);

	if (sock->type == SOCK_DGRAM) {
		ret = vsock_assign_transport(vsk, NULL);
		if (ret < 0) {
			sock_put(sk);
			return ret;
		}
	}

	vsock_insert_unbound(vsk);

	return 0;
}

static const struct net_proto_family vsock_family_ops = {
	.family = AF_VSOCK,
	.create = vsock_create,
	.owner = THIS_MODULE,
};

static long vsock_dev_do_ioctl(struct file *filp,
			       unsigned int cmd, void __user *ptr)
{
	u32 __user *p = ptr;
	u32 cid = VMADDR_CID_ANY;
	int retval = 0;

	switch (cmd) {
	case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
		/* To be compatible with the VMCI behavior, we prioritize the
		 * guest CID over the well-known host CID (VMADDR_CID_HOST).
		 */
		if (transport_g2h)
			cid = transport_g2h->get_local_cid();
		else if (transport_h2g)
			cid = transport_h2g->get_local_cid();

		if (put_user(cid, p) != 0)
			retval = -EFAULT;
		break;

	default:
		pr_err("Unknown ioctl %d\n", cmd);
		retval = -EINVAL;
	}

	return retval;
}
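
/* Illustrative userspace sketch (not part of this file): query the local CID
 * through the misc device registered below.
 *
 *	unsigned int cid;
 *	int fd = open("/dev/vsock", O_RDONLY);
 *
 *	ioctl(fd, IOCTL_VM_SOCKETS_GET_LOCAL_CID, &cid);
 */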

static long vsock_dev_ioctl(struct file *filp,
			    unsigned int cmd, unsigned long arg)
{
	return vsock_dev_do_ioctl(filp, cmd, (void __user *)arg);
}

#ifdef CONFIG_COMPAT
static long vsock_dev_compat_ioctl(struct file *filp,
				   unsigned int cmd, unsigned long arg)
{
	return vsock_dev_do_ioctl(filp, cmd, compat_ptr(arg));
}
#endif

static const struct file_operations vsock_device_ops = {
	.owner		= THIS_MODULE,
	.unlocked_ioctl	= vsock_dev_ioctl,
#ifdef CONFIG_COMPAT
	.compat_ioctl	= vsock_dev_compat_ioctl,
#endif
	.open		= nonseekable_open,
};

static struct miscdevice vsock_device = {
	.name		= "vsock",
	.fops		= &vsock_device_ops,
};

static int __init vsock_init(void)
{
	int err = 0;

	vsock_init_tables();

	vsock_proto.owner = THIS_MODULE;
	vsock_device.minor = MISC_DYNAMIC_MINOR;
	err = misc_register(&vsock_device);
	if (err) {
		pr_err("Failed to register misc device\n");
		goto err_reset_transport;
	}

	err = proto_register(&vsock_proto, 1);	/* we want our slab */
	if (err) {
		pr_err("Cannot register vsock protocol\n");
		goto err_deregister_misc;
	}

	err = sock_register(&vsock_family_ops);
	if (err) {
		pr_err("could not register af_vsock (%d) address family: %d\n",
		       AF_VSOCK, err);
		goto err_unregister_proto;
	}

	return 0;

err_unregister_proto:
	proto_unregister(&vsock_proto);
err_deregister_misc:
	misc_deregister(&vsock_device);
err_reset_transport:
	return err;
}

static void __exit vsock_exit(void)
{
	misc_deregister(&vsock_device);
	sock_unregister(AF_VSOCK);
	proto_unregister(&vsock_proto);
}

const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
{
	return vsk->transport;
}
EXPORT_SYMBOL_GPL(vsock_core_get_transport);

int vsock_core_register(const struct vsock_transport *t, int features)
{
	const struct vsock_transport *t_h2g, *t_g2h, *t_dgram;
	int err = mutex_lock_interruptible(&vsock_register_mutex);

	if (err)
		return err;

	t_h2g = transport_h2g;
	t_g2h = transport_g2h;
	t_dgram = transport_dgram;

	if (features & VSOCK_TRANSPORT_F_H2G) {
		if (t_h2g) {
			err = -EBUSY;
			goto err_busy;
		}
		t_h2g = t;
	}

	if (features & VSOCK_TRANSPORT_F_G2H) {
		if (t_g2h) {
			err = -EBUSY;
			goto err_busy;
		}
		t_g2h = t;
	}

	if (features & VSOCK_TRANSPORT_F_DGRAM) {
		if (t_dgram) {
			err = -EBUSY;
			goto err_busy;
		}
		t_dgram = t;
	}

	transport_h2g = t_h2g;
	transport_g2h = t_g2h;
	transport_dgram = t_dgram;

err_busy:
	mutex_unlock(&vsock_register_mutex);
	return err;
}
EXPORT_SYMBOL_GPL(vsock_core_register);
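
/* Illustrative sketch (not part of this file): a guest-side transport module
 * would register itself for the guest->host role like so, getting -EBUSY if
 * another transport already claimed that role.  "my_transport" is a
 * hypothetical name.
 *
 *	static const struct vsock_transport my_transport = { ... };
 *
 *	vsock_core_register(&my_transport, VSOCK_TRANSPORT_F_G2H);
 */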

void vsock_core_unregister(const struct vsock_transport *t)
{
	mutex_lock(&vsock_register_mutex);

	if (transport_h2g == t)
		transport_h2g = NULL;

	if (transport_g2h == t)
		transport_g2h = NULL;

	if (transport_dgram == t)
		transport_dgram = NULL;

	mutex_unlock(&vsock_register_mutex);
}
EXPORT_SYMBOL_GPL(vsock_core_unregister);

module_init(vsock_init);
module_exit(vsock_exit);

MODULE_AUTHOR("VMware, Inc.");
MODULE_DESCRIPTION("VMware Virtual Socket Family");
MODULE_VERSION("1.0.2.0-k");
MODULE_LICENSE("GPL v2");