net/vmw_vsock/virtio_transport_common.c
/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

static const struct virtio_transport *virtio_transport_get_ops(void)
{
        const struct vsock_transport *t = vsock_core_get_transport();

        return container_of(t, struct virtio_transport, transport);
}

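/* Allocate a packet and fill in its header from @info.  If @info carries a
 * message, the payload is copied from the caller's iovec into a freshly
 * allocated buffer.  Returns NULL on allocation or copy failure.
 */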
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
                           size_t len,
                           u32 src_cid,
                           u32 src_port,
                           u32 dst_cid,
                           u32 dst_port)
{
        struct virtio_vsock_pkt *pkt;
        int err;

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        pkt->hdr.type           = cpu_to_le16(info->type);
        pkt->hdr.op             = cpu_to_le16(info->op);
        pkt->hdr.src_cid        = cpu_to_le64(src_cid);
        pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
        pkt->hdr.src_port       = cpu_to_le32(src_port);
        pkt->hdr.dst_port       = cpu_to_le32(dst_port);
        pkt->hdr.flags          = cpu_to_le32(info->flags);
        pkt->len                = len;
        pkt->hdr.len            = cpu_to_le32(len);
        pkt->reply              = info->reply;
        pkt->vsk                = info->vsk;

        if (info->msg && len > 0) {
                pkt->buf = kmalloc(len, GFP_KERNEL);
                if (!pkt->buf)
                        goto out_pkt;
                err = memcpy_from_msg(pkt->buf, info->msg, len);
                if (err)
                        goto out;
        }

        trace_virtio_transport_alloc_pkt(src_cid, src_port,
                                         dst_cid, dst_port,
                                         len,
                                         info->type,
                                         info->op,
                                         info->flags);

        return pkt;

out:
        kfree(pkt->buf);
out_pkt:
        kfree(pkt);
        return NULL;
}

/* Packet capture */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
        struct virtio_vsock_pkt *pkt = opaque;
        struct af_vsockmon_hdr *hdr;
        struct sk_buff *skb;

        skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
                        GFP_ATOMIC);
        if (!skb)
                return NULL;

        hdr = skb_put(skb, sizeof(*hdr));

        /* pkt->hdr is little-endian so no need to byteswap here */
        hdr->src_cid = pkt->hdr.src_cid;
        hdr->src_port = pkt->hdr.src_port;
        hdr->dst_cid = pkt->hdr.dst_cid;
        hdr->dst_port = pkt->hdr.dst_port;

        hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
        hdr->len = cpu_to_le16(sizeof(pkt->hdr));
        memset(hdr->reserved, 0, sizeof(hdr->reserved));

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_REQUEST:
        case VIRTIO_VSOCK_OP_RESPONSE:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
                break;
        case VIRTIO_VSOCK_OP_RST:
        case VIRTIO_VSOCK_OP_SHUTDOWN:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
                break;
        case VIRTIO_VSOCK_OP_RW:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
                break;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
        case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
                break;
        default:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
                break;
        }

        skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

        if (pkt->len) {
                skb_put_data(skb, pkt->buf, pkt->len);
        }

        return skb;
}

void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
        vsock_deliver_tap(virtio_transport_build_skb, pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);

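/* Common transmit path: resolve the source/destination addresses, clamp the
 * length to the credit available from the peer, allocate the packet and hand
 * it to the transport's send_pkt callback.
 */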
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
                                          struct virtio_vsock_pkt_info *info)
{
        u32 src_cid, src_port, dst_cid, dst_port;
        struct virtio_vsock_sock *vvs;
        struct virtio_vsock_pkt *pkt;
        u32 pkt_len = info->pkt_len;

        src_cid = vm_sockets_get_local_cid();
        src_port = vsk->local_addr.svm_port;
        if (!info->remote_cid) {
                dst_cid = vsk->remote_addr.svm_cid;
                dst_port = vsk->remote_addr.svm_port;
        } else {
                dst_cid = info->remote_cid;
                dst_port = info->remote_port;
        }

        vvs = vsk->trans;

        /* we can send less than pkt_len bytes */
        if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
                pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

        /* virtio_transport_get_credit might return less than pkt_len credit */
        pkt_len = virtio_transport_get_credit(vvs, pkt_len);

        /* Do not send zero length OP_RW pkt */
        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
                return pkt_len;

        pkt = virtio_transport_alloc_pkt(info, pkt_len,
                                         src_cid, src_port,
                                         dst_cid, dst_port);
        if (!pkt) {
                virtio_transport_put_credit(vvs, pkt_len);
                return -ENOMEM;
        }

        virtio_transport_inc_tx_pkt(vvs, pkt);

        return virtio_transport_get_ops()->send_pkt(pkt);
}

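/* Receive-side byte accounting.  rx_bytes tracks data queued but not yet read
 * by userspace; fwd_cnt counts bytes handed to userspace and is advertised
 * back to the peer for credit updates.  Callers hold rx_lock.
 */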
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes += pkt->len;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes -= pkt->len;
        vvs->fwd_cnt += pkt->len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
        spin_lock_bh(&vvs->tx_lock);
        pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
        pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

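/* Reserve up to @credit bytes of the peer's receive buffer for transmission.
 * The amount actually granted is limited by the space the peer has advertised
 * (peer_buf_alloc) minus the bytes still in flight (tx_cnt - peer_fwd_cnt).
 * Unused credit is returned with virtio_transport_put_credit().
 */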
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        u32 ret;

        spin_lock_bh(&vvs->tx_lock);
        ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (ret > credit)
                ret = credit;
        vvs->tx_cnt += ret;
        spin_unlock_bh(&vvs->tx_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        spin_lock_bh(&vvs->tx_lock);
        vvs->tx_cnt -= credit;
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
                                               int type,
                                               struct virtio_vsock_hdr *hdr)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
                .type = type,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}

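/* Copy queued receive data into @msg, freeing packets once they have been
 * fully consumed, and then advertise the updated credit to the peer.
 * Returns the number of bytes copied, or a negative error if nothing was
 * copied.
 */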
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
                                   struct msghdr *msg,
                                   size_t len)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        struct virtio_vsock_pkt *pkt;
        size_t bytes, total = 0;
        int err = -EFAULT;

        spin_lock_bh(&vvs->rx_lock);
        while (total < len && !list_empty(&vvs->rx_queue)) {
                pkt = list_first_entry(&vvs->rx_queue,
                                       struct virtio_vsock_pkt, list);

                bytes = len - total;
                if (bytes > pkt->len - pkt->off)
                        bytes = pkt->len - pkt->off;

                /* sk_lock is held by caller so no one else can dequeue.
                 * Unlock rx_lock since memcpy_to_msg() may sleep.
                 */
                spin_unlock_bh(&vvs->rx_lock);

                err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
                if (err)
                        goto out;

                spin_lock_bh(&vvs->rx_lock);

                total += bytes;
                pkt->off += bytes;
                if (pkt->off == pkt->len) {
                        virtio_transport_dec_rx_pkt(vvs, pkt);
                        list_del(&pkt->list);
                        virtio_transport_free_pkt(pkt);
                }
        }
        spin_unlock_bh(&vvs->rx_lock);

        /* Send a credit pkt to peer */
        virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
                                            NULL);

        return total;

out:
        if (total)
                err = total;
        return err;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len, int flags)
{
        if (flags & MSG_PEEK)
                return -EOPNOTSUPP;

        return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
                               struct msghdr *msg,
                               size_t len, int flags)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->rx_lock);
        bytes = vvs->rx_bytes;
        spin_unlock_bh(&vvs->rx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

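/* Free space in the peer's receive buffer, based on the most recently
 * advertised credit.  Caller must hold tx_lock.
 */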
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (bytes < 0)
                bytes = 0;

        return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->tx_lock);
        bytes = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

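/* Allocate the per-socket transport state.  A child socket created by a
 * listener inherits the parent's buffer size limits and peer credit.
 */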
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
                                    struct vsock_sock *psk)
{
        struct virtio_vsock_sock *vvs;

        vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
        if (!vvs)
                return -ENOMEM;

        vsk->trans = vvs;
        vvs->vsk = vsk;
        if (psk) {
                struct virtio_vsock_sock *ptrans = psk->trans;

                vvs->buf_size   = ptrans->buf_size;
                vvs->buf_size_min = ptrans->buf_size_min;
                vvs->buf_size_max = ptrans->buf_size_max;
                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
        } else {
                vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
                vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
                vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
        }

        vvs->buf_alloc = vvs->buf_size;

        spin_lock_init(&vvs->rx_lock);
        spin_lock_init(&vvs->tx_lock);
        INIT_LIST_HEAD(&vvs->rx_queue);

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);

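/* The buffer size setters clamp the requested value to
 * VIRTIO_VSOCK_MAX_BUF_SIZE and keep buf_size_min <= buf_size <= buf_size_max
 * consistent.  Setting the buffer size also resets buf_alloc, the credit
 * advertised to the peer.
 */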
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size_min)
                vvs->buf_size_min = val;
        if (val > vvs->buf_size_max)
                vvs->buf_size_max = val;
        vvs->buf_size = val;
        vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);

void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val > vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);

void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);

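/* The virtio transport needs no special handling around blocking reads and
 * writes, so the vsock notify callbacks below are no-ops.
 */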
int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
                                size_t target,
                                bool *data_ready_now)
{
        if (vsock_stream_has_data(vsk))
                *data_ready_now = true;
        else
                *data_ready_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
                                 size_t target,
                                 bool *space_avail_now)
{
        s64 free_space;

        free_space = vsock_stream_has_space(vsk);
        if (free_space > 0)
                *space_avail_now = true;
        else if (free_space == 0)
                *space_avail_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
        size_t target, ssize_t copied, bool data_read,
        struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
        ssize_t written, struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
                                struct sockaddr_vm *addr)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
        return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_REQUEST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .flags = (mode & RCV_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
                         (mode & SEND_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
                               struct sockaddr_vm *remote_addr,
                               struct msghdr *msg,
                               size_t dgram_len)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RW,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .msg = msg,
                .pkt_len = len,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

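/* Reset the connection.  If @pkt is non-NULL the RST is sent as a reply to
 * that packet; replying to an incoming RST is suppressed to avoid loops.
 */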
static int virtio_transport_reset(struct vsock_sock *vsk,
                                  struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .reply = !!pkt,
                .vsk = vsk,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
        const struct virtio_transport *t;
        struct virtio_vsock_pkt *reply;
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = le16_to_cpu(pkt->hdr.type),
                .reply = true,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        reply = virtio_transport_alloc_pkt(&info, 0,
                                           le64_to_cpu(pkt->hdr.dst_cid),
                                           le32_to_cpu(pkt->hdr.dst_port),
                                           le64_to_cpu(pkt->hdr.src_cid),
                                           le32_to_cpu(pkt->hdr.src_port));
        if (!reply)
                return -ENOMEM;

        t = virtio_transport_get_ops();
        if (!t) {
                virtio_transport_free_pkt(reply);
                return -ENOTCONN;
        }

        return t->send_pkt(reply);
}

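/* Linger on close: wait until SOCK_DONE is set, the timeout expires, or a
 * signal is pending.
 */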
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
        if (timeout) {
                DEFINE_WAIT_FUNC(wait, woken_wake_function);

                add_wait_queue(sk_sleep(sk), &wait);

                do {
                        if (sk_wait_event(sk, &timeout,
                                          sock_flag(sk, SOCK_DONE), &wait))
                                break;
                } while (!signal_pending(current) && timeout);

                remove_wait_queue(sk_sleep(sk), &wait);
        }
}

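/* Tear down the connection: mark the socket done, flag the peer as fully shut
 * down and, if a close timeout was scheduled, cancel it and drop the
 * reference it held.
 */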
static void virtio_transport_do_close(struct vsock_sock *vsk,
                                      bool cancel_timeout)
{
        struct sock *sk = sk_vsock(vsk);

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        if (vsock_stream_has_data(vsk) <= 0)
                sk->sk_state = TCP_CLOSING;
        sk->sk_state_change(sk);

        if (vsk->close_work_scheduled &&
            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
                vsk->close_work_scheduled = false;

                vsock_remove_sock(vsk);

                /* Release refcnt obtained when we scheduled the timeout */
                sock_put(sk);
        }
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
        struct vsock_sock *vsk =
                container_of(work, struct vsock_sock, close_work.work);
        struct sock *sk = sk_vsock(vsk);

        sock_hold(sk);
        lock_sock(sk);

        if (!sock_flag(sk, SOCK_DONE)) {
                (void)virtio_transport_reset(vsk, NULL);

                virtio_transport_do_close(vsk, false);
        }

        vsk->close_work_scheduled = false;

        release_sock(sk);
        sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;

        if (!(sk->sk_state == TCP_ESTABLISHED ||
              sk->sk_state == TCP_CLOSING))
                return true;

        /* Already received SHUTDOWN from peer, reply with RST */
        if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
                (void)virtio_transport_reset(vsk, NULL);
                return true;
        }

        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
                (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

        if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
                virtio_transport_wait_close(sk, sk->sk_lingertime);

        if (sock_flag(sk, SOCK_DONE)) {
                return true;
        }

        sock_hold(sk);
        INIT_DELAYED_WORK(&vsk->close_work,
                          virtio_transport_close_timeout);
        vsk->close_work_scheduled = true;
        schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
        return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;
        bool remove_sock = true;

        lock_sock(sk);
        if (sk->sk_type == SOCK_STREAM)
                remove_sock = virtio_transport_close(vsk);
        release_sock(sk);

        if (remove_sock)
                vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

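/* Handle a packet received while the socket is in TCP_SYN_SENT: a RESPONSE
 * completes the connection, a RST (or any unexpected op) resets it.
 */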
static int
virtio_transport_recv_connecting(struct sock *sk,
                                 struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        int err;
        int skerr;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RESPONSE:
                sk->sk_state = TCP_ESTABLISHED;
                sk->sk_socket->state = SS_CONNECTED;
                vsock_insert_connected(vsk);
                sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_INVALID:
                break;
        case VIRTIO_VSOCK_OP_RST:
                skerr = ECONNRESET;
                err = 0;
                goto destroy;
        default:
                skerr = EPROTO;
                err = -EINVAL;
                goto destroy;
        }
        return 0;

destroy:
        virtio_transport_reset(vsk, pkt);
        sk->sk_state = TCP_CLOSE;
        sk->sk_err = skerr;
        sk->sk_error_report(sk);
        return err;
}

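/* Handle a packet received on an established connection: queue RW payloads
 * for the reader, wake writers on credit updates, and process SHUTDOWN/RST.
 */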
static int
virtio_transport_recv_connected(struct sock *sk,
                                struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        int err = 0;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RW:
                pkt->len = le32_to_cpu(pkt->hdr.len);
                pkt->off = 0;

                spin_lock_bh(&vvs->rx_lock);
                virtio_transport_inc_rx_pkt(vvs, pkt);
                list_add_tail(&pkt->list, &vvs->rx_queue);
                spin_unlock_bh(&vvs->rx_lock);

                sk->sk_data_ready(sk);
                return err;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
                sk->sk_write_space(sk);
                break;
        case VIRTIO_VSOCK_OP_SHUTDOWN:
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
                        vsk->peer_shutdown |= RCV_SHUTDOWN;
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
                        vsk->peer_shutdown |= SEND_SHUTDOWN;
                if (vsk->peer_shutdown == SHUTDOWN_MASK &&
                    vsock_stream_has_data(vsk) <= 0)
                        sk->sk_state = TCP_CLOSING;
                if (le32_to_cpu(pkt->hdr.flags))
                        sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_RST:
                virtio_transport_do_close(vsk, true);
                break;
        default:
                err = -EINVAL;
                break;
        }

        virtio_transport_free_pkt(pkt);
        return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
                                    struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
                               struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RESPONSE,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
                .remote_port = le32_to_cpu(pkt->hdr.src_port),
                .reply = true,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct vsock_sock *vchild;
        struct sock *child;

        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
                virtio_transport_reset(vsk, pkt);
                return -EINVAL;
        }

        if (sk_acceptq_is_full(sk)) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
                               sk->sk_type, 0);
        if (!child) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        sk->sk_ack_backlog++;

        lock_sock_nested(child, SINGLE_DEPTH_NESTING);

        child->sk_state = TCP_ESTABLISHED;

        vchild = vsock_sk(child);
        vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));
        vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));

        vsock_insert_connected(vchild);
        vsock_enqueue_accept(sk, child);
        virtio_transport_send_response(vchild, pkt);

        release_sock(child);

        sk->sk_data_ready(sk);
        return 0;
}

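/* Refresh our view of the peer's credit (buf_alloc and fwd_cnt are carried in
 * every packet header) and report whether there is now space to transmit.
 */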
static bool virtio_transport_space_update(struct sock *sk,
                                          struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        bool space_available;

        /* buf_alloc and fwd_cnt are always included in the hdr */
        spin_lock_bh(&vvs->tx_lock);
        vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
        vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
        space_available = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);
        return space_available;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
        struct sockaddr_vm src, dst;
        struct vsock_sock *vsk;
        struct sock *sk;
        bool space_available;

        vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));
        vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));

        trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
                                        dst.svm_cid, dst.svm_port,
                                        le32_to_cpu(pkt->hdr.len),
                                        le16_to_cpu(pkt->hdr.type),
                                        le16_to_cpu(pkt->hdr.op),
                                        le32_to_cpu(pkt->hdr.flags),
                                        le32_to_cpu(pkt->hdr.buf_alloc),
                                        le32_to_cpu(pkt->hdr.fwd_cnt));

        if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
                (void)virtio_transport_reset_no_sock(pkt);
                goto free_pkt;
        }

        /* The socket must be in connected or bound table
         * otherwise send reset back
         */
        sk = vsock_find_connected_socket(&src, &dst);
        if (!sk) {
                sk = vsock_find_bound_socket(&dst);
                if (!sk) {
                        (void)virtio_transport_reset_no_sock(pkt);
                        goto free_pkt;
                }
        }

        vsk = vsock_sk(sk);

        space_available = virtio_transport_space_update(sk, pkt);

        lock_sock(sk);

        /* Update CID in case it has changed after a transport reset event */
        vsk->local_addr.svm_cid = dst.svm_cid;

        if (space_available)
                sk->sk_write_space(sk);

        switch (sk->sk_state) {
        case TCP_LISTEN:
                virtio_transport_recv_listen(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case TCP_SYN_SENT:
                virtio_transport_recv_connecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case TCP_ESTABLISHED:
                virtio_transport_recv_connected(sk, pkt);
                break;
        case TCP_CLOSING:
                virtio_transport_recv_disconnecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        default:
                virtio_transport_free_pkt(pkt);
                break;
        }
        release_sock(sk);

        /* Release refcnt obtained when we fetched this socket out of the
         * bound or connected list.
         */
        sock_put(sk);
        return;

free_pkt:
        virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
        kfree(pkt->buf);
        kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");