Merge branch 'work.vboxsf' of git://git.kernel.org/pub/scm/linux/kernel/git/viro/vfs
[linux-2.6-microblaze.git] / net / vmw_vsock / af_vsock.c
index 74db4cd..9c5b2a9 100644 (file)
@@ -136,6 +136,8 @@ static const struct vsock_transport *transport_h2g;
 static const struct vsock_transport *transport_g2h;
 /* Transport used for DGRAM communication */
 static const struct vsock_transport *transport_dgram;
+/* Transport used for local communication */
+static const struct vsock_transport *transport_local;
 static DEFINE_MUTEX(vsock_register_mutex);
 
 /**** UTILS ****/
@@ -386,6 +388,21 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
 }
 EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
 
+static bool vsock_use_local_transport(unsigned int remote_cid)
+{
+       if (!transport_local)
+               return false;
+
+       if (remote_cid == VMADDR_CID_LOCAL)
+               return true;
+
+       if (transport_g2h) {
+               return remote_cid == transport_g2h->get_local_cid();
+       } else {
+               return remote_cid == VMADDR_CID_HOST;
+       }
+}
+
 static void vsock_deassign_transport(struct vsock_sock *vsk)
 {
        if (!vsk->transport)
@@ -402,9 +419,9 @@ static void vsock_deassign_transport(struct vsock_sock *vsk)
  * (e.g. during the connect() or when a connection request on a listener
  * socket is received).
  * The vsk->remote_addr is used to decide which transport to use:
+ *  - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
+ *    g2h is not loaded, will use local transport;
  *  - remote CID <= VMADDR_CID_HOST will use guest->host transport;
- *  - remote CID == local_cid (guest->host transport) will use guest->host
- *    transport for loopback (host->guest transports don't support loopback);
  *  - remote CID > VMADDR_CID_HOST will use host->guest transport;
  */
 int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
@@ -419,9 +436,9 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
                new_transport = transport_dgram;
                break;
        case SOCK_STREAM:
-               if (remote_cid <= VMADDR_CID_HOST ||
-                   (transport_g2h &&
-                    remote_cid == transport_g2h->get_local_cid()))
+               if (vsock_use_local_transport(remote_cid))
+                       new_transport = transport_local;
+               else if (remote_cid <= VMADDR_CID_HOST)
                        new_transport = transport_g2h;
                else
                        new_transport = transport_h2g;
@@ -464,6 +481,9 @@ bool vsock_find_cid(unsigned int cid)
        if (transport_h2g && cid == VMADDR_CID_HOST)
                return true;
 
+       if (transport_local && cid == VMADDR_CID_LOCAL)
+               return true;
+
        return false;
 }
 EXPORT_SYMBOL_GPL(vsock_find_cid);
@@ -2137,7 +2157,7 @@ EXPORT_SYMBOL_GPL(vsock_core_get_transport);
 
 int vsock_core_register(const struct vsock_transport *t, int features)
 {
-       const struct vsock_transport *t_h2g, *t_g2h, *t_dgram;
+       const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local;
        int err = mutex_lock_interruptible(&vsock_register_mutex);
 
        if (err)
@@ -2146,6 +2166,7 @@ int vsock_core_register(const struct vsock_transport *t, int features)
        t_h2g = transport_h2g;
        t_g2h = transport_g2h;
        t_dgram = transport_dgram;
+       t_local = transport_local;
 
        if (features & VSOCK_TRANSPORT_F_H2G) {
                if (t_h2g) {
@@ -2171,9 +2192,18 @@ int vsock_core_register(const struct vsock_transport *t, int features)
                t_dgram = t;
        }
 
+       if (features & VSOCK_TRANSPORT_F_LOCAL) {
+               if (t_local) {
+                       err = -EBUSY;
+                       goto err_busy;
+               }
+               t_local = t;
+       }
+
        transport_h2g = t_h2g;
        transport_g2h = t_g2h;
        transport_dgram = t_dgram;
+       transport_local = t_local;
 
 err_busy:
        mutex_unlock(&vsock_register_mutex);
@@ -2194,6 +2224,9 @@ void vsock_core_unregister(const struct vsock_transport *t)
        if (transport_dgram == t)
                transport_dgram = NULL;
 
+       if (transport_local == t)
+               transport_local = NULL;
+
        mutex_unlock(&vsock_register_mutex);
 }
 EXPORT_SYMBOL_GPL(vsock_core_unregister);