Merge tag 'for-6.8-rc4-tag' of git://git.kernel.org/pub/scm/linux/kernel/git/kdave...
[linux-2.6-microblaze.git] / drivers / vhost / vhost.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (C) 2009 Red Hat, Inc.
3  * Copyright (C) 2006 Rusty Russell IBM Corporation
4  *
5  * Author: Michael S. Tsirkin <mst@redhat.com>
6  *
7  * Inspiration, some code, and most witty comments come from
8  * Documentation/virtual/lguest/lguest.c, by Rusty Russell
9  *
10  * Generic code for virtio server in host kernel.
11  */
12
13 #include <linux/eventfd.h>
14 #include <linux/vhost.h>
15 #include <linux/uio.h>
16 #include <linux/mm.h>
17 #include <linux/miscdevice.h>
18 #include <linux/mutex.h>
19 #include <linux/poll.h>
20 #include <linux/file.h>
21 #include <linux/highmem.h>
22 #include <linux/slab.h>
23 #include <linux/vmalloc.h>
24 #include <linux/kthread.h>
25 #include <linux/module.h>
26 #include <linux/sort.h>
27 #include <linux/sched/mm.h>
28 #include <linux/sched/signal.h>
29 #include <linux/sched/vhost_task.h>
30 #include <linux/interval_tree_generic.h>
31 #include <linux/nospec.h>
32 #include <linux/kcov.h>
33
34 #include "vhost.h"
35
36 static ushort max_mem_regions = 64;
37 module_param(max_mem_regions, ushort, 0444);
38 MODULE_PARM_DESC(max_mem_regions,
39         "Maximum number of memory regions in memory map. (default: 64)");
40 static int max_iotlb_entries = 2048;
41 module_param(max_iotlb_entries, int, 0444);
42 MODULE_PARM_DESC(max_iotlb_entries,
43         "Maximum number of iotlb entries. (default: 2048)");
44
45 enum {
46         VHOST_MEMORY_F_LOG = 0x1,
47 };
48
/*
 * User address of the 16-bit event field stored just past the avail
 * (resp. used) ring, i.e. at ring index vq->num. vhost_get_avail_size()
 * and vhost_get_used_size() reserve these 2 bytes when
 * VIRTIO_RING_F_EVENT_IDX is negotiated. The macro argument is
 * parenthesised so that any pointer expression can be passed safely.
 */
#define vhost_used_event(vq) ((__virtio16 __user *)&(vq)->avail->ring[(vq)->num])
#define vhost_avail_event(vq) ((__virtio16 __user *)&(vq)->used->ring[(vq)->num])
51
52 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
53 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
54 {
55         vq->user_be = !virtio_legacy_is_little_endian();
56 }
57
58 static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
59 {
60         vq->user_be = true;
61 }
62
63 static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
64 {
65         vq->user_be = false;
66 }
67
68 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
69 {
70         struct vhost_vring_state s;
71
72         if (vq->private_data)
73                 return -EBUSY;
74
75         if (copy_from_user(&s, argp, sizeof(s)))
76                 return -EFAULT;
77
78         if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
79             s.num != VHOST_VRING_BIG_ENDIAN)
80                 return -EINVAL;
81
82         if (s.num == VHOST_VRING_BIG_ENDIAN)
83                 vhost_enable_cross_endian_big(vq);
84         else
85                 vhost_enable_cross_endian_little(vq);
86
87         return 0;
88 }
89
90 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
91                                    int __user *argp)
92 {
93         struct vhost_vring_state s = {
94                 .index = idx,
95                 .num = vq->user_be
96         };
97
98         if (copy_to_user(argp, &s, sizeof(s)))
99                 return -EFAULT;
100
101         return 0;
102 }
103
104 static void vhost_init_is_le(struct vhost_virtqueue *vq)
105 {
106         /* Note for legacy virtio: user_be is initialized at reset time
107          * according to the host endianness. If userspace does not set an
108          * explicit endianness, the default behavior is native endian, as
109          * expected by legacy virtio.
110          */
111         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
112 }
113 #else
114 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
115 {
116 }
117
118 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
119 {
120         return -ENOIOCTLCMD;
121 }
122
123 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
124                                    int __user *argp)
125 {
126         return -ENOIOCTLCMD;
127 }
128
129 static void vhost_init_is_le(struct vhost_virtqueue *vq)
130 {
131         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
132                 || virtio_legacy_is_little_endian();
133 }
134 #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
135
/* Recompute vq->is_le from the negotiated features and user_be. */
static void vhost_reset_is_le(struct vhost_virtqueue *vq)
{
	vhost_init_is_le(vq);
}
140
141 struct vhost_flush_struct {
142         struct vhost_work work;
143         struct completion wait_event;
144 };
145
146 static void vhost_flush_work(struct vhost_work *work)
147 {
148         struct vhost_flush_struct *s;
149
150         s = container_of(work, struct vhost_flush_struct, work);
151         complete(&s->wait_event);
152 }
153
154 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
155                             poll_table *pt)
156 {
157         struct vhost_poll *poll;
158
159         poll = container_of(pt, struct vhost_poll, table);
160         poll->wqh = wqh;
161         add_wait_queue(wqh, &poll->wait);
162 }
163
164 static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
165                              void *key)
166 {
167         struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
168         struct vhost_work *work = &poll->work;
169
170         if (!(key_to_poll(key) & poll->mask))
171                 return 0;
172
173         if (!poll->dev->use_worker)
174                 work->fn(work);
175         else
176                 vhost_poll_queue(poll);
177
178         return 0;
179 }
180
181 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
182 {
183         clear_bit(VHOST_WORK_QUEUED, &work->flags);
184         work->fn = fn;
185 }
186 EXPORT_SYMBOL_GPL(vhost_work_init);
187
188 /* Init poll structure */
189 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
190                      __poll_t mask, struct vhost_dev *dev,
191                      struct vhost_virtqueue *vq)
192 {
193         init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
194         init_poll_funcptr(&poll->table, vhost_poll_func);
195         poll->mask = mask;
196         poll->dev = dev;
197         poll->wqh = NULL;
198         poll->vq = vq;
199
200         vhost_work_init(&poll->work, fn);
201 }
202 EXPORT_SYMBOL_GPL(vhost_poll_init);
203
204 /* Start polling a file. We add ourselves to file's wait queue. The caller must
205  * keep a reference to a file until after vhost_poll_stop is called. */
206 int vhost_poll_start(struct vhost_poll *poll, struct file *file)
207 {
208         __poll_t mask;
209
210         if (poll->wqh)
211                 return 0;
212
213         mask = vfs_poll(file, &poll->table);
214         if (mask)
215                 vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
216         if (mask & EPOLLERR) {
217                 vhost_poll_stop(poll);
218                 return -EINVAL;
219         }
220
221         return 0;
222 }
223 EXPORT_SYMBOL_GPL(vhost_poll_start);
224
225 /* Stop polling a file. After this function returns, it becomes safe to drop the
226  * file reference. You must also flush afterwards. */
227 void vhost_poll_stop(struct vhost_poll *poll)
228 {
229         if (poll->wqh) {
230                 remove_wait_queue(poll->wqh, &poll->wait);
231                 poll->wqh = NULL;
232         }
233 }
234 EXPORT_SYMBOL_GPL(vhost_poll_stop);
235
236 static void vhost_worker_queue(struct vhost_worker *worker,
237                                struct vhost_work *work)
238 {
239         if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
240                 /* We can only add the work to the list after we're
241                  * sure it was not in the list.
242                  * test_and_set_bit() implies a memory barrier.
243                  */
244                 llist_add(&work->node, &worker->work_list);
245                 vhost_task_wake(worker->vtsk);
246         }
247 }
248
249 bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
250 {
251         struct vhost_worker *worker;
252         bool queued = false;
253
254         rcu_read_lock();
255         worker = rcu_dereference(vq->worker);
256         if (worker) {
257                 queued = true;
258                 vhost_worker_queue(worker, work);
259         }
260         rcu_read_unlock();
261
262         return queued;
263 }
264 EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
265
266 void vhost_vq_flush(struct vhost_virtqueue *vq)
267 {
268         struct vhost_flush_struct flush;
269
270         init_completion(&flush.wait_event);
271         vhost_work_init(&flush.work, vhost_flush_work);
272
273         if (vhost_vq_work_queue(vq, &flush.work))
274                 wait_for_completion(&flush.wait_event);
275 }
276 EXPORT_SYMBOL_GPL(vhost_vq_flush);
277
278 /**
279  * vhost_worker_flush - flush a worker
280  * @worker: worker to flush
281  *
282  * This does not use RCU to protect the worker, so the device or worker
283  * mutex must be held.
284  */
285 static void vhost_worker_flush(struct vhost_worker *worker)
286 {
287         struct vhost_flush_struct flush;
288
289         init_completion(&flush.wait_event);
290         vhost_work_init(&flush.work, vhost_flush_work);
291
292         vhost_worker_queue(worker, &flush.work);
293         wait_for_completion(&flush.wait_event);
294 }
295
296 void vhost_dev_flush(struct vhost_dev *dev)
297 {
298         struct vhost_worker *worker;
299         unsigned long i;
300
301         xa_for_each(&dev->worker_xa, i, worker) {
302                 mutex_lock(&worker->mutex);
303                 if (!worker->attachment_cnt) {
304                         mutex_unlock(&worker->mutex);
305                         continue;
306                 }
307                 vhost_worker_flush(worker);
308                 mutex_unlock(&worker->mutex);
309         }
310 }
311 EXPORT_SYMBOL_GPL(vhost_dev_flush);
312
313 /* A lockless hint for busy polling code to exit the loop */
314 bool vhost_vq_has_work(struct vhost_virtqueue *vq)
315 {
316         struct vhost_worker *worker;
317         bool has_work = false;
318
319         rcu_read_lock();
320         worker = rcu_dereference(vq->worker);
321         if (worker && !llist_empty(&worker->work_list))
322                 has_work = true;
323         rcu_read_unlock();
324
325         return has_work;
326 }
327 EXPORT_SYMBOL_GPL(vhost_vq_has_work);
328
329 void vhost_poll_queue(struct vhost_poll *poll)
330 {
331         vhost_vq_work_queue(poll->vq, &poll->work);
332 }
333 EXPORT_SYMBOL_GPL(vhost_poll_queue);
334
335 static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
336 {
337         int j;
338
339         for (j = 0; j < VHOST_NUM_ADDRS; j++)
340                 vq->meta_iotlb[j] = NULL;
341 }
342
343 static void vhost_vq_meta_reset(struct vhost_dev *d)
344 {
345         int i;
346
347         for (i = 0; i < d->nvqs; ++i)
348                 __vhost_vq_meta_reset(d->vqs[i]);
349 }
350
351 static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
352 {
353         call_ctx->ctx = NULL;
354         memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
355 }
356
357 bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
358 {
359         return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
360 }
361 EXPORT_SYMBOL_GPL(vhost_vq_is_setup);
362
363 static void vhost_vq_reset(struct vhost_dev *dev,
364                            struct vhost_virtqueue *vq)
365 {
366         vq->num = 1;
367         vq->desc = NULL;
368         vq->avail = NULL;
369         vq->used = NULL;
370         vq->last_avail_idx = 0;
371         vq->avail_idx = 0;
372         vq->last_used_idx = 0;
373         vq->signalled_used = 0;
374         vq->signalled_used_valid = false;
375         vq->used_flags = 0;
376         vq->log_used = false;
377         vq->log_addr = -1ull;
378         vq->private_data = NULL;
379         vq->acked_features = 0;
380         vq->acked_backend_features = 0;
381         vq->log_base = NULL;
382         vq->error_ctx = NULL;
383         vq->kick = NULL;
384         vq->log_ctx = NULL;
385         vhost_disable_cross_endian(vq);
386         vhost_reset_is_le(vq);
387         vq->busyloop_timeout = 0;
388         vq->umem = NULL;
389         vq->iotlb = NULL;
390         rcu_assign_pointer(vq->worker, NULL);
391         vhost_vring_call_reset(&vq->call_ctx);
392         __vhost_vq_meta_reset(vq);
393 }
394
395 static bool vhost_worker(void *data)
396 {
397         struct vhost_worker *worker = data;
398         struct vhost_work *work, *work_next;
399         struct llist_node *node;
400
401         node = llist_del_all(&worker->work_list);
402         if (node) {
403                 __set_current_state(TASK_RUNNING);
404
405                 node = llist_reverse_order(node);
406                 /* make sure flag is seen after deletion */
407                 smp_wmb();
408                 llist_for_each_entry_safe(work, work_next, node, node) {
409                         clear_bit(VHOST_WORK_QUEUED, &work->flags);
410                         kcov_remote_start_common(worker->kcov_handle);
411                         work->fn(work);
412                         kcov_remote_stop();
413                         cond_resched();
414                 }
415         }
416
417         return !!node;
418 }
419
420 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
421 {
422         kfree(vq->indirect);
423         vq->indirect = NULL;
424         kfree(vq->log);
425         vq->log = NULL;
426         kfree(vq->heads);
427         vq->heads = NULL;
428 }
429
430 /* Helper to allocate iovec buffers for all vqs. */
431 static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
432 {
433         struct vhost_virtqueue *vq;
434         int i;
435
436         for (i = 0; i < dev->nvqs; ++i) {
437                 vq = dev->vqs[i];
438                 vq->indirect = kmalloc_array(UIO_MAXIOV,
439                                              sizeof(*vq->indirect),
440                                              GFP_KERNEL);
441                 vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
442                                         GFP_KERNEL);
443                 vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
444                                           GFP_KERNEL);
445                 if (!vq->indirect || !vq->log || !vq->heads)
446                         goto err_nomem;
447         }
448         return 0;
449
450 err_nomem:
451         for (; i >= 0; --i)
452                 vhost_vq_free_iovecs(dev->vqs[i]);
453         return -ENOMEM;
454 }
455
456 static void vhost_dev_free_iovecs(struct vhost_dev *dev)
457 {
458         int i;
459
460         for (i = 0; i < dev->nvqs; ++i)
461                 vhost_vq_free_iovecs(dev->vqs[i]);
462 }
463
464 bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
465                           int pkts, int total_len)
466 {
467         struct vhost_dev *dev = vq->dev;
468
469         if ((dev->byte_weight && total_len >= dev->byte_weight) ||
470             pkts >= dev->weight) {
471                 vhost_poll_queue(&vq->poll);
472                 return true;
473         }
474
475         return false;
476 }
477 EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
478
479 static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
480                                    unsigned int num)
481 {
482         size_t event __maybe_unused =
483                vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
484
485         return size_add(struct_size(vq->avail, ring, num), event);
486 }
487
488 static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
489                                   unsigned int num)
490 {
491         size_t event __maybe_unused =
492                vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
493
494         return size_add(struct_size(vq->used, ring, num), event);
495 }
496
497 static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
498                                   unsigned int num)
499 {
500         return sizeof(*vq->desc) * num;
501 }
502
503 void vhost_dev_init(struct vhost_dev *dev,
504                     struct vhost_virtqueue **vqs, int nvqs,
505                     int iov_limit, int weight, int byte_weight,
506                     bool use_worker,
507                     int (*msg_handler)(struct vhost_dev *dev, u32 asid,
508                                        struct vhost_iotlb_msg *msg))
509 {
510         struct vhost_virtqueue *vq;
511         int i;
512
513         dev->vqs = vqs;
514         dev->nvqs = nvqs;
515         mutex_init(&dev->mutex);
516         dev->log_ctx = NULL;
517         dev->umem = NULL;
518         dev->iotlb = NULL;
519         dev->mm = NULL;
520         dev->iov_limit = iov_limit;
521         dev->weight = weight;
522         dev->byte_weight = byte_weight;
523         dev->use_worker = use_worker;
524         dev->msg_handler = msg_handler;
525         init_waitqueue_head(&dev->wait);
526         INIT_LIST_HEAD(&dev->read_list);
527         INIT_LIST_HEAD(&dev->pending_list);
528         spin_lock_init(&dev->iotlb_lock);
529         xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
530
531         for (i = 0; i < dev->nvqs; ++i) {
532                 vq = dev->vqs[i];
533                 vq->log = NULL;
534                 vq->indirect = NULL;
535                 vq->heads = NULL;
536                 vq->dev = dev;
537                 mutex_init(&vq->mutex);
538                 vhost_vq_reset(dev, vq);
539                 if (vq->handle_kick)
540                         vhost_poll_init(&vq->poll, vq->handle_kick,
541                                         EPOLLIN, dev, vq);
542         }
543 }
544 EXPORT_SYMBOL_GPL(vhost_dev_init);
545
546 /* Caller should have device mutex */
547 long vhost_dev_check_owner(struct vhost_dev *dev)
548 {
549         /* Are you the owner? If not, I don't think you mean to do that */
550         return dev->mm == current->mm ? 0 : -EPERM;
551 }
552 EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
553
554 /* Caller should have device mutex */
555 bool vhost_dev_has_owner(struct vhost_dev *dev)
556 {
557         return dev->mm;
558 }
559 EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
560
561 static void vhost_attach_mm(struct vhost_dev *dev)
562 {
563         /* No owner, become one */
564         if (dev->use_worker) {
565                 dev->mm = get_task_mm(current);
566         } else {
567                 /* vDPA device does not use worker thead, so there's
568                  * no need to hold the address space for mm. This help
569                  * to avoid deadlock in the case of mmap() which may
570                  * held the refcnt of the file and depends on release
571                  * method to remove vma.
572                  */
573                 dev->mm = current->mm;
574                 mmgrab(dev->mm);
575         }
576 }
577
578 static void vhost_detach_mm(struct vhost_dev *dev)
579 {
580         if (!dev->mm)
581                 return;
582
583         if (dev->use_worker)
584                 mmput(dev->mm);
585         else
586                 mmdrop(dev->mm);
587
588         dev->mm = NULL;
589 }
590
591 static void vhost_worker_destroy(struct vhost_dev *dev,
592                                  struct vhost_worker *worker)
593 {
594         if (!worker)
595                 return;
596
597         WARN_ON(!llist_empty(&worker->work_list));
598         xa_erase(&dev->worker_xa, worker->id);
599         vhost_task_stop(worker->vtsk);
600         kfree(worker);
601 }
602
603 static void vhost_workers_free(struct vhost_dev *dev)
604 {
605         struct vhost_worker *worker;
606         unsigned long i;
607
608         if (!dev->use_worker)
609                 return;
610
611         for (i = 0; i < dev->nvqs; i++)
612                 rcu_assign_pointer(dev->vqs[i]->worker, NULL);
613         /*
614          * Free the default worker we created and cleanup workers userspace
615          * created but couldn't clean up (it forgot or crashed).
616          */
617         xa_for_each(&dev->worker_xa, i, worker)
618                 vhost_worker_destroy(dev, worker);
619         xa_destroy(&dev->worker_xa);
620 }
621
622 static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
623 {
624         struct vhost_worker *worker;
625         struct vhost_task *vtsk;
626         char name[TASK_COMM_LEN];
627         int ret;
628         u32 id;
629
630         worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
631         if (!worker)
632                 return NULL;
633
634         snprintf(name, sizeof(name), "vhost-%d", current->pid);
635
636         vtsk = vhost_task_create(vhost_worker, worker, name);
637         if (!vtsk)
638                 goto free_worker;
639
640         mutex_init(&worker->mutex);
641         init_llist_head(&worker->work_list);
642         worker->kcov_handle = kcov_common_handle();
643         worker->vtsk = vtsk;
644
645         vhost_task_start(vtsk);
646
647         ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
648         if (ret < 0)
649                 goto stop_worker;
650         worker->id = id;
651
652         return worker;
653
654 stop_worker:
655         vhost_task_stop(vtsk);
656 free_worker:
657         kfree(worker);
658         return NULL;
659 }
660
/*
 * Attach @worker to @vq, replacing any previously attached worker.
 *
 * The new worker's attachment_cnt is raised under its mutex and it is
 * published in vq->worker with rcu_assign_pointer(). If there was an old
 * worker, its attachment_cnt is dropped under its mutex. Unless the vq
 * has neither a backend nor a kick file yet, we then wait for an RCU
 * grace period and flush the old worker before releasing its mutex.
 *
 * Caller must have device mutex. Takes worker->mutex, old_worker->mutex
 * and vq->mutex.
 */
