tools build: Fix test-clang.cpp with Clang 8+
[linux-2.6-microblaze.git] / net / vmw_vsock / af_vsock.c
index 582a3e4..74db4cd 100644 (file)
@@ -126,19 +126,18 @@ static struct proto vsock_proto = {
  */
 #define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
 
-static const struct vsock_transport *transport;
+#define VSOCK_DEFAULT_BUFFER_SIZE     (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
+#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128
+
+/* Transport used for host->guest communication */
+static const struct vsock_transport *transport_h2g;
+/* Transport used for guest->host communication */
+static const struct vsock_transport *transport_g2h;
+/* Transport used for DGRAM communication */
+static const struct vsock_transport *transport_dgram;
 static DEFINE_MUTEX(vsock_register_mutex);
 
-/**** EXPORTS ****/
-
-/* Get the ID of the local context.  This is transport dependent. */
-
-int vm_sockets_get_local_cid(void)
-{
-       return transport->get_local_cid();
-}
-EXPORT_SYMBOL_GPL(vm_sockets_get_local_cid);
-
 /**** UTILS ****/
 
 /* Each bound VSocket is stored in the bind hash table and each connected
@@ -188,7 +187,7 @@ static int vsock_auto_bind(struct vsock_sock *vsk)
        return __vsock_bind(sk, &local_addr);
 }
 
-static int __init vsock_init_tables(void)
+static void vsock_init_tables(void)
 {
        int i;
 
@@ -197,7 +196,6 @@ static int __init vsock_init_tables(void)
 
        for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
                INIT_LIST_HEAD(&vsock_connected_table[i]);
-       return 0;
 }
 
 static void __vsock_insert_bound(struct list_head *list,
@@ -230,9 +228,15 @@ static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
 {
        struct vsock_sock *vsk;
 
-       list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
-               if (addr->svm_port == vsk->local_addr.svm_port)
+       list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
+               if (vsock_addr_equals_addr(addr, &vsk->local_addr))
+                       return sk_vsock(vsk);
+
+               if (addr->svm_port == vsk->local_addr.svm_port &&
+                   (vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
+                    addr->svm_cid == VMADDR_CID_ANY))
                        return sk_vsock(vsk);
+       }
 
        return NULL;
 }
@@ -382,6 +386,88 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
 }
 EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
 
+static void vsock_deassign_transport(struct vsock_sock *vsk)
+{
+       if (!vsk->transport)
+               return;
+
+       vsk->transport->destruct(vsk);
+       module_put(vsk->transport->module);
+       vsk->transport = NULL;
+}
+
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for stream sockets this must be called when vsk->remote_addr is set
+ * (e.g. during connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ *  - remote CID <= VMADDR_CID_HOST will use guest->host transport;
+ *  - remote CID == local_cid (guest->host transport) will use guest->host
+ *    transport for loopback (host->guest transports don't support loopback);
+ *  - any other remote CID (> VMADDR_CID_HOST) will use host->guest transport.
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+       const struct vsock_transport *new_transport;
+       struct sock *sk = sk_vsock(vsk);
+       unsigned int remote_cid = vsk->remote_addr.svm_cid;
+       int ret;
+
+       switch (sk->sk_type) {
+       case SOCK_DGRAM:
+               new_transport = transport_dgram;
+               break;
+       case SOCK_STREAM:
+               if (remote_cid <= VMADDR_CID_HOST ||
+                   (transport_g2h &&
+                    remote_cid == transport_g2h->get_local_cid()))
+                       new_transport = transport_g2h;
+               else
+                       new_transport = transport_h2g;
+               break;
+       default:
+               return -ESOCKTNOSUPPORT;
+       }
+
+       if (vsk->transport) {
+               if (vsk->transport == new_transport)
+                       return 0;
+
+               vsk->transport->release(vsk);
+               vsock_deassign_transport(vsk);
+       }
+
+       /* We increase the module refcnt to prevent the transport unloading
+        * while there are open sockets assigned to it.
+        */
+       if (!new_transport || !try_module_get(new_transport->module))
+               return -ENODEV;
+
+       ret = new_transport->init(vsk, psk);
+       if (ret) {
+               module_put(new_transport->module);
+               return ret;
+       }
+
+       vsk->transport = new_transport;
+
+       return 0;
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
+
+bool vsock_find_cid(unsigned int cid)
+{
+       if (transport_g2h && cid == transport_g2h->get_local_cid())
+               return true;
+
+       if (transport_h2g && cid == VMADDR_CID_HOST)
+               return true;
+
+       return false;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+
 static struct sock *vsock_dequeue_accept(struct sock *listener)
 {
        struct vsock_sock *vlistener;
@@ -418,7 +504,12 @@ static bool vsock_is_pending(struct sock *sk)
 
 static int vsock_send_shutdown(struct sock *sk, int mode)
 {
-       return transport->shutdown(vsock_sk(sk), mode);
+       struct vsock_sock *vsk = vsock_sk(sk);
+
+       if (!vsk->transport)
+               return -ENODEV;
+
+       return vsk->transport->shutdown(vsk, mode);
 }
 
 static void vsock_pending_work(struct work_struct *work)
@@ -439,7 +530,7 @@ static void vsock_pending_work(struct work_struct *work)
        if (vsock_is_pending(sk)) {
                vsock_remove_pending(listener, sk);
 
-               listener->sk_ack_backlog--;
+               sk_acceptq_removed(listener);
        } else if (!vsk->rejected) {
                /* We are not on the pending list and accept() did not reject
                 * us, so we must have been accepted by our user process.  We
@@ -528,13 +619,12 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
 static int __vsock_bind_dgram(struct vsock_sock *vsk,
                              struct sockaddr_vm *addr)
 {
-       return transport->dgram_bind(vsk, addr);
+       return vsk->transport->dgram_bind(vsk, addr);
 }
 
 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
 {
        struct vsock_sock *vsk = vsock_sk(sk);
-       u32 cid;
        int retval;
 
        /* First ensure this socket isn't already bound. */
@@ -544,10 +634,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
        /* Now bind to the provided address or select appropriate values if
         * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY).  Note that
         * like AF_INET prevents binding to a non-local IP address (in most
-        * cases), we only allow binding to the local CID.
+        * cases), we only allow binding to a local CID.
         */
-       cid = transport->get_local_cid();
-       if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
+       if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
                return -EADDRNOTAVAIL;
 
        switch (sk->sk_socket->type) {
@@ -571,12 +660,12 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
 
 static void vsock_connect_timeout(struct work_struct *work);
 
-struct sock *__vsock_create(struct net *net,
-                           struct socket *sock,
-                           struct sock *parent,
-                           gfp_t priority,
-                           unsigned short type,
-                           int kern)
+static struct sock *__vsock_create(struct net *net,
+                                  struct socket *sock,
+                                  struct sock *parent,
+                                  gfp_t priority,
+                                  unsigned short type,
+                                  int kern)
 {
        struct sock *sk;
        struct vsock_sock *psk;
@@ -620,28 +709,24 @@ struct sock *__vsock_create(struct net *net,
                vsk->trusted = psk->trusted;
                vsk->owner = get_cred(psk->owner);
                vsk->connect_timeout = psk->connect_timeout;
+               vsk->buffer_size = psk->buffer_size;
+               vsk->buffer_min_size = psk->buffer_min_size;
+               vsk->buffer_max_size = psk->buffer_max_size;
        } else {
                vsk->trusted = capable(CAP_NET_ADMIN);
                vsk->owner = get_current_cred();
                vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT;
+               vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE;
+               vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE;
+               vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
        }
 
-       if (transport->init(vsk, psk) < 0) {
-               sk_free(sk);
-               return NULL;
-       }
-
-       if (sock)
-               vsock_insert_unbound(vsk);
-
        return sk;
 }
-EXPORT_SYMBOL_GPL(__vsock_create);
 
 static void __vsock_release(struct sock *sk, int level)
 {
        if (sk) {
-               struct sk_buff *skb;
                struct sock *pending;
                struct vsock_sock *vsk;
 
@@ -651,7 +736,10 @@ static void __vsock_release(struct sock *sk, int level)
                /* The release call is supposed to use lock_sock_nested()
                 * rather than lock_sock(), if a sock lock should be acquired.
                 */
-               transport->release(vsk);
+               if (vsk->transport)
+                       vsk->transport->release(vsk);
+               else if (sk->sk_type == SOCK_STREAM)
+                       vsock_remove_sock(vsk);
 
                /* When "level" is SINGLE_DEPTH_NESTING, use the nested
                 * version to avoid the warning "possible recursive locking
@@ -662,8 +750,7 @@ static void __vsock_release(struct sock *sk, int level)
                sock_orphan(sk);
                sk->sk_shutdown = SHUTDOWN_MASK;
 
-               while ((skb = skb_dequeue(&sk->sk_receive_queue)))
-                       kfree_skb(skb);
+               skb_queue_purge(&sk->sk_receive_queue);
 
                /* Clean up any sockets that never were accepted. */
                while ((pending = vsock_dequeue_accept(sk)) != NULL) {
@@ -680,7 +767,7 @@ static void vsock_sk_destruct(struct sock *sk)
 {
        struct vsock_sock *vsk = vsock_sk(sk);
 
-       transport->destruct(vsk);
+       vsock_deassign_transport(vsk);
 
        /* When clearing these addresses, there's no need to set the family and
         * possibly register the address family with the kernel.
@@ -702,15 +789,22 @@ static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
        return err;
 }
 
+struct sock *vsock_create_connected(struct sock *parent)
+{
+       return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL,
+                             parent->sk_type, 0);
+}
+EXPORT_SYMBOL_GPL(vsock_create_connected);
+
 s64 vsock_stream_has_data(struct vsock_sock *vsk)
 {
-       return transport->stream_has_data(vsk);
+       return vsk->transport->stream_has_data(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_data);
 
 s64 vsock_stream_has_space(struct vsock_sock *vsk)
 {
-       return transport->stream_has_space(vsk);
+       return vsk->transport->stream_has_space(vsk);
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_space);
 
@@ -879,6 +973,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
                        mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
 
        } else if (sock->type == SOCK_STREAM) {
+               const struct vsock_transport *transport = vsk->transport;
                lock_sock(sk);
 
                /* Listening sockets that have connections in their accept
@@ -889,7 +984,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
                        mask |= EPOLLIN | EPOLLRDNORM;
 
                /* If there is something in the queue then we can read. */
-               if (transport->stream_is_active(vsk) &&
+               if (transport && transport->stream_is_active(vsk) &&
                    !(sk->sk_shutdown & RCV_SHUTDOWN)) {
                        bool data_ready_now = false;
                        int ret = transport->notify_poll_in(
@@ -954,6 +1049,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
        struct sock *sk;
        struct vsock_sock *vsk;
        struct sockaddr_vm *remote_addr;
+       const struct vsock_transport *transport;
 
        if (msg->msg_flags & MSG_OOB)
                return -EOPNOTSUPP;
@@ -962,6 +1058,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
        err = 0;
        sk = sock->sk;
        vsk = vsock_sk(sk);
+       transport = vsk->transport;
 
        lock_sock(sk);
 
@@ -1046,8 +1143,8 @@ static int vsock_dgram_connect(struct socket *sock,
        if (err)
                goto out;
 
-       if (!transport->dgram_allow(remote_addr->svm_cid,
-                                   remote_addr->svm_port)) {
+       if (!vsk->transport->dgram_allow(remote_addr->svm_cid,
+                                        remote_addr->svm_port)) {
                err = -EINVAL;
                goto out;
        }
@@ -1063,7 +1160,9 @@ out:
 static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
                               size_t len, int flags)
 {
-       return transport->dgram_dequeue(vsock_sk(sock->sk), msg, len, flags);
+       struct vsock_sock *vsk = vsock_sk(sock->sk);
+
+       return vsk->transport->dgram_dequeue(vsk, msg, len, flags);
 }
 
 static const struct proto_ops vsock_dgram_ops = {
@@ -1089,6 +1188,8 @@ static const struct proto_ops vsock_dgram_ops = {
 
 static int vsock_transport_cancel_pkt(struct vsock_sock *vsk)
 {
+       const struct vsock_transport *transport = vsk->transport;
+
        if (!transport->cancel_pkt)
                return -EOPNOTSUPP;
 
@@ -1125,6 +1226,7 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
        int err;
        struct sock *sk;
        struct vsock_sock *vsk;
+       const struct vsock_transport *transport;
        struct sockaddr_vm *remote_addr;
        long timeout;
        DEFINE_WAIT(wait);
@@ -1159,19 +1261,26 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
                        goto out;
                }
 
+               /* Set the remote address that we are connecting to. */
+               memcpy(&vsk->remote_addr, remote_addr,
+                      sizeof(vsk->remote_addr));
+
+               err = vsock_assign_transport(vsk, NULL);
+               if (err)
+                       goto out;
+
+               transport = vsk->transport;
+
                /* The hypervisor and well-known contexts do not have socket
                 * endpoints.
                 */
-               if (!transport->stream_allow(remote_addr->svm_cid,
+               if (!transport ||
+                   !transport->stream_allow(remote_addr->svm_cid,
                                             remote_addr->svm_port)) {
                        err = -ENETUNREACH;
                        goto out;
                }
 
-               /* Set the remote address that we are connecting to. */
-               memcpy(&vsk->remote_addr, remote_addr,
-                      sizeof(vsk->remote_addr));
-
                err = vsock_auto_bind(vsk);
                if (err)
                        goto out;
@@ -1301,7 +1410,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
                err = -listener->sk_err;
 
        if (connected) {
-               listener->sk_ack_backlog--;
+               sk_acceptq_removed(listener);
 
                lock_sock_nested(connected, SINGLE_DEPTH_NESTING);
                vconnected = vsock_sk(connected);
@@ -1366,6 +1475,23 @@ out:
        return err;
 }
 
+static void vsock_update_buffer_size(struct vsock_sock *vsk,
+                                    const struct vsock_transport *transport,
+                                    u64 val)
+{
+       if (val > vsk->buffer_max_size)
+               val = vsk->buffer_max_size;
+
+       if (val < vsk->buffer_min_size)
+               val = vsk->buffer_min_size;
+
+       if (val != vsk->buffer_size &&
+           transport && transport->notify_buffer_size)
+               transport->notify_buffer_size(vsk, &val);
+
+       vsk->buffer_size = val;
+}
+
 static int vsock_stream_setsockopt(struct socket *sock,
                                   int level,
                                   int optname,
@@ -1375,6 +1501,7 @@ static int vsock_stream_setsockopt(struct socket *sock,
        int err;
        struct sock *sk;
        struct vsock_sock *vsk;
+       const struct vsock_transport *transport;
        u64 val;
 
        if (level != AF_VSOCK)
@@ -1395,23 +1522,26 @@ static int vsock_stream_setsockopt(struct socket *sock,
        err = 0;
        sk = sock->sk;
        vsk = vsock_sk(sk);
+       transport = vsk->transport;
 
        lock_sock(sk);
 
        switch (optname) {
        case SO_VM_SOCKETS_BUFFER_SIZE:
                COPY_IN(val);
-               transport->set_buffer_size(vsk, val);
+               vsock_update_buffer_size(vsk, transport, val);
                break;
 
        case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
                COPY_IN(val);
-               transport->set_max_buffer_size(vsk, val);
+               vsk->buffer_max_size = val;
+               vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
                break;
 
        case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
                COPY_IN(val);
-               transport->set_min_buffer_size(vsk, val);
+               vsk->buffer_min_size = val;
+               vsock_update_buffer_size(vsk, transport, vsk->buffer_size);
                break;
 
        case SO_VM_SOCKETS_CONNECT_TIMEOUT: {
@@ -1478,17 +1608,17 @@ static int vsock_stream_getsockopt(struct socket *sock,
 
        switch (optname) {
        case SO_VM_SOCKETS_BUFFER_SIZE:
-               val = transport->get_buffer_size(vsk);
+               val = vsk->buffer_size;
                COPY_OUT(val);
                break;
 
        case SO_VM_SOCKETS_BUFFER_MAX_SIZE:
-               val = transport->get_max_buffer_size(vsk);
+               val = vsk->buffer_max_size;
                COPY_OUT(val);
                break;
 
        case SO_VM_SOCKETS_BUFFER_MIN_SIZE:
-               val = transport->get_min_buffer_size(vsk);
+               val = vsk->buffer_min_size;
                COPY_OUT(val);
                break;
 
@@ -1519,6 +1649,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 {
        struct sock *sk;
        struct vsock_sock *vsk;
+       const struct vsock_transport *transport;
        ssize_t total_written;
        long timeout;
        int err;
@@ -1527,6 +1658,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 
        sk = sock->sk;
        vsk = vsock_sk(sk);
+       transport = vsk->transport;
        total_written = 0;
        err = 0;
 
@@ -1548,7 +1680,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                goto out;
        }
 
-       if (sk->sk_state != TCP_ESTABLISHED ||
+       if (!transport || sk->sk_state != TCP_ESTABLISHED ||
            !vsock_addr_bound(&vsk->local_addr)) {
                err = -ENOTCONN;
                goto out;
@@ -1658,6 +1790,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 {
        struct sock *sk;
        struct vsock_sock *vsk;
+       const struct vsock_transport *transport;
        int err;
        size_t target;
        ssize_t copied;
@@ -1668,11 +1801,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 
        sk = sock->sk;
        vsk = vsock_sk(sk);
+       transport = vsk->transport;
        err = 0;
 
        lock_sock(sk);
 
-       if (sk->sk_state != TCP_ESTABLISHED) {
+       if (!transport || sk->sk_state != TCP_ESTABLISHED) {
                /* Recvmsg is supposed to return 0 if a peer performs an
                 * orderly shutdown. Differentiate between that case and when a
                 * peer has not connected or a local shutdown occured with the
@@ -1846,6 +1980,10 @@ static const struct proto_ops vsock_stream_ops = {
 static int vsock_create(struct net *net, struct socket *sock,
                        int protocol, int kern)
 {
+       struct vsock_sock *vsk;
+       struct sock *sk;
+       int ret;
+
        if (!sock)
                return -EINVAL;
 
@@ -1865,7 +2003,23 @@ static int vsock_create(struct net *net, struct socket *sock,
 
        sock->state = SS_UNCONNECTED;
 
-       return __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern) ? 0 : -ENOMEM;
+       sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern);
+       if (!sk)
+               return -ENOMEM;
+
+       vsk = vsock_sk(sk);
+
+       if (sock->type == SOCK_DGRAM) {
+               ret = vsock_assign_transport(vsk, NULL);
+               if (ret < 0) {
+                       sock_put(sk);
+                       return ret;
+               }
+       }
+
+       vsock_insert_unbound(vsk);
+
+       return 0;
 }
 
 static const struct net_proto_family vsock_family_ops = {
@@ -1878,11 +2032,20 @@ static long vsock_dev_do_ioctl(struct file *filp,
                               unsigned int cmd, void __user *ptr)
 {
        u32 __user *p = ptr;
+       u32 cid = VMADDR_CID_ANY;
        int retval = 0;
 
        switch (cmd) {
        case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
-               if (put_user(transport->get_local_cid(), p) != 0)
+               /* To be compatible with the VMCI behavior, we prioritize the
+                * guest CID over the well-known host CID (VMADDR_CID_HOST).
+                */
+               if (transport_g2h)
+                       cid = transport_g2h->get_local_cid();
+               else if (transport_h2g)
+                       cid = transport_h2g->get_local_cid();
+
+               if (put_user(cid, p) != 0)
                        retval = -EFAULT;
                break;
 
@@ -1922,24 +2085,13 @@ static struct miscdevice vsock_device = {
        .fops           = &vsock_device_ops,
 };
 
-int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
+static int __init vsock_init(void)
 {
-       int err = mutex_lock_interruptible(&vsock_register_mutex);
+       int err = 0;
 
-       if (err)
-               return err;
-
-       if (transport) {
-               err = -EBUSY;
-               goto err_busy;
-       }
-
-       /* Transport must be the owner of the protocol so that it can't
-        * unload while there are open sockets.
-        */
-       vsock_proto.owner = owner;
-       transport = t;
+       vsock_init_tables();
 
+       vsock_proto.owner = THIS_MODULE;
        vsock_device.minor = MISC_DYNAMIC_MINOR;
        err = misc_register(&vsock_device);
        if (err) {
@@ -1960,7 +2112,6 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
                goto err_unregister_proto;
        }
 
-       mutex_unlock(&vsock_register_mutex);
        return 0;
 
 err_unregister_proto:
@@ -1968,44 +2119,86 @@ err_unregister_proto:
 err_deregister_misc:
        misc_deregister(&vsock_device);
 err_reset_transport:
-       transport = NULL;
-err_busy:
-       mutex_unlock(&vsock_register_mutex);
        return err;
 }
-EXPORT_SYMBOL_GPL(__vsock_core_init);
 
-void vsock_core_exit(void)
+static void __exit vsock_exit(void)
 {
-       mutex_lock(&vsock_register_mutex);
-
        misc_deregister(&vsock_device);
        sock_unregister(AF_VSOCK);
        proto_unregister(&vsock_proto);
-
-       /* We do not want the assignment below re-ordered. */
-       mb();
-       transport = NULL;
-
-       mutex_unlock(&vsock_register_mutex);
 }
-EXPORT_SYMBOL_GPL(vsock_core_exit);
 
-const struct vsock_transport *vsock_core_get_transport(void)
+const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
 {
-       /* vsock_register_mutex not taken since only the transport uses this
-        * function and only while registered.
-        */
-       return transport;
+       return vsk->transport;
 }
 EXPORT_SYMBOL_GPL(vsock_core_get_transport);
 
-static void __exit vsock_exit(void)
+int vsock_core_register(const struct vsock_transport *t, int features)
+{
+       const struct vsock_transport *t_h2g, *t_g2h, *t_dgram;
+       int err = mutex_lock_interruptible(&vsock_register_mutex);
+
+       if (err)
+               return err;
+
+       t_h2g = transport_h2g;
+       t_g2h = transport_g2h;
+       t_dgram = transport_dgram;
+
+       if (features & VSOCK_TRANSPORT_F_H2G) {
+               if (t_h2g) {
+                       err = -EBUSY;
+                       goto err_busy;
+               }
+               t_h2g = t;
+       }
+
+       if (features & VSOCK_TRANSPORT_F_G2H) {
+               if (t_g2h) {
+                       err = -EBUSY;
+                       goto err_busy;
+               }
+               t_g2h = t;
+       }
+
+       if (features & VSOCK_TRANSPORT_F_DGRAM) {
+               if (t_dgram) {
+                       err = -EBUSY;
+                       goto err_busy;
+               }
+               t_dgram = t;
+       }
+
+       transport_h2g = t_h2g;
+       transport_g2h = t_g2h;
+       transport_dgram = t_dgram;
+
+err_busy:
+       mutex_unlock(&vsock_register_mutex);
+       return err;
+}
+EXPORT_SYMBOL_GPL(vsock_core_register);
+
+void vsock_core_unregister(const struct vsock_transport *t)
 {
-       /* Do nothing.  This function makes this module removable. */
+       mutex_lock(&vsock_register_mutex);
+
+       if (transport_h2g == t)
+               transport_h2g = NULL;
+
+       if (transport_g2h == t)
+               transport_g2h = NULL;
+
+       if (transport_dgram == t)
+               transport_dgram = NULL;
+
+       mutex_unlock(&vsock_register_mutex);
 }
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
 
-module_init(vsock_init_tables);
+module_init(vsock_init);
 module_exit(vsock_exit);
 
 MODULE_AUTHOR("VMware, Inc.");