Merge tag 'x86-urgent-2022-08-06' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-microblaze.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userspace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/highmem.h>
28 #include <linux/iommu.h>
29 #include <linux/module.h>
30 #include <linux/mm.h>
31 #include <linux/kthread.h>
32 #include <linux/rbtree.h>
33 #include <linux/sched/signal.h>
34 #include <linux/sched/mm.h>
35 #include <linux/slab.h>
36 #include <linux/uaccess.h>
37 #include <linux/vfio.h>
38 #include <linux/workqueue.h>
39 #include <linux/notifier.h>
40 #include <linux/dma-iommu.h>
41 #include <linux/irqdomain.h>
42 #include "vfio.h"
43
44 #define DRIVER_VERSION  "0.2"
45 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
46 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
47
48 static bool allow_unsafe_interrupts;
49 module_param_named(allow_unsafe_interrupts,
50                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
51 MODULE_PARM_DESC(allow_unsafe_interrupts,
52                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
53
54 static bool disable_hugepages;
55 module_param_named(disable_hugepages,
56                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
57 MODULE_PARM_DESC(disable_hugepages,
58                  "Disable VFIO IOMMU support for IOMMU hugepages.");
59
60 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
61 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
62 MODULE_PARM_DESC(dma_entry_limit,
63                  "Maximum number of user DMA mappings per container (65535).");
64
65 struct vfio_iommu {
66         struct list_head        domain_list;
67         struct list_head        iova_list;
68         struct mutex            lock;
69         struct rb_root          dma_list;
70         struct list_head        device_list;
71         struct mutex            device_list_lock;
72         unsigned int            dma_avail;
73         unsigned int            vaddr_invalid_count;
74         uint64_t                pgsize_bitmap;
75         uint64_t                num_non_pinned_groups;
76         wait_queue_head_t       vaddr_wait;
77         bool                    v2;
78         bool                    nesting;
79         bool                    dirty_page_tracking;
80         bool                    container_open;
81         struct list_head        emulated_iommu_groups;
82 };
83
84 struct vfio_domain {
85         struct iommu_domain     *domain;
86         struct list_head        next;
87         struct list_head        group_list;
88         bool                    fgsp : 1;       /* Fine-grained super pages */
89         bool                    enforce_cache_coherency : 1;
90 };
91
92 struct vfio_dma {
93         struct rb_node          node;
94         dma_addr_t              iova;           /* Device address */
95         unsigned long           vaddr;          /* Process virtual addr */
96         size_t                  size;           /* Map size (bytes) */
97         int                     prot;           /* IOMMU_READ/WRITE */
98         bool                    iommu_mapped;
99         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
100         bool                    vaddr_invalid;
101         struct task_struct      *task;
102         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
103         unsigned long           *bitmap;
104 };
105
106 struct vfio_batch {
107         struct page             **pages;        /* for pin_user_pages_remote */
108         struct page             *fallback_page; /* if pages alloc fails */
109         int                     capacity;       /* length of pages array */
110         int                     size;           /* of batch currently */
111         int                     offset;         /* of next entry in pages */
112 };
113
114 struct vfio_iommu_group {
115         struct iommu_group      *iommu_group;
116         struct list_head        next;
117         bool                    pinned_page_dirty_scope;
118 };
119
120 struct vfio_iova {
121         struct list_head        list;
122         dma_addr_t              start;
123         dma_addr_t              end;
124 };
125
126 /*
127  * Guest RAM pinning working set or DMA target
128  */
129 struct vfio_pfn {
130         struct rb_node          node;
131         dma_addr_t              iova;           /* Device address */
132         unsigned long           pfn;            /* Host pfn */
133         unsigned int            ref_count;
134 };
135
136 struct vfio_regions {
137         struct list_head list;
138         dma_addr_t iova;
139         phys_addr_t phys;
140         size_t len;
141 };
142
143 #define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
144
145 /*
146  * Input argument of number of bits to bitmap_set() is unsigned integer, which
147  * further casts to signed integer for unaligned multi-bit operation,
148  * __bitmap_set().
149  * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
150  * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
151  * system.
152  */
153 #define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
154 #define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
155
156 #define WAITED 1
157
158 static int put_pfn(unsigned long pfn, int prot);
159
160 static struct vfio_iommu_group*
161 vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
162                             struct iommu_group *iommu_group);
163
164 /*
165  * This code handles mapping and unmapping of user data buffers
166  * into DMA'ble space using the IOMMU
167  */
168
169 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
170                                       dma_addr_t start, size_t size)
171 {
172         struct rb_node *node = iommu->dma_list.rb_node;
173
174         while (node) {
175                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
176
177                 if (start + size <= dma->iova)
178                         node = node->rb_left;
179                 else if (start >= dma->iova + dma->size)
180                         node = node->rb_right;
181                 else
182                         return dma;
183         }
184
185         return NULL;
186 }
187
188 static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
189                                                 dma_addr_t start, u64 size)
190 {
191         struct rb_node *res = NULL;
192         struct rb_node *node = iommu->dma_list.rb_node;
193         struct vfio_dma *dma_res = NULL;
194
195         while (node) {
196                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
197
198                 if (start < dma->iova + dma->size) {
199                         res = node;
200                         dma_res = dma;
201                         if (start >= dma->iova)
202                                 break;
203                         node = node->rb_left;
204                 } else {
205                         node = node->rb_right;
206                 }
207         }
208         if (res && size && dma_res->iova >= start + size)
209                 res = NULL;
210         return res;
211 }
212
213 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
214 {
215         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
216         struct vfio_dma *dma;
217
218         while (*link) {
219                 parent = *link;
220                 dma = rb_entry(parent, struct vfio_dma, node);
221
222                 if (new->iova + new->size <= dma->iova)
223                         link = &(*link)->rb_left;
224                 else
225                         link = &(*link)->rb_right;
226         }
227
228         rb_link_node(&new->node, parent, link);
229         rb_insert_color(&new->node, &iommu->dma_list);
230 }
231
232 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
233 {
234         rb_erase(&old->node, &iommu->dma_list);
235 }
236
237
238 static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
239 {
240         uint64_t npages = dma->size / pgsize;
241
242         if (npages > DIRTY_BITMAP_PAGES_MAX)
243                 return -EINVAL;
244
245         /*
246          * Allocate extra 64 bits that are used to calculate shift required for
247          * bitmap_shift_left() to manipulate and club unaligned number of pages
248          * in adjacent vfio_dma ranges.
249          */
250         dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
251                                GFP_KERNEL);
252         if (!dma->bitmap)
253                 return -ENOMEM;
254
255         return 0;
256 }
257
258 static void vfio_dma_bitmap_free(struct vfio_dma *dma)
259 {
260         kvfree(dma->bitmap);
261         dma->bitmap = NULL;
262 }
263
264 static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
265 {
266         struct rb_node *p;
267         unsigned long pgshift = __ffs(pgsize);
268
269         for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
270                 struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
271
272                 bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
273         }
274 }
275
276 static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
277 {
278         struct rb_node *n;
279         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
280
281         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
282                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
283
284                 bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
285         }
286 }
287
288 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
289 {
290         struct rb_node *n;
291
292         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
293                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
294                 int ret;
295
296                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
297                 if (ret) {
298                         struct rb_node *p;
299
300                         for (p = rb_prev(n); p; p = rb_prev(p)) {
301                                 struct vfio_dma *dma = rb_entry(n,
302                                                         struct vfio_dma, node);
303
304                                 vfio_dma_bitmap_free(dma);
305                         }
306                         return ret;
307                 }
308                 vfio_dma_populate_bitmap(dma, pgsize);
309         }
310         return 0;
311 }
312
313 static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
314 {
315         struct rb_node *n;
316
317         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
318                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
319
320                 vfio_dma_bitmap_free(dma);
321         }
322 }
323
324 /*
325  * Helper Functions for host iova-pfn list
326  */
327 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
328 {
329         struct vfio_pfn *vpfn;
330         struct rb_node *node = dma->pfn_list.rb_node;
331
332         while (node) {
333                 vpfn = rb_entry(node, struct vfio_pfn, node);
334
335                 if (iova < vpfn->iova)
336                         node = node->rb_left;
337                 else if (iova > vpfn->iova)
338                         node = node->rb_right;
339                 else
340                         return vpfn;
341         }
342         return NULL;
343 }
344
345 static void vfio_link_pfn(struct vfio_dma *dma,
346                           struct vfio_pfn *new)
347 {
348         struct rb_node **link, *parent = NULL;
349         struct vfio_pfn *vpfn;
350
351         link = &dma->pfn_list.rb_node;
352         while (*link) {
353                 parent = *link;
354                 vpfn = rb_entry(parent, struct vfio_pfn, node);
355
356                 if (new->iova < vpfn->iova)
357                         link = &(*link)->rb_left;
358                 else
359                         link = &(*link)->rb_right;
360         }
361
362         rb_link_node(&new->node, parent, link);
363         rb_insert_color(&new->node, &dma->pfn_list);
364 }
365
366 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
367 {
368         rb_erase(&old->node, &dma->pfn_list);
369 }
370
371 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
372                                 unsigned long pfn)
373 {
374         struct vfio_pfn *vpfn;
375
376         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
377         if (!vpfn)
378                 return -ENOMEM;
379
380         vpfn->iova = iova;
381         vpfn->pfn = pfn;
382         vpfn->ref_count = 1;
383         vfio_link_pfn(dma, vpfn);
384         return 0;
385 }
386
387 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
388                                       struct vfio_pfn *vpfn)
389 {
390         vfio_unlink_pfn(dma, vpfn);
391         kfree(vpfn);
392 }
393
394 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
395                                                unsigned long iova)
396 {
397         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
398
399         if (vpfn)
400                 vpfn->ref_count++;
401         return vpfn;
402 }
403
404 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
405 {
406         int ret = 0;
407
408         vpfn->ref_count--;
409         if (!vpfn->ref_count) {
410                 ret = put_pfn(vpfn->pfn, dma->prot);
411                 vfio_remove_from_pfn_list(dma, vpfn);
412         }
413         return ret;
414 }
415
416 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
417 {
418         struct mm_struct *mm;
419         int ret;
420
421         if (!npage)
422                 return 0;
423
424         mm = async ? get_task_mm(dma->task) : dma->task->mm;
425         if (!mm)
426                 return -ESRCH; /* process exited */
427
428         ret = mmap_write_lock_killable(mm);
429         if (!ret) {
430                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
431                                           dma->lock_cap);
432                 mmap_write_unlock(mm);
433         }
434
435         if (async)
436                 mmput(mm);
437
438         return ret;
439 }
440
441 /*
442  * Some mappings aren't backed by a struct page, for example an mmap'd
443  * MMIO range for our own or another device.  These use a different
444  * pfn conversion and shouldn't be tracked as locked pages.
445  * For compound pages, any driver that sets the reserved bit in head
446  * page needs to set the reserved bit in all subpages to be safe.
447  */
448 static bool is_invalid_reserved_pfn(unsigned long pfn)
449 {
450         if (pfn_valid(pfn))
451                 return PageReserved(pfn_to_page(pfn));
452
453         return true;
454 }
455
456 static int put_pfn(unsigned long pfn, int prot)
457 {
458         if (!is_invalid_reserved_pfn(pfn)) {
459                 struct page *page = pfn_to_page(pfn);
460
461                 unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
462                 return 1;
463         }
464         return 0;
465 }
466
467 #define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
468
469 static void vfio_batch_init(struct vfio_batch *batch)
470 {
471         batch->size = 0;
472         batch->offset = 0;
473
474         if (unlikely(disable_hugepages))
475                 goto fallback;
476
477         batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
478         if (!batch->pages)
479                 goto fallback;
480
481         batch->capacity = VFIO_BATCH_MAX_CAPACITY;
482         return;
483
484 fallback:
485         batch->pages = &batch->fallback_page;
486         batch->capacity = 1;
487 }
488
489 static void vfio_batch_unpin(struct vfio_batch *batch, struct vfio_dma *dma)
490 {
491         while (batch->size) {
492                 unsigned long pfn = page_to_pfn(batch->pages[batch->offset]);
493
494                 put_pfn(pfn, dma->prot);
495                 batch->offset++;
496                 batch->size--;
497         }
498 }
499
500 static void vfio_batch_fini(struct vfio_batch *batch)
501 {
502         if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
503                 free_page((unsigned long)batch->pages);
504 }
505
506 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
507                             unsigned long vaddr, unsigned long *pfn,
508                             bool write_fault)
509 {
510         pte_t *ptep;
511         spinlock_t *ptl;
512         int ret;
513
514         ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
515         if (ret) {
516                 bool unlocked = false;
517
518                 ret = fixup_user_fault(mm, vaddr,
519                                        FAULT_FLAG_REMOTE |
520                                        (write_fault ?  FAULT_FLAG_WRITE : 0),
521                                        &unlocked);
522                 if (unlocked)
523                         return -EAGAIN;
524
525                 if (ret)
526                         return ret;
527
528                 ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
529                 if (ret)
530                         return ret;
531         }
532
533         if (write_fault && !pte_write(*ptep))
534                 ret = -EFAULT;
535         else
536                 *pfn = pte_pfn(*ptep);
537
538         pte_unmap_unlock(ptep, ptl);
539         return ret;
540 }
541
542 /*
543  * Returns the positive number of pfns successfully obtained or a negative
544  * error code.
545  */
546 static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
547                           long npages, int prot, unsigned long *pfn,
548                           struct page **pages)
549 {
550         struct vm_area_struct *vma;
551         unsigned int flags = 0;
552         int ret;
553
554         if (prot & IOMMU_WRITE)
555                 flags |= FOLL_WRITE;
556
557         mmap_read_lock(mm);
558         ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
559                                     pages, NULL, NULL);
560         if (ret > 0) {
561                 *pfn = page_to_pfn(pages[0]);
562                 goto done;
563         }
564
565         vaddr = untagged_addr(vaddr);
566
567 retry:
568         vma = vma_lookup(mm, vaddr);
569
570         if (vma && vma->vm_flags & VM_PFNMAP) {
571                 ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
572                 if (ret == -EAGAIN)
573                         goto retry;
574
575                 if (!ret) {
576                         if (is_invalid_reserved_pfn(*pfn))
577                                 ret = 1;
578                         else
579                                 ret = -EFAULT;
580                 }
581         }
582 done:
583         mmap_read_unlock(mm);
584         return ret;
585 }
586
587 static int vfio_wait(struct vfio_iommu *iommu)
588 {
589         DEFINE_WAIT(wait);
590
591         prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
592         mutex_unlock(&iommu->lock);
593         schedule();
594         mutex_lock(&iommu->lock);
595         finish_wait(&iommu->vaddr_wait, &wait);
596         if (kthread_should_stop() || !iommu->container_open ||
597             fatal_signal_pending(current)) {
598                 return -EFAULT;
599         }
600         return WAITED;
601 }
602
603 /*
604  * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
605  * if the task waits, but is re-locked on return.  Return result in *dma_p.
606  * Return 0 on success with no waiting, WAITED on success if waited, and -errno
607  * on error.
608  */
609 static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
610                                size_t size, struct vfio_dma **dma_p)
611 {
612         int ret = 0;
613
614         do {
615                 *dma_p = vfio_find_dma(iommu, start, size);
616                 if (!*dma_p)
617                         return -EINVAL;
618                 else if (!(*dma_p)->vaddr_invalid)
619                         return ret;
620                 else
621                         ret = vfio_wait(iommu);
622         } while (ret == WAITED);
623
624         return ret;
625 }
626
627 /*
628  * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
629  * if the task waits, but is re-locked on return.  Return 0 on success with no
630  * waiting, WAITED on success if waited, and -errno on error.
631  */
632 static int vfio_wait_all_valid(struct vfio_iommu *iommu)
633 {
634         int ret = 0;
635
636         while (iommu->vaddr_invalid_count && ret >= 0)
637                 ret = vfio_wait(iommu);
638
639         return ret;
640 }
641
642 /*
643  * Attempt to pin pages.  We really don't want to track all the pfns and
644  * the iommu can only map chunks of consecutive pfns anyway, so get the
645  * first page and all consecutive pages with the same locking.
646  */
647 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
648                                   long npage, unsigned long *pfn_base,
649                                   unsigned long limit, struct vfio_batch *batch)
650 {
651         unsigned long pfn;
652         struct mm_struct *mm = current->mm;
653         long ret, pinned = 0, lock_acct = 0;
654         bool rsvd;
655         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
656
657         /* This code path is only user initiated */
658         if (!mm)
659                 return -ENODEV;
660
661         if (batch->size) {
662                 /* Leftover pages in batch from an earlier call. */
663                 *pfn_base = page_to_pfn(batch->pages[batch->offset]);
664                 pfn = *pfn_base;
665                 rsvd = is_invalid_reserved_pfn(*pfn_base);
666         } else {
667                 *pfn_base = 0;
668         }
669
670         while (npage) {
671                 if (!batch->size) {
672                         /* Empty batch, so refill it. */
673                         long req_pages = min_t(long, npage, batch->capacity);
674
675                         ret = vaddr_get_pfns(mm, vaddr, req_pages, dma->prot,
676                                              &pfn, batch->pages);
677                         if (ret < 0)
678                                 goto unpin_out;
679
680                         batch->size = ret;
681                         batch->offset = 0;
682
683                         if (!*pfn_base) {
684                                 *pfn_base = pfn;
685                                 rsvd = is_invalid_reserved_pfn(*pfn_base);
686                         }
687                 }
688
689                 /*
690                  * pfn is preset for the first iteration of this inner loop and
691                  * updated at the end to handle a VM_PFNMAP pfn.  In that case,
692                  * batch->pages isn't valid (there's no struct page), so allow
693                  * batch->pages to be touched only when there's more than one
694                  * pfn to check, which guarantees the pfns are from a
695                  * !VM_PFNMAP vma.
696                  */
697                 while (true) {
698                         if (pfn != *pfn_base + pinned ||
699                             rsvd != is_invalid_reserved_pfn(pfn))
700                                 goto out;
701
702                         /*
703                          * Reserved pages aren't counted against the user,
704                          * externally pinned pages are already counted against
705                          * the user.
706                          */
707                         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
708                                 if (!dma->lock_cap &&
709                                     mm->locked_vm + lock_acct + 1 > limit) {
710                                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
711                                                 __func__, limit << PAGE_SHIFT);
712                                         ret = -ENOMEM;
713                                         goto unpin_out;
714                                 }
715                                 lock_acct++;
716                         }
717
718                         pinned++;
719                         npage--;
720                         vaddr += PAGE_SIZE;
721                         iova += PAGE_SIZE;
722                         batch->offset++;
723                         batch->size--;
724
725                         if (!batch->size)
726                                 break;
727
728                         pfn = page_to_pfn(batch->pages[batch->offset]);
729                 }
730
731                 if (unlikely(disable_hugepages))
732                         break;
733         }
734
735 out:
736         ret = vfio_lock_acct(dma, lock_acct, false);
737
738 unpin_out:
739         if (batch->size == 1 && !batch->offset) {
740                 /* May be a VM_PFNMAP pfn, which the batch can't remember. */
741                 put_pfn(pfn, dma->prot);
742                 batch->size = 0;
743         }
744
745         if (ret < 0) {
746                 if (pinned && !rsvd) {
747                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
748                                 put_pfn(pfn, dma->prot);
749                 }
750                 vfio_batch_unpin(batch, dma);
751
752                 return ret;
753         }
754
755         return pinned;
756 }
757
758 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
759                                     unsigned long pfn, long npage,
760                                     bool do_accounting)
761 {
762         long unlocked = 0, locked = 0;
763         long i;
764
765         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
766                 if (put_pfn(pfn++, dma->prot)) {
767                         unlocked++;
768                         if (vfio_find_vpfn(dma, iova))
769                                 locked++;
770                 }
771         }
772
773         if (do_accounting)
774                 vfio_lock_acct(dma, locked - unlocked, true);
775
776         return unlocked;
777 }
778
779 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
780                                   unsigned long *pfn_base, bool do_accounting)
781 {
782         struct page *pages[1];
783         struct mm_struct *mm;
784         int ret;
785
786         mm = get_task_mm(dma->task);
787         if (!mm)
788                 return -ENODEV;
789
790         ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
791         if (ret != 1)
792                 goto out;
793
794         ret = 0;
795
796         if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
797                 ret = vfio_lock_acct(dma, 1, true);
798                 if (ret) {
799                         put_pfn(*pfn_base, dma->prot);
800                         if (ret == -ENOMEM)
801                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
802                                         "(%ld) exceeded\n", __func__,
803                                         dma->task->comm, task_pid_nr(dma->task),
804                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
805                 }
806         }
807
808 out:
809         mmput(mm);
810         return ret;
811 }
812
813 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
814                                     bool do_accounting)
815 {
816         int unlocked;
817         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
818
819         if (!vpfn)
820                 return 0;
821
822         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
823
824         if (do_accounting)
825                 vfio_lock_acct(dma, -unlocked, true);
826
827         return unlocked;
828 }
829
830 static int vfio_iommu_type1_pin_pages(void *iommu_data,
831                                       struct iommu_group *iommu_group,
832                                       dma_addr_t user_iova,
833                                       int npage, int prot,
834                                       struct page **pages)
835 {
836         struct vfio_iommu *iommu = iommu_data;
837         struct vfio_iommu_group *group;
838         int i, j, ret;
839         unsigned long remote_vaddr;
840         struct vfio_dma *dma;
841         bool do_accounting;
842         dma_addr_t iova;
843
844         if (!iommu || !pages)
845                 return -EINVAL;
846
847         /* Supported for v2 version only */
848         if (!iommu->v2)
849                 return -EACCES;
850
851         mutex_lock(&iommu->lock);
852
853         /*
854          * Wait for all necessary vaddr's to be valid so they can be used in
855          * the main loop without dropping the lock, to avoid racing vs unmap.
856          */
857 again:
858         if (iommu->vaddr_invalid_count) {
859                 for (i = 0; i < npage; i++) {
860                         iova = user_iova + PAGE_SIZE * i;
861                         ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
862                         if (ret < 0)
863                                 goto pin_done;
864                         if (ret == WAITED)
865                                 goto again;
866                 }
867         }
868
869         /* Fail if no dma_umap notifier is registered */
870         if (list_empty(&iommu->device_list)) {
871                 ret = -EINVAL;
872                 goto pin_done;
873         }
874
875         /*
876          * If iommu capable domain exist in the container then all pages are
877          * already pinned and accounted. Accounting should be done if there is no
878          * iommu capable domain in the container.
879          */
880         do_accounting = list_empty(&iommu->domain_list);
881
882         for (i = 0; i < npage; i++) {
883                 unsigned long phys_pfn;
884                 struct vfio_pfn *vpfn;
885
886                 iova = user_iova + PAGE_SIZE * i;
887                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
888                 if (!dma) {
889                         ret = -EINVAL;
890                         goto pin_unwind;
891                 }
892
893                 if ((dma->prot & prot) != prot) {
894                         ret = -EPERM;
895                         goto pin_unwind;
896                 }
897
898                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
899                 if (vpfn) {
900                         pages[i] = pfn_to_page(vpfn->pfn);
901                         continue;
902                 }
903
904                 remote_vaddr = dma->vaddr + (iova - dma->iova);
905                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn,
906                                              do_accounting);
907                 if (ret)
908                         goto pin_unwind;
909
910                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn);
911                 if (ret) {
912                         if (put_pfn(phys_pfn, dma->prot) && do_accounting)
913                                 vfio_lock_acct(dma, -1, true);
914                         goto pin_unwind;
915                 }
916
917                 pages[i] = pfn_to_page(phys_pfn);
918
919                 if (iommu->dirty_page_tracking) {
920                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
921
922                         /*
923                          * Bitmap populated with the smallest supported page
924                          * size
925                          */
926                         bitmap_set(dma->bitmap,
927                                    (iova - dma->iova) >> pgshift, 1);
928                 }
929         }
930         ret = i;
931
932         group = vfio_iommu_find_iommu_group(iommu, iommu_group);
933         if (!group->pinned_page_dirty_scope) {
934                 group->pinned_page_dirty_scope = true;
935                 iommu->num_non_pinned_groups--;
936         }
937
938         goto pin_done;
939
940 pin_unwind:
941         pages[i] = NULL;
942         for (j = 0; j < i; j++) {
943                 dma_addr_t iova;
944
945                 iova = user_iova + PAGE_SIZE * j;
946                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
947                 vfio_unpin_page_external(dma, iova, do_accounting);
948                 pages[j] = NULL;
949         }
950 pin_done:
951         mutex_unlock(&iommu->lock);
952         return ret;
953 }
954
955 static void vfio_iommu_type1_unpin_pages(void *iommu_data,
956                                          dma_addr_t user_iova, int npage)
957 {
958         struct vfio_iommu *iommu = iommu_data;
959         bool do_accounting;
960         int i;
961
962         /* Supported for v2 version only */
963         if (WARN_ON(!iommu->v2))
964                 return;
965
966         mutex_lock(&iommu->lock);
967
968         do_accounting = list_empty(&iommu->domain_list);
969         for (i = 0; i < npage; i++) {
970                 dma_addr_t iova = user_iova + PAGE_SIZE * i;
971                 struct vfio_dma *dma;
972
973                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
974                 if (!dma)
975                         break;
976
977                 vfio_unpin_page_external(dma, iova, do_accounting);
978         }
979
980         mutex_unlock(&iommu->lock);
981
982         WARN_ON(i != npage);
983 }
984
985 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
986                             struct list_head *regions,
987                             struct iommu_iotlb_gather *iotlb_gather)
988 {
989         long unlocked = 0;
990         struct vfio_regions *entry, *next;
991
992         iommu_iotlb_sync(domain->domain, iotlb_gather);
993
994         list_for_each_entry_safe(entry, next, regions, list) {
995                 unlocked += vfio_unpin_pages_remote(dma,
996                                                     entry->iova,
997                                                     entry->phys >> PAGE_SHIFT,
998                                                     entry->len >> PAGE_SHIFT,
999                                                     false);
1000                 list_del(&entry->list);
1001                 kfree(entry);
1002         }
1003
1004         cond_resched();
1005
1006         return unlocked;
1007 }
1008
1009 /*
1010  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
1011  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
1012  * of these regions (currently using a list).
1013  *
1014  * This value specifies maximum number of regions for each IOTLB flush sync.
1015  */
1016 #define VFIO_IOMMU_TLB_SYNC_MAX         512
1017
1018 static size_t unmap_unpin_fast(struct vfio_domain *domain,
1019                                struct vfio_dma *dma, dma_addr_t *iova,
1020                                size_t len, phys_addr_t phys, long *unlocked,
1021                                struct list_head *unmapped_list,
1022                                int *unmapped_cnt,
1023                                struct iommu_iotlb_gather *iotlb_gather)
1024 {
1025         size_t unmapped = 0;
1026         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
1027
1028         if (entry) {
1029                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
1030                                             iotlb_gather);
1031
1032                 if (!unmapped) {
1033                         kfree(entry);
1034                 } else {
1035                         entry->iova = *iova;
1036                         entry->phys = phys;
1037                         entry->len  = unmapped;
1038                         list_add_tail(&entry->list, unmapped_list);
1039
1040                         *iova += unmapped;
1041                         (*unmapped_cnt)++;
1042                 }
1043         }
1044
1045         /*
1046          * Sync if the number of fast-unmap regions hits the limit
1047          * or in case of errors.
1048          */
1049         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
1050                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
1051                                              iotlb_gather);
1052                 *unmapped_cnt = 0;
1053         }
1054
1055         return unmapped;
1056 }
1057
1058 static size_t unmap_unpin_slow(struct vfio_domain *domain,
1059                                struct vfio_dma *dma, dma_addr_t *iova,
1060                                size_t len, phys_addr_t phys,
1061                                long *unlocked)
1062 {
1063         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
1064
1065         if (unmapped) {
1066                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
1067                                                      phys >> PAGE_SHIFT,
1068                                                      unmapped >> PAGE_SHIFT,
1069                                                      false);
1070                 *iova += unmapped;
1071                 cond_resched();
1072         }
1073         return unmapped;
1074 }
1075
1076 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
1077                              bool do_accounting)
1078 {
1079         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
1080         struct vfio_domain *domain, *d;
1081         LIST_HEAD(unmapped_region_list);
1082         struct iommu_iotlb_gather iotlb_gather;
1083         int unmapped_region_cnt = 0;
1084         long unlocked = 0;
1085
1086         if (!dma->size)
1087                 return 0;
1088
1089         if (list_empty(&iommu->domain_list))
1090                 return 0;
1091
1092         /*
1093          * We use the IOMMU to track the physical addresses, otherwise we'd
1094          * need a much more complicated tracking system.  Unfortunately that
1095          * means we need to use one of the iommu domains to figure out the
1096          * pfns to unpin.  The rest need to be unmapped in advance so we have
1097          * no iommu translations remaining when the pages are unpinned.
1098          */
1099         domain = d = list_first_entry(&iommu->domain_list,
1100                                       struct vfio_domain, next);
1101
1102         list_for_each_entry_continue(d, &iommu->domain_list, next) {
1103                 iommu_unmap(d->domain, dma->iova, dma->size);
1104                 cond_resched();
1105         }
1106
1107         iommu_iotlb_gather_init(&iotlb_gather);
1108         while (iova < end) {
1109                 size_t unmapped, len;
1110                 phys_addr_t phys, next;
1111
1112                 phys = iommu_iova_to_phys(domain->domain, iova);
1113                 if (WARN_ON(!phys)) {
1114                         iova += PAGE_SIZE;
1115                         continue;
1116                 }
1117
1118                 /*
1119                  * To optimize for fewer iommu_unmap() calls, each of which
1120                  * may require hardware cache flushing, try to find the
1121                  * largest contiguous physical memory chunk to unmap.
1122                  */
1123                 for (len = PAGE_SIZE;
1124                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
1125                         next = iommu_iova_to_phys(domain->domain, iova + len);
1126                         if (next != phys + len)
1127                                 break;
1128                 }
1129
1130                 /*
1131                  * First, try to use fast unmap/unpin. In case of failure,
1132                  * switch to slow unmap/unpin path.
1133                  */
1134                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
1135                                             &unlocked, &unmapped_region_list,
1136                                             &unmapped_region_cnt,
1137                                             &iotlb_gather);
1138                 if (!unmapped) {
1139                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
1140                                                     phys, &unlocked);
1141                         if (WARN_ON(!unmapped))
1142                                 break;
1143                 }
1144         }
1145
1146         dma->iommu_mapped = false;
1147
1148         if (unmapped_region_cnt) {
1149                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
1150                                             &iotlb_gather);
1151         }
1152
1153         if (do_accounting) {
1154                 vfio_lock_acct(dma, -unlocked, true);
1155                 return 0;
1156         }
1157         return unlocked;
1158 }
1159
1160 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
1161 {
1162         WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
1163         vfio_unmap_unpin(iommu, dma, true);
1164         vfio_unlink_dma(iommu, dma);
1165         put_task_struct(dma->task);
1166         vfio_dma_bitmap_free(dma);
1167         if (dma->vaddr_invalid) {
1168                 iommu->vaddr_invalid_count--;
1169                 wake_up_all(&iommu->vaddr_wait);
1170         }
1171         kfree(dma);
1172         iommu->dma_avail++;
1173 }
1174
1175 static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
1176 {
1177         struct vfio_domain *domain;
1178
1179         iommu->pgsize_bitmap = ULONG_MAX;
1180
1181         list_for_each_entry(domain, &iommu->domain_list, next)
1182                 iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
1183
1184         /*
1185          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
1186          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
1187          * That way the user will be able to map/unmap buffers whose size/
1188          * start address is aligned with PAGE_SIZE. Pinning code uses that
1189          * granularity while iommu driver can use the sub-PAGE_SIZE size
1190          * to map the buffer.
1191          */
1192         if (iommu->pgsize_bitmap & ~PAGE_MASK) {
1193                 iommu->pgsize_bitmap &= PAGE_MASK;
1194                 iommu->pgsize_bitmap |= PAGE_SIZE;
1195         }
1196 }
1197
1198 static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1199                               struct vfio_dma *dma, dma_addr_t base_iova,
1200                               size_t pgsize)
1201 {
1202         unsigned long pgshift = __ffs(pgsize);
1203         unsigned long nbits = dma->size >> pgshift;
1204         unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
1205         unsigned long copy_offset = bit_offset / BITS_PER_LONG;
1206         unsigned long shift = bit_offset % BITS_PER_LONG;
1207         unsigned long leftover;
1208
1209         /*
1210          * mark all pages dirty if any IOMMU capable device is not able
1211          * to report dirty pages and all pages are pinned and mapped.
1212          */
1213         if (iommu->num_non_pinned_groups && dma->iommu_mapped)
1214                 bitmap_set(dma->bitmap, 0, nbits);
1215
1216         if (shift) {
1217                 bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1218                                   nbits + shift);
1219
1220                 if (copy_from_user(&leftover,
1221                                    (void __user *)(bitmap + copy_offset),
1222                                    sizeof(leftover)))
1223                         return -EFAULT;
1224
1225                 bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1226         }
1227
1228         if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1229                          DIRTY_BITMAP_BYTES(nbits + shift)))
1230                 return -EFAULT;
1231
1232         return 0;
1233 }
1234
1235 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1236                                   dma_addr_t iova, size_t size, size_t pgsize)
1237 {
1238         struct vfio_dma *dma;
1239         struct rb_node *n;
1240         unsigned long pgshift = __ffs(pgsize);
1241         int ret;
1242
1243         /*
1244          * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1245          * vfio_dma mappings may be clubbed by specifying large ranges, but
1246          * there must not be any previous mappings bisected by the range.
1247          * An error will be returned if these conditions are not met.
1248          */
1249         dma = vfio_find_dma(iommu, iova, 1);
1250         if (dma && dma->iova != iova)
1251                 return -EINVAL;
1252
1253         dma = vfio_find_dma(iommu, iova + size - 1, 0);
1254         if (dma && dma->iova + dma->size != iova + size)
1255                 return -EINVAL;
1256
1257         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1258                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1259
1260                 if (dma->iova < iova)
1261                         continue;
1262
1263                 if (dma->iova > iova + size - 1)
1264                         break;
1265
1266                 ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1267                 if (ret)
1268                         return ret;
1269
1270                 /*
1271                  * Re-populate bitmap to include all pinned pages which are
1272                  * considered as dirty but exclude pages which are unpinned and
1273                  * pages which are marked dirty by vfio_dma_rw()
1274                  */
1275                 bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1276                 vfio_dma_populate_bitmap(dma, pgsize);
1277         }
1278         return 0;
1279 }
1280
1281 static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1282 {
1283         if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1284             (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1285                 return -EINVAL;
1286
1287         return 0;
1288 }
1289
1290 /*
1291  * Notify VFIO drivers using vfio_register_emulated_iommu_dev() to invalidate
1292  * and unmap iovas within the range we're about to unmap. Drivers MUST unpin
1293  * pages in response to an invalidation.
1294  */
1295 static void vfio_notify_dma_unmap(struct vfio_iommu *iommu,
1296                                   struct vfio_dma *dma)
1297 {
1298         struct vfio_device *device;
1299
1300         if (list_empty(&iommu->device_list))
1301                 return;
1302
1303         /*
1304          * The device is expected to call vfio_unpin_pages() for any IOVA it has
1305          * pinned within the range. Since vfio_unpin_pages() will eventually
1306          * call back down to this code and try to obtain the iommu->lock we must
1307          * drop it.
1308          */
1309         mutex_lock(&iommu->device_list_lock);
1310         mutex_unlock(&iommu->lock);
1311
1312         list_for_each_entry(device, &iommu->device_list, iommu_entry)
1313                 device->ops->dma_unmap(device, dma->iova, dma->size);
1314
1315         mutex_unlock(&iommu->device_list_lock);
1316         mutex_lock(&iommu->lock);
1317 }
1318
1319 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1320                              struct vfio_iommu_type1_dma_unmap *unmap,
1321                              struct vfio_bitmap *bitmap)
1322 {
1323         struct vfio_dma *dma, *dma_last = NULL;
1324         size_t unmapped = 0, pgsize;
1325         int ret = -EINVAL, retries = 0;
1326         unsigned long pgshift;
1327         dma_addr_t iova = unmap->iova;
1328         u64 size = unmap->size;
1329         bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
1330         bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
1331         struct rb_node *n, *first_n;
1332
1333         mutex_lock(&iommu->lock);
1334
1335         pgshift = __ffs(iommu->pgsize_bitmap);
1336         pgsize = (size_t)1 << pgshift;
1337
1338         if (iova & (pgsize - 1))
1339                 goto unlock;
1340
1341         if (unmap_all) {
1342                 if (iova || size)
1343                         goto unlock;
1344                 size = U64_MAX;
1345         } else if (!size || size & (pgsize - 1) ||
1346                    iova + size - 1 < iova || size > SIZE_MAX) {
1347                 goto unlock;
1348         }
1349
1350         /* When dirty tracking is enabled, allow only min supported pgsize */
1351         if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1352             (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1353                 goto unlock;
1354         }
1355
1356         WARN_ON((pgsize - 1) & PAGE_MASK);
1357 again:
1358         /*
1359          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1360          * avoid tracking individual mappings.  This means that the granularity
1361          * of the original mapping was lost and the user was allowed to attempt
1362          * to unmap any range.  Depending on the contiguousness of physical
1363          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1364          * or may not have worked.  We only guaranteed unmap granularity
1365          * matching the original mapping; even though it was untracked here,
1366          * the original mappings are reflected in IOMMU mappings.  This
1367          * resulted in a couple unusual behaviors.  First, if a range is not
1368          * able to be unmapped, ex. a set of 4k pages that was mapped as a
1369          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1370          * a zero sized unmap.  Also, if an unmap request overlaps the first
1371          * address of a hugepage, the IOMMU will unmap the entire hugepage.
1372          * This also returns success and the returned unmap size reflects the
1373          * actual size unmapped.
1374          *
1375          * We attempt to maintain compatibility with this "v1" interface, but
1376          * we take control out of the hands of the IOMMU.  Therefore, an unmap
1377          * request offset from the beginning of the original mapping will
1378          * return success with zero sized unmap.  And an unmap request covering
1379          * the first iova of mapping will unmap the entire range.
1380          *
1381          * The v2 version of this interface intends to be more deterministic.
1382          * Unmap requests must fully cover previous mappings.  Multiple
1383          * mappings may still be unmaped by specifying large ranges, but there
1384          * must not be any previous mappings bisected by the range.  An error
1385          * will be returned if these conditions are not met.  The v2 interface
1386          * will only return success and a size of zero if there were no
1387          * mappings within the range.
1388          */
1389         if (iommu->v2 && !unmap_all) {
1390                 dma = vfio_find_dma(iommu, iova, 1);
1391                 if (dma && dma->iova != iova)
1392                         goto unlock;
1393
1394                 dma = vfio_find_dma(iommu, iova + size - 1, 0);
1395                 if (dma && dma->iova + dma->size != iova + size)
1396                         goto unlock;
1397         }
1398
1399         ret = 0;
1400         n = first_n = vfio_find_dma_first_node(iommu, iova, size);
1401
1402         while (n) {
1403                 dma = rb_entry(n, struct vfio_dma, node);
1404                 if (dma->iova >= iova + size)
1405                         break;
1406
1407                 if (!iommu->v2 && iova > dma->iova)
1408                         break;
1409
1410                 if (invalidate_vaddr) {
1411                         if (dma->vaddr_invalid) {
1412                                 struct rb_node *last_n = n;
1413
1414                                 for (n = first_n; n != last_n; n = rb_next(n)) {
1415                                         dma = rb_entry(n,
1416                                                        struct vfio_dma, node);
1417                                         dma->vaddr_invalid = false;
1418                                         iommu->vaddr_invalid_count--;
1419                                 }
1420                                 ret = -EINVAL;
1421                                 unmapped = 0;
1422                                 break;
1423                         }
1424                         dma->vaddr_invalid = true;
1425                         iommu->vaddr_invalid_count++;
1426                         unmapped += dma->size;
1427                         n = rb_next(n);
1428                         continue;
1429                 }
1430
1431                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1432                         if (dma_last == dma) {
1433                                 BUG_ON(++retries > 10);
1434                         } else {
1435                                 dma_last = dma;
1436                                 retries = 0;
1437                         }
1438
1439                         vfio_notify_dma_unmap(iommu, dma);
1440                         goto again;
1441                 }
1442
1443                 if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1444                         ret = update_user_bitmap(bitmap->data, iommu, dma,
1445                                                  iova, pgsize);
1446                         if (ret)
1447                                 break;
1448                 }
1449
1450                 unmapped += dma->size;
1451                 n = rb_next(n);
1452                 vfio_remove_dma(iommu, dma);
1453         }
1454
1455 unlock:
1456         mutex_unlock(&iommu->lock);
1457
1458         /* Report how much was unmapped */
1459         unmap->size = unmapped;
1460
1461         return ret;
1462 }
1463
1464 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1465                           unsigned long pfn, long npage, int prot)
1466 {
1467         struct vfio_domain *d;
1468         int ret;
1469
1470         list_for_each_entry(d, &iommu->domain_list, next) {
1471                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1472                                 npage << PAGE_SHIFT, prot | IOMMU_CACHE);
1473                 if (ret)
1474                         goto unwind;
1475
1476                 cond_resched();
1477         }
1478
1479         return 0;
1480
1481 unwind:
1482         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1483                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1484                 cond_resched();
1485         }
1486
1487         return ret;
1488 }
1489
1490 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1491                             size_t map_size)
1492 {
1493         dma_addr_t iova = dma->iova;
1494         unsigned long vaddr = dma->vaddr;
1495         struct vfio_batch batch;
1496         size_t size = map_size;
1497         long npage;
1498         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1499         int ret = 0;
1500
1501         vfio_batch_init(&batch);
1502
1503         while (size) {
1504                 /* Pin a contiguous chunk of memory */
1505                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1506                                               size >> PAGE_SHIFT, &pfn, limit,
1507                                               &batch);
1508                 if (npage <= 0) {
1509                         WARN_ON(!npage);
1510                         ret = (int)npage;
1511                         break;
1512                 }
1513
1514                 /* Map it! */
1515                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1516                                      dma->prot);
1517                 if (ret) {
1518                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1519                                                 npage, true);
1520                         vfio_batch_unpin(&batch, dma);
1521                         break;
1522                 }
1523
1524                 size -= npage << PAGE_SHIFT;
1525                 dma->size += npage << PAGE_SHIFT;
1526         }
1527
1528         vfio_batch_fini(&batch);
1529         dma->iommu_mapped = true;
1530
1531         if (ret)
1532                 vfio_remove_dma(iommu, dma);
1533
1534         return ret;
1535 }
1536
1537 /*
1538  * Check dma map request is within a valid iova range
1539  */
1540 static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1541                                       dma_addr_t start, dma_addr_t end)
1542 {
1543         struct list_head *iova = &iommu->iova_list;
1544         struct vfio_iova *node;
1545
1546         list_for_each_entry(node, iova, list) {
1547                 if (start >= node->start && end <= node->end)
1548                         return true;
1549         }
1550
1551         /*
1552          * Check for list_empty() as well since a container with
1553          * a single mdev device will have an empty list.
1554          */
1555         return list_empty(iova);
1556 }
1557
1558 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1559                            struct vfio_iommu_type1_dma_map *map)
1560 {
1561         bool set_vaddr = map->flags & VFIO_DMA_MAP_FLAG_VADDR;
1562         dma_addr_t iova = map->iova;
1563         unsigned long vaddr = map->vaddr;
1564         size_t size = map->size;
1565         int ret = 0, prot = 0;
1566         size_t pgsize;
1567         struct vfio_dma *dma;
1568
1569         /* Verify that none of our __u64 fields overflow */
1570         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1571                 return -EINVAL;
1572
1573         /* READ/WRITE from device perspective */
1574         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1575                 prot |= IOMMU_WRITE;
1576         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1577                 prot |= IOMMU_READ;
1578
1579         if ((prot && set_vaddr) || (!prot && !set_vaddr))
1580                 return -EINVAL;
1581
1582         mutex_lock(&iommu->lock);
1583
1584         pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1585
1586         WARN_ON((pgsize - 1) & PAGE_MASK);
1587
1588         if (!size || (size | iova | vaddr) & (pgsize - 1)) {
1589                 ret = -EINVAL;
1590                 goto out_unlock;
1591         }
1592
1593         /* Don't allow IOVA or virtual address wrap */
1594         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1595                 ret = -EINVAL;
1596                 goto out_unlock;
1597         }
1598
1599         dma = vfio_find_dma(iommu, iova, size);
1600         if (set_vaddr) {
1601                 if (!dma) {
1602                         ret = -ENOENT;
1603                 } else if (!dma->vaddr_invalid || dma->iova != iova ||
1604                            dma->size != size) {
1605                         ret = -EINVAL;
1606                 } else {
1607                         dma->vaddr = vaddr;
1608                         dma->vaddr_invalid = false;
1609                         iommu->vaddr_invalid_count--;
1610                         wake_up_all(&iommu->vaddr_wait);
1611                 }
1612                 goto out_unlock;
1613         } else if (dma) {
1614                 ret = -EEXIST;
1615                 goto out_unlock;
1616         }
1617
1618         if (!iommu->dma_avail) {
1619                 ret = -ENOSPC;
1620                 goto out_unlock;
1621         }
1622
1623         if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1624                 ret = -EINVAL;
1625                 goto out_unlock;
1626         }
1627
1628         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1629         if (!dma) {
1630                 ret = -ENOMEM;
1631                 goto out_unlock;
1632         }
1633
1634         iommu->dma_avail--;
1635         dma->iova = iova;
1636         dma->vaddr = vaddr;
1637         dma->prot = prot;
1638
1639         /*
1640          * We need to be able to both add to a task's locked memory and test
1641          * against the locked memory limit and we need to be able to do both
1642          * outside of this call path as pinning can be asynchronous via the
1643          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1644          * task_struct and VM locked pages requires an mm_struct, however
1645          * holding an indefinite mm reference is not recommended, therefore we
1646          * only hold a reference to a task.  We could hold a reference to
1647          * current, however QEMU uses this call path through vCPU threads,
1648          * which can be killed resulting in a NULL mm and failure in the unmap
1649          * path when called via a different thread.  Avoid this problem by
1650          * using the group_leader as threads within the same group require
1651          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1652          * mm_struct.
1653          *
1654          * Previously we also used the task for testing CAP_IPC_LOCK at the
1655          * time of pinning and accounting, however has_capability() makes use
1656          * of real_cred, a copy-on-write field, so we can't guarantee that it
1657          * matches group_leader, or in fact that it might not change by the
1658          * time it's evaluated.  If a process were to call MAP_DMA with
1659          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1660          * possibly see different results for an iommu_mapped vfio_dma vs
1661          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1662          * time of calling MAP_DMA.
1663          */
1664         get_task_struct(current->group_leader);
1665         dma->task = current->group_leader;
1666         dma->lock_cap = capable(CAP_IPC_LOCK);
1667
1668         dma->pfn_list = RB_ROOT;
1669
1670         /* Insert zero-sized and grow as we map chunks of it */
1671         vfio_link_dma(iommu, dma);
1672
1673         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1674         if (list_empty(&iommu->domain_list))
1675                 dma->size = size;
1676         else
1677                 ret = vfio_pin_map_dma(iommu, dma, size);
1678
1679         if (!ret && iommu->dirty_page_tracking) {
1680                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
1681                 if (ret)
1682                         vfio_remove_dma(iommu, dma);
1683         }
1684
1685 out_unlock:
1686         mutex_unlock(&iommu->lock);
1687         return ret;
1688 }
1689
1690 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1691                              struct vfio_domain *domain)
1692 {
1693         struct vfio_batch batch;
1694         struct vfio_domain *d = NULL;
1695         struct rb_node *n;
1696         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1697         int ret;
1698
1699         ret = vfio_wait_all_valid(iommu);
1700         if (ret < 0)
1701                 return ret;
1702
1703         /* Arbitrarily pick the first domain in the list for lookups */
1704         if (!list_empty(&iommu->domain_list))
1705                 d = list_first_entry(&iommu->domain_list,
1706                                      struct vfio_domain, next);
1707
1708         vfio_batch_init(&batch);
1709
1710         n = rb_first(&iommu->dma_list);
1711
1712         for (; n; n = rb_next(n)) {
1713                 struct vfio_dma *dma;
1714                 dma_addr_t iova;
1715
1716                 dma = rb_entry(n, struct vfio_dma, node);
1717                 iova = dma->iova;
1718
1719                 while (iova < dma->iova + dma->size) {
1720                         phys_addr_t phys;
1721                         size_t size;
1722
1723                         if (dma->iommu_mapped) {
1724                                 phys_addr_t p;
1725                                 dma_addr_t i;
1726
1727                                 if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1728                                         ret = -EINVAL;
1729                                         goto unwind;
1730                                 }
1731
1732                                 phys = iommu_iova_to_phys(d->domain, iova);
1733
1734                                 if (WARN_ON(!phys)) {
1735                                         iova += PAGE_SIZE;
1736                                         continue;
1737                                 }
1738
1739                                 size = PAGE_SIZE;
1740                                 p = phys + size;
1741                                 i = iova + size;
1742                                 while (i < dma->iova + dma->size &&
1743                                        p == iommu_iova_to_phys(d->domain, i)) {
1744                                         size += PAGE_SIZE;
1745                                         p += PAGE_SIZE;
1746                                         i += PAGE_SIZE;
1747                                 }
1748                         } else {
1749                                 unsigned long pfn;
1750                                 unsigned long vaddr = dma->vaddr +
1751                                                      (iova - dma->iova);
1752                                 size_t n = dma->iova + dma->size - iova;
1753                                 long npage;
1754
1755                                 npage = vfio_pin_pages_remote(dma, vaddr,
1756                                                               n >> PAGE_SHIFT,
1757                                                               &pfn, limit,
1758                                                               &batch);
1759                                 if (npage <= 0) {
1760                                         WARN_ON(!npage);
1761                                         ret = (int)npage;
1762                                         goto unwind;
1763                                 }
1764
1765                                 phys = pfn << PAGE_SHIFT;
1766                                 size = npage << PAGE_SHIFT;
1767                         }
1768
1769                         ret = iommu_map(domain->domain, iova, phys,
1770                                         size, dma->prot | IOMMU_CACHE);
1771                         if (ret) {
1772                                 if (!dma->iommu_mapped) {
1773                                         vfio_unpin_pages_remote(dma, iova,
1774                                                         phys >> PAGE_SHIFT,
1775                                                         size >> PAGE_SHIFT,
1776                                                         true);
1777                                         vfio_batch_unpin(&batch, dma);
1778                                 }
1779                                 goto unwind;
1780                         }
1781
1782                         iova += size;
1783                 }
1784         }
1785
1786         /* All dmas are now mapped, defer to second tree walk for unwind */
1787         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1788                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1789
1790                 dma->iommu_mapped = true;
1791         }
1792
1793         vfio_batch_fini(&batch);
1794         return 0;
1795
1796 unwind:
1797         for (; n; n = rb_prev(n)) {
1798                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1799                 dma_addr_t iova;
1800
1801                 if (dma->iommu_mapped) {
1802                         iommu_unmap(domain->domain, dma->iova, dma->size);
1803                         continue;
1804                 }
1805
1806                 iova = dma->iova;
1807                 while (iova < dma->iova + dma->size) {
1808                         phys_addr_t phys, p;
1809                         size_t size;
1810                         dma_addr_t i;
1811
1812                         phys = iommu_iova_to_phys(domain->domain, iova);
1813                         if (!phys) {
1814                                 iova += PAGE_SIZE;
1815                                 continue;
1816                         }
1817
1818                         size = PAGE_SIZE;
1819                         p = phys + size;
1820                         i = iova + size;
1821                         while (i < dma->iova + dma->size &&
1822                                p == iommu_iova_to_phys(domain->domain, i)) {
1823                                 size += PAGE_SIZE;
1824                                 p += PAGE_SIZE;
1825                                 i += PAGE_SIZE;
1826                         }
1827
1828                         iommu_unmap(domain->domain, iova, size);
1829                         vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1830                                                 size >> PAGE_SHIFT, true);
1831                 }
1832         }
1833
1834         vfio_batch_fini(&batch);
1835         return ret;
1836 }
1837
1838 /*
1839  * We change our unmap behavior slightly depending on whether the IOMMU
1840  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1841  * for practically any contiguous power-of-two mapping we give it.  This means
1842  * we don't need to look for contiguous chunks ourselves to make unmapping
1843  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1844  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1845  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1846  * hugetlbfs is in use.
1847  */
1848 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1849 {
1850         struct page *pages;
1851         int ret, order = get_order(PAGE_SIZE * 2);
1852
1853         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1854         if (!pages)
1855                 return;
1856
1857         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1858                         IOMMU_READ | IOMMU_WRITE | IOMMU_CACHE);
1859         if (!ret) {
1860                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1861
1862                 if (unmapped == PAGE_SIZE)
1863                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1864                 else
1865                         domain->fgsp = true;
1866         }
1867
1868         __free_pages(pages, order);
1869 }
1870
1871 static struct vfio_iommu_group *find_iommu_group(struct vfio_domain *domain,
1872                                                  struct iommu_group *iommu_group)
1873 {
1874         struct vfio_iommu_group *g;
1875
1876         list_for_each_entry(g, &domain->group_list, next) {
1877                 if (g->iommu_group == iommu_group)
1878                         return g;
1879         }
1880
1881         return NULL;
1882 }
1883
1884 static struct vfio_iommu_group*
1885 vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1886                             struct iommu_group *iommu_group)
1887 {
1888         struct vfio_iommu_group *group;
1889         struct vfio_domain *domain;
1890
1891         list_for_each_entry(domain, &iommu->domain_list, next) {
1892                 group = find_iommu_group(domain, iommu_group);
1893                 if (group)
1894                         return group;
1895         }
1896
1897         list_for_each_entry(group, &iommu->emulated_iommu_groups, next)
1898                 if (group->iommu_group == iommu_group)
1899                         return group;
1900         return NULL;
1901 }
1902
1903 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1904                                   phys_addr_t *base)
1905 {
1906         struct iommu_resv_region *region;
1907         bool ret = false;
1908
1909         list_for_each_entry(region, group_resv_regions, list) {
1910                 /*
1911                  * The presence of any 'real' MSI regions should take
1912                  * precedence over the software-managed one if the
1913                  * IOMMU driver happens to advertise both types.
1914                  */
1915                 if (region->type == IOMMU_RESV_MSI) {
1916                         ret = false;
1917                         break;
1918                 }
1919
1920                 if (region->type == IOMMU_RESV_SW_MSI) {
1921                         *base = region->start;
1922                         ret = true;
1923                 }
1924         }
1925
1926         return ret;
1927 }
1928
1929 /*
1930  * This is a helper function to insert an address range to iova list.
1931  * The list is initially created with a single entry corresponding to
1932  * the IOMMU domain geometry to which the device group is attached.
1933  * The list aperture gets modified when a new domain is added to the
1934  * container if the new aperture doesn't conflict with the current one
1935  * or with any existing dma mappings. The list is also modified to
1936  * exclude any reserved regions associated with the device group.
1937  */
1938 static int vfio_iommu_iova_insert(struct list_head *head,
1939                                   dma_addr_t start, dma_addr_t end)
1940 {
1941         struct vfio_iova *region;
1942
1943         region = kmalloc(sizeof(*region), GFP_KERNEL);
1944         if (!region)
1945                 return -ENOMEM;
1946
1947         INIT_LIST_HEAD(&region->list);
1948         region->start = start;
1949         region->end = end;
1950
1951         list_add_tail(&region->list, head);
1952         return 0;
1953 }
1954
1955 /*
1956  * Check the new iommu aperture conflicts with existing aper or with any
1957  * existing dma mappings.
1958  */
1959 static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
1960                                      dma_addr_t start, dma_addr_t end)
1961 {
1962         struct vfio_iova *first, *last;
1963         struct list_head *iova = &iommu->iova_list;
1964
1965         if (list_empty(iova))
1966                 return false;
1967
1968         /* Disjoint sets, return conflict */
1969         first = list_first_entry(iova, struct vfio_iova, list);
1970         last = list_last_entry(iova, struct vfio_iova, list);
1971         if (start > last->end || end < first->start)
1972                 return true;
1973
1974         /* Check for any existing dma mappings below the new start */
1975         if (start > first->start) {
1976                 if (vfio_find_dma(iommu, first->start, start - first->start))
1977                         return true;
1978         }
1979
1980         /* Check for any existing dma mappings beyond the new end */
1981         if (end < last->end) {
1982                 if (vfio_find_dma(iommu, end + 1, last->end - end))
1983                         return true;
1984         }
1985
1986         return false;
1987 }
1988
1989 /*
1990  * Resize iommu iova aperture window. This is called only if the new
1991  * aperture has no conflict with existing aperture and dma mappings.
1992  */
1993 static int vfio_iommu_aper_resize(struct list_head *iova,
1994                                   dma_addr_t start, dma_addr_t end)
1995 {
1996         struct vfio_iova *node, *next;
1997
1998         if (list_empty(iova))
1999                 return vfio_iommu_iova_insert(iova, start, end);
2000
2001         /* Adjust iova list start */
2002         list_for_each_entry_safe(node, next, iova, list) {
2003                 if (start < node->start)
2004                         break;
2005                 if (start >= node->start && start < node->end) {
2006                         node->start = start;
2007                         break;
2008                 }
2009                 /* Delete nodes before new start */
2010                 list_del(&node->list);
2011                 kfree(node);
2012         }
2013
2014         /* Adjust iova list end */
2015         list_for_each_entry_safe(node, next, iova, list) {
2016                 if (end > node->end)
2017                         continue;
2018                 if (end > node->start && end <= node->end) {
2019                         node->end = end;
2020                         continue;
2021                 }
2022                 /* Delete nodes after new end */
2023                 list_del(&node->list);
2024                 kfree(node);
2025         }
2026
2027         return 0;
2028 }
2029
2030 /*
2031  * Check reserved region conflicts with existing dma mappings
2032  */
2033 static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
2034                                      struct list_head *resv_regions)
2035 {
2036         struct iommu_resv_region *region;
2037
2038         /* Check for conflict with existing dma mappings */
2039         list_for_each_entry(region, resv_regions, list) {
2040                 if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
2041                         continue;
2042
2043                 if (vfio_find_dma(iommu, region->start, region->length))
2044                         return true;
2045         }
2046
2047         return false;
2048 }
2049
2050 /*
2051  * Check iova region overlap with  reserved regions and
2052  * exclude them from the iommu iova range
2053  */
2054 static int vfio_iommu_resv_exclude(struct list_head *iova,
2055                                    struct list_head *resv_regions)
2056 {
2057         struct iommu_resv_region *resv;
2058         struct vfio_iova *n, *next;
2059
2060         list_for_each_entry(resv, resv_regions, list) {
2061                 phys_addr_t start, end;
2062
2063                 if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
2064                         continue;
2065
2066                 start = resv->start;
2067                 end = resv->start + resv->length - 1;
2068
2069                 list_for_each_entry_safe(n, next, iova, list) {
2070                         int ret = 0;
2071
2072                         /* No overlap */
2073                         if (start > n->end || end < n->start)
2074                                 continue;
2075                         /*
2076                          * Insert a new node if current node overlaps with the
2077                          * reserve region to exclude that from valid iova range.
2078                          * Note that, new node is inserted before the current
2079                          * node and finally the current node is deleted keeping
2080                          * the list updated and sorted.
2081                          */
2082                         if (start > n->start)
2083                                 ret = vfio_iommu_iova_insert(&n->list, n->start,
2084                                                              start - 1);
2085                         if (!ret && end < n->end)
2086                                 ret = vfio_iommu_iova_insert(&n->list, end + 1,
2087                                                              n->end);
2088                         if (ret)
2089                                 return ret;
2090
2091                         list_del(&n->list);
2092                         kfree(n);
2093                 }
2094         }
2095
2096         if (list_empty(iova))
2097                 return -EINVAL;
2098
2099         return 0;
2100 }
2101
2102 static void vfio_iommu_resv_free(struct list_head *resv_regions)
2103 {
2104         struct iommu_resv_region *n, *next;
2105
2106         list_for_each_entry_safe(n, next, resv_regions, list) {
2107                 list_del(&n->list);
2108                 kfree(n);
2109         }
2110 }
2111
2112 static void vfio_iommu_iova_free(struct list_head *iova)
2113 {
2114         struct vfio_iova *n, *next;
2115
2116         list_for_each_entry_safe(n, next, iova, list) {
2117                 list_del(&n->list);
2118                 kfree(n);
2119         }
2120 }
2121
2122 static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
2123                                     struct list_head *iova_copy)
2124 {
2125         struct list_head *iova = &iommu->iova_list;
2126         struct vfio_iova *n;
2127         int ret;
2128
2129         list_for_each_entry(n, iova, list) {
2130                 ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
2131                 if (ret)
2132                         goto out_free;
2133         }
2134
2135         return 0;
2136
2137 out_free:
2138         vfio_iommu_iova_free(iova_copy);
2139         return ret;
2140 }
2141
2142 static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2143                                         struct list_head *iova_copy)
2144 {
2145         struct list_head *iova = &iommu->iova_list;
2146
2147         vfio_iommu_iova_free(iova);
2148
2149         list_splice_tail(iova_copy, iova);
2150 }
2151
2152 /* Redundantly walks non-present capabilities to simplify caller */
2153 static int vfio_iommu_device_capable(struct device *dev, void *data)
2154 {
2155         return device_iommu_capable(dev, (enum iommu_cap)data);
2156 }
2157
2158 static int vfio_iommu_domain_alloc(struct device *dev, void *data)
2159 {
2160         struct iommu_domain **domain = data;
2161
2162         *domain = iommu_domain_alloc(dev->bus);
2163         return 1; /* Don't iterate */
2164 }
2165
2166 static int vfio_iommu_type1_attach_group(void *iommu_data,
2167                 struct iommu_group *iommu_group, enum vfio_group_type type)
2168 {
2169         struct vfio_iommu *iommu = iommu_data;
2170         struct vfio_iommu_group *group;
2171         struct vfio_domain *domain, *d;
2172         bool resv_msi, msi_remap;
2173         phys_addr_t resv_msi_base = 0;
2174         struct iommu_domain_geometry *geo;
2175         LIST_HEAD(iova_copy);
2176         LIST_HEAD(group_resv_regions);
2177         int ret = -EINVAL;
2178
2179         mutex_lock(&iommu->lock);
2180
2181         /* Check for duplicates */
2182         if (vfio_iommu_find_iommu_group(iommu, iommu_group))
2183                 goto out_unlock;
2184
2185         ret = -ENOMEM;
2186         group = kzalloc(sizeof(*group), GFP_KERNEL);
2187         if (!group)
2188                 goto out_unlock;
2189         group->iommu_group = iommu_group;
2190
2191         if (type == VFIO_EMULATED_IOMMU) {
2192                 list_add(&group->next, &iommu->emulated_iommu_groups);
2193                 /*
2194                  * An emulated IOMMU group cannot dirty memory directly, it can
2195                  * only use interfaces that provide dirty tracking.
2196                  * The iommu scope can only be promoted with the addition of a
2197                  * dirty tracking group.
2198                  */
2199                 group->pinned_page_dirty_scope = true;
2200                 ret = 0;
2201                 goto out_unlock;
2202         }
2203
2204         ret = -ENOMEM;
2205         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2206         if (!domain)
2207                 goto out_free_group;
2208
2209         /*
2210          * Going via the iommu_group iterator avoids races, and trivially gives
2211          * us a representative device for the IOMMU API call. We don't actually
2212          * want to iterate beyond the first device (if any).
2213          */
2214         ret = -EIO;
2215         iommu_group_for_each_dev(iommu_group, &domain->domain,
2216                                  vfio_iommu_domain_alloc);
2217         if (!domain->domain)
2218                 goto out_free_domain;
2219
2220         if (iommu->nesting) {
2221                 ret = iommu_enable_nesting(domain->domain);
2222                 if (ret)
2223                         goto out_domain;
2224         }
2225
2226         ret = iommu_attach_group(domain->domain, group->iommu_group);
2227         if (ret)
2228                 goto out_domain;
2229
2230         /* Get aperture info */
2231         geo = &domain->domain->geometry;
2232         if (vfio_iommu_aper_conflict(iommu, geo->aperture_start,
2233                                      geo->aperture_end)) {
2234                 ret = -EINVAL;
2235                 goto out_detach;
2236         }
2237
2238         ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2239         if (ret)
2240                 goto out_detach;
2241
2242         if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2243                 ret = -EINVAL;
2244                 goto out_detach;
2245         }
2246
2247         /*
2248          * We don't want to work on the original iova list as the list
2249          * gets modified and in case of failure we have to retain the
2250          * original list. Get a copy here.
2251          */
2252         ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2253         if (ret)
2254                 goto out_detach;
2255
2256         ret = vfio_iommu_aper_resize(&iova_copy, geo->aperture_start,
2257                                      geo->aperture_end);
2258         if (ret)
2259                 goto out_detach;
2260
2261         ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2262         if (ret)
2263                 goto out_detach;
2264
2265         resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2266
2267         INIT_LIST_HEAD(&domain->group_list);
2268         list_add(&group->next, &domain->group_list);
2269
2270         msi_remap = irq_domain_check_msi_remap() ||
2271                     iommu_group_for_each_dev(iommu_group, (void *)IOMMU_CAP_INTR_REMAP,
2272                                              vfio_iommu_device_capable);
2273
2274         if (!allow_unsafe_interrupts && !msi_remap) {
2275                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2276                        __func__);
2277                 ret = -EPERM;
2278                 goto out_detach;
2279         }
2280
2281         /*
2282          * If the IOMMU can block non-coherent operations (ie PCIe TLPs with
2283          * no-snoop set) then VFIO always turns this feature on because on Intel
2284          * platforms it optimizes KVM to disable wbinvd emulation.
2285          */
2286         if (domain->domain->ops->enforce_cache_coherency)
2287                 domain->enforce_cache_coherency =
2288                         domain->domain->ops->enforce_cache_coherency(
2289                                 domain->domain);
2290
2291         /*
2292          * Try to match an existing compatible domain.  We don't want to
2293          * preclude an IOMMU driver supporting multiple bus_types and being
2294          * able to include different bus_types in the same IOMMU domain, so
2295          * we test whether the domains use the same iommu_ops rather than
2296          * testing if they're on the same bus_type.
2297          */
2298         list_for_each_entry(d, &iommu->domain_list, next) {
2299                 if (d->domain->ops == domain->domain->ops &&
2300                     d->enforce_cache_coherency ==
2301                             domain->enforce_cache_coherency) {
2302                         iommu_detach_group(domain->domain, group->iommu_group);
2303                         if (!iommu_attach_group(d->domain,
2304                                                 group->iommu_group)) {
2305                                 list_add(&group->next, &d->group_list);
2306                                 iommu_domain_free(domain->domain);
2307                                 kfree(domain);
2308                                 goto done;
2309                         }
2310
2311                         ret = iommu_attach_group(domain->domain,
2312                                                  group->iommu_group);
2313                         if (ret)
2314                                 goto out_domain;
2315                 }
2316         }
2317
2318         vfio_test_domain_fgsp(domain);
2319
2320         /* replay mappings on new domains */
2321         ret = vfio_iommu_replay(iommu, domain);
2322         if (ret)
2323                 goto out_detach;
2324
2325         if (resv_msi) {
2326                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2327                 if (ret && ret != -ENODEV)
2328                         goto out_detach;
2329         }
2330
2331         list_add(&domain->next, &iommu->domain_list);
2332         vfio_update_pgsize_bitmap(iommu);
2333 done:
2334         /* Delete the old one and insert new iova list */
2335         vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2336
2337         /*
2338          * An iommu backed group can dirty memory directly and therefore
2339          * demotes the iommu scope until it declares itself dirty tracking
2340          * capable via the page pinning interface.
2341          */
2342         iommu->num_non_pinned_groups++;
2343         mutex_unlock(&iommu->lock);
2344         vfio_iommu_resv_free(&group_resv_regions);
2345
2346         return 0;
2347
2348 out_detach:
2349         iommu_detach_group(domain->domain, group->iommu_group);
2350 out_domain:
2351         iommu_domain_free(domain->domain);
2352         vfio_iommu_iova_free(&iova_copy);
2353         vfio_iommu_resv_free(&group_resv_regions);
2354 out_free_domain:
2355         kfree(domain);
2356 out_free_group:
2357         kfree(group);
2358 out_unlock:
2359         mutex_unlock(&iommu->lock);
2360         return ret;
2361 }
2362
2363 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2364 {
2365         struct rb_node *node;
2366
2367         while ((node = rb_first(&iommu->dma_list)))
2368                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2369 }
2370
2371 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2372 {
2373         struct rb_node *n, *p;
2374
2375         n = rb_first(&iommu->dma_list);
2376         for (; n; n = rb_next(n)) {
2377                 struct vfio_dma *dma;
2378                 long locked = 0, unlocked = 0;
2379
2380                 dma = rb_entry(n, struct vfio_dma, node);
2381                 unlocked += vfio_unmap_unpin(iommu, dma, false);
2382                 p = rb_first(&dma->pfn_list);
2383                 for (; p; p = rb_next(p)) {
2384                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2385                                                          node);
2386
2387                         if (!is_invalid_reserved_pfn(vpfn->pfn))
2388                                 locked++;
2389                 }
2390                 vfio_lock_acct(dma, locked - unlocked, true);
2391         }
2392 }
2393
2394 /*
2395  * Called when a domain is removed in detach. It is possible that
2396  * the removed domain decided the iova aperture window. Modify the
2397  * iova aperture with the smallest window among existing domains.
2398  */
2399 static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2400                                    struct list_head *iova_copy)
2401 {
2402         struct vfio_domain *domain;
2403         struct vfio_iova *node;
2404         dma_addr_t start = 0;
2405         dma_addr_t end = (dma_addr_t)~0;
2406
2407         if (list_empty(iova_copy))
2408                 return;
2409
2410         list_for_each_entry(domain, &iommu->domain_list, next) {
2411                 struct iommu_domain_geometry *geo = &domain->domain->geometry;
2412
2413                 if (geo->aperture_start > start)
2414                         start = geo->aperture_start;
2415                 if (geo->aperture_end < end)
2416                         end = geo->aperture_end;
2417         }
2418
2419         /* Modify aperture limits. The new aper is either same or bigger */
2420         node = list_first_entry(iova_copy, struct vfio_iova, list);
2421         node->start = start;
2422         node = list_last_entry(iova_copy, struct vfio_iova, list);
2423         node->end = end;
2424 }
2425
2426 /*
2427  * Called when a group is detached. The reserved regions for that
2428  * group can be part of valid iova now. But since reserved regions
2429  * may be duplicated among groups, populate the iova valid regions
2430  * list again.
2431  */
2432 static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2433                                    struct list_head *iova_copy)
2434 {
2435         struct vfio_domain *d;
2436         struct vfio_iommu_group *g;
2437         struct vfio_iova *node;
2438         dma_addr_t start, end;
2439         LIST_HEAD(resv_regions);
2440         int ret;
2441
2442         if (list_empty(iova_copy))
2443                 return -EINVAL;
2444
2445         list_for_each_entry(d, &iommu->domain_list, next) {
2446                 list_for_each_entry(g, &d->group_list, next) {
2447                         ret = iommu_get_group_resv_regions(g->iommu_group,
2448                                                            &resv_regions);
2449                         if (ret)
2450                                 goto done;
2451                 }
2452         }
2453
2454         node = list_first_entry(iova_copy, struct vfio_iova, list);
2455         start = node->start;
2456         node = list_last_entry(iova_copy, struct vfio_iova, list);
2457         end = node->end;
2458
2459         /* purge the iova list and create new one */
2460         vfio_iommu_iova_free(iova_copy);
2461
2462         ret = vfio_iommu_aper_resize(iova_copy, start, end);
2463         if (ret)
2464                 goto done;
2465
2466         /* Exclude current reserved regions from iova ranges */
2467         ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2468 done:
2469         vfio_iommu_resv_free(&resv_regions);
2470         return ret;
2471 }
2472
2473 static void vfio_iommu_type1_detach_group(void *iommu_data,
2474                                           struct iommu_group *iommu_group)
2475 {
2476         struct vfio_iommu *iommu = iommu_data;
2477         struct vfio_domain *domain;
2478         struct vfio_iommu_group *group;
2479         bool update_dirty_scope = false;
2480         LIST_HEAD(iova_copy);
2481
2482         mutex_lock(&iommu->lock);
2483         list_for_each_entry(group, &iommu->emulated_iommu_groups, next) {
2484                 if (group->iommu_group != iommu_group)
2485                         continue;
2486                 update_dirty_scope = !group->pinned_page_dirty_scope;
2487                 list_del(&group->next);
2488                 kfree(group);
2489
2490                 if (list_empty(&iommu->emulated_iommu_groups) &&
2491                     list_empty(&iommu->domain_list)) {
2492                         WARN_ON(!list_empty(&iommu->device_list));
2493                         vfio_iommu_unmap_unpin_all(iommu);
2494                 }
2495                 goto detach_group_done;
2496         }
2497
2498         /*
2499          * Get a copy of iova list. This will be used to update
2500          * and to replace the current one later. Please note that
2501          * we will leave the original list as it is if update fails.
2502          */
2503         vfio_iommu_iova_get_copy(iommu, &iova_copy);
2504
2505         list_for_each_entry(domain, &iommu->domain_list, next) {
2506                 group = find_iommu_group(domain, iommu_group);
2507                 if (!group)
2508                         continue;
2509
2510                 iommu_detach_group(domain->domain, group->iommu_group);
2511                 update_dirty_scope = !group->pinned_page_dirty_scope;
2512                 list_del(&group->next);
2513                 kfree(group);
2514                 /*
2515                  * Group ownership provides privilege, if the group list is
2516                  * empty, the domain goes away. If it's the last domain with
2517                  * iommu and external domain doesn't exist, then all the
2518                  * mappings go away too. If it's the last domain with iommu and
2519                  * external domain exist, update accounting
2520                  */
2521                 if (list_empty(&domain->group_list)) {
2522                         if (list_is_singular(&iommu->domain_list)) {
2523                                 if (list_empty(&iommu->emulated_iommu_groups)) {
2524                                         WARN_ON(!list_empty(
2525                                                 &iommu->device_list));
2526                                         vfio_iommu_unmap_unpin_all(iommu);
2527                                 } else {
2528                                         vfio_iommu_unmap_unpin_reaccount(iommu);
2529                                 }
2530                         }
2531                         iommu_domain_free(domain->domain);
2532                         list_del(&domain->next);
2533                         kfree(domain);
2534                         vfio_iommu_aper_expand(iommu, &iova_copy);
2535                         vfio_update_pgsize_bitmap(iommu);
2536                 }
2537                 break;
2538         }
2539
2540         if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2541                 vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2542         else
2543                 vfio_iommu_iova_free(&iova_copy);
2544
2545 detach_group_done:
2546         /*
2547          * Removal of a group without dirty tracking may allow the iommu scope
2548          * to be promoted.
2549          */
2550         if (update_dirty_scope) {
2551                 iommu->num_non_pinned_groups--;
2552                 if (iommu->dirty_page_tracking)
2553                         vfio_iommu_populate_bitmap_full(iommu);
2554         }
2555         mutex_unlock(&iommu->lock);
2556 }
2557
2558 static void *vfio_iommu_type1_open(unsigned long arg)
2559 {
2560         struct vfio_iommu *iommu;
2561
2562         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2563         if (!iommu)
2564                 return ERR_PTR(-ENOMEM);
2565
2566         switch (arg) {
2567         case VFIO_TYPE1_IOMMU:
2568                 break;
2569         case VFIO_TYPE1_NESTING_IOMMU:
2570                 iommu->nesting = true;
2571                 fallthrough;
2572         case VFIO_TYPE1v2_IOMMU:
2573                 iommu->v2 = true;
2574                 break;
2575         default:
2576                 kfree(iommu);
2577                 return ERR_PTR(-EINVAL);
2578         }
2579
2580         INIT_LIST_HEAD(&iommu->domain_list);
2581         INIT_LIST_HEAD(&iommu->iova_list);
2582         iommu->dma_list = RB_ROOT;
2583         iommu->dma_avail = dma_entry_limit;
2584         iommu->container_open = true;
2585         mutex_init(&iommu->lock);
2586         mutex_init(&iommu->device_list_lock);
2587         INIT_LIST_HEAD(&iommu->device_list);
2588         init_waitqueue_head(&iommu->vaddr_wait);
2589         iommu->pgsize_bitmap = PAGE_MASK;
2590         INIT_LIST_HEAD(&iommu->emulated_iommu_groups);
2591
2592         return iommu;
2593 }
2594
2595 static void vfio_release_domain(struct vfio_domain *domain)
2596 {
2597         struct vfio_iommu_group *group, *group_tmp;
2598
2599         list_for_each_entry_safe(group, group_tmp,
2600                                  &domain->group_list, next) {
2601                 iommu_detach_group(domain->domain, group->iommu_group);
2602                 list_del(&group->next);
2603                 kfree(group);
2604         }
2605
2606         iommu_domain_free(domain->domain);
2607 }
2608
2609 static void vfio_iommu_type1_release(void *iommu_data)
2610 {
2611         struct vfio_iommu *iommu = iommu_data;
2612         struct vfio_domain *domain, *domain_tmp;
2613         struct vfio_iommu_group *group, *next_group;
2614
2615         list_for_each_entry_safe(group, next_group,
2616                         &iommu->emulated_iommu_groups, next) {
2617                 list_del(&group->next);
2618                 kfree(group);
2619         }
2620
2621         vfio_iommu_unmap_unpin_all(iommu);
2622
2623         list_for_each_entry_safe(domain, domain_tmp,
2624                                  &iommu->domain_list, next) {
2625                 vfio_release_domain(domain);
2626                 list_del(&domain->next);
2627                 kfree(domain);
2628         }
2629
2630         vfio_iommu_iova_free(&iommu->iova_list);
2631
2632         kfree(iommu);
2633 }
2634
2635 static int vfio_domains_have_enforce_cache_coherency(struct vfio_iommu *iommu)
2636 {
2637         struct vfio_domain *domain;
2638         int ret = 1;
2639
2640         mutex_lock(&iommu->lock);
2641         list_for_each_entry(domain, &iommu->domain_list, next) {
2642                 if (!(domain->enforce_cache_coherency)) {
2643                         ret = 0;
2644                         break;
2645                 }
2646         }
2647         mutex_unlock(&iommu->lock);
2648
2649         return ret;
2650 }
2651
2652 static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2653                                             unsigned long arg)
2654 {
2655         switch (arg) {
2656         case VFIO_TYPE1_IOMMU:
2657         case VFIO_TYPE1v2_IOMMU:
2658         case VFIO_TYPE1_NESTING_IOMMU:
2659         case VFIO_UNMAP_ALL:
2660         case VFIO_UPDATE_VADDR:
2661                 return 1;
2662         case VFIO_DMA_CC_IOMMU:
2663                 if (!iommu)
2664                         return 0;
2665                 return vfio_domains_have_enforce_cache_coherency(iommu);
2666         default:
2667                 return 0;
2668         }
2669 }
2670
2671 static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2672                  struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2673                  size_t size)
2674 {
2675         struct vfio_info_cap_header *header;
2676         struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2677
2678         header = vfio_info_cap_add(caps, size,
2679                                    VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2680         if (IS_ERR(header))
2681                 return PTR_ERR(header);
2682
2683         iova_cap = container_of(header,
2684                                 struct vfio_iommu_type1_info_cap_iova_range,
2685                                 header);
2686         iova_cap->nr_iovas = cap_iovas->nr_iovas;
2687         memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2688                cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2689         return 0;
2690 }
2691
2692 static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2693                                       struct vfio_info_cap *caps)
2694 {
2695         struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2696         struct vfio_iova *iova;
2697         size_t size;
2698         int iovas = 0, i = 0, ret;
2699
2700         list_for_each_entry(iova, &iommu->iova_list, list)
2701                 iovas++;
2702
2703         if (!iovas) {
2704                 /*
2705                  * Return 0 as a container with a single mdev device
2706                  * will have an empty list
2707                  */
2708                 return 0;
2709         }
2710
2711         size = struct_size(cap_iovas, iova_ranges, iovas);
2712
2713         cap_iovas = kzalloc(size, GFP_KERNEL);
2714         if (!cap_iovas)
2715                 return -ENOMEM;
2716
2717         cap_iovas->nr_iovas = iovas;
2718
2719         list_for_each_entry(iova, &iommu->iova_list, list) {
2720                 cap_iovas->iova_ranges[i].start = iova->start;
2721                 cap_iovas->iova_ranges[i].end = iova->end;
2722                 i++;
2723         }
2724
2725         ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2726
2727         kfree(cap_iovas);
2728         return ret;
2729 }
2730
2731 static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2732                                            struct vfio_info_cap *caps)
2733 {
2734         struct vfio_iommu_type1_info_cap_migration cap_mig;
2735
2736         cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2737         cap_mig.header.version = 1;
2738
2739         cap_mig.flags = 0;
2740         /* support minimum pgsize */
2741         cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2742         cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2743
2744         return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2745 }
2746
2747 static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2748                                            struct vfio_info_cap *caps)
2749 {
2750         struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2751
2752         cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2753         cap_dma_avail.header.version = 1;
2754
2755         cap_dma_avail.avail = iommu->dma_avail;
2756
2757         return vfio_info_add_capability(caps, &cap_dma_avail.header,
2758                                         sizeof(cap_dma_avail));
2759 }
2760
2761 static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2762                                      unsigned long arg)
2763 {
2764         struct vfio_iommu_type1_info info;
2765         unsigned long minsz;
2766         struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2767         unsigned long capsz;
2768         int ret;
2769
2770         minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2771
2772         /* For backward compatibility, cannot require this */
2773         capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2774
2775         if (copy_from_user(&info, (void __user *)arg, minsz))
2776                 return -EFAULT;
2777
2778         if (info.argsz < minsz)
2779                 return -EINVAL;
2780
2781         if (info.argsz >= capsz) {
2782                 minsz = capsz;
2783                 info.cap_offset = 0; /* output, no-recopy necessary */
2784         }
2785
2786         mutex_lock(&iommu->lock);
2787         info.flags = VFIO_IOMMU_INFO_PGSIZES;
2788
2789         info.iova_pgsizes = iommu->pgsize_bitmap;
2790
2791         ret = vfio_iommu_migration_build_caps(iommu, &caps);
2792
2793         if (!ret)
2794                 ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2795
2796         if (!ret)
2797                 ret = vfio_iommu_iova_build_caps(iommu, &caps);
2798
2799         mutex_unlock(&iommu->lock);
2800
2801         if (ret)
2802                 return ret;
2803
2804         if (caps.size) {
2805                 info.flags |= VFIO_IOMMU_INFO_CAPS;
2806
2807                 if (info.argsz < sizeof(info) + caps.size) {
2808                         info.argsz = sizeof(info) + caps.size;
2809                 } else {
2810                         vfio_info_cap_shift(&caps, sizeof(info));
2811                         if (copy_to_user((void __user *)arg +
2812                                         sizeof(info), caps.buf,
2813                                         caps.size)) {
2814                                 kfree(caps.buf);
2815                                 return -EFAULT;
2816                         }
2817                         info.cap_offset = sizeof(info);
2818                 }
2819
2820                 kfree(caps.buf);
2821         }
2822
2823         return copy_to_user((void __user *)arg, &info, minsz) ?
2824                         -EFAULT : 0;
2825 }
2826
2827 static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2828                                     unsigned long arg)
2829 {
2830         struct vfio_iommu_type1_dma_map map;
2831         unsigned long minsz;
2832         uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE |
2833                         VFIO_DMA_MAP_FLAG_VADDR;
2834
2835         minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2836
2837         if (copy_from_user(&map, (void __user *)arg, minsz))
2838                 return -EFAULT;
2839
2840         if (map.argsz < minsz || map.flags & ~mask)
2841                 return -EINVAL;
2842
2843         return vfio_dma_do_map(iommu, &map);
2844 }
2845
2846 static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2847                                       unsigned long arg)
2848 {
2849         struct vfio_iommu_type1_dma_unmap unmap;
2850         struct vfio_bitmap bitmap = { 0 };
2851         uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
2852                         VFIO_DMA_UNMAP_FLAG_VADDR |
2853                         VFIO_DMA_UNMAP_FLAG_ALL;
2854         unsigned long minsz;
2855         int ret;
2856
2857         minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2858
2859         if (copy_from_user(&unmap, (void __user *)arg, minsz))
2860                 return -EFAULT;
2861
2862         if (unmap.argsz < minsz || unmap.flags & ~mask)
2863                 return -EINVAL;
2864
2865         if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
2866             (unmap.flags & (VFIO_DMA_UNMAP_FLAG_ALL |
2867                             VFIO_DMA_UNMAP_FLAG_VADDR)))
2868                 return -EINVAL;
2869
2870         if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2871                 unsigned long pgshift;
2872
2873                 if (unmap.argsz < (minsz + sizeof(bitmap)))
2874                         return -EINVAL;
2875
2876                 if (copy_from_user(&bitmap,
2877                                    (void __user *)(arg + minsz),
2878                                    sizeof(bitmap)))
2879                         return -EFAULT;
2880
2881                 if (!access_ok((void __user *)bitmap.data, bitmap.size))
2882                         return -EINVAL;
2883
2884                 pgshift = __ffs(bitmap.pgsize);
2885                 ret = verify_bitmap_size(unmap.size >> pgshift,
2886                                          bitmap.size);
2887                 if (ret)
2888                         return ret;
2889         }
2890
2891         ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2892         if (ret)
2893                 return ret;
2894
2895         return copy_to_user((void __user *)arg, &unmap, minsz) ?
2896                         -EFAULT : 0;
2897 }
2898
2899 static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
2900                                         unsigned long arg)
2901 {
2902         struct vfio_iommu_type1_dirty_bitmap dirty;
2903         uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2904                         VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2905                         VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2906         unsigned long minsz;
2907         int ret = 0;
2908
2909         if (!iommu->v2)
2910                 return -EACCES;
2911
2912         minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
2913
2914         if (copy_from_user(&dirty, (void __user *)arg, minsz))
2915                 return -EFAULT;
2916
2917         if (dirty.argsz < minsz || dirty.flags & ~mask)
2918                 return -EINVAL;
2919
2920         /* only one flag should be set at a time */
2921         if (__ffs(dirty.flags) != __fls(dirty.flags))
2922                 return -EINVAL;
2923
2924         if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
2925                 size_t pgsize;
2926
2927                 mutex_lock(&iommu->lock);
2928                 pgsize = 1 << __ffs(iommu->pgsize_bitmap);
2929                 if (!iommu->dirty_page_tracking) {
2930                         ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
2931                         if (!ret)
2932                                 iommu->dirty_page_tracking = true;
2933                 }
2934                 mutex_unlock(&iommu->lock);
2935                 return ret;
2936         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
2937                 mutex_lock(&iommu->lock);
2938                 if (iommu->dirty_page_tracking) {
2939                         iommu->dirty_page_tracking = false;
2940                         vfio_dma_bitmap_free_all(iommu);
2941                 }
2942                 mutex_unlock(&iommu->lock);
2943                 return 0;
2944         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
2945                 struct vfio_iommu_type1_dirty_bitmap_get range;
2946                 unsigned long pgshift;
2947                 size_t data_size = dirty.argsz - minsz;
2948                 size_t iommu_pgsize;
2949
2950                 if (!data_size || data_size < sizeof(range))
2951                         return -EINVAL;
2952
2953                 if (copy_from_user(&range, (void __user *)(arg + minsz),
2954                                    sizeof(range)))
2955                         return -EFAULT;
2956
2957                 if (range.iova + range.size < range.iova)
2958                         return -EINVAL;
2959                 if (!access_ok((void __user *)range.bitmap.data,
2960                                range.bitmap.size))
2961                         return -EINVAL;
2962
2963                 pgshift = __ffs(range.bitmap.pgsize);
2964                 ret = verify_bitmap_size(range.size >> pgshift,
2965                                          range.bitmap.size);
2966                 if (ret)
2967                         return ret;
2968
2969                 mutex_lock(&iommu->lock);
2970
2971                 iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2972
2973                 /* allow only smallest supported pgsize */
2974                 if (range.bitmap.pgsize != iommu_pgsize) {
2975                         ret = -EINVAL;
2976                         goto out_unlock;
2977                 }
2978                 if (range.iova & (iommu_pgsize - 1)) {
2979                         ret = -EINVAL;
2980                         goto out_unlock;
2981                 }
2982                 if (!range.size || range.size & (iommu_pgsize - 1)) {
2983                         ret = -EINVAL;
2984                         goto out_unlock;
2985                 }
2986
2987                 if (iommu->dirty_page_tracking)
2988                         ret = vfio_iova_dirty_bitmap(range.bitmap.data,
2989                                                      iommu, range.iova,
2990                                                      range.size,
2991                                                      range.bitmap.pgsize);
2992                 else
2993                         ret = -EINVAL;
2994 out_unlock:
2995                 mutex_unlock(&iommu->lock);
2996
2997                 return ret;
2998         }
2999
3000         return -EINVAL;
3001 }
3002
3003 static long vfio_iommu_type1_ioctl(void *iommu_data,
3004                                    unsigned int cmd, unsigned long arg)
3005 {
3006         struct vfio_iommu *iommu = iommu_data;
3007
3008         switch (cmd) {
3009         case VFIO_CHECK_EXTENSION:
3010                 return vfio_iommu_type1_check_extension(iommu, arg);
3011         case VFIO_IOMMU_GET_INFO:
3012                 return vfio_iommu_type1_get_info(iommu, arg);
3013         case VFIO_IOMMU_MAP_DMA:
3014                 return vfio_iommu_type1_map_dma(iommu, arg);
3015         case VFIO_IOMMU_UNMAP_DMA:
3016                 return vfio_iommu_type1_unmap_dma(iommu, arg);
3017         case VFIO_IOMMU_DIRTY_PAGES:
3018                 return vfio_iommu_type1_dirty_pages(iommu, arg);
3019         default:
3020                 return -ENOTTY;
3021         }
3022 }
3023
3024 static void vfio_iommu_type1_register_device(void *iommu_data,
3025                                              struct vfio_device *vdev)
3026 {
3027         struct vfio_iommu *iommu = iommu_data;
3028
3029         if (!vdev->ops->dma_unmap)
3030                 return;
3031
3032         /*
3033          * list_empty(&iommu->device_list) is tested under the iommu->lock while
3034          * iteration for dma_unmap must be done under the device_list_lock.
3035          * Holding both locks here allows avoiding the device_list_lock in
3036          * several fast paths. See vfio_notify_dma_unmap()
3037          */
3038         mutex_lock(&iommu->lock);
3039         mutex_lock(&iommu->device_list_lock);
3040         list_add(&vdev->iommu_entry, &iommu->device_list);
3041         mutex_unlock(&iommu->device_list_lock);
3042         mutex_unlock(&iommu->lock);
3043 }
3044
3045 static void vfio_iommu_type1_unregister_device(void *iommu_data,
3046                                                struct vfio_device *vdev)
3047 {
3048         struct vfio_iommu *iommu = iommu_data;
3049
3050         if (!vdev->ops->dma_unmap)
3051                 return;
3052
3053         mutex_lock(&iommu->lock);
3054         mutex_lock(&iommu->device_list_lock);
3055         list_del(&vdev->iommu_entry);
3056         mutex_unlock(&iommu->device_list_lock);
3057         mutex_unlock(&iommu->lock);
3058 }
3059
3060 static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
3061                                          dma_addr_t user_iova, void *data,
3062                                          size_t count, bool write,
3063                                          size_t *copied)
3064 {
3065         struct mm_struct *mm;
3066         unsigned long vaddr;
3067         struct vfio_dma *dma;
3068         bool kthread = current->mm == NULL;
3069         size_t offset;
3070         int ret;
3071
3072         *copied = 0;
3073
3074         ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
3075         if (ret < 0)
3076                 return ret;
3077
3078         if ((write && !(dma->prot & IOMMU_WRITE)) ||
3079                         !(dma->prot & IOMMU_READ))
3080                 return -EPERM;
3081
3082         mm = get_task_mm(dma->task);
3083
3084         if (!mm)
3085                 return -EPERM;
3086
3087         if (kthread)
3088                 kthread_use_mm(mm);
3089         else if (current->mm != mm)
3090                 goto out;
3091
3092         offset = user_iova - dma->iova;
3093
3094         if (count > dma->size - offset)
3095                 count = dma->size - offset;
3096
3097         vaddr = dma->vaddr + offset;
3098
3099         if (write) {
3100                 *copied = copy_to_user((void __user *)vaddr, data,
3101                                          count) ? 0 : count;
3102                 if (*copied && iommu->dirty_page_tracking) {
3103                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
3104                         /*
3105                          * Bitmap populated with the smallest supported page
3106                          * size
3107                          */
3108                         bitmap_set(dma->bitmap, offset >> pgshift,
3109                                    ((offset + *copied - 1) >> pgshift) -
3110                                    (offset >> pgshift) + 1);
3111                 }
3112         } else
3113                 *copied = copy_from_user(data, (void __user *)vaddr,
3114                                            count) ? 0 : count;
3115         if (kthread)
3116                 kthread_unuse_mm(mm);
3117 out:
3118         mmput(mm);
3119         return *copied ? 0 : -EFAULT;
3120 }
3121
3122 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
3123                                    void *data, size_t count, bool write)
3124 {
3125         struct vfio_iommu *iommu = iommu_data;
3126         int ret = 0;
3127         size_t done;
3128
3129         mutex_lock(&iommu->lock);
3130         while (count > 0) {
3131                 ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
3132                                                     count, write, &done);
3133                 if (ret)
3134                         break;
3135
3136                 count -= done;
3137                 data += done;
3138                 user_iova += done;
3139         }
3140
3141         mutex_unlock(&iommu->lock);
3142         return ret;
3143 }
3144
3145 static struct iommu_domain *
3146 vfio_iommu_type1_group_iommu_domain(void *iommu_data,
3147                                     struct iommu_group *iommu_group)
3148 {
3149         struct iommu_domain *domain = ERR_PTR(-ENODEV);
3150         struct vfio_iommu *iommu = iommu_data;
3151         struct vfio_domain *d;
3152
3153         if (!iommu || !iommu_group)
3154                 return ERR_PTR(-EINVAL);
3155
3156         mutex_lock(&iommu->lock);
3157         list_for_each_entry(d, &iommu->domain_list, next) {
3158                 if (find_iommu_group(d, iommu_group)) {
3159                         domain = d->domain;
3160                         break;
3161                 }
3162         }
3163         mutex_unlock(&iommu->lock);
3164
3165         return domain;
3166 }
3167
3168 static void vfio_iommu_type1_notify(void *iommu_data,
3169                                     enum vfio_iommu_notify_type event)
3170 {
3171         struct vfio_iommu *iommu = iommu_data;
3172
3173         if (event != VFIO_IOMMU_CONTAINER_CLOSE)
3174                 return;
3175         mutex_lock(&iommu->lock);
3176         iommu->container_open = false;
3177         mutex_unlock(&iommu->lock);
3178         wake_up_all(&iommu->vaddr_wait);
3179 }
3180
3181 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3182         .name                   = "vfio-iommu-type1",
3183         .owner                  = THIS_MODULE,
3184         .open                   = vfio_iommu_type1_open,
3185         .release                = vfio_iommu_type1_release,
3186         .ioctl                  = vfio_iommu_type1_ioctl,
3187         .attach_group           = vfio_iommu_type1_attach_group,
3188         .detach_group           = vfio_iommu_type1_detach_group,
3189         .pin_pages              = vfio_iommu_type1_pin_pages,
3190         .unpin_pages            = vfio_iommu_type1_unpin_pages,
3191         .register_device        = vfio_iommu_type1_register_device,
3192         .unregister_device      = vfio_iommu_type1_unregister_device,
3193         .dma_rw                 = vfio_iommu_type1_dma_rw,
3194         .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
3195         .notify                 = vfio_iommu_type1_notify,
3196 };
3197
3198 static int __init vfio_iommu_type1_init(void)
3199 {
3200         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3201 }
3202
3203 static void __exit vfio_iommu_type1_cleanup(void)
3204 {
3205         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3206 }
3207
3208 module_init(vfio_iommu_type1_init);
3209 module_exit(vfio_iommu_type1_cleanup);
3210
3211 MODULE_VERSION(DRIVER_VERSION);
3212 MODULE_LICENSE("GPL v2");
3213 MODULE_AUTHOR(DRIVER_AUTHOR);
3214 MODULE_DESCRIPTION(DRIVER_DESC);