static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
				     struct vhost_worker *worker)
{
	struct vhost_worker *old_worker;

	old_worker = rcu_dereference_check(vq->worker,
					   lockdep_is_held(&vq->dev->mutex));

	mutex_lock(&worker->mutex);
	worker->attachment_cnt++;
	mutex_unlock(&worker->mutex);
	rcu_assign_pointer(vq->worker, worker);

	if (!old_worker)
		return;
	/*
	 * Take the worker mutex to make sure we see the work queued from
	 * device wide flushes which doesn't use RCU for execution.
	 */
	mutex_lock(&old_worker->mutex);
	old_worker->attachment_cnt--;
	/*
	 * We don't want to call synchronize_rcu for every vq during setup
	 * because it will slow down VM startup. If we haven't done
	 * VHOST_SET_VRING_KICK and not done the driver specific
	 * SET_ENDPOINT/RUNNING then we can skip the sync since there will
	 * not be any works queued for scsi and net.
	 */
	mutex_lock(&vq->mutex);
	if (!vhost_vq_get_backend(vq) && !vq->kick) {
		mutex_unlock(&vq->mutex);
		mutex_unlock(&old_worker->mutex);
		/*
		 * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
		 * Warn if it adds support for multiple workers but forgets to
		 * handle the early queueing case.
		 */
		WARN_ON(!old_worker->attachment_cnt &&
			!llist_empty(&old_worker->work_list));
		return;
	}
	mutex_unlock(&vq->mutex);

