vfio/iommu_type1: Populate full dirty when detach non-pinned group
[linux-2.6-microblaze.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userpsace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/iommu.h>
28 #include <linux/module.h>
29 #include <linux/mm.h>
30 #include <linux/kthread.h>
31 #include <linux/rbtree.h>
32 #include <linux/sched/signal.h>
33 #include <linux/sched/mm.h>
34 #include <linux/slab.h>
35 #include <linux/uaccess.h>
36 #include <linux/vfio.h>
37 #include <linux/workqueue.h>
38 #include <linux/mdev.h>
39 #include <linux/notifier.h>
40 #include <linux/dma-iommu.h>
41 #include <linux/irqdomain.h>
42
43 #define DRIVER_VERSION  "0.2"
44 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
45 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
46
47 static bool allow_unsafe_interrupts;
48 module_param_named(allow_unsafe_interrupts,
49                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
50 MODULE_PARM_DESC(allow_unsafe_interrupts,
51                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
52
53 static bool disable_hugepages;
54 module_param_named(disable_hugepages,
55                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
56 MODULE_PARM_DESC(disable_hugepages,
57                  "Disable VFIO IOMMU support for IOMMU hugepages.");
58
59 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
60 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
61 MODULE_PARM_DESC(dma_entry_limit,
62                  "Maximum number of user DMA mappings per container (65535).");
63
64 struct vfio_iommu {
65         struct list_head        domain_list;
66         struct list_head        iova_list;
67         struct vfio_domain      *external_domain; /* domain for external user */
68         struct mutex            lock;
69         struct rb_root          dma_list;
70         struct blocking_notifier_head notifier;
71         unsigned int            dma_avail;
72         uint64_t                pgsize_bitmap;
73         bool                    v2;
74         bool                    nesting;
75         bool                    dirty_page_tracking;
76         bool                    pinned_page_dirty_scope;
77 };
78
79 struct vfio_domain {
80         struct iommu_domain     *domain;
81         struct list_head        next;
82         struct list_head        group_list;
83         int                     prot;           /* IOMMU_CACHE */
84         bool                    fgsp;           /* Fine-grained super pages */
85 };
86
87 struct vfio_dma {
88         struct rb_node          node;
89         dma_addr_t              iova;           /* Device address */
90         unsigned long           vaddr;          /* Process virtual addr */
91         size_t                  size;           /* Map size (bytes) */
92         int                     prot;           /* IOMMU_READ/WRITE */
93         bool                    iommu_mapped;
94         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
95         struct task_struct      *task;
96         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
97         unsigned long           *bitmap;
98 };
99
100 struct vfio_group {
101         struct iommu_group      *iommu_group;
102         struct list_head        next;
103         bool                    mdev_group;     /* An mdev group */
104         bool                    pinned_page_dirty_scope;
105 };
106
107 struct vfio_iova {
108         struct list_head        list;
109         dma_addr_t              start;
110         dma_addr_t              end;
111 };
112
113 /*
114  * Guest RAM pinning working set or DMA target
115  */
116 struct vfio_pfn {
117         struct rb_node          node;
118         dma_addr_t              iova;           /* Device address */
119         unsigned long           pfn;            /* Host pfn */
120         unsigned int            ref_count;
121 };
122
123 struct vfio_regions {
124         struct list_head list;
125         dma_addr_t iova;
126         phys_addr_t phys;
127         size_t len;
128 };
129
130 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
131                                         (!list_empty(&iommu->domain_list))
132
133 #define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
134
135 /*
136  * Input argument of number of bits to bitmap_set() is unsigned integer, which
137  * further casts to signed integer for unaligned multi-bit operation,
138  * __bitmap_set().
139  * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
140  * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
141  * system.
142  */
143 #define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
144 #define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
145
146 static int put_pfn(unsigned long pfn, int prot);
147
148 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
149                                                struct iommu_group *iommu_group);
150
151 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu);
152 /*
153  * This code handles mapping and unmapping of user data buffers
154  * into DMA'ble space using the IOMMU
155  */
156
157 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
158                                       dma_addr_t start, size_t size)
159 {
160         struct rb_node *node = iommu->dma_list.rb_node;
161
162         while (node) {
163                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
164
165                 if (start + size <= dma->iova)
166                         node = node->rb_left;
167                 else if (start >= dma->iova + dma->size)
168                         node = node->rb_right;
169                 else
170                         return dma;
171         }
172
173         return NULL;
174 }
175
176 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
177 {
178         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
179         struct vfio_dma *dma;
180
181         while (*link) {
182                 parent = *link;
183                 dma = rb_entry(parent, struct vfio_dma, node);
184
185                 if (new->iova + new->size <= dma->iova)
186                         link = &(*link)->rb_left;
187                 else
188                         link = &(*link)->rb_right;
189         }
190
191         rb_link_node(&new->node, parent, link);
192         rb_insert_color(&new->node, &iommu->dma_list);
193 }
194
195 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
196 {
197         rb_erase(&old->node, &iommu->dma_list);
198 }
199
200
201 static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
202 {
203         uint64_t npages = dma->size / pgsize;
204
205         if (npages > DIRTY_BITMAP_PAGES_MAX)
206                 return -EINVAL;
207
208         /*
209          * Allocate extra 64 bits that are used to calculate shift required for
210          * bitmap_shift_left() to manipulate and club unaligned number of pages
211          * in adjacent vfio_dma ranges.
212          */
213         dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
214                                GFP_KERNEL);
215         if (!dma->bitmap)
216                 return -ENOMEM;
217
218         return 0;
219 }
220
221 static void vfio_dma_bitmap_free(struct vfio_dma *dma)
222 {
223         kfree(dma->bitmap);
224         dma->bitmap = NULL;
225 }
226
227 static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
228 {
229         struct rb_node *p;
230         unsigned long pgshift = __ffs(pgsize);
231
232         for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
233                 struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
234
235                 bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
236         }
237 }
238
239 static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
240 {
241         struct rb_node *n;
242         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
243
244         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
245                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
246
247                 bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
248         }
249 }
250
251 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
252 {
253         struct rb_node *n;
254
255         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
256                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
257                 int ret;
258
259                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
260                 if (ret) {
261                         struct rb_node *p;
262
263                         for (p = rb_prev(n); p; p = rb_prev(p)) {
264                                 struct vfio_dma *dma = rb_entry(n,
265                                                         struct vfio_dma, node);
266
267                                 vfio_dma_bitmap_free(dma);
268                         }
269                         return ret;
270                 }
271                 vfio_dma_populate_bitmap(dma, pgsize);
272         }
273         return 0;
274 }
275
276 static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
277 {
278         struct rb_node *n;
279
280         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
281                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
282
283                 vfio_dma_bitmap_free(dma);
284         }
285 }
286
287 /*
288  * Helper Functions for host iova-pfn list
289  */
290 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
291 {
292         struct vfio_pfn *vpfn;
293         struct rb_node *node = dma->pfn_list.rb_node;
294
295         while (node) {
296                 vpfn = rb_entry(node, struct vfio_pfn, node);
297
298                 if (iova < vpfn->iova)
299                         node = node->rb_left;
300                 else if (iova > vpfn->iova)
301                         node = node->rb_right;
302                 else
303                         return vpfn;
304         }
305         return NULL;
306 }
307
308 static void vfio_link_pfn(struct vfio_dma *dma,
309                           struct vfio_pfn *new)
310 {
311         struct rb_node **link, *parent = NULL;
312         struct vfio_pfn *vpfn;
313
314         link = &dma->pfn_list.rb_node;
315         while (*link) {
316                 parent = *link;
317                 vpfn = rb_entry(parent, struct vfio_pfn, node);
318
319                 if (new->iova < vpfn->iova)
320                         link = &(*link)->rb_left;
321                 else
322                         link = &(*link)->rb_right;
323         }
324
325         rb_link_node(&new->node, parent, link);
326         rb_insert_color(&new->node, &dma->pfn_list);
327 }
328
329 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
330 {
331         rb_erase(&old->node, &dma->pfn_list);
332 }
333
334 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
335                                 unsigned long pfn)
336 {
337         struct vfio_pfn *vpfn;
338
339         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
340         if (!vpfn)
341                 return -ENOMEM;
342
343         vpfn->iova = iova;
344         vpfn->pfn = pfn;
345         vpfn->ref_count = 1;
346         vfio_link_pfn(dma, vpfn);
347         return 0;
348 }
349
350 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
351                                       struct vfio_pfn *vpfn)
352 {
353         vfio_unlink_pfn(dma, vpfn);
354         kfree(vpfn);
355 }
356
357 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
358                                                unsigned long iova)
359 {
360         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
361
362         if (vpfn)
363                 vpfn->ref_count++;
364         return vpfn;
365 }
366
367 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
368 {
369         int ret = 0;
370
371         vpfn->ref_count--;
372         if (!vpfn->ref_count) {
373                 ret = put_pfn(vpfn->pfn, dma->prot);
374                 vfio_remove_from_pfn_list(dma, vpfn);
375         }
376         return ret;
377 }
378
379 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
380 {
381         struct mm_struct *mm;
382         int ret;
383
384         if (!npage)
385                 return 0;
386
387         mm = async ? get_task_mm(dma->task) : dma->task->mm;
388         if (!mm)
389                 return -ESRCH; /* process exited */
390
391         ret = mmap_write_lock_killable(mm);
392         if (!ret) {
393                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
394                                           dma->lock_cap);
395                 mmap_write_unlock(mm);
396         }
397
398         if (async)
399                 mmput(mm);
400
401         return ret;
402 }
403
404 /*
405  * Some mappings aren't backed by a struct page, for example an mmap'd
406  * MMIO range for our own or another device.  These use a different
407  * pfn conversion and shouldn't be tracked as locked pages.
408  * For compound pages, any driver that sets the reserved bit in head
409  * page needs to set the reserved bit in all subpages to be safe.
410  */
411 static bool is_invalid_reserved_pfn(unsigned long pfn)
412 {
413         if (pfn_valid(pfn))
414                 return PageReserved(pfn_to_page(pfn));
415
416         return true;
417 }
418
419 static int put_pfn(unsigned long pfn, int prot)
420 {
421         if (!is_invalid_reserved_pfn(pfn)) {
422                 struct page *page = pfn_to_page(pfn);
423
424                 unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
425                 return 1;
426         }
427         return 0;
428 }
429
430 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
431                             unsigned long vaddr, unsigned long *pfn,
432                             bool write_fault)
433 {
434         int ret;
435
436         ret = follow_pfn(vma, vaddr, pfn);
437         if (ret) {
438                 bool unlocked = false;
439
440                 ret = fixup_user_fault(mm, vaddr,
441                                        FAULT_FLAG_REMOTE |
442                                        (write_fault ?  FAULT_FLAG_WRITE : 0),
443                                        &unlocked);
444                 if (unlocked)
445                         return -EAGAIN;
446
447                 if (ret)
448                         return ret;
449
450                 ret = follow_pfn(vma, vaddr, pfn);
451         }
452
453         return ret;
454 }
455
456 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
457                          int prot, unsigned long *pfn)
458 {
459         struct page *page[1];
460         struct vm_area_struct *vma;
461         unsigned int flags = 0;
462         int ret;
463
464         if (prot & IOMMU_WRITE)
465                 flags |= FOLL_WRITE;
466
467         mmap_read_lock(mm);
468         ret = pin_user_pages_remote(mm, vaddr, 1, flags | FOLL_LONGTERM,
469                                     page, NULL, NULL);
470         if (ret == 1) {
471                 *pfn = page_to_pfn(page[0]);
472                 ret = 0;
473                 goto done;
474         }
475
476         vaddr = untagged_addr(vaddr);
477
478 retry:
479         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
480
481         if (vma && vma->vm_flags & VM_PFNMAP) {
482                 ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
483                 if (ret == -EAGAIN)
484                         goto retry;
485
486                 if (!ret && !is_invalid_reserved_pfn(*pfn))
487                         ret = -EFAULT;
488         }
489 done:
490         mmap_read_unlock(mm);
491         return ret;
492 }
493
494 /*
495  * Attempt to pin pages.  We really don't want to track all the pfns and
496  * the iommu can only map chunks of consecutive pfns anyway, so get the
497  * first page and all consecutive pages with the same locking.
498  */
499 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
500                                   long npage, unsigned long *pfn_base,
501                                   unsigned long limit)
502 {
503         unsigned long pfn = 0;
504         long ret, pinned = 0, lock_acct = 0;
505         bool rsvd;
506         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
507
508         /* This code path is only user initiated */
509         if (!current->mm)
510                 return -ENODEV;
511
512         ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
513         if (ret)
514                 return ret;
515
516         pinned++;
517         rsvd = is_invalid_reserved_pfn(*pfn_base);
518
519         /*
520          * Reserved pages aren't counted against the user, externally pinned
521          * pages are already counted against the user.
522          */
523         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
524                 if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
525                         put_pfn(*pfn_base, dma->prot);
526                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
527                                         limit << PAGE_SHIFT);
528                         return -ENOMEM;
529                 }
530                 lock_acct++;
531         }
532
533         if (unlikely(disable_hugepages))
534                 goto out;
535
536         /* Lock all the consecutive pages from pfn_base */
537         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
538              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
539                 ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
540                 if (ret)
541                         break;
542
543                 if (pfn != *pfn_base + pinned ||
544                     rsvd != is_invalid_reserved_pfn(pfn)) {
545                         put_pfn(pfn, dma->prot);
546                         break;
547                 }
548
549                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
550                         if (!dma->lock_cap &&
551                             current->mm->locked_vm + lock_acct + 1 > limit) {
552                                 put_pfn(pfn, dma->prot);
553                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
554                                         __func__, limit << PAGE_SHIFT);
555                                 ret = -ENOMEM;
556                                 goto unpin_out;
557                         }
558                         lock_acct++;
559                 }
560         }
561
562 out:
563         ret = vfio_lock_acct(dma, lock_acct, false);
564
565 unpin_out:
566         if (ret) {
567                 if (!rsvd) {
568                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
569                                 put_pfn(pfn, dma->prot);
570                 }
571
572                 return ret;
573         }
574
575         return pinned;
576 }
577
578 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
579                                     unsigned long pfn, long npage,
580                                     bool do_accounting)
581 {
582         long unlocked = 0, locked = 0;
583         long i;
584
585         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
586                 if (put_pfn(pfn++, dma->prot)) {
587                         unlocked++;
588                         if (vfio_find_vpfn(dma, iova))
589                                 locked++;
590                 }
591         }
592
593         if (do_accounting)
594                 vfio_lock_acct(dma, locked - unlocked, true);
595
596         return unlocked;
597 }
598
599 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
600                                   unsigned long *pfn_base, bool do_accounting)
601 {
602         struct mm_struct *mm;
603         int ret;
604
605         mm = get_task_mm(dma->task);
606         if (!mm)
607                 return -ENODEV;
608
609         ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
610         if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
611                 ret = vfio_lock_acct(dma, 1, true);
612                 if (ret) {
613                         put_pfn(*pfn_base, dma->prot);
614                         if (ret == -ENOMEM)
615                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
616                                         "(%ld) exceeded\n", __func__,
617                                         dma->task->comm, task_pid_nr(dma->task),
618                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
619                 }
620         }
621
622         mmput(mm);
623         return ret;
624 }
625
626 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
627                                     bool do_accounting)
628 {
629         int unlocked;
630         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
631
632         if (!vpfn)
633                 return 0;
634
635         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
636
637         if (do_accounting)
638                 vfio_lock_acct(dma, -unlocked, true);
639
640         return unlocked;
641 }
642
643 static int vfio_iommu_type1_pin_pages(void *iommu_data,
644                                       struct iommu_group *iommu_group,
645                                       unsigned long *user_pfn,
646                                       int npage, int prot,
647                                       unsigned long *phys_pfn)
648 {
649         struct vfio_iommu *iommu = iommu_data;
650         struct vfio_group *group;
651         int i, j, ret;
652         unsigned long remote_vaddr;
653         struct vfio_dma *dma;
654         bool do_accounting;
655
656         if (!iommu || !user_pfn || !phys_pfn)
657                 return -EINVAL;
658
659         /* Supported for v2 version only */
660         if (!iommu->v2)
661                 return -EACCES;
662
663         mutex_lock(&iommu->lock);
664
665         /* Fail if notifier list is empty */
666         if (!iommu->notifier.head) {
667                 ret = -EINVAL;
668                 goto pin_done;
669         }
670
671         /*
672          * If iommu capable domain exist in the container then all pages are
673          * already pinned and accounted. Accouting should be done if there is no
674          * iommu capable domain in the container.
675          */
676         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
677
678         for (i = 0; i < npage; i++) {
679                 dma_addr_t iova;
680                 struct vfio_pfn *vpfn;
681
682                 iova = user_pfn[i] << PAGE_SHIFT;
683                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
684                 if (!dma) {
685                         ret = -EINVAL;
686                         goto pin_unwind;
687                 }
688
689                 if ((dma->prot & prot) != prot) {
690                         ret = -EPERM;
691                         goto pin_unwind;
692                 }
693
694                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
695                 if (vpfn) {
696                         phys_pfn[i] = vpfn->pfn;
697                         continue;
698                 }
699
700                 remote_vaddr = dma->vaddr + (iova - dma->iova);
701                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
702                                              do_accounting);
703                 if (ret)
704                         goto pin_unwind;
705
706                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
707                 if (ret) {
708                         if (put_pfn(phys_pfn[i], dma->prot) && do_accounting)
709                                 vfio_lock_acct(dma, -1, true);
710                         goto pin_unwind;
711                 }
712
713                 if (iommu->dirty_page_tracking) {
714                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
715
716                         /*
717                          * Bitmap populated with the smallest supported page
718                          * size
719                          */
720                         bitmap_set(dma->bitmap,
721                                    (iova - dma->iova) >> pgshift, 1);
722                 }
723         }
724         ret = i;
725
726         group = vfio_iommu_find_iommu_group(iommu, iommu_group);
727         if (!group->pinned_page_dirty_scope) {
728                 group->pinned_page_dirty_scope = true;
729                 update_pinned_page_dirty_scope(iommu);
730         }
731
732         goto pin_done;
733
734 pin_unwind:
735         phys_pfn[i] = 0;
736         for (j = 0; j < i; j++) {
737                 dma_addr_t iova;
738
739                 iova = user_pfn[j] << PAGE_SHIFT;
740                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
741                 vfio_unpin_page_external(dma, iova, do_accounting);
742                 phys_pfn[j] = 0;
743         }
744 pin_done:
745         mutex_unlock(&iommu->lock);
746         return ret;
747 }
748
749 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
750                                         unsigned long *user_pfn,
751                                         int npage)
752 {
753         struct vfio_iommu *iommu = iommu_data;
754         bool do_accounting;
755         int i;
756
757         if (!iommu || !user_pfn)
758                 return -EINVAL;
759
760         /* Supported for v2 version only */
761         if (!iommu->v2)
762                 return -EACCES;
763
764         mutex_lock(&iommu->lock);
765
766         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
767         for (i = 0; i < npage; i++) {
768                 struct vfio_dma *dma;
769                 dma_addr_t iova;
770
771                 iova = user_pfn[i] << PAGE_SHIFT;
772                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
773                 if (!dma)
774                         goto unpin_exit;
775                 vfio_unpin_page_external(dma, iova, do_accounting);
776         }
777
778 unpin_exit:
779         mutex_unlock(&iommu->lock);
780         return i > npage ? npage : (i > 0 ? i : -EINVAL);
781 }
782
783 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
784                             struct list_head *regions,
785                             struct iommu_iotlb_gather *iotlb_gather)
786 {
787         long unlocked = 0;
788         struct vfio_regions *entry, *next;
789
790         iommu_iotlb_sync(domain->domain, iotlb_gather);
791
792         list_for_each_entry_safe(entry, next, regions, list) {
793                 unlocked += vfio_unpin_pages_remote(dma,
794                                                     entry->iova,
795                                                     entry->phys >> PAGE_SHIFT,
796                                                     entry->len >> PAGE_SHIFT,
797                                                     false);
798                 list_del(&entry->list);
799                 kfree(entry);
800         }
801
802         cond_resched();
803
804         return unlocked;
805 }
806
807 /*
808  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
809  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
810  * of these regions (currently using a list).
811  *
812  * This value specifies maximum number of regions for each IOTLB flush sync.
813  */
814 #define VFIO_IOMMU_TLB_SYNC_MAX         512
815
816 static size_t unmap_unpin_fast(struct vfio_domain *domain,
817                                struct vfio_dma *dma, dma_addr_t *iova,
818                                size_t len, phys_addr_t phys, long *unlocked,
819                                struct list_head *unmapped_list,
820                                int *unmapped_cnt,
821                                struct iommu_iotlb_gather *iotlb_gather)
822 {
823         size_t unmapped = 0;
824         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
825
826         if (entry) {
827                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
828                                             iotlb_gather);
829
830                 if (!unmapped) {
831                         kfree(entry);
832                 } else {
833                         entry->iova = *iova;
834                         entry->phys = phys;
835                         entry->len  = unmapped;
836                         list_add_tail(&entry->list, unmapped_list);
837
838                         *iova += unmapped;
839                         (*unmapped_cnt)++;
840                 }
841         }
842
843         /*
844          * Sync if the number of fast-unmap regions hits the limit
845          * or in case of errors.
846          */
847         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
848                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
849                                              iotlb_gather);
850                 *unmapped_cnt = 0;
851         }
852
853         return unmapped;
854 }
855
856 static size_t unmap_unpin_slow(struct vfio_domain *domain,
857                                struct vfio_dma *dma, dma_addr_t *iova,
858                                size_t len, phys_addr_t phys,
859                                long *unlocked)
860 {
861         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
862
863         if (unmapped) {
864                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
865                                                      phys >> PAGE_SHIFT,
866                                                      unmapped >> PAGE_SHIFT,
867                                                      false);
868                 *iova += unmapped;
869                 cond_resched();
870         }
871         return unmapped;
872 }
873
874 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
875                              bool do_accounting)
876 {
877         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
878         struct vfio_domain *domain, *d;
879         LIST_HEAD(unmapped_region_list);
880         struct iommu_iotlb_gather iotlb_gather;
881         int unmapped_region_cnt = 0;
882         long unlocked = 0;
883
884         if (!dma->size)
885                 return 0;
886
887         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
888                 return 0;
889
890         /*
891          * We use the IOMMU to track the physical addresses, otherwise we'd
892          * need a much more complicated tracking system.  Unfortunately that
893          * means we need to use one of the iommu domains to figure out the
894          * pfns to unpin.  The rest need to be unmapped in advance so we have
895          * no iommu translations remaining when the pages are unpinned.
896          */
897         domain = d = list_first_entry(&iommu->domain_list,
898                                       struct vfio_domain, next);
899
900         list_for_each_entry_continue(d, &iommu->domain_list, next) {
901                 iommu_unmap(d->domain, dma->iova, dma->size);
902                 cond_resched();
903         }
904
905         iommu_iotlb_gather_init(&iotlb_gather);
906         while (iova < end) {
907                 size_t unmapped, len;
908                 phys_addr_t phys, next;
909
910                 phys = iommu_iova_to_phys(domain->domain, iova);
911                 if (WARN_ON(!phys)) {
912                         iova += PAGE_SIZE;
913                         continue;
914                 }
915
916                 /*
917                  * To optimize for fewer iommu_unmap() calls, each of which
918                  * may require hardware cache flushing, try to find the
919                  * largest contiguous physical memory chunk to unmap.
920                  */
921                 for (len = PAGE_SIZE;
922                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
923                         next = iommu_iova_to_phys(domain->domain, iova + len);
924                         if (next != phys + len)
925                                 break;
926                 }
927
928                 /*
929                  * First, try to use fast unmap/unpin. In case of failure,
930                  * switch to slow unmap/unpin path.
931                  */
932                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
933                                             &unlocked, &unmapped_region_list,
934                                             &unmapped_region_cnt,
935                                             &iotlb_gather);
936                 if (!unmapped) {
937                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
938                                                     phys, &unlocked);
939                         if (WARN_ON(!unmapped))
940                                 break;
941                 }
942         }
943
944         dma->iommu_mapped = false;
945
946         if (unmapped_region_cnt) {
947                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
948                                             &iotlb_gather);
949         }
950
951         if (do_accounting) {
952                 vfio_lock_acct(dma, -unlocked, true);
953                 return 0;
954         }
955         return unlocked;
956 }
957
958 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
959 {
960         vfio_unmap_unpin(iommu, dma, true);
961         vfio_unlink_dma(iommu, dma);
962         put_task_struct(dma->task);
963         vfio_dma_bitmap_free(dma);
964         kfree(dma);
965         iommu->dma_avail++;
966 }
967
968 static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
969 {
970         struct vfio_domain *domain;
971
972         iommu->pgsize_bitmap = ULONG_MAX;
973
974         list_for_each_entry(domain, &iommu->domain_list, next)
975                 iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
976
977         /*
978          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
979          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
980          * That way the user will be able to map/unmap buffers whose size/
981          * start address is aligned with PAGE_SIZE. Pinning code uses that
982          * granularity while iommu driver can use the sub-PAGE_SIZE size
983          * to map the buffer.
984          */
985         if (iommu->pgsize_bitmap & ~PAGE_MASK) {
986                 iommu->pgsize_bitmap &= PAGE_MASK;
987                 iommu->pgsize_bitmap |= PAGE_SIZE;
988         }
989 }
990
991 static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
992                               struct vfio_dma *dma, dma_addr_t base_iova,
993                               size_t pgsize)
994 {
995         unsigned long pgshift = __ffs(pgsize);
996         unsigned long nbits = dma->size >> pgshift;
997         unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
998         unsigned long copy_offset = bit_offset / BITS_PER_LONG;
999         unsigned long shift = bit_offset % BITS_PER_LONG;
1000         unsigned long leftover;
1001
1002         /*
1003          * mark all pages dirty if any IOMMU capable device is not able
1004          * to report dirty pages and all pages are pinned and mapped.
1005          */
1006         if (!iommu->pinned_page_dirty_scope && dma->iommu_mapped)
1007                 bitmap_set(dma->bitmap, 0, nbits);
1008
1009         if (shift) {
1010                 bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1011                                   nbits + shift);
1012
1013                 if (copy_from_user(&leftover,
1014                                    (void __user *)(bitmap + copy_offset),
1015                                    sizeof(leftover)))
1016                         return -EFAULT;
1017
1018                 bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1019         }
1020
1021         if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1022                          DIRTY_BITMAP_BYTES(nbits + shift)))
1023                 return -EFAULT;
1024
1025         return 0;
1026 }
1027
1028 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1029                                   dma_addr_t iova, size_t size, size_t pgsize)
1030 {
1031         struct vfio_dma *dma;
1032         struct rb_node *n;
1033         unsigned long pgshift = __ffs(pgsize);
1034         int ret;
1035
1036         /*
1037          * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1038          * vfio_dma mappings may be clubbed by specifying large ranges, but
1039          * there must not be any previous mappings bisected by the range.
1040          * An error will be returned if these conditions are not met.
1041          */
1042         dma = vfio_find_dma(iommu, iova, 1);
1043         if (dma && dma->iova != iova)
1044                 return -EINVAL;
1045
1046         dma = vfio_find_dma(iommu, iova + size - 1, 0);
1047         if (dma && dma->iova + dma->size != iova + size)
1048                 return -EINVAL;
1049
1050         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1051                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1052
1053                 if (dma->iova < iova)
1054                         continue;
1055
1056                 if (dma->iova > iova + size - 1)
1057                         break;
1058
1059                 ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1060                 if (ret)
1061                         return ret;
1062
1063                 /*
1064                  * Re-populate bitmap to include all pinned pages which are
1065                  * considered as dirty but exclude pages which are unpinned and
1066                  * pages which are marked dirty by vfio_dma_rw()
1067                  */
1068                 bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1069                 vfio_dma_populate_bitmap(dma, pgsize);
1070         }
1071         return 0;
1072 }
1073
1074 static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1075 {
1076         if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1077             (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1078                 return -EINVAL;
1079
1080         return 0;
1081 }
1082
1083 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1084                              struct vfio_iommu_type1_dma_unmap *unmap,
1085                              struct vfio_bitmap *bitmap)
1086 {
1087         struct vfio_dma *dma, *dma_last = NULL;
1088         size_t unmapped = 0, pgsize;
1089         int ret = 0, retries = 0;
1090         unsigned long pgshift;
1091
1092         mutex_lock(&iommu->lock);
1093
1094         pgshift = __ffs(iommu->pgsize_bitmap);
1095         pgsize = (size_t)1 << pgshift;
1096
1097         if (unmap->iova & (pgsize - 1)) {
1098                 ret = -EINVAL;
1099                 goto unlock;
1100         }
1101
1102         if (!unmap->size || unmap->size & (pgsize - 1)) {
1103                 ret = -EINVAL;
1104                 goto unlock;
1105         }
1106
1107         if (unmap->iova + unmap->size - 1 < unmap->iova ||
1108             unmap->size > SIZE_MAX) {
1109                 ret = -EINVAL;
1110                 goto unlock;
1111         }
1112
1113         /* When dirty tracking is enabled, allow only min supported pgsize */
1114         if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1115             (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1116                 ret = -EINVAL;
1117                 goto unlock;
1118         }
1119
1120         WARN_ON((pgsize - 1) & PAGE_MASK);
1121 again:
1122         /*
1123          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1124          * avoid tracking individual mappings.  This means that the granularity
1125          * of the original mapping was lost and the user was allowed to attempt
1126          * to unmap any range.  Depending on the contiguousness of physical
1127          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1128          * or may not have worked.  We only guaranteed unmap granularity
1129          * matching the original mapping; even though it was untracked here,
1130          * the original mappings are reflected in IOMMU mappings.  This
1131          * resulted in a couple unusual behaviors.  First, if a range is not
1132          * able to be unmapped, ex. a set of 4k pages that was mapped as a
1133          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1134          * a zero sized unmap.  Also, if an unmap request overlaps the first
1135          * address of a hugepage, the IOMMU will unmap the entire hugepage.
1136          * This also returns success and the returned unmap size reflects the
1137          * actual size unmapped.
1138          *
1139          * We attempt to maintain compatibility with this "v1" interface, but
1140          * we take control out of the hands of the IOMMU.  Therefore, an unmap
1141          * request offset from the beginning of the original mapping will
1142          * return success with zero sized unmap.  And an unmap request covering
1143          * the first iova of mapping will unmap the entire range.
1144          *
1145          * The v2 version of this interface intends to be more deterministic.
1146          * Unmap requests must fully cover previous mappings.  Multiple
1147          * mappings may still be unmaped by specifying large ranges, but there
1148          * must not be any previous mappings bisected by the range.  An error
1149          * will be returned if these conditions are not met.  The v2 interface
1150          * will only return success and a size of zero if there were no
1151          * mappings within the range.
1152          */
1153         if (iommu->v2) {
1154                 dma = vfio_find_dma(iommu, unmap->iova, 1);
1155                 if (dma && dma->iova != unmap->iova) {
1156                         ret = -EINVAL;
1157                         goto unlock;
1158                 }
1159                 dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
1160                 if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
1161                         ret = -EINVAL;
1162                         goto unlock;
1163                 }
1164         }
1165
1166         while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
1167                 if (!iommu->v2 && unmap->iova > dma->iova)
1168                         break;
1169                 /*
1170                  * Task with same address space who mapped this iova range is
1171                  * allowed to unmap the iova range.
1172                  */
1173                 if (dma->task->mm != current->mm)
1174                         break;
1175
1176                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1177                         struct vfio_iommu_type1_dma_unmap nb_unmap;
1178
1179                         if (dma_last == dma) {
1180                                 BUG_ON(++retries > 10);
1181                         } else {
1182                                 dma_last = dma;
1183                                 retries = 0;
1184                         }
1185
1186                         nb_unmap.iova = dma->iova;
1187                         nb_unmap.size = dma->size;
1188
1189                         /*
1190                          * Notify anyone (mdev vendor drivers) to invalidate and
1191                          * unmap iovas within the range we're about to unmap.
1192                          * Vendor drivers MUST unpin pages in response to an
1193                          * invalidation.
1194                          */
1195                         mutex_unlock(&iommu->lock);
1196                         blocking_notifier_call_chain(&iommu->notifier,
1197                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
1198                                                     &nb_unmap);
1199                         mutex_lock(&iommu->lock);
1200                         goto again;
1201                 }
1202
1203                 if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1204                         ret = update_user_bitmap(bitmap->data, iommu, dma,
1205                                                  unmap->iova, pgsize);
1206                         if (ret)
1207                                 break;
1208                 }
1209
1210                 unmapped += dma->size;
1211                 vfio_remove_dma(iommu, dma);
1212         }
1213
1214 unlock:
1215         mutex_unlock(&iommu->lock);
1216
1217         /* Report how much was unmapped */
1218         unmap->size = unmapped;
1219
1220         return ret;
1221 }
1222
1223 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1224                           unsigned long pfn, long npage, int prot)
1225 {
1226         struct vfio_domain *d;
1227         int ret;
1228
1229         list_for_each_entry(d, &iommu->domain_list, next) {
1230                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1231                                 npage << PAGE_SHIFT, prot | d->prot);
1232                 if (ret)
1233                         goto unwind;
1234
1235                 cond_resched();
1236         }
1237
1238         return 0;
1239
1240 unwind:
1241         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1242                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1243                 cond_resched();
1244         }
1245
1246         return ret;
1247 }
1248
1249 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1250                             size_t map_size)
1251 {
1252         dma_addr_t iova = dma->iova;
1253         unsigned long vaddr = dma->vaddr;
1254         size_t size = map_size;
1255         long npage;
1256         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1257         int ret = 0;
1258
1259         while (size) {
1260                 /* Pin a contiguous chunk of memory */
1261                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1262                                               size >> PAGE_SHIFT, &pfn, limit);
1263                 if (npage <= 0) {
1264                         WARN_ON(!npage);
1265                         ret = (int)npage;
1266                         break;
1267                 }
1268
1269                 /* Map it! */
1270                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1271                                      dma->prot);
1272                 if (ret) {
1273                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1274                                                 npage, true);
1275                         break;
1276                 }
1277
1278                 size -= npage << PAGE_SHIFT;
1279                 dma->size += npage << PAGE_SHIFT;
1280         }
1281
1282         dma->iommu_mapped = true;
1283
1284         if (ret)
1285                 vfio_remove_dma(iommu, dma);
1286
1287         return ret;
1288 }
1289
1290 /*
1291  * Check dma map request is within a valid iova range
1292  */
1293 static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1294                                       dma_addr_t start, dma_addr_t end)
1295 {
1296         struct list_head *iova = &iommu->iova_list;
1297         struct vfio_iova *node;
1298
1299         list_for_each_entry(node, iova, list) {
1300                 if (start >= node->start && end <= node->end)
1301                         return true;
1302         }
1303
1304         /*
1305          * Check for list_empty() as well since a container with
1306          * a single mdev device will have an empty list.
1307          */
1308         return list_empty(iova);
1309 }
1310
1311 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1312                            struct vfio_iommu_type1_dma_map *map)
1313 {
1314         dma_addr_t iova = map->iova;
1315         unsigned long vaddr = map->vaddr;
1316         size_t size = map->size;
1317         int ret = 0, prot = 0;
1318         size_t pgsize;
1319         struct vfio_dma *dma;
1320
1321         /* Verify that none of our __u64 fields overflow */
1322         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1323                 return -EINVAL;
1324
1325         /* READ/WRITE from device perspective */
1326         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1327                 prot |= IOMMU_WRITE;
1328         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1329                 prot |= IOMMU_READ;
1330
1331         mutex_lock(&iommu->lock);
1332
1333         pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1334
1335         WARN_ON((pgsize - 1) & PAGE_MASK);
1336
1337         if (!prot || !size || (size | iova | vaddr) & (pgsize - 1)) {
1338                 ret = -EINVAL;
1339                 goto out_unlock;
1340         }
1341
1342         /* Don't allow IOVA or virtual address wrap */
1343         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1344                 ret = -EINVAL;
1345                 goto out_unlock;
1346         }
1347
1348         if (vfio_find_dma(iommu, iova, size)) {
1349                 ret = -EEXIST;
1350                 goto out_unlock;
1351         }
1352
1353         if (!iommu->dma_avail) {
1354                 ret = -ENOSPC;
1355                 goto out_unlock;
1356         }
1357
1358         if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1359                 ret = -EINVAL;
1360                 goto out_unlock;
1361         }
1362
1363         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1364         if (!dma) {
1365                 ret = -ENOMEM;
1366                 goto out_unlock;
1367         }
1368
1369         iommu->dma_avail--;
1370         dma->iova = iova;
1371         dma->vaddr = vaddr;
1372         dma->prot = prot;
1373
1374         /*
1375          * We need to be able to both add to a task's locked memory and test
1376          * against the locked memory limit and we need to be able to do both
1377          * outside of this call path as pinning can be asynchronous via the
1378          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1379          * task_struct and VM locked pages requires an mm_struct, however
1380          * holding an indefinite mm reference is not recommended, therefore we
1381          * only hold a reference to a task.  We could hold a reference to
1382          * current, however QEMU uses this call path through vCPU threads,
1383          * which can be killed resulting in a NULL mm and failure in the unmap
1384          * path when called via a different thread.  Avoid this problem by
1385          * using the group_leader as threads within the same group require
1386          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1387          * mm_struct.
1388          *
1389          * Previously we also used the task for testing CAP_IPC_LOCK at the
1390          * time of pinning and accounting, however has_capability() makes use
1391          * of real_cred, a copy-on-write field, so we can't guarantee that it
1392          * matches group_leader, or in fact that it might not change by the
1393          * time it's evaluated.  If a process were to call MAP_DMA with
1394          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1395          * possibly see different results for an iommu_mapped vfio_dma vs
1396          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1397          * time of calling MAP_DMA.
1398          */
1399         get_task_struct(current->group_leader);
1400         dma->task = current->group_leader;
1401         dma->lock_cap = capable(CAP_IPC_LOCK);
1402
1403         dma->pfn_list = RB_ROOT;
1404
1405         /* Insert zero-sized and grow as we map chunks of it */
1406         vfio_link_dma(iommu, dma);
1407
1408         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1409         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1410                 dma->size = size;
1411         else
1412                 ret = vfio_pin_map_dma(iommu, dma, size);
1413
1414         if (!ret && iommu->dirty_page_tracking) {
1415                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
1416                 if (ret)
1417                         vfio_remove_dma(iommu, dma);
1418         }
1419
1420 out_unlock:
1421         mutex_unlock(&iommu->lock);
1422         return ret;
1423 }
1424
1425 static int vfio_bus_type(struct device *dev, void *data)
1426 {
1427         struct bus_type **bus = data;
1428
1429         if (*bus && *bus != dev->bus)
1430                 return -EINVAL;
1431
1432         *bus = dev->bus;
1433
1434         return 0;
1435 }
1436
1437 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1438                              struct vfio_domain *domain)
1439 {
1440         struct vfio_domain *d = NULL;
1441         struct rb_node *n;
1442         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1443         int ret;
1444
1445         /* Arbitrarily pick the first domain in the list for lookups */
1446         if (!list_empty(&iommu->domain_list))
1447                 d = list_first_entry(&iommu->domain_list,
1448                                      struct vfio_domain, next);
1449
1450         n = rb_first(&iommu->dma_list);
1451
1452         for (; n; n = rb_next(n)) {
1453                 struct vfio_dma *dma;
1454                 dma_addr_t iova;
1455
1456                 dma = rb_entry(n, struct vfio_dma, node);
1457                 iova = dma->iova;
1458
1459                 while (iova < dma->iova + dma->size) {
1460                         phys_addr_t phys;
1461                         size_t size;
1462
1463                         if (dma->iommu_mapped) {
1464                                 phys_addr_t p;
1465                                 dma_addr_t i;
1466
1467                                 if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1468                                         ret = -EINVAL;
1469                                         goto unwind;
1470                                 }
1471
1472                                 phys = iommu_iova_to_phys(d->domain, iova);
1473
1474                                 if (WARN_ON(!phys)) {
1475                                         iova += PAGE_SIZE;
1476                                         continue;
1477                                 }
1478
1479                                 size = PAGE_SIZE;
1480                                 p = phys + size;
1481                                 i = iova + size;
1482                                 while (i < dma->iova + dma->size &&
1483                                        p == iommu_iova_to_phys(d->domain, i)) {
1484                                         size += PAGE_SIZE;
1485                                         p += PAGE_SIZE;
1486                                         i += PAGE_SIZE;
1487                                 }
1488                         } else {
1489                                 unsigned long pfn;
1490                                 unsigned long vaddr = dma->vaddr +
1491                                                      (iova - dma->iova);
1492                                 size_t n = dma->iova + dma->size - iova;
1493                                 long npage;
1494
1495                                 npage = vfio_pin_pages_remote(dma, vaddr,
1496                                                               n >> PAGE_SHIFT,
1497                                                               &pfn, limit);
1498                                 if (npage <= 0) {
1499                                         WARN_ON(!npage);
1500                                         ret = (int)npage;
1501                                         goto unwind;
1502                                 }
1503
1504                                 phys = pfn << PAGE_SHIFT;
1505                                 size = npage << PAGE_SHIFT;
1506                         }
1507
1508                         ret = iommu_map(domain->domain, iova, phys,
1509                                         size, dma->prot | domain->prot);
1510                         if (ret) {
1511                                 if (!dma->iommu_mapped)
1512                                         vfio_unpin_pages_remote(dma, iova,
1513                                                         phys >> PAGE_SHIFT,
1514                                                         size >> PAGE_SHIFT,
1515                                                         true);
1516                                 goto unwind;
1517                         }
1518
1519                         iova += size;
1520                 }
1521         }
1522
1523         /* All dmas are now mapped, defer to second tree walk for unwind */
1524         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1525                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1526
1527                 dma->iommu_mapped = true;
1528         }
1529
1530         return 0;
1531
1532 unwind:
1533         for (; n; n = rb_prev(n)) {
1534                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1535                 dma_addr_t iova;
1536
1537                 if (dma->iommu_mapped) {
1538                         iommu_unmap(domain->domain, dma->iova, dma->size);
1539                         continue;
1540                 }
1541
1542                 iova = dma->iova;
1543                 while (iova < dma->iova + dma->size) {
1544                         phys_addr_t phys, p;
1545                         size_t size;
1546                         dma_addr_t i;
1547
1548                         phys = iommu_iova_to_phys(domain->domain, iova);
1549                         if (!phys) {
1550                                 iova += PAGE_SIZE;
1551                                 continue;
1552                         }
1553
1554                         size = PAGE_SIZE;
1555                         p = phys + size;
1556                         i = iova + size;
1557                         while (i < dma->iova + dma->size &&
1558                                p == iommu_iova_to_phys(domain->domain, i)) {
1559                                 size += PAGE_SIZE;
1560                                 p += PAGE_SIZE;
1561                                 i += PAGE_SIZE;
1562                         }
1563
1564                         iommu_unmap(domain->domain, iova, size);
1565                         vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1566                                                 size >> PAGE_SHIFT, true);
1567                 }
1568         }
1569
1570         return ret;
1571 }
1572
1573 /*
1574  * We change our unmap behavior slightly depending on whether the IOMMU
1575  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1576  * for practically any contiguous power-of-two mapping we give it.  This means
1577  * we don't need to look for contiguous chunks ourselves to make unmapping
1578  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1579  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1580  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1581  * hugetlbfs is in use.
1582  */
1583 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1584 {
1585         struct page *pages;
1586         int ret, order = get_order(PAGE_SIZE * 2);
1587
1588         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1589         if (!pages)
1590                 return;
1591
1592         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1593                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1594         if (!ret) {
1595                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1596
1597                 if (unmapped == PAGE_SIZE)
1598                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1599                 else
1600                         domain->fgsp = true;
1601         }
1602
1603         __free_pages(pages, order);
1604 }
1605
1606 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1607                                            struct iommu_group *iommu_group)
1608 {
1609         struct vfio_group *g;
1610
1611         list_for_each_entry(g, &domain->group_list, next) {
1612                 if (g->iommu_group == iommu_group)
1613                         return g;
1614         }
1615
1616         return NULL;
1617 }
1618
1619 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1620                                                struct iommu_group *iommu_group)
1621 {
1622         struct vfio_domain *domain;
1623         struct vfio_group *group = NULL;
1624
1625         list_for_each_entry(domain, &iommu->domain_list, next) {
1626                 group = find_iommu_group(domain, iommu_group);
1627                 if (group)
1628                         return group;
1629         }
1630
1631         if (iommu->external_domain)
1632                 group = find_iommu_group(iommu->external_domain, iommu_group);
1633
1634         return group;
1635 }
1636
1637 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu)
1638 {
1639         struct vfio_domain *domain;
1640         struct vfio_group *group;
1641
1642         list_for_each_entry(domain, &iommu->domain_list, next) {
1643                 list_for_each_entry(group, &domain->group_list, next) {
1644                         if (!group->pinned_page_dirty_scope) {
1645                                 iommu->pinned_page_dirty_scope = false;
1646                                 return;
1647                         }
1648                 }
1649         }
1650
1651         if (iommu->external_domain) {
1652                 domain = iommu->external_domain;
1653                 list_for_each_entry(group, &domain->group_list, next) {
1654                         if (!group->pinned_page_dirty_scope) {
1655                                 iommu->pinned_page_dirty_scope = false;
1656                                 return;
1657                         }
1658                 }
1659         }
1660
1661         iommu->pinned_page_dirty_scope = true;
1662 }
1663
1664 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1665                                   phys_addr_t *base)
1666 {
1667         struct iommu_resv_region *region;
1668         bool ret = false;
1669
1670         list_for_each_entry(region, group_resv_regions, list) {
1671                 /*
1672                  * The presence of any 'real' MSI regions should take
1673                  * precedence over the software-managed one if the
1674                  * IOMMU driver happens to advertise both types.
1675                  */
1676                 if (region->type == IOMMU_RESV_MSI) {
1677                         ret = false;
1678                         break;
1679                 }
1680
1681                 if (region->type == IOMMU_RESV_SW_MSI) {
1682                         *base = region->start;
1683                         ret = true;
1684                 }
1685         }
1686
1687         return ret;
1688 }
1689
1690 static struct device *vfio_mdev_get_iommu_device(struct device *dev)
1691 {
1692         struct device *(*fn)(struct device *dev);
1693         struct device *iommu_device;
1694
1695         fn = symbol_get(mdev_get_iommu_device);
1696         if (fn) {
1697                 iommu_device = fn(dev);
1698                 symbol_put(mdev_get_iommu_device);
1699
1700                 return iommu_device;
1701         }
1702
1703         return NULL;
1704 }
1705
1706 static int vfio_mdev_attach_domain(struct device *dev, void *data)
1707 {
1708         struct iommu_domain *domain = data;
1709         struct device *iommu_device;
1710
1711         iommu_device = vfio_mdev_get_iommu_device(dev);
1712         if (iommu_device) {
1713                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1714                         return iommu_aux_attach_device(domain, iommu_device);
1715                 else
1716                         return iommu_attach_device(domain, iommu_device);
1717         }
1718
1719         return -EINVAL;
1720 }
1721
1722 static int vfio_mdev_detach_domain(struct device *dev, void *data)
1723 {
1724         struct iommu_domain *domain = data;
1725         struct device *iommu_device;
1726
1727         iommu_device = vfio_mdev_get_iommu_device(dev);
1728         if (iommu_device) {
1729                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1730                         iommu_aux_detach_device(domain, iommu_device);
1731                 else
1732                         iommu_detach_device(domain, iommu_device);
1733         }
1734
1735         return 0;
1736 }
1737
1738 static int vfio_iommu_attach_group(struct vfio_domain *domain,
1739                                    struct vfio_group *group)
1740 {
1741         if (group->mdev_group)
1742                 return iommu_group_for_each_dev(group->iommu_group,
1743                                                 domain->domain,
1744                                                 vfio_mdev_attach_domain);
1745         else
1746                 return iommu_attach_group(domain->domain, group->iommu_group);
1747 }
1748
1749 static void vfio_iommu_detach_group(struct vfio_domain *domain,
1750                                     struct vfio_group *group)
1751 {
1752         if (group->mdev_group)
1753                 iommu_group_for_each_dev(group->iommu_group, domain->domain,
1754                                          vfio_mdev_detach_domain);
1755         else
1756                 iommu_detach_group(domain->domain, group->iommu_group);
1757 }
1758
1759 static bool vfio_bus_is_mdev(struct bus_type *bus)
1760 {
1761         struct bus_type *mdev_bus;
1762         bool ret = false;
1763
1764         mdev_bus = symbol_get(mdev_bus_type);
1765         if (mdev_bus) {
1766                 ret = (bus == mdev_bus);
1767                 symbol_put(mdev_bus_type);
1768         }
1769
1770         return ret;
1771 }
1772
1773 static int vfio_mdev_iommu_device(struct device *dev, void *data)
1774 {
1775         struct device **old = data, *new;
1776
1777         new = vfio_mdev_get_iommu_device(dev);
1778         if (!new || (*old && *old != new))
1779                 return -EINVAL;
1780
1781         *old = new;
1782
1783         return 0;
1784 }
1785
1786 /*
1787  * This is a helper function to insert an address range to iova list.
1788  * The list is initially created with a single entry corresponding to
1789  * the IOMMU domain geometry to which the device group is attached.
1790  * The list aperture gets modified when a new domain is added to the
1791  * container if the new aperture doesn't conflict with the current one
1792  * or with any existing dma mappings. The list is also modified to
1793  * exclude any reserved regions associated with the device group.
1794  */
1795 static int vfio_iommu_iova_insert(struct list_head *head,
1796                                   dma_addr_t start, dma_addr_t end)
1797 {
1798         struct vfio_iova *region;
1799
1800         region = kmalloc(sizeof(*region), GFP_KERNEL);
1801         if (!region)
1802                 return -ENOMEM;
1803
1804         INIT_LIST_HEAD(&region->list);
1805         region->start = start;
1806         region->end = end;
1807
1808         list_add_tail(&region->list, head);
1809         return 0;
1810 }
1811
1812 /*
1813  * Check the new iommu aperture conflicts with existing aper or with any
1814  * existing dma mappings.
1815  */
1816 static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
1817                                      dma_addr_t start, dma_addr_t end)
1818 {
1819         struct vfio_iova *first, *last;
1820         struct list_head *iova = &iommu->iova_list;
1821
1822         if (list_empty(iova))
1823                 return false;
1824
1825         /* Disjoint sets, return conflict */
1826         first = list_first_entry(iova, struct vfio_iova, list);
1827         last = list_last_entry(iova, struct vfio_iova, list);
1828         if (start > last->end || end < first->start)
1829                 return true;
1830
1831         /* Check for any existing dma mappings below the new start */
1832         if (start > first->start) {
1833                 if (vfio_find_dma(iommu, first->start, start - first->start))
1834                         return true;
1835         }
1836
1837         /* Check for any existing dma mappings beyond the new end */
1838         if (end < last->end) {
1839                 if (vfio_find_dma(iommu, end + 1, last->end - end))
1840                         return true;
1841         }
1842
1843         return false;
1844 }
1845
1846 /*
1847  * Resize iommu iova aperture window. This is called only if the new
1848  * aperture has no conflict with existing aperture and dma mappings.
1849  */
1850 static int vfio_iommu_aper_resize(struct list_head *iova,
1851                                   dma_addr_t start, dma_addr_t end)
1852 {
1853         struct vfio_iova *node, *next;
1854
1855         if (list_empty(iova))
1856                 return vfio_iommu_iova_insert(iova, start, end);
1857
1858         /* Adjust iova list start */
1859         list_for_each_entry_safe(node, next, iova, list) {
1860                 if (start < node->start)
1861                         break;
1862                 if (start >= node->start && start < node->end) {
1863                         node->start = start;
1864                         break;
1865                 }
1866                 /* Delete nodes before new start */
1867                 list_del(&node->list);
1868                 kfree(node);
1869         }
1870
1871         /* Adjust iova list end */
1872         list_for_each_entry_safe(node, next, iova, list) {
1873                 if (end > node->end)
1874                         continue;
1875                 if (end > node->start && end <= node->end) {
1876                         node->end = end;
1877                         continue;
1878                 }
1879                 /* Delete nodes after new end */
1880                 list_del(&node->list);
1881                 kfree(node);
1882         }
1883
1884         return 0;
1885 }
1886
1887 /*
1888  * Check reserved region conflicts with existing dma mappings
1889  */
1890 static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
1891                                      struct list_head *resv_regions)
1892 {
1893         struct iommu_resv_region *region;
1894
1895         /* Check for conflict with existing dma mappings */
1896         list_for_each_entry(region, resv_regions, list) {
1897                 if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
1898                         continue;
1899
1900                 if (vfio_find_dma(iommu, region->start, region->length))
1901                         return true;
1902         }
1903
1904         return false;
1905 }
1906
1907 /*
1908  * Check iova region overlap with  reserved regions and
1909  * exclude them from the iommu iova range
1910  */
1911 static int vfio_iommu_resv_exclude(struct list_head *iova,
1912                                    struct list_head *resv_regions)
1913 {
1914         struct iommu_resv_region *resv;
1915         struct vfio_iova *n, *next;
1916
1917         list_for_each_entry(resv, resv_regions, list) {
1918                 phys_addr_t start, end;
1919
1920                 if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
1921                         continue;
1922
1923                 start = resv->start;
1924                 end = resv->start + resv->length - 1;
1925
1926                 list_for_each_entry_safe(n, next, iova, list) {
1927                         int ret = 0;
1928
1929                         /* No overlap */
1930                         if (start > n->end || end < n->start)
1931                                 continue;
1932                         /*
1933                          * Insert a new node if current node overlaps with the
1934                          * reserve region to exlude that from valid iova range.
1935                          * Note that, new node is inserted before the current
1936                          * node and finally the current node is deleted keeping
1937                          * the list updated and sorted.
1938                          */
1939                         if (start > n->start)
1940                                 ret = vfio_iommu_iova_insert(&n->list, n->start,
1941                                                              start - 1);
1942                         if (!ret && end < n->end)
1943                                 ret = vfio_iommu_iova_insert(&n->list, end + 1,
1944                                                              n->end);
1945                         if (ret)
1946                                 return ret;
1947
1948                         list_del(&n->list);
1949                         kfree(n);
1950                 }
1951         }
1952
1953         if (list_empty(iova))
1954                 return -EINVAL;
1955
1956         return 0;
1957 }
1958
1959 static void vfio_iommu_resv_free(struct list_head *resv_regions)
1960 {
1961         struct iommu_resv_region *n, *next;
1962
1963         list_for_each_entry_safe(n, next, resv_regions, list) {
1964                 list_del(&n->list);
1965                 kfree(n);
1966         }
1967 }
1968
1969 static void vfio_iommu_iova_free(struct list_head *iova)
1970 {
1971         struct vfio_iova *n, *next;
1972
1973         list_for_each_entry_safe(n, next, iova, list) {
1974                 list_del(&n->list);
1975                 kfree(n);
1976         }
1977 }
1978
1979 static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
1980                                     struct list_head *iova_copy)
1981 {
1982         struct list_head *iova = &iommu->iova_list;
1983         struct vfio_iova *n;
1984         int ret;
1985
1986         list_for_each_entry(n, iova, list) {
1987                 ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
1988                 if (ret)
1989                         goto out_free;
1990         }
1991
1992         return 0;
1993
1994 out_free:
1995         vfio_iommu_iova_free(iova_copy);
1996         return ret;
1997 }
1998
1999 static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2000                                         struct list_head *iova_copy)
2001 {
2002         struct list_head *iova = &iommu->iova_list;
2003
2004         vfio_iommu_iova_free(iova);
2005
2006         list_splice_tail(iova_copy, iova);
2007 }
2008
2009 static int vfio_iommu_type1_attach_group(void *iommu_data,
2010                                          struct iommu_group *iommu_group)
2011 {
2012         struct vfio_iommu *iommu = iommu_data;
2013         struct vfio_group *group;
2014         struct vfio_domain *domain, *d;
2015         struct bus_type *bus = NULL;
2016         int ret;
2017         bool resv_msi, msi_remap;
2018         phys_addr_t resv_msi_base = 0;
2019         struct iommu_domain_geometry geo;
2020         LIST_HEAD(iova_copy);
2021         LIST_HEAD(group_resv_regions);
2022
2023         mutex_lock(&iommu->lock);
2024
2025         /* Check for duplicates */
2026         if (vfio_iommu_find_iommu_group(iommu, iommu_group)) {
2027                 mutex_unlock(&iommu->lock);
2028                 return -EINVAL;
2029         }
2030
2031         group = kzalloc(sizeof(*group), GFP_KERNEL);
2032         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2033         if (!group || !domain) {
2034                 ret = -ENOMEM;
2035                 goto out_free;
2036         }
2037
2038         group->iommu_group = iommu_group;
2039
2040         /* Determine bus_type in order to allocate a domain */
2041         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
2042         if (ret)
2043                 goto out_free;
2044
2045         if (vfio_bus_is_mdev(bus)) {
2046                 struct device *iommu_device = NULL;
2047
2048                 group->mdev_group = true;
2049
2050                 /* Determine the isolation type */
2051                 ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
2052                                                vfio_mdev_iommu_device);
2053                 if (ret || !iommu_device) {
2054                         if (!iommu->external_domain) {
2055                                 INIT_LIST_HEAD(&domain->group_list);
2056                                 iommu->external_domain = domain;
2057                                 vfio_update_pgsize_bitmap(iommu);
2058                         } else {
2059                                 kfree(domain);
2060                         }
2061
2062                         list_add(&group->next,
2063                                  &iommu->external_domain->group_list);
2064                         /*
2065                          * Non-iommu backed group cannot dirty memory directly,
2066                          * it can only use interfaces that provide dirty
2067                          * tracking.
2068                          * The iommu scope can only be promoted with the
2069                          * addition of a dirty tracking group.
2070                          */
2071                         group->pinned_page_dirty_scope = true;
2072                         if (!iommu->pinned_page_dirty_scope)
2073                                 update_pinned_page_dirty_scope(iommu);
2074                         mutex_unlock(&iommu->lock);
2075
2076                         return 0;
2077                 }
2078
2079                 bus = iommu_device->bus;
2080         }
2081
2082         domain->domain = iommu_domain_alloc(bus);
2083         if (!domain->domain) {
2084                 ret = -EIO;
2085                 goto out_free;
2086         }
2087
2088         if (iommu->nesting) {
2089                 int attr = 1;
2090
2091                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
2092                                             &attr);
2093                 if (ret)
2094                         goto out_domain;
2095         }
2096
2097         ret = vfio_iommu_attach_group(domain, group);
2098         if (ret)
2099                 goto out_domain;
2100
2101         /* Get aperture info */
2102         iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY, &geo);
2103
2104         if (vfio_iommu_aper_conflict(iommu, geo.aperture_start,
2105                                      geo.aperture_end)) {
2106                 ret = -EINVAL;
2107                 goto out_detach;
2108         }
2109
2110         ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2111         if (ret)
2112                 goto out_detach;
2113
2114         if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2115                 ret = -EINVAL;
2116                 goto out_detach;
2117         }
2118
2119         /*
2120          * We don't want to work on the original iova list as the list
2121          * gets modified and in case of failure we have to retain the
2122          * original list. Get a copy here.
2123          */
2124         ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2125         if (ret)
2126                 goto out_detach;
2127
2128         ret = vfio_iommu_aper_resize(&iova_copy, geo.aperture_start,
2129                                      geo.aperture_end);
2130         if (ret)
2131                 goto out_detach;
2132
2133         ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2134         if (ret)
2135                 goto out_detach;
2136
2137         resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2138
2139         INIT_LIST_HEAD(&domain->group_list);
2140         list_add(&group->next, &domain->group_list);
2141
2142         msi_remap = irq_domain_check_msi_remap() ||
2143                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
2144
2145         if (!allow_unsafe_interrupts && !msi_remap) {
2146                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2147                        __func__);
2148                 ret = -EPERM;
2149                 goto out_detach;
2150         }
2151
2152         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
2153                 domain->prot |= IOMMU_CACHE;
2154
2155         /*
2156          * Try to match an existing compatible domain.  We don't want to
2157          * preclude an IOMMU driver supporting multiple bus_types and being
2158          * able to include different bus_types in the same IOMMU domain, so
2159          * we test whether the domains use the same iommu_ops rather than
2160          * testing if they're on the same bus_type.
2161          */
2162         list_for_each_entry(d, &iommu->domain_list, next) {
2163                 if (d->domain->ops == domain->domain->ops &&
2164                     d->prot == domain->prot) {
2165                         vfio_iommu_detach_group(domain, group);
2166                         if (!vfio_iommu_attach_group(d, group)) {
2167                                 list_add(&group->next, &d->group_list);
2168                                 iommu_domain_free(domain->domain);
2169                                 kfree(domain);
2170                                 goto done;
2171                         }
2172
2173                         ret = vfio_iommu_attach_group(domain, group);
2174                         if (ret)
2175                                 goto out_domain;
2176                 }
2177         }
2178
2179         vfio_test_domain_fgsp(domain);
2180
2181         /* replay mappings on new domains */
2182         ret = vfio_iommu_replay(iommu, domain);
2183         if (ret)
2184                 goto out_detach;
2185
2186         if (resv_msi) {
2187                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2188                 if (ret && ret != -ENODEV)
2189                         goto out_detach;
2190         }
2191
2192         list_add(&domain->next, &iommu->domain_list);
2193         vfio_update_pgsize_bitmap(iommu);
2194 done:
2195         /* Delete the old one and insert new iova list */
2196         vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2197
2198         /*
2199          * An iommu backed group can dirty memory directly and therefore
2200          * demotes the iommu scope until it declares itself dirty tracking
2201          * capable via the page pinning interface.
2202          */
2203         iommu->pinned_page_dirty_scope = false;
2204         mutex_unlock(&iommu->lock);
2205         vfio_iommu_resv_free(&group_resv_regions);
2206
2207         return 0;
2208
2209 out_detach:
2210         vfio_iommu_detach_group(domain, group);
2211 out_domain:
2212         iommu_domain_free(domain->domain);
2213         vfio_iommu_iova_free(&iova_copy);
2214         vfio_iommu_resv_free(&group_resv_regions);
2215 out_free:
2216         kfree(domain);
2217         kfree(group);
2218         mutex_unlock(&iommu->lock);
2219         return ret;
2220 }
2221
2222 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2223 {
2224         struct rb_node *node;
2225
2226         while ((node = rb_first(&iommu->dma_list)))
2227                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2228 }
2229
2230 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2231 {
2232         struct rb_node *n, *p;
2233
2234         n = rb_first(&iommu->dma_list);
2235         for (; n; n = rb_next(n)) {
2236                 struct vfio_dma *dma;
2237                 long locked = 0, unlocked = 0;
2238
2239                 dma = rb_entry(n, struct vfio_dma, node);
2240                 unlocked += vfio_unmap_unpin(iommu, dma, false);
2241                 p = rb_first(&dma->pfn_list);
2242                 for (; p; p = rb_next(p)) {
2243                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2244                                                          node);
2245
2246                         if (!is_invalid_reserved_pfn(vpfn->pfn))
2247                                 locked++;
2248                 }
2249                 vfio_lock_acct(dma, locked - unlocked, true);
2250         }
2251 }
2252
2253 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
2254 {
2255         struct rb_node *n;
2256
2257         n = rb_first(&iommu->dma_list);
2258         for (; n; n = rb_next(n)) {
2259                 struct vfio_dma *dma;
2260
2261                 dma = rb_entry(n, struct vfio_dma, node);
2262
2263                 if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
2264                         break;
2265         }
2266         /* mdev vendor driver must unregister notifier */
2267         WARN_ON(iommu->notifier.head);
2268 }
2269
2270 /*
2271  * Called when a domain is removed in detach. It is possible that
2272  * the removed domain decided the iova aperture window. Modify the
2273  * iova aperture with the smallest window among existing domains.
2274  */
2275 static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2276                                    struct list_head *iova_copy)
2277 {
2278         struct vfio_domain *domain;
2279         struct iommu_domain_geometry geo;
2280         struct vfio_iova *node;
2281         dma_addr_t start = 0;
2282         dma_addr_t end = (dma_addr_t)~0;
2283
2284         if (list_empty(iova_copy))
2285                 return;
2286
2287         list_for_each_entry(domain, &iommu->domain_list, next) {
2288                 iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY,
2289                                       &geo);
2290                 if (geo.aperture_start > start)
2291                         start = geo.aperture_start;
2292                 if (geo.aperture_end < end)
2293                         end = geo.aperture_end;
2294         }
2295
2296         /* Modify aperture limits. The new aper is either same or bigger */
2297         node = list_first_entry(iova_copy, struct vfio_iova, list);
2298         node->start = start;
2299         node = list_last_entry(iova_copy, struct vfio_iova, list);
2300         node->end = end;
2301 }
2302
2303 /*
2304  * Called when a group is detached. The reserved regions for that
2305  * group can be part of valid iova now. But since reserved regions
2306  * may be duplicated among groups, populate the iova valid regions
2307  * list again.
2308  */
2309 static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2310                                    struct list_head *iova_copy)
2311 {
2312         struct vfio_domain *d;
2313         struct vfio_group *g;
2314         struct vfio_iova *node;
2315         dma_addr_t start, end;
2316         LIST_HEAD(resv_regions);
2317         int ret;
2318
2319         if (list_empty(iova_copy))
2320                 return -EINVAL;
2321
2322         list_for_each_entry(d, &iommu->domain_list, next) {
2323                 list_for_each_entry(g, &d->group_list, next) {
2324                         ret = iommu_get_group_resv_regions(g->iommu_group,
2325                                                            &resv_regions);
2326                         if (ret)
2327                                 goto done;
2328                 }
2329         }
2330
2331         node = list_first_entry(iova_copy, struct vfio_iova, list);
2332         start = node->start;
2333         node = list_last_entry(iova_copy, struct vfio_iova, list);
2334         end = node->end;
2335
2336         /* purge the iova list and create new one */
2337         vfio_iommu_iova_free(iova_copy);
2338
2339         ret = vfio_iommu_aper_resize(iova_copy, start, end);
2340         if (ret)
2341                 goto done;
2342
2343         /* Exclude current reserved regions from iova ranges */
2344         ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2345 done:
2346         vfio_iommu_resv_free(&resv_regions);
2347         return ret;
2348 }
2349
2350 static void vfio_iommu_type1_detach_group(void *iommu_data,
2351                                           struct iommu_group *iommu_group)
2352 {
2353         struct vfio_iommu *iommu = iommu_data;
2354         struct vfio_domain *domain;
2355         struct vfio_group *group;
2356         bool update_dirty_scope = false;
2357         LIST_HEAD(iova_copy);
2358
2359         mutex_lock(&iommu->lock);
2360
2361         if (iommu->external_domain) {
2362                 group = find_iommu_group(iommu->external_domain, iommu_group);
2363                 if (group) {
2364                         update_dirty_scope = !group->pinned_page_dirty_scope;
2365                         list_del(&group->next);
2366                         kfree(group);
2367
2368                         if (list_empty(&iommu->external_domain->group_list)) {
2369                                 vfio_sanity_check_pfn_list(iommu);
2370
2371                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
2372                                         vfio_iommu_unmap_unpin_all(iommu);
2373
2374                                 kfree(iommu->external_domain);
2375                                 iommu->external_domain = NULL;
2376                         }
2377                         goto detach_group_done;
2378                 }
2379         }
2380
2381         /*
2382          * Get a copy of iova list. This will be used to update
2383          * and to replace the current one later. Please note that
2384          * we will leave the original list as it is if update fails.
2385          */
2386         vfio_iommu_iova_get_copy(iommu, &iova_copy);
2387
2388         list_for_each_entry(domain, &iommu->domain_list, next) {
2389                 group = find_iommu_group(domain, iommu_group);
2390                 if (!group)
2391                         continue;
2392
2393                 vfio_iommu_detach_group(domain, group);
2394                 update_dirty_scope = !group->pinned_page_dirty_scope;
2395                 list_del(&group->next);
2396                 kfree(group);
2397                 /*
2398                  * Group ownership provides privilege, if the group list is
2399                  * empty, the domain goes away. If it's the last domain with
2400                  * iommu and external domain doesn't exist, then all the
2401                  * mappings go away too. If it's the last domain with iommu and
2402                  * external domain exist, update accounting
2403                  */
2404                 if (list_empty(&domain->group_list)) {
2405                         if (list_is_singular(&iommu->domain_list)) {
2406                                 if (!iommu->external_domain)
2407                                         vfio_iommu_unmap_unpin_all(iommu);
2408                                 else
2409                                         vfio_iommu_unmap_unpin_reaccount(iommu);
2410                         }
2411                         iommu_domain_free(domain->domain);
2412                         list_del(&domain->next);
2413                         kfree(domain);
2414                         vfio_iommu_aper_expand(iommu, &iova_copy);
2415                         vfio_update_pgsize_bitmap(iommu);
2416                 }
2417                 break;
2418         }
2419
2420         if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2421                 vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2422         else
2423                 vfio_iommu_iova_free(&iova_copy);
2424
2425 detach_group_done:
2426         /*
2427          * Removal of a group without dirty tracking may allow the iommu scope
2428          * to be promoted.
2429          */
2430         if (update_dirty_scope) {
2431                 update_pinned_page_dirty_scope(iommu);
2432                 if (iommu->dirty_page_tracking)
2433                         vfio_iommu_populate_bitmap_full(iommu);
2434         }
2435         mutex_unlock(&iommu->lock);
2436 }
2437
2438 static void *vfio_iommu_type1_open(unsigned long arg)
2439 {
2440         struct vfio_iommu *iommu;
2441
2442         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2443         if (!iommu)
2444                 return ERR_PTR(-ENOMEM);
2445
2446         switch (arg) {
2447         case VFIO_TYPE1_IOMMU:
2448                 break;
2449         case VFIO_TYPE1_NESTING_IOMMU:
2450                 iommu->nesting = true;
2451                 fallthrough;
2452         case VFIO_TYPE1v2_IOMMU:
2453                 iommu->v2 = true;
2454                 break;
2455         default:
2456                 kfree(iommu);
2457                 return ERR_PTR(-EINVAL);
2458         }
2459
2460         INIT_LIST_HEAD(&iommu->domain_list);
2461         INIT_LIST_HEAD(&iommu->iova_list);
2462         iommu->dma_list = RB_ROOT;
2463         iommu->dma_avail = dma_entry_limit;
2464         mutex_init(&iommu->lock);
2465         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
2466
2467         return iommu;
2468 }
2469
2470 static void vfio_release_domain(struct vfio_domain *domain, bool external)
2471 {
2472         struct vfio_group *group, *group_tmp;
2473
2474         list_for_each_entry_safe(group, group_tmp,
2475                                  &domain->group_list, next) {
2476                 if (!external)
2477                         vfio_iommu_detach_group(domain, group);
2478                 list_del(&group->next);
2479                 kfree(group);
2480         }
2481
2482         if (!external)
2483                 iommu_domain_free(domain->domain);
2484 }
2485
2486 static void vfio_iommu_type1_release(void *iommu_data)
2487 {
2488         struct vfio_iommu *iommu = iommu_data;
2489         struct vfio_domain *domain, *domain_tmp;
2490
2491         if (iommu->external_domain) {
2492                 vfio_release_domain(iommu->external_domain, true);
2493                 vfio_sanity_check_pfn_list(iommu);
2494                 kfree(iommu->external_domain);
2495         }
2496
2497         vfio_iommu_unmap_unpin_all(iommu);
2498
2499         list_for_each_entry_safe(domain, domain_tmp,
2500                                  &iommu->domain_list, next) {
2501                 vfio_release_domain(domain, false);
2502                 list_del(&domain->next);
2503                 kfree(domain);
2504         }
2505
2506         vfio_iommu_iova_free(&iommu->iova_list);
2507
2508         kfree(iommu);
2509 }
2510
2511 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
2512 {
2513         struct vfio_domain *domain;
2514         int ret = 1;
2515
2516         mutex_lock(&iommu->lock);
2517         list_for_each_entry(domain, &iommu->domain_list, next) {
2518                 if (!(domain->prot & IOMMU_CACHE)) {
2519                         ret = 0;
2520                         break;
2521                 }
2522         }
2523         mutex_unlock(&iommu->lock);
2524
2525         return ret;
2526 }
2527
2528 static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2529                                             unsigned long arg)
2530 {
2531         switch (arg) {
2532         case VFIO_TYPE1_IOMMU:
2533         case VFIO_TYPE1v2_IOMMU:
2534         case VFIO_TYPE1_NESTING_IOMMU:
2535                 return 1;
2536         case VFIO_DMA_CC_IOMMU:
2537                 if (!iommu)
2538                         return 0;
2539                 return vfio_domains_have_iommu_cache(iommu);
2540         default:
2541                 return 0;
2542         }
2543 }
2544
2545 static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2546                  struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2547                  size_t size)
2548 {
2549         struct vfio_info_cap_header *header;
2550         struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2551
2552         header = vfio_info_cap_add(caps, size,
2553                                    VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2554         if (IS_ERR(header))
2555                 return PTR_ERR(header);
2556
2557         iova_cap = container_of(header,
2558                                 struct vfio_iommu_type1_info_cap_iova_range,
2559                                 header);
2560         iova_cap->nr_iovas = cap_iovas->nr_iovas;
2561         memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2562                cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2563         return 0;
2564 }
2565
2566 static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2567                                       struct vfio_info_cap *caps)
2568 {
2569         struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2570         struct vfio_iova *iova;
2571         size_t size;
2572         int iovas = 0, i = 0, ret;
2573
2574         list_for_each_entry(iova, &iommu->iova_list, list)
2575                 iovas++;
2576
2577         if (!iovas) {
2578                 /*
2579                  * Return 0 as a container with a single mdev device
2580                  * will have an empty list
2581                  */
2582                 return 0;
2583         }
2584
2585         size = sizeof(*cap_iovas) + (iovas * sizeof(*cap_iovas->iova_ranges));
2586
2587         cap_iovas = kzalloc(size, GFP_KERNEL);
2588         if (!cap_iovas)
2589                 return -ENOMEM;
2590
2591         cap_iovas->nr_iovas = iovas;
2592
2593         list_for_each_entry(iova, &iommu->iova_list, list) {
2594                 cap_iovas->iova_ranges[i].start = iova->start;
2595                 cap_iovas->iova_ranges[i].end = iova->end;
2596                 i++;
2597         }
2598
2599         ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2600
2601         kfree(cap_iovas);
2602         return ret;
2603 }
2604
2605 static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2606                                            struct vfio_info_cap *caps)
2607 {
2608         struct vfio_iommu_type1_info_cap_migration cap_mig;
2609
2610         cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2611         cap_mig.header.version = 1;
2612
2613         cap_mig.flags = 0;
2614         /* support minimum pgsize */
2615         cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2616         cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2617
2618         return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2619 }
2620
2621 static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2622                                            struct vfio_info_cap *caps)
2623 {
2624         struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2625
2626         cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2627         cap_dma_avail.header.version = 1;
2628
2629         cap_dma_avail.avail = iommu->dma_avail;
2630
2631         return vfio_info_add_capability(caps, &cap_dma_avail.header,
2632                                         sizeof(cap_dma_avail));
2633 }
2634
2635 static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2636                                      unsigned long arg)
2637 {
2638         struct vfio_iommu_type1_info info;
2639         unsigned long minsz;
2640         struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2641         unsigned long capsz;
2642         int ret;
2643
2644         minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2645
2646         /* For backward compatibility, cannot require this */
2647         capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2648
2649         if (copy_from_user(&info, (void __user *)arg, minsz))
2650                 return -EFAULT;
2651
2652         if (info.argsz < minsz)
2653                 return -EINVAL;
2654
2655         if (info.argsz >= capsz) {
2656                 minsz = capsz;
2657                 info.cap_offset = 0; /* output, no-recopy necessary */
2658         }
2659
2660         mutex_lock(&iommu->lock);
2661         info.flags = VFIO_IOMMU_INFO_PGSIZES;
2662
2663         info.iova_pgsizes = iommu->pgsize_bitmap;
2664
2665         ret = vfio_iommu_migration_build_caps(iommu, &caps);
2666
2667         if (!ret)
2668                 ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2669
2670         if (!ret)
2671                 ret = vfio_iommu_iova_build_caps(iommu, &caps);
2672
2673         mutex_unlock(&iommu->lock);
2674
2675         if (ret)
2676                 return ret;
2677
2678         if (caps.size) {
2679                 info.flags |= VFIO_IOMMU_INFO_CAPS;
2680
2681                 if (info.argsz < sizeof(info) + caps.size) {
2682                         info.argsz = sizeof(info) + caps.size;
2683                 } else {
2684                         vfio_info_cap_shift(&caps, sizeof(info));
2685                         if (copy_to_user((void __user *)arg +
2686                                         sizeof(info), caps.buf,
2687                                         caps.size)) {
2688                                 kfree(caps.buf);
2689                                 return -EFAULT;
2690                         }
2691                         info.cap_offset = sizeof(info);
2692                 }
2693
2694                 kfree(caps.buf);
2695         }
2696
2697         return copy_to_user((void __user *)arg, &info, minsz) ?
2698                         -EFAULT : 0;
2699 }
2700
2701 static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2702                                     unsigned long arg)
2703 {
2704         struct vfio_iommu_type1_dma_map map;
2705         unsigned long minsz;
2706         uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE;
2707
2708         minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2709
2710         if (copy_from_user(&map, (void __user *)arg, minsz))
2711                 return -EFAULT;
2712
2713         if (map.argsz < minsz || map.flags & ~mask)
2714                 return -EINVAL;
2715
2716         return vfio_dma_do_map(iommu, &map);
2717 }
2718
2719 static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2720                                       unsigned long arg)
2721 {
2722         struct vfio_iommu_type1_dma_unmap unmap;
2723         struct vfio_bitmap bitmap = { 0 };
2724         unsigned long minsz;
2725         int ret;
2726
2727         minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2728
2729         if (copy_from_user(&unmap, (void __user *)arg, minsz))
2730                 return -EFAULT;
2731
2732         if (unmap.argsz < minsz ||
2733             unmap.flags & ~VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP)
2734                 return -EINVAL;
2735
2736         if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2737                 unsigned long pgshift;
2738
2739                 if (unmap.argsz < (minsz + sizeof(bitmap)))
2740                         return -EINVAL;
2741
2742                 if (copy_from_user(&bitmap,
2743                                    (void __user *)(arg + minsz),
2744                                    sizeof(bitmap)))
2745                         return -EFAULT;
2746
2747                 if (!access_ok((void __user *)bitmap.data, bitmap.size))
2748                         return -EINVAL;
2749
2750                 pgshift = __ffs(bitmap.pgsize);
2751                 ret = verify_bitmap_size(unmap.size >> pgshift,
2752                                          bitmap.size);
2753                 if (ret)
2754                         return ret;
2755         }
2756
2757         ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2758         if (ret)
2759                 return ret;
2760
2761         return copy_to_user((void __user *)arg, &unmap, minsz) ?
2762                         -EFAULT : 0;
2763 }
2764
2765 static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
2766                                         unsigned long arg)
2767 {
2768         struct vfio_iommu_type1_dirty_bitmap dirty;
2769         uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2770                         VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2771                         VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2772         unsigned long minsz;
2773         int ret = 0;
2774
2775         if (!iommu->v2)
2776                 return -EACCES;
2777
2778         minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
2779
2780         if (copy_from_user(&dirty, (void __user *)arg, minsz))
2781                 return -EFAULT;
2782
2783         if (dirty.argsz < minsz || dirty.flags & ~mask)
2784                 return -EINVAL;
2785
2786         /* only one flag should be set at a time */
2787         if (__ffs(dirty.flags) != __fls(dirty.flags))
2788                 return -EINVAL;
2789
2790         if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
2791                 size_t pgsize;
2792
2793                 mutex_lock(&iommu->lock);
2794                 pgsize = 1 << __ffs(iommu->pgsize_bitmap);
2795                 if (!iommu->dirty_page_tracking) {
2796                         ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
2797                         if (!ret)
2798                                 iommu->dirty_page_tracking = true;
2799                 }
2800                 mutex_unlock(&iommu->lock);
2801                 return ret;
2802         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
2803                 mutex_lock(&iommu->lock);
2804                 if (iommu->dirty_page_tracking) {
2805                         iommu->dirty_page_tracking = false;
2806                         vfio_dma_bitmap_free_all(iommu);
2807                 }
2808                 mutex_unlock(&iommu->lock);
2809                 return 0;
2810         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
2811                 struct vfio_iommu_type1_dirty_bitmap_get range;
2812                 unsigned long pgshift;
2813                 size_t data_size = dirty.argsz - minsz;
2814                 size_t iommu_pgsize;
2815
2816                 if (!data_size || data_size < sizeof(range))
2817                         return -EINVAL;
2818
2819                 if (copy_from_user(&range, (void __user *)(arg + minsz),
2820                                    sizeof(range)))
2821                         return -EFAULT;
2822
2823                 if (range.iova + range.size < range.iova)
2824                         return -EINVAL;
2825                 if (!access_ok((void __user *)range.bitmap.data,
2826                                range.bitmap.size))
2827                         return -EINVAL;
2828
2829                 pgshift = __ffs(range.bitmap.pgsize);
2830                 ret = verify_bitmap_size(range.size >> pgshift,
2831                                          range.bitmap.size);
2832                 if (ret)
2833                         return ret;
2834
2835                 mutex_lock(&iommu->lock);
2836
2837                 iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2838
2839                 /* allow only smallest supported pgsize */
2840                 if (range.bitmap.pgsize != iommu_pgsize) {
2841                         ret = -EINVAL;
2842                         goto out_unlock;
2843                 }
2844                 if (range.iova & (iommu_pgsize - 1)) {
2845                         ret = -EINVAL;
2846                         goto out_unlock;
2847                 }
2848                 if (!range.size || range.size & (iommu_pgsize - 1)) {
2849                         ret = -EINVAL;
2850                         goto out_unlock;
2851                 }
2852
2853                 if (iommu->dirty_page_tracking)
2854                         ret = vfio_iova_dirty_bitmap(range.bitmap.data,
2855                                                      iommu, range.iova,
2856                                                      range.size,
2857                                                      range.bitmap.pgsize);
2858                 else
2859                         ret = -EINVAL;
2860 out_unlock:
2861                 mutex_unlock(&iommu->lock);
2862
2863                 return ret;
2864         }
2865
2866         return -EINVAL;
2867 }
2868
2869 static long vfio_iommu_type1_ioctl(void *iommu_data,
2870                                    unsigned int cmd, unsigned long arg)
2871 {
2872         struct vfio_iommu *iommu = iommu_data;
2873
2874         switch (cmd) {
2875         case VFIO_CHECK_EXTENSION:
2876                 return vfio_iommu_type1_check_extension(iommu, arg);
2877         case VFIO_IOMMU_GET_INFO:
2878                 return vfio_iommu_type1_get_info(iommu, arg);
2879         case VFIO_IOMMU_MAP_DMA:
2880                 return vfio_iommu_type1_map_dma(iommu, arg);
2881         case VFIO_IOMMU_UNMAP_DMA:
2882                 return vfio_iommu_type1_unmap_dma(iommu, arg);
2883         case VFIO_IOMMU_DIRTY_PAGES:
2884                 return vfio_iommu_type1_dirty_pages(iommu, arg);
2885         default:
2886                 return -ENOTTY;
2887         }
2888 }
2889
2890 static int vfio_iommu_type1_register_notifier(void *iommu_data,
2891                                               unsigned long *events,
2892                                               struct notifier_block *nb)
2893 {
2894         struct vfio_iommu *iommu = iommu_data;
2895
2896         /* clear known events */
2897         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
2898
2899         /* refuse to register if still events remaining */
2900         if (*events)
2901                 return -EINVAL;
2902
2903         return blocking_notifier_chain_register(&iommu->notifier, nb);
2904 }
2905
2906 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
2907                                                 struct notifier_block *nb)
2908 {
2909         struct vfio_iommu *iommu = iommu_data;
2910
2911         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
2912 }
2913
2914 static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
2915                                          dma_addr_t user_iova, void *data,
2916                                          size_t count, bool write,
2917                                          size_t *copied)
2918 {
2919         struct mm_struct *mm;
2920         unsigned long vaddr;
2921         struct vfio_dma *dma;
2922         bool kthread = current->mm == NULL;
2923         size_t offset;
2924
2925         *copied = 0;
2926
2927         dma = vfio_find_dma(iommu, user_iova, 1);
2928         if (!dma)
2929                 return -EINVAL;
2930
2931         if ((write && !(dma->prot & IOMMU_WRITE)) ||
2932                         !(dma->prot & IOMMU_READ))
2933                 return -EPERM;
2934
2935         mm = get_task_mm(dma->task);
2936
2937         if (!mm)
2938                 return -EPERM;
2939
2940         if (kthread)
2941                 kthread_use_mm(mm);
2942         else if (current->mm != mm)
2943                 goto out;
2944
2945         offset = user_iova - dma->iova;
2946
2947         if (count > dma->size - offset)
2948                 count = dma->size - offset;
2949
2950         vaddr = dma->vaddr + offset;
2951
2952         if (write) {
2953                 *copied = copy_to_user((void __user *)vaddr, data,
2954                                          count) ? 0 : count;
2955                 if (*copied && iommu->dirty_page_tracking) {
2956                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
2957                         /*
2958                          * Bitmap populated with the smallest supported page
2959                          * size
2960                          */
2961                         bitmap_set(dma->bitmap, offset >> pgshift,
2962                                    ((offset + *copied - 1) >> pgshift) -
2963                                    (offset >> pgshift) + 1);
2964                 }
2965         } else
2966                 *copied = copy_from_user(data, (void __user *)vaddr,
2967                                            count) ? 0 : count;
2968         if (kthread)
2969                 kthread_unuse_mm(mm);
2970 out:
2971         mmput(mm);
2972         return *copied ? 0 : -EFAULT;
2973 }
2974
2975 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
2976                                    void *data, size_t count, bool write)
2977 {
2978         struct vfio_iommu *iommu = iommu_data;
2979         int ret = 0;
2980         size_t done;
2981
2982         mutex_lock(&iommu->lock);
2983         while (count > 0) {
2984                 ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
2985                                                     count, write, &done);
2986                 if (ret)
2987                         break;
2988
2989                 count -= done;
2990                 data += done;
2991                 user_iova += done;
2992         }
2993
2994         mutex_unlock(&iommu->lock);
2995         return ret;
2996 }
2997
2998 static struct iommu_domain *
2999 vfio_iommu_type1_group_iommu_domain(void *iommu_data,
3000                                     struct iommu_group *iommu_group)
3001 {
3002         struct iommu_domain *domain = ERR_PTR(-ENODEV);
3003         struct vfio_iommu *iommu = iommu_data;
3004         struct vfio_domain *d;
3005
3006         if (!iommu || !iommu_group)
3007                 return ERR_PTR(-EINVAL);
3008
3009         mutex_lock(&iommu->lock);
3010         list_for_each_entry(d, &iommu->domain_list, next) {
3011                 if (find_iommu_group(d, iommu_group)) {
3012                         domain = d->domain;
3013                         break;
3014                 }
3015         }
3016         mutex_unlock(&iommu->lock);
3017
3018         return domain;
3019 }
3020
3021 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3022         .name                   = "vfio-iommu-type1",
3023         .owner                  = THIS_MODULE,
3024         .open                   = vfio_iommu_type1_open,
3025         .release                = vfio_iommu_type1_release,
3026         .ioctl                  = vfio_iommu_type1_ioctl,
3027         .attach_group           = vfio_iommu_type1_attach_group,
3028         .detach_group           = vfio_iommu_type1_detach_group,
3029         .pin_pages              = vfio_iommu_type1_pin_pages,
3030         .unpin_pages            = vfio_iommu_type1_unpin_pages,
3031         .register_notifier      = vfio_iommu_type1_register_notifier,
3032         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
3033         .dma_rw                 = vfio_iommu_type1_dma_rw,
3034         .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
3035 };
3036
3037 static int __init vfio_iommu_type1_init(void)
3038 {
3039         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3040 }
3041
3042 static void __exit vfio_iommu_type1_cleanup(void)
3043 {
3044         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3045 }
3046
3047 module_init(vfio_iommu_type1_init);
3048 module_exit(vfio_iommu_type1_cleanup);
3049
3050 MODULE_VERSION(DRIVER_VERSION);
3051 MODULE_LICENSE("GPL v2");
3052 MODULE_AUTHOR(DRIVER_AUTHOR);
3053 MODULE_DESCRIPTION(DRIVER_DESC);