	/* Make sure new vq queue/flush/poll calls see the new worker */
	synchronize_rcu();
	/* Make sure whatever was queued gets run */
	vhost_worker_flush(old_worker);
	mutex_unlock(&old_worker->mutex);
}
711
712  /* Caller must have device mutex */
713 static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
714                                   struct vhost_vring_worker *info)
715 {
716         unsigned long index = info->worker_id;
717         struct vhost_dev *dev = vq->dev;
718         struct vhost_worker *worker;
719
720         if (!dev->use_worker)
721                 return -EINVAL;
722
723         worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
724         if (!worker || worker->id != info->worker_id)
725                 return -ENODEV;
726
727         __vhost_vq_attach_worker(vq, worker);
728         return 0;
729 }
730
731 /* Caller must have device mutex */
732 static int vhost_new_worker(struct vhost_dev *dev,
733                             struct vhost_worker_state *info)
734 {
735         struct vhost_worker *worker;
736
737         worker = vhost_worker_create(dev);
738         if (!worker)
739                 return -ENOMEM;
740
741         info->worker_id = worker->id;
742         return 0;
743 }
744
745 /* Caller must have device mutex */
746 static int vhost_free_worker(struct vhost_dev *dev,
747                              struct vhost_worker_state *info)
748 {
749         unsigned long index = info->worker_id;
750         struct vhost_worker *worker;
751
752         worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
753         if (!worker || worker->id != info->worker_id)
754                 return -ENODEV;
755
756         mutex_lock(&worker->mutex);
757         if (worker->attachment_cnt) {
758                 mutex_unlock(&worker->mutex);
759                 return -EBUSY;
760         }
761         mutex_unlock(&worker->mutex);
762
763         vhost_worker_destroy(dev, worker);
764         return 0;
765 }
766
767 static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
768                                   struct vhost_virtqueue **vq, u32 *id)
769 {
770         u32 __user *idxp = argp;
771         u32 idx;
772         long r;
773
774         r = get_user(idx, idxp);
775         if (r < 0)
776                 return r;
777
778         if (idx >= dev->nvqs)
779                 return -ENOBUFS;
780
781         idx = array_index_nospec(idx, dev->nvqs);
782
783         *vq = dev->vqs[idx];
784         *id = idx;
785         return 0;
786 }
787
788 /* Caller must have device mutex */
789 long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
790                         void __user *argp)
791 {
792         struct vhost_vring_worker ring_worker;
793         struct vhost_worker_state state;
794         struct vhost_worker *worker;
795         struct vhost_virtqueue *vq;
796         long ret;
797         u32 idx;
798
799         if (!dev->use_worker)
800                 return -EINVAL;
801
802         if (!vhost_dev_has_owner(dev))
803                 return -EINVAL;
804
805         ret = vhost_dev_check_owner(dev);
806         if (ret)
807                 return ret;
808
809         switch (ioctl) {
810         /* dev worker ioctls */
811         case VHOST_NEW_WORKER:
812                 ret = vhost_new_worker(dev, &state);
813                 if (!ret && copy_to_user(argp, &state, sizeof(state)))
814                         ret = -EFAULT;
815                 return ret;
816         case VHOST_FREE_WORKER:
817                 if (copy_from_user(&state, argp, sizeof(state)))
818                         return -EFAULT;
819                 return vhost_free_worker(dev, &state);
820         /* vring worker ioctls */
821         case VHOST_ATTACH_VRING_WORKER:
822         case VHOST_GET_VRING_WORKER:
823                 break;
824         default:
825                 return -ENOIOCTLCMD;
826         }
827
828         ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
829         if (ret)
830                 return ret;
831
832         switch (ioctl) {
833         case VHOST_ATTACH_VRING_WORKER:
834                 if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
835                         ret = -EFAULT;
836                         break;
837                 }
838
839                 ret = vhost_vq_attach_worker(vq, &ring_worker);
840                 break;
841         case VHOST_GET_VRING_WORKER:
842                 worker = rcu_dereference_check(vq->worker,
843                                                lockdep_is_held(&dev->mutex));
844                 if (!worker) {
845                         ret = -EINVAL;
846                         break;
847                 }
848
849                 ring_worker.index = idx;
850                 ring_worker.worker_id = worker->id;
851
852                 if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
853                         ret = -EFAULT;
854                 break;
855         default:
856                 ret = -ENOIOCTLCMD;
857                 break;
858         }
859
860         return ret;
861 }
862 EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
863
864 /* Caller should have device mutex */
865 long vhost_dev_set_owner(struct vhost_dev *dev)
866 {
867         struct vhost_worker *worker;
868         int err, i;
869
870         /* Is there an owner already? */
871         if (vhost_dev_has_owner(dev)) {
872                 err = -EBUSY;
873                 goto err_mm;
874         }
875
876         vhost_attach_mm(dev);
877
878         err = vhost_dev_alloc_iovecs(dev);
879         if (err)
880                 goto err_iovecs;
881
882         if (dev->use_worker) {
883                 /*
884                  * This should be done last, because vsock can queue work
885                  * before VHOST_SET_OWNER so it simplifies the failure path
886                  * below since we don't have to worry about vsock queueing
887                  * while we free the worker.
888                  */
889                 worker = vhost_worker_create(dev);
890                 if (!worker) {
891                         err = -ENOMEM;
892                         goto err_worker;
893                 }
894
895                 for (i = 0; i < dev->nvqs; i++)
896                         __vhost_vq_attach_worker(dev->vqs[i], worker);
897         }
898
899         return 0;
900
901 err_worker:
902         vhost_dev_free_iovecs(dev);
903 err_iovecs:
904         vhost_detach_mm(dev);
905 err_mm:
906         return err;
907 }
908 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
909
910 static struct vhost_iotlb *iotlb_alloc(void)
911 {
912         return vhost_iotlb_alloc(max_iotlb_entries,
913                                  VHOST_IOTLB_FLAG_RETIRE);
914 }
915
916 struct vhost_iotlb *vhost_dev_reset_owner_prepare(void)
917 {
918         return iotlb_alloc();
919 }
920 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
921
922 /* Caller should have device mutex */
923 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
924 {
925         int i;
926
927         vhost_dev_cleanup(dev);
928
929         dev->umem = umem;
930         /* We don't need VQ locks below since vhost_dev_cleanup makes sure
931          * VQs aren't running.
932          */
933         for (i = 0; i < dev->nvqs; ++i)
934                 dev->vqs[i]->umem = umem;
935 }
936 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
937
938 void vhost_dev_stop(struct vhost_dev *dev)
939 {
940         int i;
941
942         for (i = 0; i < dev->nvqs; ++i) {
943                 if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
944                         vhost_poll_stop(&dev->vqs[i]->poll);
945         }
946
947         vhost_dev_flush(dev);
948 }
949 EXPORT_SYMBOL_GPL(vhost_dev_stop);
950
951 void vhost_clear_msg(struct vhost_dev *dev)
952 {
953         struct vhost_msg_node *node, *n;
954
955         spin_lock(&dev->iotlb_lock);
956
957         list_for_each_entry_safe(node, n, &dev->read_list, node) {
958                 list_del(&node->node);
959                 kfree(node);
960         }
961
962         list_for_each_entry_safe(node, n, &dev->pending_list, node) {
963                 list_del(&node->node);
964                 kfree(node);
965         }
966
967         spin_unlock(&dev->iotlb_lock);
968 }
969 EXPORT_SYMBOL_GPL(vhost_clear_msg);
970
/*
 * Release what the device holds.  For each virtqueue, drop the error and
 * call eventfd contexts and the kick file, then reset the vq with
 * vhost_vq_reset().  After that, free the iovec arrays, drop the log
 * eventfd, free the umem and IOTLB tables and discard queued IOTLB messages.
 * Finally, wake readers, free the workers and detach the mm.
 *
 * NOTE(review): vhost_dev_reset_owner() calls this with the device mutex
 * held; confirm the locking requirement for other callers.
 */
void vhost_dev_cleanup(struct vhost_dev *dev)
{
        int i;

        for (i = 0; i < dev->nvqs; ++i) {
                if (dev->vqs[i]->error_ctx)
                        eventfd_ctx_put(dev->vqs[i]->error_ctx);
                if (dev->vqs[i]->kick)
                        fput(dev->vqs[i]->kick);
                if (dev->vqs[i]->call_ctx.ctx)
                        eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
                /* NOTE(review): presumably clears the pointers released
                 * above; confirm in vhost_vq_reset().
                 */
                vhost_vq_reset(dev, dev->vqs[i]);
        }
        vhost_dev_free_iovecs(dev);
        if (dev->log_ctx)
                eventfd_ctx_put(dev->log_ctx);
        dev->log_ctx = NULL;
        /* No one will access memory at this point */
        vhost_iotlb_free(dev->umem);
        dev->umem = NULL;
        vhost_iotlb_free(dev->iotlb);
        dev->iotlb = NULL;
        vhost_clear_msg(dev);
        /* Wake readers sleeping in vhost_chr_read_iter(); with dev->iotlb
         * now NULL they return -EBADFD instead of sleeping again.
         */
        wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
        vhost_workers_free(dev);
        vhost_detach_mm(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
999
1000 static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
1001 {
1002         u64 a = addr / VHOST_PAGE_SIZE / 8;
1003
1004         /* Make sure 64 bit math will not overflow. */
1005         if (a > ULONG_MAX - (unsigned long)log_base ||
1006             a + (unsigned long)log_base > ULONG_MAX)
1007                 return false;
1008
1009         return access_ok(log_base + a,
1010                          (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
1011 }
1012
1013 /* Make sure 64 bit math will not overflow. */
1014 static bool vhost_overflow(u64 uaddr, u64 size)
1015 {
1016         if (uaddr > ULONG_MAX || size > ULONG_MAX)
1017                 return true;
1018
1019         if (!size)
1020                 return false;
1021
1022         return uaddr > ULONG_MAX - size + 1;
1023 }
1024
1025 /* Caller should have vq mutex and device mutex. */
1026 static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
1027                                 int log_all)
1028 {
1029         struct vhost_iotlb_map *map;
1030
1031         if (!umem)
1032                 return false;
1033
1034         list_for_each_entry(map, &umem->list, link) {
1035                 unsigned long a = map->addr;
1036
1037                 if (vhost_overflow(map->addr, map->size))
1038                         return false;
1039
1040
1041                 if (!access_ok((void __user *)a, map->size))
1042                         return false;
1043                 else if (log_all && !log_access_ok(log_base,
1044                                                    map->start,
1045                                                    map->size))
1046                         return false;
1047         }
1048         return true;
1049 }
1050
1051 static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
1052                                                u64 addr, unsigned int size,
1053                                                int type)
1054 {
1055         const struct vhost_iotlb_map *map = vq->meta_iotlb[type];
1056
1057         if (!map)
1058                 return NULL;
1059
1060         return (void __user *)(uintptr_t)(map->addr + addr - map->start);
1061 }
1062
1063 /* Can we switch to this memory table? */
1064 /* Caller should have device mutex but not vq mutex */
1065 static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
1066                              int log_all)
1067 {
1068         int i;
1069
1070         for (i = 0; i < d->nvqs; ++i) {
1071                 bool ok;
1072                 bool log;
1073
1074                 mutex_lock(&d->vqs[i]->mutex);
1075                 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
1076                 /* If ring is inactive, will check when it's enabled. */
1077                 if (d->vqs[i]->private_data)
1078                         ok = vq_memory_access_ok(d->vqs[i]->log_base,
1079                                                  umem, log);
1080                 else
1081                         ok = true;
1082                 mutex_unlock(&d->vqs[i]->mutex);
1083                 if (!ok)
1084                         return false;
1085         }
1086         return true;
1087 }
1088
/* Forward declaration: translate_desc() is defined later in this file and is
 * used by the IOTLB copy helpers below.  Judging from those uses, it returns
 * the number of entries filled in @iov on success, or a negative error.
 */
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
                          struct iovec iov[], int iov_size, int access);
1091
1092 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
1093                               const void *from, unsigned size)
1094 {
1095         int ret;
1096
1097         if (!vq->iotlb)
1098                 return __copy_to_user(to, from, size);
1099         else {
1100                 /* This function should be called after iotlb
1101                  * prefetch, which means we're sure that all vq
1102                  * could be access through iotlb. So -EAGAIN should
1103                  * not happen in this case.
1104                  */
1105                 struct iov_iter t;
1106                 void __user *uaddr = vhost_vq_meta_fetch(vq,
1107                                      (u64)(uintptr_t)to, size,
1108                                      VHOST_ADDR_USED);
1109
1110                 if (uaddr)
1111                         return __copy_to_user(uaddr, from, size);
1112
1113                 ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
1114                                      ARRAY_SIZE(vq->iotlb_iov),
1115                                      VHOST_ACCESS_WO);
1116                 if (ret < 0)
1117                         goto out;
1118                 iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
1119                 ret = copy_to_iter(from, size, &t);
1120                 if (ret == size)
1121                         ret = 0;
1122         }
1123 out:
1124         return ret;
1125 }
1126
1127 static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
1128                                 void __user *from, unsigned size)
1129 {
1130         int ret;
1131
1132         if (!vq->iotlb)
1133                 return __copy_from_user(to, from, size);
1134         else {
1135                 /* This function should be called after iotlb
1136                  * prefetch, which means we're sure that vq
1137                  * could be access through iotlb. So -EAGAIN should
1138                  * not happen in this case.
1139                  */
1140                 void __user *uaddr = vhost_vq_meta_fetch(vq,
1141                                      (u64)(uintptr_t)from, size,
1142                                      VHOST_ADDR_DESC);
1143                 struct iov_iter f;
1144
1145                 if (uaddr)
1146                         return __copy_from_user(to, uaddr, size);
1147
1148                 ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
1149                                      ARRAY_SIZE(vq->iotlb_iov),
1150                                      VHOST_ACCESS_RO);
1151                 if (ret < 0) {
1152                         vq_err(vq, "IOTLB translation failure: uaddr "
1153                                "%p size 0x%llx\n", from,
1154                                (unsigned long long) size);
1155                         goto out;
1156                 }
1157                 iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size);
1158                 ret = copy_from_iter(to, size, &f);
1159                 if (ret == size)
1160                         ret = 0;
1161         }
1162
1163 out:
1164         return ret;
1165 }
1166
1167 static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
1168                                           void __user *addr, unsigned int size,
1169                                           int type)
1170 {
1171         int ret;
1172
1173         ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
1174                              ARRAY_SIZE(vq->iotlb_iov),
1175                              VHOST_ACCESS_RO);
1176         if (ret < 0) {
1177                 vq_err(vq, "IOTLB translation failure: uaddr "
1178                         "%p size 0x%llx\n", addr,
1179                         (unsigned long long) size);
1180                 return NULL;
1181         }
1182
1183         if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
1184                 vq_err(vq, "Non atomic userspace memory access: uaddr "
1185                         "%p size 0x%llx\n", addr,
1186                         (unsigned long long) size);
1187                 return NULL;
1188         }
1189
1190         return vq->iotlb_iov[0].iov_base;
1191 }
1192
1193 /* This function should be called after iotlb
1194  * prefetch, which means we're sure that vq
1195  * could be access through iotlb. So -EAGAIN should
1196  * not happen in this case.
1197  */
1198 static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
1199                                             void __user *addr, unsigned int size,
1200                                             int type)
1201 {
1202         void __user *uaddr = vhost_vq_meta_fetch(vq,
1203                              (u64)(uintptr_t)addr, size, type);
1204         if (uaddr)
1205                 return uaddr;
1206
1207         return __vhost_get_user_slow(vq, addr, size, type);
1208 }
1209
/*
 * Store scalar @x at vring address @ptr.  Without an IOTLB, @ptr is a
 * userspace address.  With one, @ptr is first translated by
 * __vhost_get_user() using the VHOST_ADDR_USED cache slot.  Evaluates to
 * __put_user()'s result, or -EFAULT when translation fails.
 *
 * Arguments are parenthesized and the temporaries carry a __vpu_ prefix.
 * This keeps expression arguments from being mis-parsed, and keeps caller
 * variables (e.g. ones named ret or to) from being captured by the macro's
 * own locals.
 */
#define vhost_put_user(vq, x, ptr)                                      \
({                                                                      \
        int __vpu_ret;                                                  \
        if (!(vq)->iotlb) {                                             \
                __vpu_ret = __put_user(x, ptr);                         \
        } else {                                                        \
                __typeof__(ptr) __vpu_to =                              \
                        (__typeof__(ptr)) __vhost_get_user(vq, ptr,     \
                                        sizeof(*(ptr)), VHOST_ADDR_USED); \
                if (__vpu_to != NULL)                                   \
                        __vpu_ret = __put_user(x, __vpu_to);            \
                else                                                    \
                        __vpu_ret = -EFAULT;                            \
        }                                                               \
        __vpu_ret;                                                      \
})
1226
1227 static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
1228 {
1229         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1230                               vhost_avail_event(vq));
1231 }
1232
1233 static inline int vhost_put_used(struct vhost_virtqueue *vq,
1234                                  struct vring_used_elem *head, int idx,
1235                                  int count)
1236 {
1237         return vhost_copy_to_user(vq, vq->used->ring + idx, head,
1238                                   count * sizeof(*head));
1239 }
1240
1241 static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
1242
1243 {
1244         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1245                               &vq->used->flags);
1246 }
1247
1248 static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
1249
1250 {
1251         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
1252                               &vq->used->idx);
1253 }
1254
/*
 * Load the scalar at vring address @ptr into @x.  Without an IOTLB, @ptr is
 * a userspace address.  With one, @ptr is translated by __vhost_get_user()
 * using metadata cache slot @type (VHOST_ADDR_AVAIL or VHOST_ADDR_USED in
 * this file).  Evaluates to __get_user()'s result, or -EFAULT when
 * translation fails.
 *
 * Arguments are parenthesized and the temporaries carry a __vgu_ prefix.
 * This keeps expression arguments from being mis-parsed, and keeps caller
 * variables (e.g. ones named ret or from) from being captured by the macro's
 * own locals.
 */
#define vhost_get_user(vq, x, ptr, type)                                \
({                                                                      \
        int __vgu_ret;                                                  \
        if (!(vq)->iotlb) {                                             \
                __vgu_ret = __get_user(x, ptr);                         \
        } else {                                                        \
                __typeof__(ptr) __vgu_from =                            \
                        (__typeof__(ptr)) __vhost_get_user(vq, ptr,     \
                                                sizeof(*(ptr)), type);  \
                if (__vgu_from != NULL)                                 \
                        __vgu_ret = __get_user(x, __vgu_from);          \
                else                                                    \
                        __vgu_ret = -EFAULT;                            \
        }                                                               \
        __vgu_ret;                                                      \
})
1272
/* Read a scalar from the avail ring (AVAIL metadata cache slot under IOTLB). */
#define vhost_get_avail(vq, x, ptr) vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)

/* Read a scalar from the used ring (USED metadata cache slot under IOTLB). */
#define vhost_get_used(vq, x, ptr) vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
1278
1279 static void vhost_dev_lock_vqs(struct vhost_dev *d)
1280 {
1281         int i = 0;
1282         for (i = 0; i < d->nvqs; ++i)
1283                 mutex_lock_nested(&d->vqs[i]->mutex, i);
1284 }
1285
1286 static void vhost_dev_unlock_vqs(struct vhost_dev *d)
1287 {
1288         int i = 0;
1289         for (i = 0; i < d->nvqs; ++i)
1290                 mutex_unlock(&d->vqs[i]->mutex);
1291 }
1292
1293 static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
1294                                       __virtio16 *idx)
1295 {
1296         return vhost_get_avail(vq, *idx, &vq->avail->idx);
1297 }
1298
1299 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
1300                                        __virtio16 *head, int idx)
1301 {
1302         return vhost_get_avail(vq, *head,
1303                                &vq->avail->ring[idx & (vq->num - 1)]);
1304 }
1305
1306 static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
1307                                         __virtio16 *flags)
1308 {
1309         return vhost_get_avail(vq, *flags, &vq->avail->flags);
1310 }
1311
1312 static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
1313                                        __virtio16 *event)
1314 {
1315         return vhost_get_avail(vq, *event, vhost_used_event(vq));
1316 }
1317
1318 static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
1319                                      __virtio16 *idx)
1320 {
1321         return vhost_get_used(vq, *idx, &vq->used->idx);
1322 }
1323
1324 static inline int vhost_get_desc(struct vhost_virtqueue *vq,
1325                                  struct vring_desc *desc, int idx)
1326 {
1327         return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
1328 }
1329
1330 static void vhost_iotlb_notify_vq(struct vhost_dev *d,
1331                                   struct vhost_iotlb_msg *msg)
1332 {
1333         struct vhost_msg_node *node, *n;
1334
1335         spin_lock(&d->iotlb_lock);
1336
1337         list_for_each_entry_safe(node, n, &d->pending_list, node) {
1338                 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
1339                 if (msg->iova <= vq_msg->iova &&
1340                     msg->iova + msg->size - 1 >= vq_msg->iova &&
1341                     vq_msg->type == VHOST_IOTLB_MISS) {
1342                         vhost_poll_queue(&node->vq->poll);
1343                         list_del(&node->node);
1344                         kfree(node);
1345                 }
1346         }
1347
1348         spin_unlock(&d->iotlb_lock);
1349 }
1350
1351 static bool umem_access_ok(u64 uaddr, u64 size, int access)
1352 {
1353         unsigned long a = uaddr;
1354
1355         /* Make sure 64 bit math will not overflow. */
1356         if (vhost_overflow(uaddr, size))
1357                 return false;
1358
1359         if ((access & VHOST_ACCESS_RO) &&
1360             !access_ok((void __user *)a, size))
1361                 return false;
1362         if ((access & VHOST_ACCESS_WO) &&
1363             !access_ok((void __user *)a, size))
1364                 return false;
1365         return true;
1366 }
1367
/*
 * Default handler for one IOTLB message from userspace; vhost_chr_write_iter()
 * uses it when dev->msg_handler is not set.  Only address space 0 is accepted.
 *
 * VHOST_IOTLB_UPDATE: validate the userspace range, map
 *   [iova, iova + size - 1] -> uaddr with msg->perm, and wake the virtqueues
 *   whose pending IOTLB miss the new mapping covers.
 * VHOST_IOTLB_INVALIDATE: remove mappings in [iova, iova + size - 1].
 *
 * Both operations reset the per-vq metadata cache first.  They run under the
 * device mutex plus every vq mutex, so no virtqueue uses the IOTLB
 * concurrently.
 *
 * Returns 0 on success.  On failure it returns -EINVAL for a non-zero asid
 * or unknown type, -EFAULT when there is no IOTLB or the range is invalid,
 * or -ENOMEM when vhost_iotlb_add_range() fails.
 */
static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
                                   struct vhost_iotlb_msg *msg)
{
        int ret = 0;

        if (asid != 0)
                return -EINVAL;

        mutex_lock(&dev->mutex);
        vhost_dev_lock_vqs(dev);
        switch (msg->type) {
        case VHOST_IOTLB_UPDATE:
                if (!dev->iotlb) {
                        ret = -EFAULT;
                        break;
                }
                if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
                        ret = -EFAULT;
                        break;
                }
                /* Drop cached maps: they may point at entries being replaced. */
                vhost_vq_meta_reset(dev);
                if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
                                          msg->iova + msg->size - 1,
                                          msg->uaddr, msg->perm)) {
                        ret = -ENOMEM;
                        break;
                }
                vhost_iotlb_notify_vq(dev, msg);
                break;
        case VHOST_IOTLB_INVALIDATE:
                if (!dev->iotlb) {
                        ret = -EFAULT;
                        break;
                }
                vhost_vq_meta_reset(dev);
                /* NOTE(review): msg->size is not checked for INVALIDATE.
                 * With size 0 the last address is iova - 1, and for
                 * iova == 0 that wraps to the maximum u64 — confirm this is
                 * intended.
                 */
                vhost_iotlb_del_range(dev->iotlb, msg->iova,
                                      msg->iova + msg->size - 1);
                break;
        default:
                ret = -EINVAL;
                break;
        }

        vhost_dev_unlock_vqs(dev);
        mutex_unlock(&dev->mutex);

        return ret;
}
/*
 * Parse one IOTLB message written by userspace through @from.
 *
 * Layout:
 *  - an int type, either VHOST_IOTLB_MSG or VHOST_IOTLB_MSG_V2;
 *  - for V2, a 32-bit asid (read when VHOST_BACKEND_F_IOTLB_ASID was
 *    negotiated on vqs[0], skipped otherwise);
 *  - a struct vhost_iotlb_msg.
 * The message goes to dev->msg_handler when set, otherwise to
 * vhost_process_iotlb_msg().
 *
 * Returns the size of the whole message (struct vhost_msg or
 * struct vhost_msg_v2) on success.  Returns -EINVAL for a short or malformed
 * message, including a zero-sized UPDATE.  Returns -EFAULT if the handler
 * fails; the handler's own error code is not propagated.
 */
ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
                             struct iov_iter *from)
{
        struct vhost_iotlb_msg msg;
        size_t offset;
        int type, ret;
        u32 asid = 0;

        ret = copy_from_iter(&type, sizeof(type), from);
        if (ret != sizeof(type)) {
                ret = -EINVAL;
                goto done;
        }

        switch (type) {
        case VHOST_IOTLB_MSG:
                /* There may be a hole after type for V1 message type,
                 * so skip it here.
                 */
                offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
                break;
        case VHOST_IOTLB_MSG_V2:
                if (vhost_backend_has_feature(dev->vqs[0],
                                              VHOST_BACKEND_F_IOTLB_ASID)) {
                        ret = copy_from_iter(&asid, sizeof(asid), from);
                        if (ret != sizeof(asid)) {
                                ret = -EINVAL;
                                goto done;
                        }
                        offset = 0;
                } else
                        /* Skip the asid field without interpreting it. */
                        offset = sizeof(__u32);
                break;
        default:
                ret = -EINVAL;
                goto done;
        }

        iov_iter_advance(from, offset);
        ret = copy_from_iter(&msg, sizeof(msg), from);
        if (ret != sizeof(msg)) {
                ret = -EINVAL;
                goto done;
        }

        if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) {
                ret = -EINVAL;
                goto done;
        }

        if (dev->msg_handler)
                ret = dev->msg_handler(dev, asid, &msg);
        else
                ret = vhost_process_iotlb_msg(dev, asid, &msg);
        if (ret) {
                ret = -EFAULT;
                goto done;
        }

        /* Report the whole message as consumed, including skipped bytes. */
        ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
              sizeof(struct vhost_msg_v2);
done:
        return ret;
}
EXPORT_SYMBOL(vhost_chr_write_iter);
1481
1482 __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1483                             poll_table *wait)
1484 {
1485         __poll_t mask = 0;
1486
1487         poll_wait(file, &dev->wait, wait);
1488
1489         if (!list_empty(&dev->read_list))
1490                 mask |= EPOLLIN | EPOLLRDNORM;
1491
1492         return mask;
1493 }
1494 EXPORT_SYMBOL(vhost_chr_poll);
1495
/*
 * Read one queued IOTLB message (e.g. a miss queued by vhost_iotlb_miss())
 * into @to.
 *
 * Returns 0 without consuming anything if @to cannot hold a struct vhost_msg.
 * Otherwise it waits on dev->wait until dev->read_list is non-empty.  It
 * returns early with:
 *  - -EAGAIN when @noblock is set and nothing is queued;
 *  - -ERESTARTSYS when a signal is pending;
 *  - -EBADFD once the device no longer has an IOTLB.
 * On success it returns the number of bytes copied.  A fully copied
 * VHOST_IOTLB_MISS is moved to dev->pending_list to await the matching
 * update (see vhost_iotlb_notify_vq()); any other message is freed.
 */
ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
                            int noblock)
{
        DEFINE_WAIT(wait);
        struct vhost_msg_node *node;
        ssize_t ret = 0;
        unsigned size = sizeof(struct vhost_msg);

        if (iov_iter_count(to) < size)
                return 0;

        while (1) {
                /* Queue on dev->wait before checking the list, so a wakeup
                 * between the check and schedule() is not lost.
                 */
                if (!noblock)
                        prepare_to_wait(&dev->wait, &wait,
                                        TASK_INTERRUPTIBLE);

                node = vhost_dequeue_msg(dev, &dev->read_list);
                if (node)
                        break;
                if (noblock) {
                        ret = -EAGAIN;
                        break;
                }
                if (signal_pending(current)) {
                        ret = -ERESTARTSYS;
                        break;
                }
                /* vhost_dev_cleanup() clears dev->iotlb and wakes us. */
                if (!dev->iotlb) {
                        ret = -EBADFD;
                        break;
                }

                schedule();
        }

        if (!noblock)
                finish_wait(&dev->wait, &wait);

        if (node) {
                struct vhost_iotlb_msg *msg;
                void *start = &node->msg;

                switch (node->msg.type) {
                case VHOST_IOTLB_MSG:
                        size = sizeof(node->msg);
                        msg = &node->msg.iotlb;
                        break;
                case VHOST_IOTLB_MSG_V2:
                        size = sizeof(node->msg_v2);
                        msg = &node->msg_v2.iotlb;
                        break;
                default:
                        /* vhost_iotlb_miss() only queues the two types above. */
                        BUG();
                        break;
                }

                ret = copy_to_iter(start, size, to);
                if (ret != size || msg->type != VHOST_IOTLB_MISS) {
                        kfree(node);
                        return ret;
                }
                vhost_enqueue_msg(dev, &dev->pending_list, node);
        }

        return ret;
}
EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1563
1564 static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1565 {
1566         struct vhost_dev *dev = vq->dev;
1567         struct vhost_msg_node *node;
1568         struct vhost_iotlb_msg *msg;
1569         bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
1570
1571         node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
1572         if (!node)
1573                 return -ENOMEM;
1574
1575         if (v2) {
1576                 node->msg_v2.type = VHOST_IOTLB_MSG_V2;
1577                 msg = &node->msg_v2.iotlb;
1578         } else {
1579                 msg = &node->msg.iotlb;
1580         }
1581
1582         msg->type = VHOST_IOTLB_MISS;
1583         msg->iova = iova;
1584         msg->perm = access;
1585
1586         vhost_enqueue_msg(dev, &dev->read_list, node);
1587
1588         return 0;
1589 }
1590
1591 static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1592                          vring_desc_t __user *desc,
1593                          vring_avail_t __user *avail,
1594                          vring_used_t __user *used)
1595
1596 {
1597         /* If an IOTLB device is present, the vring addresses are
1598          * GIOVAs. Access validation occurs at prefetch time. */
1599         if (vq->iotlb)
1600                 return true;
1601
1602         return access_ok(desc, vhost_get_desc_size(vq, num)) &&
1603                access_ok(avail, vhost_get_avail_size(vq, num)) &&
1604                access_ok(used, vhost_get_used_size(vq, num));
1605 }
1606
1607 static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1608                                  const struct vhost_iotlb_map *map,
1609                                  int type)
1610 {
1611         int access = (type == VHOST_ADDR_USED) ?
1612                      VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1613
1614         if (likely(map->perm & access))
1615                 vq->meta_iotlb[type] = map;
1616 }
1617
/*
 * Check that the IOVA range [addr, addr + len - 1] is mapped in vq->iotlb
 * with @access permission, walking consecutive mappings.  On a hole, queue
 * an IOTLB miss for the first unmapped address and return false.  When a
 * single map covers the whole range, cache it in metadata slot @type.
 */
static bool iotlb_access_ok(struct vhost_virtqueue *vq,
                            int access, u64 addr, u64 len, int type)
{
        const struct vhost_iotlb_map *map;
        struct vhost_iotlb *umem = vq->iotlb;
        u64 s = 0, size, orig_addr = addr, last = addr + len - 1;

        /* A map for this slot was already cached by an earlier check. */
        if (vhost_vq_meta_fetch(vq, addr, len, type))
                return true;

        /* s counts the bytes verified so far. */
        while (len > s) {
                map = vhost_iotlb_itree_first(umem, addr, last);
                if (map == NULL || map->start > addr) {
                        /* addr is unmapped: ask userspace for a translation. */
                        vhost_iotlb_miss(vq, addr, access);
                        return false;
                } else if (!(map->perm & access)) {
                        /* Mapped, but without the required permission.  Fail
                         * without queueing a miss.  NOTE(review): nothing is
                         * reported to userspace on this path — confirm that
                         * a silent failure is intended.
                         */
                        return false;
                }

                /* Bytes of this map from addr up to its end. */
                size = map->size - addr + map->start;

                if (orig_addr == addr && size >= len)
                        vhost_vq_meta_update(vq, map, type);

                s += size;
                addr += size;
        }

        return true;
}
1651
1652 int vq_meta_prefetch(struct vhost_virtqueue *vq)
1653 {
1654         unsigned int num = vq->num;
1655
1656         if (!vq->iotlb)
1657                 return 1;
1658
1659         return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
1660                                vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
1661                iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
1662                                vhost_get_avail_size(vq, num),
1663                                VHOST_ADDR_AVAIL) &&
1664                iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
1665                                vhost_get_used_size(vq, num), VHOST_ADDR_USED);
1666 }
1667 EXPORT_SYMBOL_GPL(vq_meta_prefetch);
1668
1669 /* Can we log writes? */
1670 /* Caller should have device mutex but not vq mutex */
1671 bool vhost_log_access_ok(struct vhost_dev *dev)
1672 {
1673         return memory_access_ok(dev, dev->umem, 1);
1674 }
1675 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1676
1677 static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
1678                                   void __user *log_base,
1679                                   bool log_used,
1680                                   u64 log_addr)
1681 {
1682         /* If an IOTLB device is present, log_addr is a GIOVA that
1683          * will never be logged by log_used(). */
1684         if (vq->iotlb)
1685                 return true;
1686
1687         return !log_used || log_access_ok(log_base, log_addr,
1688                                           vhost_get_used_size(vq, vq->num));
1689 }
1690
1691 /* Verify access for write logging. */
1692 /* Caller should have vq mutex and device mutex */
1693 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1694                              void __user *log_base)
1695 {
1696         return vq_memory_access_ok(log_base, vq->umem,
1697                                    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1698                 vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
1699 }
1700
1701 /* Can we start vq? */
1702 /* Caller should have vq mutex and device mutex */
1703 bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1704 {
1705         if (!vq_log_access_ok(vq, vq->log_base))
1706                 return false;
1707
1708         return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1709 }
1710 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1711
/*
 * Replace the device memory table with the one described by @m in userspace.
 * @m holds a struct vhost_memory header followed by mem.nregions regions.
 * Each region becomes one VHOST_MAP_RW mapping
 * guest_phys_addr -> userspace_addr in a new umem IOTLB.  The new table is
 * checked against every virtqueue with a backend.  It is then installed on
 * the device and on each vq, and the old table is freed.
 *
 * Caller should have device mutex but not vq mutex; memory_access_ok() and
 * the per-vq locking below both require this.
 *
 * Returns 0 on success.  On failure it returns:
 *  - -EFAULT for a failed copy, a rejected region, or a failed access check;
 *  - -EOPNOTSUPP if padding is non-zero;
 *  - -E2BIG if there are more than max_mem_regions regions;
 *  - -ENOMEM if an allocation fails.
 */
static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{
        struct vhost_memory mem, *newmem;
        struct vhost_memory_region *region;
        struct vhost_iotlb *newumem, *oldumem;
        /* Size of the fixed header that precedes the regions array. */
        unsigned long size = offsetof(struct vhost_memory, regions);
        int i;

        if (copy_from_user(&mem, m, size))
                return -EFAULT;
        if (mem.padding)
                return -EOPNOTSUPP;
        if (mem.nregions > max_mem_regions)
                return -E2BIG;
        newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
                        GFP_KERNEL);
        if (!newmem)
                return -ENOMEM;

        memcpy(newmem, &mem, size);
        if (copy_from_user(newmem->regions, m->regions,
                           flex_array_size(newmem, regions, mem.nregions))) {
                kvfree(newmem);
                return -EFAULT;
        }

        newumem = iotlb_alloc();
        if (!newumem) {
                kvfree(newmem);
                return -ENOMEM;
        }

        for (region = newmem->regions;
             region < newmem->regions + mem.nregions;
             region++) {
                /* NOTE(review): whatever vhost_iotlb_add_range() returns is
                 * reported to the caller as -EFAULT via the err label.
                 */
                if (vhost_iotlb_add_range(newumem,
                                          region->guest_phys_addr,
                                          region->guest_phys_addr +
                                          region->memory_size - 1,
                                          region->userspace_addr,
                                          VHOST_MAP_RW))
                        goto err;
        }

        if (!memory_access_ok(d, newumem, 0))
                goto err;

        oldumem = d->umem;
        d->umem = newumem;

        /* All memory accesses are done under some VQ mutex. */
        for (i = 0; i < d->nvqs; ++i) {
                mutex_lock(&d->vqs[i]->mutex);
                d->vqs[i]->umem = newumem;
                mutex_unlock(&d->vqs[i]->mutex);
        }

        kvfree(newmem);
        vhost_iotlb_free(oldumem);
        return 0;

err:
        vhost_iotlb_free(newumem);
        kvfree(newmem);
        return -EFAULT;
}
1778
1779 static long vhost_vring_set_num(struct vhost_dev *d,
1780                                 struct vhost_virtqueue *vq,
1781                                 void __user *argp)
1782 {
1783         struct vhost_vring_state s;
1784
1785         /* Resizing ring with an active backend?
1786          * You don't want to do that. */
1787         if (vq->private_data)
1788                 return -EBUSY;
1789
1790         if (copy_from_user(&s, argp, sizeof s))
1791                 return -EFAULT;
1792
1793         if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
1794                 return -EINVAL;
1795         vq->num = s.num;
1796
1797         return 0;
1798 }
1799
/*
 * Apply the struct vhost_vring_addr at @argp to @vq.  It sets the
 * desc/avail/used ring addresses, the used-ring log address, and whether
 * used-ring writes are logged (VHOST_VRING_F_LOG).  Only the avail, used and
 * log addresses are alignment-checked here.
 *
 * Returns -EFAULT if the copy fails or an address does not fit in an unsigned
 * long, and -EOPNOTSUPP for unknown flags.  Returns -EINVAL for misaligned
 * addresses or, when a backend is attached, for addresses that fail the
 * access checks.  @d is unused.
 */
static long vhost_vring_set_addr(struct vhost_dev *d,
                                 struct vhost_virtqueue *vq,
                                 void __user *argp)
{
        struct vhost_vring_addr a;

        if (copy_from_user(&a, argp, sizeof a))
                return -EFAULT;
        if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
                return -EOPNOTSUPP;

        /* For 32bit, verify that the top 32bits of the user
           data are set to zero. */
        if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
            (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
            (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
                return -EFAULT;

        /* Make sure it's safe to cast pointers to vring types. */
        BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
        BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
        if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
            (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
            (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
                return -EINVAL;

        /* We only verify access here if backend is configured.
         * If it is not, we don't as size might not have been setup.
         * We will verify when backend is configured. */
        if (vq->private_data) {
                if (!vq_access_ok(vq, vq->num,
                        (void __user *)(unsigned long)a.desc_user_addr,
                        (void __user *)(unsigned long)a.avail_user_addr,
                        (void __user *)(unsigned long)a.used_user_addr))
                        return -EINVAL;

                /* Also validate log access for used ring if enabled. */
                if (!vq_log_used_access_ok(vq, vq->log_base,
                                a.flags & (0x1 << VHOST_VRING_F_LOG),
                                a.log_guest_addr))
                        return -EINVAL;
        }

        /* All checks passed: commit the new addresses. */
        vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
        vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
        vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
        vq->log_addr = a.log_guest_addr;
        vq->used = (void __user *)(unsigned long)a.used_user_addr;

        return 0;
}
1851
1852 static long vhost_vring_set_num_addr(struct vhost_dev *d,
1853                                      struct vhost_virtqueue *vq,
1854                                      unsigned int ioctl,
1855                                      void __user *argp)
1856 {
1857         long r;
1858
1859         mutex_lock(&vq->mutex);
1860
1861         switch (ioctl) {
1862         case VHOST_SET_VRING_NUM:
1863                 r = vhost_vring_set_num(d, vq, argp);
1864                 break;
1865         case VHOST_SET_VRING_ADDR:
1866                 r = vhost_vring_set_addr(d, vq, argp);
1867                 break;
1868         default:
1869                 BUG();
1870         }
1871
1872         mutex_unlock(&vq->mutex);
1873
1874         return r;
1875 }
/*
 * Per-virtqueue ioctl handler.  The virtqueue is selected from the userspace
 * argument by vhost_get_vq_from_user(), which also yields its index @idx.
 * VHOST_SET_VRING_NUM and VHOST_SET_VRING_ADDR are delegated to
 * vhost_vring_set_num_addr(); all other ioctls are handled below with
 * vq->mutex held.  Unknown ioctls return -ENOIOCTLCMD.
 */
long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
	struct file *eventfp, *filep = NULL;
	bool pollstart = false, pollstop = false;
	struct eventfd_ctx *ctx = NULL;
	struct vhost_virtqueue *vq;
	struct vhost_vring_state s;
	struct vhost_vring_file f;
	u32 idx;
	long r;

	/* NOTE(review): r is not reset after this call, so its non-negative
	 * result (presumably 0) is what cases that leave r untouched return. */
	r = vhost_get_vq_from_user(d, argp, &vq, &idx);
	if (r < 0)
		return r;

	if (ioctl == VHOST_SET_VRING_NUM ||
	    ioctl == VHOST_SET_VRING_ADDR) {
		return vhost_vring_set_num_addr(d, vq, ioctl, argp);
	}

	mutex_lock(&vq->mutex);

	switch (ioctl) {
	case VHOST_SET_VRING_BASE:
		/* Moving base with an active backend?
		 * You don't want to do that. */
		if (vq->private_data) {
			r = -EBUSY;
			break;
		}
		if (copy_from_user(&s, argp, sizeof s)) {
			r = -EFAULT;
			break;
		}
		/* Packed rings pass last_avail_idx in bits 0-15 and
		 * last_used_idx in bits 16-31; split rings pass only a
		 * 16-bit last_avail_idx. */
		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
			vq->last_avail_idx = s.num & 0xffff;
			vq->last_used_idx = (s.num >> 16) & 0xffff;
		} else {
			if (s.num > 0xffff) {
				r = -EINVAL;
				break;
			}
			vq->last_avail_idx = s.num;
		}
		/* Forget the cached index value. */
		vq->avail_idx = vq->last_avail_idx;
		break;
	case VHOST_GET_VRING_BASE:
		/* Same encoding as accepted by VHOST_SET_VRING_BASE. */
		s.index = idx;
		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
			s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16);
		else
			s.num = vq->last_avail_idx;
		if (copy_to_user(argp, &s, sizeof s))
			r = -EFAULT;
		break;
	case VHOST_SET_VRING_KICK:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd);
		if (IS_ERR(eventfp)) {
			r = PTR_ERR(eventfp);
			break;
		}
		/* filep is released after the switch: either the kick file
		 * being replaced, or the extra reference just taken when
		 * userspace passes the file that is already installed. */
		if (eventfp != vq->kick) {
			pollstop = (filep = vq->kick) != NULL;
			pollstart = (vq->kick = eventfp) != NULL;
		} else
			filep = eventfp;
		break;
	case VHOST_SET_VRING_CALL:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}

		/* Install the new context; ctx now holds the previous one,
		 * which is released after the switch. */
		swap(ctx, vq->call_ctx.ctx);
		break;
	case VHOST_SET_VRING_ERR:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}
		/* ctx now holds the previous error context, released below. */
		swap(ctx, vq->error_ctx);
		break;
	case VHOST_SET_VRING_ENDIAN:
		r = vhost_set_vring_endian(vq, argp);
		break;
	case VHOST_GET_VRING_ENDIAN:
		r = vhost_get_vring_endian(vq, idx, argp);
		break;
	case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
		if (copy_from_user(&s, argp, sizeof(s))) {
			r = -EFAULT;
			break;
		}
		vq->busyloop_timeout = s.num;
		break;
	case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
		s.index = idx;
		s.num = vq->busyloop_timeout;
		if (copy_to_user(argp, &s, sizeof(s)))
			r = -EFAULT;
		break;
	default:
		r = -ENOIOCTLCMD;
	}

	/* Polling on the old kick file stops before its reference is
	 * dropped below. */
	if (pollstop && vq->handle_kick)
		vhost_poll_stop(&vq->poll);

	/* ctx may hold an ERR_PTR from a failed lookup; only put real ones. */
	if (!IS_ERR_OR_NULL(ctx))
		eventfd_ctx_put(ctx);
	if (filep)
		fput(filep);

	/* Start polling the new kick file; this result overrides r. */
	if (pollstart && vq->handle_kick)
		r = vhost_poll_start(&vq->poll, vq->kick);

	mutex_unlock(&vq->mutex);

	/* The device flush is done only after vq->mutex has been dropped. */
	if (pollstop && vq->handle_kick)
		vhost_dev_flush(vq->poll.dev);
	return r;
}
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
2014
2015 int vhost_init_device_iotlb(struct vhost_dev *d)
2016 {
2017         struct vhost_iotlb *niotlb, *oiotlb;
2018         int i;
2019
2020         niotlb = iotlb_alloc();
2021         if (!niotlb)
2022                 return -ENOMEM;
2023
2024         oiotlb = d->iotlb;
2025         d->iotlb = niotlb;
2026
2027         for (i = 0; i < d->nvqs; ++i) {
2028                 struct vhost_virtqueue *vq = d->vqs[i];
2029
2030                 mutex_lock(&vq->mutex);
2031                 vq->iotlb = niotlb;
2032                 __vhost_vq_meta_reset(vq);
2033                 mutex_unlock(&vq->mutex);
2034         }
2035
2036         vhost_iotlb_free(oiotlb);
2037
2038         return 0;
2039 }
2040 EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
2041
2042 /* Caller must have device mutex */
2043 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
2044 {
2045         struct eventfd_ctx *ctx;
2046         u64 p;
2047         long r;
2048         int i, fd;
2049
2050         /* If you are not the owner, you can become one */
2051         if (ioctl == VHOST_SET_OWNER) {
2052                 r = vhost_dev_set_owner(d);
2053                 goto done;
2054         }
2055
2056         /* You must be the owner to do anything else */
2057         r = vhost_dev_check_owner(d);
2058         if (r)
2059                 goto done;
2060
2061         switch (ioctl) {
2062         case VHOST_SET_MEM_TABLE:
2063                 r = vhost_set_memory(d, argp);
2064                 break;
2065         case VHOST_SET_LOG_BASE:
2066                 if (copy_from_user(&p, argp, sizeof p)) {
2067                         r = -EFAULT;
2068                         break;
2069                 }
2070                 if ((u64)(unsigned long)p != p) {
2071                         r = -EFAULT;
2072                         break;
2073                 }
2074                 for (i = 0; i < d->nvqs; ++i) {
2075                         struct vhost_virtqueue *vq;
2076                         void __user *base = (void __user *)(unsigned long)p;
2077                         vq = d->vqs[i];
2078                         mutex_lock(&vq->mutex);
2079                         /* If ring is inactive, will check when it's enabled. */
2080                         if (vq->private_data && !vq_log_access_ok(vq, base))
2081                                 r = -EFAULT;
2082                         else
2083                                 vq->log_base = base;
2084                         mutex_unlock(&vq->mutex);
2085                 }
2086                 break;
2087         case VHOST_SET_LOG_FD:
2088                 r = get_user(fd, (int __user *)argp);
2089                 if (r < 0)
2090                         break;
2091                 ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
2092                 if (IS_ERR(ctx)) {
2093                         r = PTR_ERR(ctx);
2094                         break;
2095                 }
2096                 swap(ctx, d->log_ctx);
2097                 for (i = 0; i < d->nvqs; ++i) {
2098                         mutex_lock(&d->vqs[i]->mutex);
2099                         d->vqs[i]->log_ctx = d->log_ctx;
2100                         mutex_unlock(&d->vqs[i]->mutex);
2101                 }
2102                 if (ctx)
2103                         eventfd_ctx_put(ctx);
2104                 break;
2105         default:
2106                 r = -ENOIOCTLCMD;
2107                 break;
2108         }
2109 done:
2110         return r;
2111 }
2112 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
2113
2114 /* TODO: This is really inefficient.  We need something like get_user()
2115  * (instruction directly accesses the data, with an exception table entry
2116  * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
2117  */
2118 static int set_bit_to_user(int nr, void __user *addr)
2119 {
2120         unsigned long log = (unsigned long)addr;
2121         struct page *page;
2122         void *base;
2123         int bit = nr + (log % PAGE_SIZE) * 8;
2124         int r;
2125
2126         r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page);
2127         if (r < 0)
2128                 return r;
2129         BUG_ON(r != 1);
2130         base = kmap_atomic(page);
2131         set_bit(bit, base);
2132         kunmap_atomic(base);
2133         unpin_user_pages_dirty_lock(&page, 1, true);
2134         return 0;
2135 }
2136
2137 static int log_write(void __user *log_base,
2138                      u64 write_address, u64 write_length)
2139 {
2140         u64 write_page = write_address / VHOST_PAGE_SIZE;
2141         int r;
2142
2143         if (!write_length)
2144                 return 0;
2145         write_length += write_address % VHOST_PAGE_SIZE;
2146         for (;;) {
2147                 u64 base = (u64)(unsigned long)log_base;
2148                 u64 log = base + write_page / 8;
2149                 int bit = write_page % 8;
2150                 if ((u64)(unsigned long)log != log)
2151                         return -EFAULT;
2152                 r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
2153                 if (r < 0)
2154                         return r;
2155                 if (write_length <= VHOST_PAGE_SIZE)
2156                         break;
2157                 write_length -= VHOST_PAGE_SIZE;
2158                 write_page += 1;
2159         }
2160         return r;
2161 }
2162
/*
 * Log a write of @len bytes at host virtual address @hva.  For every region
 * of vq->umem that overlaps the range, the overlapping part is converted to
 * u->start + (start - u->addr) (presumably the guest address — confirm
 * against vhost_iotlb_map) and marked with log_write().
 *
 * Returns 0, -EFAULT if the first pass finds no overlapping region, or the
 * negative error from log_write().
 */
static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
{
	struct vhost_iotlb *umem = vq->umem;
	struct vhost_iotlb_map *u;
	u64 start, end, l, min;
	int r;
	/* NOTE(review): hit is set once and never cleared.  After the first
	 * pass, a later pass that finds no overlapping region leaves min ==
	 * len and consumes the rest of the range instead of returning
	 * -EFAULT.  Confirm this is intended. */
	bool hit = false;

	while (len) {
		min = len;
		/* More than one GPAs can be mapped into a single HVA. So
		 * iterate all possible umems here to be safe.
		 */
		list_for_each_entry(u, &umem->list, link) {
			/* Skip regions not overlapping [hva, hva + len); the
			 * "- 1" forms compare inclusive end addresses. */
			if (u->addr > hva - 1 + len ||
			    u->addr - 1 + u->size < hva)
				continue;
			start = max(u->addr, hva);
			end = min(u->addr - 1 + u->size, hva - 1 + len);
			l = end - start + 1;
			r = log_write(vq->log_base,
				      u->start + start - u->addr,
				      l);
			if (r < 0)
				return r;
			hit = true;
			/* Advance by the shortest overlap seen in this pass. */
			min = min(l, min);
		}

		if (!hit)
			return -EFAULT;

		len -= min;
		hva += min;
	}

	return 0;
}
2201
2202 static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
2203 {
2204         struct iovec *iov = vq->log_iov;
2205         int i, ret;
2206
2207         if (!vq->iotlb)
2208                 return log_write(vq->log_base, vq->log_addr + used_offset, len);
2209
2210         ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
2211                              len, iov, 64, VHOST_ACCESS_WO);
2212         if (ret < 0)
2213                 return ret;
2214
2215         for (i = 0; i < ret; i++) {
2216                 ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
2217                                     iov[i].iov_len);
2218                 if (ret)
2219                         return ret;
2220         }
2221
2222         return 0;
2223 }
2224
2225 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
2226                     unsigned int log_num, u64 len, struct iovec *iov, int count)
2227 {
2228         int i, r;
2229
2230         /* Make sure data written is seen before log. */
2231         smp_wmb();
2232
2233         if (vq->iotlb) {
2234                 for (i = 0; i < count; i++) {
2235                         r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
2236                                           iov[i].iov_len);
2237                         if (r < 0)
2238                                 return r;
2239                 }
2240                 return 0;
2241         }
2242
2243         for (i = 0; i < log_num; ++i) {
2244                 u64 l = min(log[i].len, len);
2245                 r = log_write(vq->log_base, log[i].addr, l);
2246                 if (r < 0)
2247                         return r;
2248                 len -= l;
2249                 if (!len) {
2250                         if (vq->log_ctx)
2251                                 eventfd_signal(vq->log_ctx);
2252                         return 0;
2253                 }
2254         }
2255         /* Length written exceeds what we have stored. This is a bug. */
2256         BUG();
2257         return 0;
2258 }
2259 EXPORT_SYMBOL_GPL(vhost_log_write);
2260
2261 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
2262 {
2263         void __user *used;
2264         if (vhost_put_used_flags(vq))
2265                 return -EFAULT;
2266         if (unlikely(vq->log_used)) {
2267                 /* Make sure the flag is seen before log. */
2268                 smp_wmb();
2269                 /* Log used flag write. */
2270                 used = &vq->used->flags;
2271                 log_used(vq, (used - (void __user *)vq->used),
2272                          sizeof vq->used->flags);
2273                 if (vq->log_ctx)
2274                         eventfd_signal(vq->log_ctx);
2275         }
2276         return 0;
2277 }
2278
2279 static int vhost_update_avail_event(struct vhost_virtqueue *vq)
2280 {
2281         if (vhost_put_avail_event(vq))
2282                 return -EFAULT;
2283         if (unlikely(vq->log_used)) {
2284                 void __user *used;
2285                 /* Make sure the event is seen before log. */
2286                 smp_wmb();
2287                 /* Log avail event write */
2288                 used = vhost_avail_event(vq);
2289                 log_used(vq, (used - (void __user *)vq->used),
2290                          sizeof *vhost_avail_event(vq));
2291                 if (vq->log_ctx)
2292                         eventfd_signal(vq->log_ctx);
2293         }
2294         return 0;
2295 }
2296
2297 int vhost_vq_init_access(struct vhost_virtqueue *vq)
2298 {
2299         __virtio16 last_used_idx;
2300         int r;
2301         bool is_le = vq->is_le;
2302
2303         if (!vq->private_data)
2304                 return 0;
2305
2306         vhost_init_is_le(vq);
2307
2308         r = vhost_update_used_flags(vq);
2309         if (r)
2310                 goto err;
2311         vq->signalled_used_valid = false;
2312         if (!vq->iotlb &&
2313             !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
2314                 r = -EFAULT;
2315                 goto err;
2316         }
2317         r = vhost_get_used_idx(vq, &last_used_idx);
2318         if (r) {
2319                 vq_err(vq, "Can't access used idx at %p\n",
2320                        &vq->used->idx);
2321                 goto err;
2322         }
2323         vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
2324         return 0;
2325
2326 err:
2327         vq->is_le = is_le;
2328         return r;
2329 }
2330 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
2331
/*
 * Translate the range [@addr, @addr + @len) into host userspace iovecs.
 *
 * Lookups use the device IOTLB when one exists, otherwise the memory table
 * (dev->umem).  Each map covering the next part of the range fills one
 * entry of @iov, up to @iov_size entries, and must grant @access.
 *
 * Returns the number of iovecs filled, or a negative errno:
 *   -ENOBUFS  @iov_size entries were not enough;
 *   -EFAULT   part of the range is not covered by the memory table;
 *   -EAGAIN   part of the range is not covered by the IOTLB; a miss for the
 *             first uncovered address is reported via vhost_iotlb_miss();
 *   -EPERM    a covering map does not grant @access.
 */
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
			  struct iovec iov[], int iov_size, int access)
{
	const struct vhost_iotlb_map *map;
	struct vhost_dev *dev = vq->dev;
	struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
	struct iovec *_iov;
	/* s: bytes translated so far; last: inclusive end of the range. */
	u64 s = 0, last = addr + len - 1;
	int ret = 0;

	while ((u64)len > s) {
		u64 size;
		if (unlikely(ret >= iov_size)) {
			ret = -ENOBUFS;
			break;
		}

		/* A map starting after addr means addr itself is unmapped. */
		map = vhost_iotlb_itree_first(umem, addr, last);
		if (map == NULL || map->start > addr) {
			if (umem != dev->iotlb) {
				ret = -EFAULT;
				break;
			}
			ret = -EAGAIN;
			break;
		} else if (!(map->perm & access)) {
			ret = -EPERM;
			break;
		}

		_iov = iov + ret;
		/* Bytes from addr to the end of this map.  This may exceed
		 * what remains of the range; iov_len is clamped, and the
		 * loop then ends because s >= len. */
		size = map->size - addr + map->start;
		_iov->iov_len = min((u64)len - s, size);
		_iov->iov_base = (void __user *)(unsigned long)
				 (map->addr + addr - map->start);
		s += size;
		addr += size;
		++ret;
	}

	if (ret == -EAGAIN)
		vhost_iotlb_miss(vq, addr, access);
	return ret;
}
2376
2377 /* Each buffer in the virtqueues is actually a chain of descriptors.  This
2378  * function returns the next descriptor in the chain,
2379  * or -1U if we're at the end. */
2380 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
2381 {
2382         unsigned int next;
2383
2384         /* If this descriptor says it doesn't chain, we're done. */
2385         if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
2386                 return -1U;
2387
2388         /* Check they're not leading us off end of descriptors. */
2389         next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
2390         return next;
2391 }
2392
/*
 * Expand the indirect descriptor @indirect into @iov, appending after the
 * *@out_num + *@in_num entries that are already filled.
 *
 * The indirect table is translated into vq->indirect and read front to
 * back.  Each descriptor's next field (via next_desc()) only decides
 * whether the walk continues and appears in diagnostics; it is not used to
 * index the table.  Nested indirect descriptors are rejected, and
 * device-readable (out) descriptors must come before device-writable (in)
 * ones.  When @log is non-NULL, each writable descriptor's guest address
 * and length are appended to @log and counted in *@log_num.
 *
 * Returns 0 on success or a negative errno.  -EAGAIN (IOTLB miss) is
 * returned without an error message.
 */
static int get_indirect(struct vhost_virtqueue *vq,
			struct iovec iov[], unsigned int iov_size,
			unsigned int *out_num, unsigned int *in_num,
			struct vhost_log *log, unsigned int *log_num,
			struct vring_desc *indirect)
{
	struct vring_desc desc;
	unsigned int i = 0, count, found = 0;
	u32 len = vhost32_to_cpu(vq, indirect->len);
	struct iov_iter from;
	int ret, access;

	/* Sanity check */
	if (unlikely(len % sizeof desc)) {
		vq_err(vq, "Invalid length in indirect descriptor: "
		       "len 0x%llx not multiple of 0x%zx\n",
		       (unsigned long long)len,
		       sizeof desc);
		return -EINVAL;
	}

	ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
			     UIO_MAXIOV, VHOST_ACCESS_RO);
	if (unlikely(ret < 0)) {
		if (ret != -EAGAIN)
			vq_err(vq, "Translation failure %d in indirect.\n", ret);
		return ret;
	}
	/* Read the (possibly fragmented) table as one byte stream. */
	iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len);
	count = len / sizeof desc;
	/* Buffers are chained via a 16 bit next field, so
	 * we can have at most 2^16 of these. */
	if (unlikely(count > USHRT_MAX + 1)) {
		/* NOTE(review): prints the raw, unconverted indirect->len. */
		vq_err(vq, "Indirect buffer length too big: %d\n",
		       indirect->len);
		return -E2BIG;
	}

	do {
		unsigned iov_count = *in_num + *out_num;
		/* Walking more entries than the table holds means a loop. */
		if (unlikely(++found > count)) {
			vq_err(vq, "Loop detected: last one at %u "
			       "indirect size %u\n",
			       i, count);
			return -EINVAL;
		}
		if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
			vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
			return -EINVAL;
		}
		if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
			vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
			return -EINVAL;
		}

		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
			access = VHOST_ACCESS_WO;
		else
			access = VHOST_ACCESS_RO;

		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
				     iov_size - iov_count, access);
		if (unlikely(ret < 0)) {
			if (ret != -EAGAIN)
				vq_err(vq, "Translation failure %d indirect idx %d\n",
					ret, i);
			return ret;
		}
		/* If this is an input descriptor, increment that count. */
		if (access == VHOST_ACCESS_WO) {
			*in_num += ret;
			if (unlikely(log && ret)) {
				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
				log[*log_num].len = vhost32_to_cpu(vq, desc.len);
				++*log_num;
			}
		} else {
			/* If it's an output descriptor, they're all supposed
			 * to come before any input descriptors. */
			if (unlikely(*in_num)) {
				vq_err(vq, "Indirect descriptor "
				       "has out after in: idx %d\n", i);
				return -EINVAL;
			}
			*out_num += ret;
		}
	} while ((i = next_desc(vq, &desc)) != -1);
	return 0;
}
2485
2486 /* This looks in the virtqueue and for the first available buffer, and converts
2487  * it to an iovec for convenient access.  Since descriptors consist of some
2488  * number of output then some number of input descriptors, it's actually two
2489  * iovecs, but we pack them into one and note how many of each there were.
2490  *
2491  * This function returns the descriptor number found, or vq->num (which is
2492  * never a valid descriptor number) if none was found.  A negative code is
2493  * returned on error. */
2494 int vhost_get_vq_desc(struct vhost_virtqueue *vq,
2495                       struct iovec iov[], unsigned int iov_size,
2496                       unsigned int *out_num, unsigned int *in_num,
2497                       struct vhost_log *log, unsigned int *log_num)
2498 {
2499         struct vring_desc desc;
2500         unsigned int i, head, found = 0;
2501         u16 last_avail_idx;
2502         __virtio16 avail_idx;
2503         __virtio16 ring_head;
2504         int ret, access;
2505
2506         /* Check it isn't doing very strange things with descriptor numbers. */
2507         last_avail_idx = vq->last_avail_idx;
2508
2509         if (vq->avail_idx == vq->last_avail_idx) {
2510                 if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
2511                         vq_err(vq, "Failed to access avail idx at %p\n",
2512                                 &vq->avail->idx);
2513                         return -EFAULT;
2514                 }
2515                 vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2516
2517                 if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
2518                         vq_err(vq, "Guest moved used index from %u to %u",
2519                                 last_avail_idx, vq->avail_idx);
2520                         return -EFAULT;
2521                 }
2522
2523                 /* If there's nothing new since last we looked, return
2524                  * invalid.
2525                  */
2526                 if (vq->avail_idx == last_avail_idx)
2527                         return vq->num;
2528
2529                 /* Only get avail ring entries after they have been
2530                  * exposed by guest.
2531                  */
2532                 smp_rmb();
2533         }
2534
2535         /* Grab the next descriptor number they're advertising, and increment
2536          * the index we've seen. */
2537         if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
2538                 vq_err(vq, "Failed to read head: idx %d address %p\n",
2539                        last_avail_idx,
2540                        &vq->avail->ring[last_avail_idx % vq->num]);
2541                 return -EFAULT;
2542         }
2543
2544         head = vhost16_to_cpu(vq, ring_head);
2545
2546         /* If their number is silly, that's an error. */
2547         if (unlikely(head >= vq->num)) {
2548                 vq_err(vq, "Guest says index %u > %u is available",
2549                        head, vq->num);
2550                 return -EINVAL;
2551         }
2552
2553         /* When we start there are none of either input nor output. */
2554         *out_num = *in_num = 0;
2555         if (unlikely(log))
2556                 *log_num = 0;
2557
2558         i = head;
2559         do {
2560                 unsigned iov_count = *in_num + *out_num;
2561                 if (unlikely(i >= vq->num)) {
2562                         vq_err(vq, "Desc index is %u > %u, head = %u",
2563                                i, vq->num, head);
2564                         return -EINVAL;
2565                 }
2566                 if (unlikely(++found > vq->num)) {
2567                         vq_err(vq, "Loop detected: last one at %u "
2568                                "vq size %u head %u\n",
2569                                i, vq->num, head);
2570                         return -EINVAL;
2571                 }
2572                 ret = vhost_get_desc(vq, &desc, i);
2573                 if (unlikely(ret)) {
2574                         vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2575                                i, vq->desc + i);
2576                         return -EFAULT;
2577                 }
2578                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2579                         ret = get_indirect(vq, iov, iov_size,
2580                                            out_num, in_num,
2581                                            log, log_num, &desc);
2582                         if (unlikely(ret < 0)) {
2583                                 if (ret != -EAGAIN)
2584                                         vq_err(vq, "Failure detected "
2585                                                 "in indirect descriptor at idx %d\n", i);
2586                                 return ret;
2587                         }
2588                         continue;
2589                 }
2590
2591                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2592                         access = VHOST_ACCESS_WO;
2593                 else
2594                         access = VHOST_ACCESS_RO;
2595                 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2596                                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2597                                      iov_size - iov_count, access);
2598                 if (unlikely(ret < 0)) {
2599                         if (ret != -EAGAIN)
2600                                 vq_err(vq, "Translation failure %d descriptor idx %d\n",
2601                                         ret, i);
2602                         return ret;
2603                 }
2604                 if (access == VHOST_ACCESS_WO) {
2605                         /* If this is an input descriptor,
2606                          * increment that count. */
2607                         *in_num += ret;
2608                         if (unlikely(log && ret)) {
2609                                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2610                                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2611                                 ++*log_num;
2612                         }
2613                 } else {
2614                         /* If it's an output descriptor, they're all supposed
2615                          * to come before any input descriptors. */
2616                         if (unlikely(*in_num)) {
2617                                 vq_err(vq, "Descriptor has out after in: "
2618                                        "idx %d\n", i);
2619                                 return -EINVAL;
2620                         }
2621                         *out_num += ret;
2622                 }
2623         } while ((i = next_desc(vq, &desc)) != -1);
2624
2625         /* On success, increment avail index. */
2626         vq->last_avail_idx++;
2627
2628         /* Assume notifications from guest are disabled at this point,
2629          * if they aren't we would need to update avail_event index. */
2630         BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2631         return head;
2632 }
2633 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
2634
2635 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2636 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2637 {
2638         vq->last_avail_idx -= n;
2639 }
2640 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
2641
2642 /* After we've used one of their buffers, we tell them about it.  We'll then
2643  * want to notify the guest, using eventfd. */
2644 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2645 {
2646         struct vring_used_elem heads = {
2647                 cpu_to_vhost32(vq, head),
2648                 cpu_to_vhost32(vq, len)
2649         };
2650
2651         return vhost_add_used_n(vq, &heads, 1);
2652 }
2653 EXPORT_SYMBOL_GPL(vhost_add_used);
2654
2655 static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2656                             struct vring_used_elem *heads,
2657                             unsigned count)
2658 {
2659         vring_used_elem_t __user *used;
2660         u16 old, new;
2661         int start;
2662
2663         start = vq->last_used_idx & (vq->num - 1);
2664         used = vq->used->ring + start;
2665         if (vhost_put_used(vq, heads, start, count)) {
2666                 vq_err(vq, "Failed to write used");
2667                 return -EFAULT;
2668         }
2669         if (unlikely(vq->log_used)) {
2670                 /* Make sure data is seen before log. */
2671                 smp_wmb();
2672                 /* Log used ring entry write. */
2673                 log_used(vq, ((void __user *)used - (void __user *)vq->used),
2674                          count * sizeof *used);
2675         }
2676         old = vq->last_used_idx;
2677         new = (vq->last_used_idx += count);
2678         /* If the driver never bothers to signal in a very long while,
2679          * used index might wrap around. If that happens, invalidate
2680          * signalled_used index we stored. TODO: make sure driver
2681          * signals at least once in 2^16 and remove this. */
2682         if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2683                 vq->signalled_used_valid = false;
2684         return 0;
2685 }
2686
2687 /* After we've used one of their buffers, we tell them about it.  We'll then
2688  * want to notify the guest, using eventfd. */
2689 int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2690                      unsigned count)
2691 {
2692         int start, n, r;
2693
2694         start = vq->last_used_idx & (vq->num - 1);
2695         n = vq->num - start;
2696         if (n < count) {
2697                 r = __vhost_add_used_n(vq, heads, n);
2698                 if (r < 0)
2699                         return r;
2700                 heads += n;
2701                 count -= n;
2702         }
2703         r = __vhost_add_used_n(vq, heads, count);
2704
2705         /* Make sure buffer is written before we update index. */
2706         smp_wmb();
2707         if (vhost_put_used_idx(vq)) {
2708                 vq_err(vq, "Failed to increment used idx");
2709                 return -EFAULT;
2710         }
2711         if (unlikely(vq->log_used)) {
2712                 /* Make sure used idx is seen before log. */
2713                 smp_wmb();
2714                 /* Log used index update. */
2715                 log_used(vq, offsetof(struct vring_used, idx),
2716                          sizeof vq->used->idx);
2717                 if (vq->log_ctx)
2718                         eventfd_signal(vq->log_ctx);
2719         }
2720         return r;
2721 }
2722 EXPORT_SYMBOL_GPL(vhost_add_used_n);
2723
2724 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2725 {
2726         __u16 old, new;
2727         __virtio16 event;
2728         bool v;
2729         /* Flush out used index updates. This is paired
2730          * with the barrier that the Guest executes when enabling
2731          * interrupts. */
2732         smp_mb();
2733
2734         if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2735             unlikely(vq->avail_idx == vq->last_avail_idx))
2736                 return true;
2737
2738         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2739                 __virtio16 flags;
2740                 if (vhost_get_avail_flags(vq, &flags)) {
2741                         vq_err(vq, "Failed to get flags");
2742                         return true;
2743                 }
2744                 return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2745         }
2746         old = vq->signalled_used;
2747         v = vq->signalled_used_valid;
2748         new = vq->signalled_used = vq->last_used_idx;
2749         vq->signalled_used_valid = true;
2750
2751         if (unlikely(!v))
2752                 return true;
2753
2754         if (vhost_get_used_event(vq, &event)) {
2755                 vq_err(vq, "Failed to get used event idx");
2756                 return true;
2757         }
2758         return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2759 }
2760
2761 /* This actually signals the guest, using eventfd. */
2762 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2763 {
2764         /* Signal the Guest tell them we used something up. */
2765         if (vq->call_ctx.ctx && vhost_notify(dev, vq))
2766                 eventfd_signal(vq->call_ctx.ctx);
2767 }
2768 EXPORT_SYMBOL_GPL(vhost_signal);
2769
/*
 * Convenience wrapper: vhost_add_used() followed by vhost_signal().
 * The result of vhost_add_used() is deliberately ignored; the guest is
 * signalled (subject to vhost_notify()) either way.
 */
void vhost_add_used_and_signal(struct vhost_dev *dev,
			       struct vhost_virtqueue *vq,
			       unsigned int head, int len)
{
	(void)vhost_add_used(vq, head, len);
	vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2779
/*
 * Batched form of vhost_add_used_and_signal(): add @count used elements
 * with vhost_add_used_n(), ignoring its result, then vhost_signal().
 */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
				 struct vhost_virtqueue *vq,
				 struct vring_used_elem *heads, unsigned count)
{
	(void)vhost_add_used_n(vq, heads, count);
	vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
2789
2790 /* return true if we're sure that avaiable ring is empty */
2791 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2792 {
2793         __virtio16 avail_idx;
2794         int r;
2795
2796         if (vq->avail_idx != vq->last_avail_idx)
2797                 return false;
2798
2799         r = vhost_get_avail_idx(vq, &avail_idx);
2800         if (unlikely(r))
2801                 return false;
2802         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2803
2804         return vq->avail_idx == vq->last_avail_idx;
2805 }
2806 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
2807
2808 /* OK, now we need to know about added descriptors. */
2809 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2810 {
2811         __virtio16 avail_idx;
2812         int r;
2813
2814         if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2815                 return false;
2816         vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2817         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2818                 r = vhost_update_used_flags(vq);
2819                 if (r) {
2820                         vq_err(vq, "Failed to enable notification at %p: %d\n",
2821                                &vq->used->flags, r);
2822                         return false;
2823                 }
2824         } else {
2825                 r = vhost_update_avail_event(vq);
2826                 if (r) {
2827                         vq_err(vq, "Failed to update avail event index at %p: %d\n",
2828                                vhost_avail_event(vq), r);
2829                         return false;
2830                 }
2831         }
2832         /* They could have slipped one in as we were doing that: make
2833          * sure it's written, then check again. */
2834         smp_mb();
2835         r = vhost_get_avail_idx(vq, &avail_idx);
2836         if (r) {
2837                 vq_err(vq, "Failed to check avail idx at %p: %d\n",
2838                        &vq->avail->idx, r);
2839                 return false;
2840         }
2841         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2842
2843         return vq->avail_idx != vq->last_avail_idx;
2844 }
2845 EXPORT_SYMBOL_GPL(vhost_enable_notify);
2846
2847 /* We don't need to be notified again. */
2848 void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2849 {
2850         int r;
2851
2852         if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2853                 return;
2854         vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2855         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2856                 r = vhost_update_used_flags(vq);
2857                 if (r)
2858                         vq_err(vq, "Failed to disable notification at %p: %d\n",
2859                                &vq->used->flags, r);
2860         }
2861 }
2862 EXPORT_SYMBOL_GPL(vhost_disable_notify);
2863
2864 /* Create a new message. */
2865 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2866 {
2867         /* Make sure all padding within the structure is initialized. */
2868         struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
2869         if (!node)
2870                 return NULL;
2871
2872         node->vq = vq;
2873         node->msg.type = type;
2874         return node;
2875 }
2876 EXPORT_SYMBOL_GPL(vhost_new_msg);
2877
2878 void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2879                        struct vhost_msg_node *node)
2880 {
2881         spin_lock(&dev->iotlb_lock);
2882         list_add_tail(&node->node, head);
2883         spin_unlock(&dev->iotlb_lock);
2884
2885         wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
2886 }
2887 EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2888
2889 struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2890                                          struct list_head *head)
2891 {
2892         struct vhost_msg_node *node = NULL;
2893
2894         spin_lock(&dev->iotlb_lock);
2895         if (!list_empty(head)) {
2896                 node = list_first_entry(head, struct vhost_msg_node,
2897                                         node);
2898                 list_del(&node->node);
2899         }
2900         spin_unlock(&dev->iotlb_lock);
2901
2902         return node;
2903 }
2904 EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2905
2906 void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
2907 {
2908         struct vhost_virtqueue *vq;
2909         int i;
2910
2911         mutex_lock(&dev->mutex);
2912         for (i = 0; i < dev->nvqs; ++i) {
2913                 vq = dev->vqs[i];
2914                 mutex_lock(&vq->mutex);
2915                 vq->acked_backend_features = features;
2916                 mutex_unlock(&vq->mutex);
2917         }
2918         mutex_unlock(&dev->mutex);
2919 }
2920 EXPORT_SYMBOL_GPL(vhost_set_backend_features);
2921
2922 static int __init vhost_init(void)
2923 {
2924         return 0;
2925 }
2926
2927 static void __exit vhost_exit(void)
2928 {
2929 }
2930
2931 module_init(vhost_init);
2932 module_exit(vhost_exit);
2933
2934 MODULE_VERSION("0.0.1");
2935 MODULE_LICENSE("GPL v2");
2936 MODULE_AUTHOR("Michael S. Tsirkin");
2937 MODULE_DESCRIPTION("Host kernel accelerator for virtio");