vfio/type1: massage unmap iteration
[linux-2.6-microblaze.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userpsace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/iommu.h>
28 #include <linux/module.h>
29 #include <linux/mm.h>
30 #include <linux/kthread.h>
31 #include <linux/rbtree.h>
32 #include <linux/sched/signal.h>
33 #include <linux/sched/mm.h>
34 #include <linux/slab.h>
35 #include <linux/uaccess.h>
36 #include <linux/vfio.h>
37 #include <linux/workqueue.h>
38 #include <linux/mdev.h>
39 #include <linux/notifier.h>
40 #include <linux/dma-iommu.h>
41 #include <linux/irqdomain.h>
42
43 #define DRIVER_VERSION  "0.2"
44 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
45 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
46
47 static bool allow_unsafe_interrupts;
48 module_param_named(allow_unsafe_interrupts,
49                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
50 MODULE_PARM_DESC(allow_unsafe_interrupts,
51                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
52
53 static bool disable_hugepages;
54 module_param_named(disable_hugepages,
55                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
56 MODULE_PARM_DESC(disable_hugepages,
57                  "Disable VFIO IOMMU support for IOMMU hugepages.");
58
59 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
60 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
61 MODULE_PARM_DESC(dma_entry_limit,
62                  "Maximum number of user DMA mappings per container (65535).");
63
64 struct vfio_iommu {
65         struct list_head        domain_list;
66         struct list_head        iova_list;
67         struct vfio_domain      *external_domain; /* domain for external user */
68         struct mutex            lock;
69         struct rb_root          dma_list;
70         struct blocking_notifier_head notifier;
71         unsigned int            dma_avail;
72         uint64_t                pgsize_bitmap;
73         bool                    v2;
74         bool                    nesting;
75         bool                    dirty_page_tracking;
76         bool                    pinned_page_dirty_scope;
77 };
78
79 struct vfio_domain {
80         struct iommu_domain     *domain;
81         struct list_head        next;
82         struct list_head        group_list;
83         int                     prot;           /* IOMMU_CACHE */
84         bool                    fgsp;           /* Fine-grained super pages */
85 };
86
87 struct vfio_dma {
88         struct rb_node          node;
89         dma_addr_t              iova;           /* Device address */
90         unsigned long           vaddr;          /* Process virtual addr */
91         size_t                  size;           /* Map size (bytes) */
92         int                     prot;           /* IOMMU_READ/WRITE */
93         bool                    iommu_mapped;
94         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
95         struct task_struct      *task;
96         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
97         unsigned long           *bitmap;
98 };
99
100 struct vfio_group {
101         struct iommu_group      *iommu_group;
102         struct list_head        next;
103         bool                    mdev_group;     /* An mdev group */
104         bool                    pinned_page_dirty_scope;
105 };
106
107 struct vfio_iova {
108         struct list_head        list;
109         dma_addr_t              start;
110         dma_addr_t              end;
111 };
112
113 /*
114  * Guest RAM pinning working set or DMA target
115  */
116 struct vfio_pfn {
117         struct rb_node          node;
118         dma_addr_t              iova;           /* Device address */
119         unsigned long           pfn;            /* Host pfn */
120         unsigned int            ref_count;
121 };
122
123 struct vfio_regions {
124         struct list_head list;
125         dma_addr_t iova;
126         phys_addr_t phys;
127         size_t len;
128 };
129
130 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
131                                         (!list_empty(&iommu->domain_list))
132
133 #define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
134
135 /*
136  * Input argument of number of bits to bitmap_set() is unsigned integer, which
137  * further casts to signed integer for unaligned multi-bit operation,
138  * __bitmap_set().
139  * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
140  * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
141  * system.
142  */
143 #define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
144 #define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
145
146 static int put_pfn(unsigned long pfn, int prot);
147
148 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
149                                                struct iommu_group *iommu_group);
150
151 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu);
152 /*
153  * This code handles mapping and unmapping of user data buffers
154  * into DMA'ble space using the IOMMU
155  */
156
157 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
158                                       dma_addr_t start, size_t size)
159 {
160         struct rb_node *node = iommu->dma_list.rb_node;
161
162         while (node) {
163                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
164
165                 if (start + size <= dma->iova)
166                         node = node->rb_left;
167                 else if (start >= dma->iova + dma->size)
168                         node = node->rb_right;
169                 else
170                         return dma;
171         }
172
173         return NULL;
174 }
175
176 static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
177                                                 dma_addr_t start, size_t size)
178 {
179         struct rb_node *res = NULL;
180         struct rb_node *node = iommu->dma_list.rb_node;
181         struct vfio_dma *dma_res = NULL;
182
183         while (node) {
184                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
185
186                 if (start < dma->iova + dma->size) {
187                         res = node;
188                         dma_res = dma;
189                         if (start >= dma->iova)
190                                 break;
191                         node = node->rb_left;
192                 } else {
193                         node = node->rb_right;
194                 }
195         }
196         if (res && size && dma_res->iova >= start + size)
197                 res = NULL;
198         return res;
199 }
200
201 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
202 {
203         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
204         struct vfio_dma *dma;
205
206         while (*link) {
207                 parent = *link;
208                 dma = rb_entry(parent, struct vfio_dma, node);
209
210                 if (new->iova + new->size <= dma->iova)
211                         link = &(*link)->rb_left;
212                 else
213                         link = &(*link)->rb_right;
214         }
215
216         rb_link_node(&new->node, parent, link);
217         rb_insert_color(&new->node, &iommu->dma_list);
218 }
219
220 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
221 {
222         rb_erase(&old->node, &iommu->dma_list);
223 }
224
225
226 static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
227 {
228         uint64_t npages = dma->size / pgsize;
229
230         if (npages > DIRTY_BITMAP_PAGES_MAX)
231                 return -EINVAL;
232
233         /*
234          * Allocate extra 64 bits that are used to calculate shift required for
235          * bitmap_shift_left() to manipulate and club unaligned number of pages
236          * in adjacent vfio_dma ranges.
237          */
238         dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
239                                GFP_KERNEL);
240         if (!dma->bitmap)
241                 return -ENOMEM;
242
243         return 0;
244 }
245
246 static void vfio_dma_bitmap_free(struct vfio_dma *dma)
247 {
248         kfree(dma->bitmap);
249         dma->bitmap = NULL;
250 }
251
252 static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
253 {
254         struct rb_node *p;
255         unsigned long pgshift = __ffs(pgsize);
256
257         for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
258                 struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
259
260                 bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
261         }
262 }
263
264 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
265 {
266         struct rb_node *n;
267
268         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
269                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
270                 int ret;
271
272                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
273                 if (ret) {
274                         struct rb_node *p;
275
276                         for (p = rb_prev(n); p; p = rb_prev(p)) {
277                                 struct vfio_dma *dma = rb_entry(n,
278                                                         struct vfio_dma, node);
279
280                                 vfio_dma_bitmap_free(dma);
281                         }
282                         return ret;
283                 }
284                 vfio_dma_populate_bitmap(dma, pgsize);
285         }
286         return 0;
287 }
288
289 static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
290 {
291         struct rb_node *n;
292
293         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
294                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
295
296                 vfio_dma_bitmap_free(dma);
297         }
298 }
299
300 /*
301  * Helper Functions for host iova-pfn list
302  */
303 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
304 {
305         struct vfio_pfn *vpfn;
306         struct rb_node *node = dma->pfn_list.rb_node;
307
308         while (node) {
309                 vpfn = rb_entry(node, struct vfio_pfn, node);
310
311                 if (iova < vpfn->iova)
312                         node = node->rb_left;
313                 else if (iova > vpfn->iova)
314                         node = node->rb_right;
315                 else
316                         return vpfn;
317         }
318         return NULL;
319 }
320
321 static void vfio_link_pfn(struct vfio_dma *dma,
322                           struct vfio_pfn *new)
323 {
324         struct rb_node **link, *parent = NULL;
325         struct vfio_pfn *vpfn;
326
327         link = &dma->pfn_list.rb_node;
328         while (*link) {
329                 parent = *link;
330                 vpfn = rb_entry(parent, struct vfio_pfn, node);
331
332                 if (new->iova < vpfn->iova)
333                         link = &(*link)->rb_left;
334                 else
335                         link = &(*link)->rb_right;
336         }
337
338         rb_link_node(&new->node, parent, link);
339         rb_insert_color(&new->node, &dma->pfn_list);
340 }
341
342 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
343 {
344         rb_erase(&old->node, &dma->pfn_list);
345 }
346
347 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
348                                 unsigned long pfn)
349 {
350         struct vfio_pfn *vpfn;
351
352         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
353         if (!vpfn)
354                 return -ENOMEM;
355
356         vpfn->iova = iova;
357         vpfn->pfn = pfn;
358         vpfn->ref_count = 1;
359         vfio_link_pfn(dma, vpfn);
360         return 0;
361 }
362
363 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
364                                       struct vfio_pfn *vpfn)
365 {
366         vfio_unlink_pfn(dma, vpfn);
367         kfree(vpfn);
368 }
369
370 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
371                                                unsigned long iova)
372 {
373         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
374
375         if (vpfn)
376                 vpfn->ref_count++;
377         return vpfn;
378 }
379
380 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
381 {
382         int ret = 0;
383
384         vpfn->ref_count--;
385         if (!vpfn->ref_count) {
386                 ret = put_pfn(vpfn->pfn, dma->prot);
387                 vfio_remove_from_pfn_list(dma, vpfn);
388         }
389         return ret;
390 }
391
392 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
393 {
394         struct mm_struct *mm;
395         int ret;
396
397         if (!npage)
398                 return 0;
399
400         mm = async ? get_task_mm(dma->task) : dma->task->mm;
401         if (!mm)
402                 return -ESRCH; /* process exited */
403
404         ret = mmap_write_lock_killable(mm);
405         if (!ret) {
406                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
407                                           dma->lock_cap);
408                 mmap_write_unlock(mm);
409         }
410
411         if (async)
412                 mmput(mm);
413
414         return ret;
415 }
416
417 /*
418  * Some mappings aren't backed by a struct page, for example an mmap'd
419  * MMIO range for our own or another device.  These use a different
420  * pfn conversion and shouldn't be tracked as locked pages.
421  * For compound pages, any driver that sets the reserved bit in head
422  * page needs to set the reserved bit in all subpages to be safe.
423  */
424 static bool is_invalid_reserved_pfn(unsigned long pfn)
425 {
426         if (pfn_valid(pfn))
427                 return PageReserved(pfn_to_page(pfn));
428
429         return true;
430 }
431
432 static int put_pfn(unsigned long pfn, int prot)
433 {
434         if (!is_invalid_reserved_pfn(pfn)) {
435                 struct page *page = pfn_to_page(pfn);
436
437                 unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
438                 return 1;
439         }
440         return 0;
441 }
442
443 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
444                             unsigned long vaddr, unsigned long *pfn,
445                             bool write_fault)
446 {
447         int ret;
448
449         ret = follow_pfn(vma, vaddr, pfn);
450         if (ret) {
451                 bool unlocked = false;
452
453                 ret = fixup_user_fault(mm, vaddr,
454                                        FAULT_FLAG_REMOTE |
455                                        (write_fault ?  FAULT_FLAG_WRITE : 0),
456                                        &unlocked);
457                 if (unlocked)
458                         return -EAGAIN;
459
460                 if (ret)
461                         return ret;
462
463                 ret = follow_pfn(vma, vaddr, pfn);
464         }
465
466         return ret;
467 }
468
469 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
470                          int prot, unsigned long *pfn)
471 {
472         struct page *page[1];
473         struct vm_area_struct *vma;
474         unsigned int flags = 0;
475         int ret;
476
477         if (prot & IOMMU_WRITE)
478                 flags |= FOLL_WRITE;
479
480         mmap_read_lock(mm);
481         ret = pin_user_pages_remote(mm, vaddr, 1, flags | FOLL_LONGTERM,
482                                     page, NULL, NULL);
483         if (ret == 1) {
484                 *pfn = page_to_pfn(page[0]);
485                 ret = 0;
486                 goto done;
487         }
488
489         vaddr = untagged_addr(vaddr);
490
491 retry:
492         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
493
494         if (vma && vma->vm_flags & VM_PFNMAP) {
495                 ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
496                 if (ret == -EAGAIN)
497                         goto retry;
498
499                 if (!ret && !is_invalid_reserved_pfn(*pfn))
500                         ret = -EFAULT;
501         }
502 done:
503         mmap_read_unlock(mm);
504         return ret;
505 }
506
507 /*
508  * Attempt to pin pages.  We really don't want to track all the pfns and
509  * the iommu can only map chunks of consecutive pfns anyway, so get the
510  * first page and all consecutive pages with the same locking.
511  */
512 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
513                                   long npage, unsigned long *pfn_base,
514                                   unsigned long limit)
515 {
516         unsigned long pfn = 0;
517         long ret, pinned = 0, lock_acct = 0;
518         bool rsvd;
519         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
520
521         /* This code path is only user initiated */
522         if (!current->mm)
523                 return -ENODEV;
524
525         ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
526         if (ret)
527                 return ret;
528
529         pinned++;
530         rsvd = is_invalid_reserved_pfn(*pfn_base);
531
532         /*
533          * Reserved pages aren't counted against the user, externally pinned
534          * pages are already counted against the user.
535          */
536         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
537                 if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
538                         put_pfn(*pfn_base, dma->prot);
539                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
540                                         limit << PAGE_SHIFT);
541                         return -ENOMEM;
542                 }
543                 lock_acct++;
544         }
545
546         if (unlikely(disable_hugepages))
547                 goto out;
548
549         /* Lock all the consecutive pages from pfn_base */
550         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
551              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
552                 ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
553                 if (ret)
554                         break;
555
556                 if (pfn != *pfn_base + pinned ||
557                     rsvd != is_invalid_reserved_pfn(pfn)) {
558                         put_pfn(pfn, dma->prot);
559                         break;
560                 }
561
562                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
563                         if (!dma->lock_cap &&
564                             current->mm->locked_vm + lock_acct + 1 > limit) {
565                                 put_pfn(pfn, dma->prot);
566                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
567                                         __func__, limit << PAGE_SHIFT);
568                                 ret = -ENOMEM;
569                                 goto unpin_out;
570                         }
571                         lock_acct++;
572                 }
573         }
574
575 out:
576         ret = vfio_lock_acct(dma, lock_acct, false);
577
578 unpin_out:
579         if (ret) {
580                 if (!rsvd) {
581                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
582                                 put_pfn(pfn, dma->prot);
583                 }
584
585                 return ret;
586         }
587
588         return pinned;
589 }
590
591 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
592                                     unsigned long pfn, long npage,
593                                     bool do_accounting)
594 {
595         long unlocked = 0, locked = 0;
596         long i;
597
598         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
599                 if (put_pfn(pfn++, dma->prot)) {
600                         unlocked++;
601                         if (vfio_find_vpfn(dma, iova))
602                                 locked++;
603                 }
604         }
605
606         if (do_accounting)
607                 vfio_lock_acct(dma, locked - unlocked, true);
608
609         return unlocked;
610 }
611
612 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
613                                   unsigned long *pfn_base, bool do_accounting)
614 {
615         struct mm_struct *mm;
616         int ret;
617
618         mm = get_task_mm(dma->task);
619         if (!mm)
620                 return -ENODEV;
621
622         ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
623         if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
624                 ret = vfio_lock_acct(dma, 1, true);
625                 if (ret) {
626                         put_pfn(*pfn_base, dma->prot);
627                         if (ret == -ENOMEM)
628                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
629                                         "(%ld) exceeded\n", __func__,
630                                         dma->task->comm, task_pid_nr(dma->task),
631                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
632                 }
633         }
634
635         mmput(mm);
636         return ret;
637 }
638
639 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
640                                     bool do_accounting)
641 {
642         int unlocked;
643         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
644
645         if (!vpfn)
646                 return 0;
647
648         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
649
650         if (do_accounting)
651                 vfio_lock_acct(dma, -unlocked, true);
652
653         return unlocked;
654 }
655
656 static int vfio_iommu_type1_pin_pages(void *iommu_data,
657                                       struct iommu_group *iommu_group,
658                                       unsigned long *user_pfn,
659                                       int npage, int prot,
660                                       unsigned long *phys_pfn)
661 {
662         struct vfio_iommu *iommu = iommu_data;
663         struct vfio_group *group;
664         int i, j, ret;
665         unsigned long remote_vaddr;
666         struct vfio_dma *dma;
667         bool do_accounting;
668
669         if (!iommu || !user_pfn || !phys_pfn)
670                 return -EINVAL;
671
672         /* Supported for v2 version only */
673         if (!iommu->v2)
674                 return -EACCES;
675
676         mutex_lock(&iommu->lock);
677
678         /* Fail if notifier list is empty */
679         if (!iommu->notifier.head) {
680                 ret = -EINVAL;
681                 goto pin_done;
682         }
683
684         /*
685          * If iommu capable domain exist in the container then all pages are
686          * already pinned and accounted. Accouting should be done if there is no
687          * iommu capable domain in the container.
688          */
689         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
690
691         for (i = 0; i < npage; i++) {
692                 dma_addr_t iova;
693                 struct vfio_pfn *vpfn;
694
695                 iova = user_pfn[i] << PAGE_SHIFT;
696                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
697                 if (!dma) {
698                         ret = -EINVAL;
699                         goto pin_unwind;
700                 }
701
702                 if ((dma->prot & prot) != prot) {
703                         ret = -EPERM;
704                         goto pin_unwind;
705                 }
706
707                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
708                 if (vpfn) {
709                         phys_pfn[i] = vpfn->pfn;
710                         continue;
711                 }
712
713                 remote_vaddr = dma->vaddr + (iova - dma->iova);
714                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
715                                              do_accounting);
716                 if (ret)
717                         goto pin_unwind;
718
719                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
720                 if (ret) {
721                         if (put_pfn(phys_pfn[i], dma->prot) && do_accounting)
722                                 vfio_lock_acct(dma, -1, true);
723                         goto pin_unwind;
724                 }
725
726                 if (iommu->dirty_page_tracking) {
727                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
728
729                         /*
730                          * Bitmap populated with the smallest supported page
731                          * size
732                          */
733                         bitmap_set(dma->bitmap,
734                                    (iova - dma->iova) >> pgshift, 1);
735                 }
736         }
737         ret = i;
738
739         group = vfio_iommu_find_iommu_group(iommu, iommu_group);
740         if (!group->pinned_page_dirty_scope) {
741                 group->pinned_page_dirty_scope = true;
742                 update_pinned_page_dirty_scope(iommu);
743         }
744
745         goto pin_done;
746
747 pin_unwind:
748         phys_pfn[i] = 0;
749         for (j = 0; j < i; j++) {
750                 dma_addr_t iova;
751
752                 iova = user_pfn[j] << PAGE_SHIFT;
753                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
754                 vfio_unpin_page_external(dma, iova, do_accounting);
755                 phys_pfn[j] = 0;
756         }
757 pin_done:
758         mutex_unlock(&iommu->lock);
759         return ret;
760 }
761
762 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
763                                         unsigned long *user_pfn,
764                                         int npage)
765 {
766         struct vfio_iommu *iommu = iommu_data;
767         bool do_accounting;
768         int i;
769
770         if (!iommu || !user_pfn)
771                 return -EINVAL;
772
773         /* Supported for v2 version only */
774         if (!iommu->v2)
775                 return -EACCES;
776
777         mutex_lock(&iommu->lock);
778
779         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
780         for (i = 0; i < npage; i++) {
781                 struct vfio_dma *dma;
782                 dma_addr_t iova;
783
784                 iova = user_pfn[i] << PAGE_SHIFT;
785                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
786                 if (!dma)
787                         goto unpin_exit;
788                 vfio_unpin_page_external(dma, iova, do_accounting);
789         }
790
791 unpin_exit:
792         mutex_unlock(&iommu->lock);
793         return i > npage ? npage : (i > 0 ? i : -EINVAL);
794 }
795
796 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
797                             struct list_head *regions,
798                             struct iommu_iotlb_gather *iotlb_gather)
799 {
800         long unlocked = 0;
801         struct vfio_regions *entry, *next;
802
803         iommu_iotlb_sync(domain->domain, iotlb_gather);
804
805         list_for_each_entry_safe(entry, next, regions, list) {
806                 unlocked += vfio_unpin_pages_remote(dma,
807                                                     entry->iova,
808                                                     entry->phys >> PAGE_SHIFT,
809                                                     entry->len >> PAGE_SHIFT,
810                                                     false);
811                 list_del(&entry->list);
812                 kfree(entry);
813         }
814
815         cond_resched();
816
817         return unlocked;
818 }
819
820 /*
821  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
822  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
823  * of these regions (currently using a list).
824  *
825  * This value specifies maximum number of regions for each IOTLB flush sync.
826  */
827 #define VFIO_IOMMU_TLB_SYNC_MAX         512
828
829 static size_t unmap_unpin_fast(struct vfio_domain *domain,
830                                struct vfio_dma *dma, dma_addr_t *iova,
831                                size_t len, phys_addr_t phys, long *unlocked,
832                                struct list_head *unmapped_list,
833                                int *unmapped_cnt,
834                                struct iommu_iotlb_gather *iotlb_gather)
835 {
836         size_t unmapped = 0;
837         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
838
839         if (entry) {
840                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
841                                             iotlb_gather);
842
843                 if (!unmapped) {
844                         kfree(entry);
845                 } else {
846                         entry->iova = *iova;
847                         entry->phys = phys;
848                         entry->len  = unmapped;
849                         list_add_tail(&entry->list, unmapped_list);
850
851                         *iova += unmapped;
852                         (*unmapped_cnt)++;
853                 }
854         }
855
856         /*
857          * Sync if the number of fast-unmap regions hits the limit
858          * or in case of errors.
859          */
860         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
861                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
862                                              iotlb_gather);
863                 *unmapped_cnt = 0;
864         }
865
866         return unmapped;
867 }
868
869 static size_t unmap_unpin_slow(struct vfio_domain *domain,
870                                struct vfio_dma *dma, dma_addr_t *iova,
871                                size_t len, phys_addr_t phys,
872                                long *unlocked)
873 {
874         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
875
876         if (unmapped) {
877                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
878                                                      phys >> PAGE_SHIFT,
879                                                      unmapped >> PAGE_SHIFT,
880                                                      false);
881                 *iova += unmapped;
882                 cond_resched();
883         }
884         return unmapped;
885 }
886
887 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
888                              bool do_accounting)
889 {
890         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
891         struct vfio_domain *domain, *d;
892         LIST_HEAD(unmapped_region_list);
893         struct iommu_iotlb_gather iotlb_gather;
894         int unmapped_region_cnt = 0;
895         long unlocked = 0;
896
897         if (!dma->size)
898                 return 0;
899
900         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
901                 return 0;
902
903         /*
904          * We use the IOMMU to track the physical addresses, otherwise we'd
905          * need a much more complicated tracking system.  Unfortunately that
906          * means we need to use one of the iommu domains to figure out the
907          * pfns to unpin.  The rest need to be unmapped in advance so we have
908          * no iommu translations remaining when the pages are unpinned.
909          */
910         domain = d = list_first_entry(&iommu->domain_list,
911                                       struct vfio_domain, next);
912
913         list_for_each_entry_continue(d, &iommu->domain_list, next) {
914                 iommu_unmap(d->domain, dma->iova, dma->size);
915                 cond_resched();
916         }
917
918         iommu_iotlb_gather_init(&iotlb_gather);
919         while (iova < end) {
920                 size_t unmapped, len;
921                 phys_addr_t phys, next;
922
923                 phys = iommu_iova_to_phys(domain->domain, iova);
924                 if (WARN_ON(!phys)) {
925                         iova += PAGE_SIZE;
926                         continue;
927                 }
928
929                 /*
930                  * To optimize for fewer iommu_unmap() calls, each of which
931                  * may require hardware cache flushing, try to find the
932                  * largest contiguous physical memory chunk to unmap.
933                  */
934                 for (len = PAGE_SIZE;
935                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
936                         next = iommu_iova_to_phys(domain->domain, iova + len);
937                         if (next != phys + len)
938                                 break;
939                 }
940
941                 /*
942                  * First, try to use fast unmap/unpin. In case of failure,
943                  * switch to slow unmap/unpin path.
944                  */
945                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
946                                             &unlocked, &unmapped_region_list,
947                                             &unmapped_region_cnt,
948                                             &iotlb_gather);
949                 if (!unmapped) {
950                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
951                                                     phys, &unlocked);
952                         if (WARN_ON(!unmapped))
953                                 break;
954                 }
955         }
956
957         dma->iommu_mapped = false;
958
959         if (unmapped_region_cnt) {
960                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
961                                             &iotlb_gather);
962         }
963
964         if (do_accounting) {
965                 vfio_lock_acct(dma, -unlocked, true);
966                 return 0;
967         }
968         return unlocked;
969 }
970
971 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
972 {
973         vfio_unmap_unpin(iommu, dma, true);
974         vfio_unlink_dma(iommu, dma);
975         put_task_struct(dma->task);
976         vfio_dma_bitmap_free(dma);
977         kfree(dma);
978         iommu->dma_avail++;
979 }
980
981 static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
982 {
983         struct vfio_domain *domain;
984
985         iommu->pgsize_bitmap = ULONG_MAX;
986
987         list_for_each_entry(domain, &iommu->domain_list, next)
988                 iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
989
990         /*
991          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
992          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
993          * That way the user will be able to map/unmap buffers whose size/
994          * start address is aligned with PAGE_SIZE. Pinning code uses that
995          * granularity while iommu driver can use the sub-PAGE_SIZE size
996          * to map the buffer.
997          */
998         if (iommu->pgsize_bitmap & ~PAGE_MASK) {
999                 iommu->pgsize_bitmap &= PAGE_MASK;
1000                 iommu->pgsize_bitmap |= PAGE_SIZE;
1001         }
1002 }
1003
1004 static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1005                               struct vfio_dma *dma, dma_addr_t base_iova,
1006                               size_t pgsize)
1007 {
1008         unsigned long pgshift = __ffs(pgsize);
1009         unsigned long nbits = dma->size >> pgshift;
1010         unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
1011         unsigned long copy_offset = bit_offset / BITS_PER_LONG;
1012         unsigned long shift = bit_offset % BITS_PER_LONG;
1013         unsigned long leftover;
1014
1015         /*
1016          * mark all pages dirty if any IOMMU capable device is not able
1017          * to report dirty pages and all pages are pinned and mapped.
1018          */
1019         if (!iommu->pinned_page_dirty_scope && dma->iommu_mapped)
1020                 bitmap_set(dma->bitmap, 0, nbits);
1021
1022         if (shift) {
1023                 bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1024                                   nbits + shift);
1025
1026                 if (copy_from_user(&leftover,
1027                                    (void __user *)(bitmap + copy_offset),
1028                                    sizeof(leftover)))
1029                         return -EFAULT;
1030
1031                 bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1032         }
1033
1034         if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1035                          DIRTY_BITMAP_BYTES(nbits + shift)))
1036                 return -EFAULT;
1037
1038         return 0;
1039 }
1040
1041 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1042                                   dma_addr_t iova, size_t size, size_t pgsize)
1043 {
1044         struct vfio_dma *dma;
1045         struct rb_node *n;
1046         unsigned long pgshift = __ffs(pgsize);
1047         int ret;
1048
1049         /*
1050          * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1051          * vfio_dma mappings may be clubbed by specifying large ranges, but
1052          * there must not be any previous mappings bisected by the range.
1053          * An error will be returned if these conditions are not met.
1054          */
1055         dma = vfio_find_dma(iommu, iova, 1);
1056         if (dma && dma->iova != iova)
1057                 return -EINVAL;
1058
1059         dma = vfio_find_dma(iommu, iova + size - 1, 0);
1060         if (dma && dma->iova + dma->size != iova + size)
1061                 return -EINVAL;
1062
1063         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1064                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1065
1066                 if (dma->iova < iova)
1067                         continue;
1068
1069                 if (dma->iova > iova + size - 1)
1070                         break;
1071
1072                 ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1073                 if (ret)
1074                         return ret;
1075
1076                 /*
1077                  * Re-populate bitmap to include all pinned pages which are
1078                  * considered as dirty but exclude pages which are unpinned and
1079                  * pages which are marked dirty by vfio_dma_rw()
1080                  */
1081                 bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1082                 vfio_dma_populate_bitmap(dma, pgsize);
1083         }
1084         return 0;
1085 }
1086
1087 static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1088 {
1089         if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1090             (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1091                 return -EINVAL;
1092
1093         return 0;
1094 }
1095
1096 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1097                              struct vfio_iommu_type1_dma_unmap *unmap,
1098                              struct vfio_bitmap *bitmap)
1099 {
1100         struct vfio_dma *dma, *dma_last = NULL;
1101         size_t unmapped = 0, pgsize;
1102         int ret = -EINVAL, retries = 0;
1103         unsigned long pgshift;
1104         dma_addr_t iova = unmap->iova;
1105         unsigned long size = unmap->size;
1106         bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
1107         struct rb_node *n;
1108
1109         mutex_lock(&iommu->lock);
1110
1111         pgshift = __ffs(iommu->pgsize_bitmap);
1112         pgsize = (size_t)1 << pgshift;
1113
1114         if (iova & (pgsize - 1))
1115                 goto unlock;
1116
1117         if (unmap_all) {
1118                 if (iova || size)
1119                         goto unlock;
1120                 size = SIZE_MAX;
1121         } else if (!size || size & (pgsize - 1)) {
1122                 goto unlock;
1123         }
1124
1125         if (iova + size - 1 < iova || size > SIZE_MAX)
1126                 goto unlock;
1127
1128         /* When dirty tracking is enabled, allow only min supported pgsize */
1129         if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1130             (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1131                 goto unlock;
1132         }
1133
1134         WARN_ON((pgsize - 1) & PAGE_MASK);
1135 again:
1136         /*
1137          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1138          * avoid tracking individual mappings.  This means that the granularity
1139          * of the original mapping was lost and the user was allowed to attempt
1140          * to unmap any range.  Depending on the contiguousness of physical
1141          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1142          * or may not have worked.  We only guaranteed unmap granularity
1143          * matching the original mapping; even though it was untracked here,
1144          * the original mappings are reflected in IOMMU mappings.  This
1145          * resulted in a couple unusual behaviors.  First, if a range is not
1146          * able to be unmapped, ex. a set of 4k pages that was mapped as a
1147          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1148          * a zero sized unmap.  Also, if an unmap request overlaps the first
1149          * address of a hugepage, the IOMMU will unmap the entire hugepage.
1150          * This also returns success and the returned unmap size reflects the
1151          * actual size unmapped.
1152          *
1153          * We attempt to maintain compatibility with this "v1" interface, but
1154          * we take control out of the hands of the IOMMU.  Therefore, an unmap
1155          * request offset from the beginning of the original mapping will
1156          * return success with zero sized unmap.  And an unmap request covering
1157          * the first iova of mapping will unmap the entire range.
1158          *
1159          * The v2 version of this interface intends to be more deterministic.
1160          * Unmap requests must fully cover previous mappings.  Multiple
1161          * mappings may still be unmaped by specifying large ranges, but there
1162          * must not be any previous mappings bisected by the range.  An error
1163          * will be returned if these conditions are not met.  The v2 interface
1164          * will only return success and a size of zero if there were no
1165          * mappings within the range.
1166          */
1167         if (iommu->v2 && !unmap_all) {
1168                 dma = vfio_find_dma(iommu, iova, 1);
1169                 if (dma && dma->iova != iova)
1170                         goto unlock;
1171
1172                 dma = vfio_find_dma(iommu, iova + size - 1, 0);
1173                 if (dma && dma->iova + dma->size != iova + size)
1174                         goto unlock;
1175         }
1176
1177         ret = 0;
1178         n = vfio_find_dma_first_node(iommu, iova, size);
1179
1180         while (n) {
1181                 dma = rb_entry(n, struct vfio_dma, node);
1182                 if (dma->iova >= iova + size)
1183                         break;
1184
1185                 if (!iommu->v2 && iova > dma->iova)
1186                         break;
1187                 /*
1188                  * Task with same address space who mapped this iova range is
1189                  * allowed to unmap the iova range.
1190                  */
1191                 if (dma->task->mm != current->mm)
1192                         break;
1193
1194                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1195                         struct vfio_iommu_type1_dma_unmap nb_unmap;
1196
1197                         if (dma_last == dma) {
1198                                 BUG_ON(++retries > 10);
1199                         } else {
1200                                 dma_last = dma;
1201                                 retries = 0;
1202                         }
1203
1204                         nb_unmap.iova = dma->iova;
1205                         nb_unmap.size = dma->size;
1206
1207                         /*
1208                          * Notify anyone (mdev vendor drivers) to invalidate and
1209                          * unmap iovas within the range we're about to unmap.
1210                          * Vendor drivers MUST unpin pages in response to an
1211                          * invalidation.
1212                          */
1213                         mutex_unlock(&iommu->lock);
1214                         blocking_notifier_call_chain(&iommu->notifier,
1215                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
1216                                                     &nb_unmap);
1217                         mutex_lock(&iommu->lock);
1218                         goto again;
1219                 }
1220
1221                 if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1222                         ret = update_user_bitmap(bitmap->data, iommu, dma,
1223                                                  iova, pgsize);
1224                         if (ret)
1225                                 break;
1226                 }
1227
1228                 unmapped += dma->size;
1229                 n = rb_next(n);
1230                 vfio_remove_dma(iommu, dma);
1231         }
1232
1233 unlock:
1234         mutex_unlock(&iommu->lock);
1235
1236         /* Report how much was unmapped */
1237         unmap->size = unmapped;
1238
1239         return ret;
1240 }
1241
1242 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1243                           unsigned long pfn, long npage, int prot)
1244 {
1245         struct vfio_domain *d;
1246         int ret;
1247
1248         list_for_each_entry(d, &iommu->domain_list, next) {
1249                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1250                                 npage << PAGE_SHIFT, prot | d->prot);
1251                 if (ret)
1252                         goto unwind;
1253
1254                 cond_resched();
1255         }
1256
1257         return 0;
1258
1259 unwind:
1260         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1261                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1262                 cond_resched();
1263         }
1264
1265         return ret;
1266 }
1267
1268 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1269                             size_t map_size)
1270 {
1271         dma_addr_t iova = dma->iova;
1272         unsigned long vaddr = dma->vaddr;
1273         size_t size = map_size;
1274         long npage;
1275         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1276         int ret = 0;
1277
1278         while (size) {
1279                 /* Pin a contiguous chunk of memory */
1280                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1281                                               size >> PAGE_SHIFT, &pfn, limit);
1282                 if (npage <= 0) {
1283                         WARN_ON(!npage);
1284                         ret = (int)npage;
1285                         break;
1286                 }
1287
1288                 /* Map it! */
1289                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1290                                      dma->prot);
1291                 if (ret) {
1292                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1293                                                 npage, true);
1294                         break;
1295                 }
1296
1297                 size -= npage << PAGE_SHIFT;
1298                 dma->size += npage << PAGE_SHIFT;
1299         }
1300
1301         dma->iommu_mapped = true;
1302
1303         if (ret)
1304                 vfio_remove_dma(iommu, dma);
1305
1306         return ret;
1307 }
1308
1309 /*
1310  * Check dma map request is within a valid iova range
1311  */
1312 static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1313                                       dma_addr_t start, dma_addr_t end)
1314 {
1315         struct list_head *iova = &iommu->iova_list;
1316         struct vfio_iova *node;
1317
1318         list_for_each_entry(node, iova, list) {
1319                 if (start >= node->start && end <= node->end)
1320                         return true;
1321         }
1322
1323         /*
1324          * Check for list_empty() as well since a container with
1325          * a single mdev device will have an empty list.
1326          */
1327         return list_empty(iova);
1328 }
1329
1330 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1331                            struct vfio_iommu_type1_dma_map *map)
1332 {
1333         dma_addr_t iova = map->iova;
1334         unsigned long vaddr = map->vaddr;
1335         size_t size = map->size;
1336         int ret = 0, prot = 0;
1337         size_t pgsize;
1338         struct vfio_dma *dma;
1339
1340         /* Verify that none of our __u64 fields overflow */
1341         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1342                 return -EINVAL;
1343
1344         /* READ/WRITE from device perspective */
1345         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1346                 prot |= IOMMU_WRITE;
1347         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1348                 prot |= IOMMU_READ;
1349
1350         mutex_lock(&iommu->lock);
1351
1352         pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1353
1354         WARN_ON((pgsize - 1) & PAGE_MASK);
1355
1356         if (!prot || !size || (size | iova | vaddr) & (pgsize - 1)) {
1357                 ret = -EINVAL;
1358                 goto out_unlock;
1359         }
1360
1361         /* Don't allow IOVA or virtual address wrap */
1362         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1363                 ret = -EINVAL;
1364                 goto out_unlock;
1365         }
1366
1367         if (vfio_find_dma(iommu, iova, size)) {
1368                 ret = -EEXIST;
1369                 goto out_unlock;
1370         }
1371
1372         if (!iommu->dma_avail) {
1373                 ret = -ENOSPC;
1374                 goto out_unlock;
1375         }
1376
1377         if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1378                 ret = -EINVAL;
1379                 goto out_unlock;
1380         }
1381
1382         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1383         if (!dma) {
1384                 ret = -ENOMEM;
1385                 goto out_unlock;
1386         }
1387
1388         iommu->dma_avail--;
1389         dma->iova = iova;
1390         dma->vaddr = vaddr;
1391         dma->prot = prot;
1392
1393         /*
1394          * We need to be able to both add to a task's locked memory and test
1395          * against the locked memory limit and we need to be able to do both
1396          * outside of this call path as pinning can be asynchronous via the
1397          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1398          * task_struct and VM locked pages requires an mm_struct, however
1399          * holding an indefinite mm reference is not recommended, therefore we
1400          * only hold a reference to a task.  We could hold a reference to
1401          * current, however QEMU uses this call path through vCPU threads,
1402          * which can be killed resulting in a NULL mm and failure in the unmap
1403          * path when called via a different thread.  Avoid this problem by
1404          * using the group_leader as threads within the same group require
1405          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1406          * mm_struct.
1407          *
1408          * Previously we also used the task for testing CAP_IPC_LOCK at the
1409          * time of pinning and accounting, however has_capability() makes use
1410          * of real_cred, a copy-on-write field, so we can't guarantee that it
1411          * matches group_leader, or in fact that it might not change by the
1412          * time it's evaluated.  If a process were to call MAP_DMA with
1413          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1414          * possibly see different results for an iommu_mapped vfio_dma vs
1415          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1416          * time of calling MAP_DMA.
1417          */
1418         get_task_struct(current->group_leader);
1419         dma->task = current->group_leader;
1420         dma->lock_cap = capable(CAP_IPC_LOCK);
1421
1422         dma->pfn_list = RB_ROOT;
1423
1424         /* Insert zero-sized and grow as we map chunks of it */
1425         vfio_link_dma(iommu, dma);
1426
1427         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1428         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1429                 dma->size = size;
1430         else
1431                 ret = vfio_pin_map_dma(iommu, dma, size);
1432
1433         if (!ret && iommu->dirty_page_tracking) {
1434                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
1435                 if (ret)
1436                         vfio_remove_dma(iommu, dma);
1437         }
1438
1439 out_unlock:
1440         mutex_unlock(&iommu->lock);
1441         return ret;
1442 }
1443
1444 static int vfio_bus_type(struct device *dev, void *data)
1445 {
1446         struct bus_type **bus = data;
1447
1448         if (*bus && *bus != dev->bus)
1449                 return -EINVAL;
1450
1451         *bus = dev->bus;
1452
1453         return 0;
1454 }
1455
1456 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1457                              struct vfio_domain *domain)
1458 {
1459         struct vfio_domain *d = NULL;
1460         struct rb_node *n;
1461         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1462         int ret;
1463
1464         /* Arbitrarily pick the first domain in the list for lookups */
1465         if (!list_empty(&iommu->domain_list))
1466                 d = list_first_entry(&iommu->domain_list,
1467                                      struct vfio_domain, next);
1468
1469         n = rb_first(&iommu->dma_list);
1470
1471         for (; n; n = rb_next(n)) {
1472                 struct vfio_dma *dma;
1473                 dma_addr_t iova;
1474
1475                 dma = rb_entry(n, struct vfio_dma, node);
1476                 iova = dma->iova;
1477
1478                 while (iova < dma->iova + dma->size) {
1479                         phys_addr_t phys;
1480                         size_t size;
1481
1482                         if (dma->iommu_mapped) {
1483                                 phys_addr_t p;
1484                                 dma_addr_t i;
1485
1486                                 if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1487                                         ret = -EINVAL;
1488                                         goto unwind;
1489                                 }
1490
1491                                 phys = iommu_iova_to_phys(d->domain, iova);
1492
1493                                 if (WARN_ON(!phys)) {
1494                                         iova += PAGE_SIZE;
1495                                         continue;
1496                                 }
1497
1498                                 size = PAGE_SIZE;
1499                                 p = phys + size;
1500                                 i = iova + size;
1501                                 while (i < dma->iova + dma->size &&
1502                                        p == iommu_iova_to_phys(d->domain, i)) {
1503                                         size += PAGE_SIZE;
1504                                         p += PAGE_SIZE;
1505                                         i += PAGE_SIZE;
1506                                 }
1507                         } else {
1508                                 unsigned long pfn;
1509                                 unsigned long vaddr = dma->vaddr +
1510                                                      (iova - dma->iova);
1511                                 size_t n = dma->iova + dma->size - iova;
1512                                 long npage;
1513
1514                                 npage = vfio_pin_pages_remote(dma, vaddr,
1515                                                               n >> PAGE_SHIFT,
1516                                                               &pfn, limit);
1517                                 if (npage <= 0) {
1518                                         WARN_ON(!npage);
1519                                         ret = (int)npage;
1520                                         goto unwind;
1521                                 }
1522
1523                                 phys = pfn << PAGE_SHIFT;
1524                                 size = npage << PAGE_SHIFT;
1525                         }
1526
1527                         ret = iommu_map(domain->domain, iova, phys,
1528                                         size, dma->prot | domain->prot);
1529                         if (ret) {
1530                                 if (!dma->iommu_mapped)
1531                                         vfio_unpin_pages_remote(dma, iova,
1532                                                         phys >> PAGE_SHIFT,
1533                                                         size >> PAGE_SHIFT,
1534                                                         true);
1535                                 goto unwind;
1536                         }
1537
1538                         iova += size;
1539                 }
1540         }
1541
1542         /* All dmas are now mapped, defer to second tree walk for unwind */
1543         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1544                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1545
1546                 dma->iommu_mapped = true;
1547         }
1548
1549         return 0;
1550
1551 unwind:
1552         for (; n; n = rb_prev(n)) {
1553                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1554                 dma_addr_t iova;
1555
1556                 if (dma->iommu_mapped) {
1557                         iommu_unmap(domain->domain, dma->iova, dma->size);
1558                         continue;
1559                 }
1560
1561                 iova = dma->iova;
1562                 while (iova < dma->iova + dma->size) {
1563                         phys_addr_t phys, p;
1564                         size_t size;
1565                         dma_addr_t i;
1566
1567                         phys = iommu_iova_to_phys(domain->domain, iova);
1568                         if (!phys) {
1569                                 iova += PAGE_SIZE;
1570                                 continue;
1571                         }
1572
1573                         size = PAGE_SIZE;
1574                         p = phys + size;
1575                         i = iova + size;
1576                         while (i < dma->iova + dma->size &&
1577                                p == iommu_iova_to_phys(domain->domain, i)) {
1578                                 size += PAGE_SIZE;
1579                                 p += PAGE_SIZE;
1580                                 i += PAGE_SIZE;
1581                         }
1582
1583                         iommu_unmap(domain->domain, iova, size);
1584                         vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1585                                                 size >> PAGE_SHIFT, true);
1586                 }
1587         }
1588
1589         return ret;
1590 }
1591
1592 /*
1593  * We change our unmap behavior slightly depending on whether the IOMMU
1594  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1595  * for practically any contiguous power-of-two mapping we give it.  This means
1596  * we don't need to look for contiguous chunks ourselves to make unmapping
1597  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1598  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1599  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1600  * hugetlbfs is in use.
1601  */
1602 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1603 {
1604         struct page *pages;
1605         int ret, order = get_order(PAGE_SIZE * 2);
1606
1607         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1608         if (!pages)
1609                 return;
1610
1611         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1612                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1613         if (!ret) {
1614                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1615
1616                 if (unmapped == PAGE_SIZE)
1617                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1618                 else
1619                         domain->fgsp = true;
1620         }
1621
1622         __free_pages(pages, order);
1623 }
1624
1625 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1626                                            struct iommu_group *iommu_group)
1627 {
1628         struct vfio_group *g;
1629
1630         list_for_each_entry(g, &domain->group_list, next) {
1631                 if (g->iommu_group == iommu_group)
1632                         return g;
1633         }
1634
1635         return NULL;
1636 }
1637
1638 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1639                                                struct iommu_group *iommu_group)
1640 {
1641         struct vfio_domain *domain;
1642         struct vfio_group *group = NULL;
1643
1644         list_for_each_entry(domain, &iommu->domain_list, next) {
1645                 group = find_iommu_group(domain, iommu_group);
1646                 if (group)
1647                         return group;
1648         }
1649
1650         if (iommu->external_domain)
1651                 group = find_iommu_group(iommu->external_domain, iommu_group);
1652
1653         return group;
1654 }
1655
1656 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu)
1657 {
1658         struct vfio_domain *domain;
1659         struct vfio_group *group;
1660
1661         list_for_each_entry(domain, &iommu->domain_list, next) {
1662                 list_for_each_entry(group, &domain->group_list, next) {
1663                         if (!group->pinned_page_dirty_scope) {
1664                                 iommu->pinned_page_dirty_scope = false;
1665                                 return;
1666                         }
1667                 }
1668         }
1669
1670         if (iommu->external_domain) {
1671                 domain = iommu->external_domain;
1672                 list_for_each_entry(group, &domain->group_list, next) {
1673                         if (!group->pinned_page_dirty_scope) {
1674                                 iommu->pinned_page_dirty_scope = false;
1675                                 return;
1676                         }
1677                 }
1678         }
1679
1680         iommu->pinned_page_dirty_scope = true;
1681 }
1682
1683 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1684                                   phys_addr_t *base)
1685 {
1686         struct iommu_resv_region *region;
1687         bool ret = false;
1688
1689         list_for_each_entry(region, group_resv_regions, list) {
1690                 /*
1691                  * The presence of any 'real' MSI regions should take
1692                  * precedence over the software-managed one if the
1693                  * IOMMU driver happens to advertise both types.
1694                  */
1695                 if (region->type == IOMMU_RESV_MSI) {
1696                         ret = false;
1697                         break;
1698                 }
1699
1700                 if (region->type == IOMMU_RESV_SW_MSI) {
1701                         *base = region->start;
1702                         ret = true;
1703                 }
1704         }
1705
1706         return ret;
1707 }
1708
1709 static struct device *vfio_mdev_get_iommu_device(struct device *dev)
1710 {
1711         struct device *(*fn)(struct device *dev);
1712         struct device *iommu_device;
1713
1714         fn = symbol_get(mdev_get_iommu_device);
1715         if (fn) {
1716                 iommu_device = fn(dev);
1717                 symbol_put(mdev_get_iommu_device);
1718
1719                 return iommu_device;
1720         }
1721
1722         return NULL;
1723 }
1724
1725 static int vfio_mdev_attach_domain(struct device *dev, void *data)
1726 {
1727         struct iommu_domain *domain = data;
1728         struct device *iommu_device;
1729
1730         iommu_device = vfio_mdev_get_iommu_device(dev);
1731         if (iommu_device) {
1732                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1733                         return iommu_aux_attach_device(domain, iommu_device);
1734                 else
1735                         return iommu_attach_device(domain, iommu_device);
1736         }
1737
1738         return -EINVAL;
1739 }
1740
1741 static int vfio_mdev_detach_domain(struct device *dev, void *data)
1742 {
1743         struct iommu_domain *domain = data;
1744         struct device *iommu_device;
1745
1746         iommu_device = vfio_mdev_get_iommu_device(dev);
1747         if (iommu_device) {
1748                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1749                         iommu_aux_detach_device(domain, iommu_device);
1750                 else
1751                         iommu_detach_device(domain, iommu_device);
1752         }
1753
1754         return 0;
1755 }
1756
1757 static int vfio_iommu_attach_group(struct vfio_domain *domain,
1758                                    struct vfio_group *group)
1759 {
1760         if (group->mdev_group)
1761                 return iommu_group_for_each_dev(group->iommu_group,
1762                                                 domain->domain,
1763                                                 vfio_mdev_attach_domain);
1764         else
1765                 return iommu_attach_group(domain->domain, group->iommu_group);
1766 }
1767
1768 static void vfio_iommu_detach_group(struct vfio_domain *domain,
1769                                     struct vfio_group *group)
1770 {
1771         if (group->mdev_group)
1772                 iommu_group_for_each_dev(group->iommu_group, domain->domain,
1773                                          vfio_mdev_detach_domain);
1774         else
1775                 iommu_detach_group(domain->domain, group->iommu_group);
1776 }
1777
1778 static bool vfio_bus_is_mdev(struct bus_type *bus)
1779 {
1780         struct bus_type *mdev_bus;
1781         bool ret = false;
1782
1783         mdev_bus = symbol_get(mdev_bus_type);
1784         if (mdev_bus) {
1785                 ret = (bus == mdev_bus);
1786                 symbol_put(mdev_bus_type);
1787         }
1788
1789         return ret;
1790 }
1791
1792 static int vfio_mdev_iommu_device(struct device *dev, void *data)
1793 {
1794         struct device **old = data, *new;
1795
1796         new = vfio_mdev_get_iommu_device(dev);
1797         if (!new || (*old && *old != new))
1798                 return -EINVAL;
1799
1800         *old = new;
1801
1802         return 0;
1803 }
1804
1805 /*
1806  * This is a helper function to insert an address range to iova list.
1807  * The list is initially created with a single entry corresponding to
1808  * the IOMMU domain geometry to which the device group is attached.
1809  * The list aperture gets modified when a new domain is added to the
1810  * container if the new aperture doesn't conflict with the current one
1811  * or with any existing dma mappings. The list is also modified to
1812  * exclude any reserved regions associated with the device group.
1813  */
1814 static int vfio_iommu_iova_insert(struct list_head *head,
1815                                   dma_addr_t start, dma_addr_t end)
1816 {
1817         struct vfio_iova *region;
1818
1819         region = kmalloc(sizeof(*region), GFP_KERNEL);
1820         if (!region)
1821                 return -ENOMEM;
1822
1823         INIT_LIST_HEAD(&region->list);
1824         region->start = start;
1825         region->end = end;
1826
1827         list_add_tail(&region->list, head);
1828         return 0;
1829 }
1830
1831 /*
1832  * Check the new iommu aperture conflicts with existing aper or with any
1833  * existing dma mappings.
1834  */
1835 static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
1836                                      dma_addr_t start, dma_addr_t end)
1837 {
1838         struct vfio_iova *first, *last;
1839         struct list_head *iova = &iommu->iova_list;
1840
1841         if (list_empty(iova))
1842                 return false;
1843
1844         /* Disjoint sets, return conflict */
1845         first = list_first_entry(iova, struct vfio_iova, list);
1846         last = list_last_entry(iova, struct vfio_iova, list);
1847         if (start > last->end || end < first->start)
1848                 return true;
1849
1850         /* Check for any existing dma mappings below the new start */
1851         if (start > first->start) {
1852                 if (vfio_find_dma(iommu, first->start, start - first->start))
1853                         return true;
1854         }
1855
1856         /* Check for any existing dma mappings beyond the new end */
1857         if (end < last->end) {
1858                 if (vfio_find_dma(iommu, end + 1, last->end - end))
1859                         return true;
1860         }
1861
1862         return false;
1863 }
1864
1865 /*
1866  * Resize iommu iova aperture window. This is called only if the new
1867  * aperture has no conflict with existing aperture and dma mappings.
1868  */
1869 static int vfio_iommu_aper_resize(struct list_head *iova,
1870                                   dma_addr_t start, dma_addr_t end)
1871 {
1872         struct vfio_iova *node, *next;
1873
1874         if (list_empty(iova))
1875                 return vfio_iommu_iova_insert(iova, start, end);
1876
1877         /* Adjust iova list start */
1878         list_for_each_entry_safe(node, next, iova, list) {
1879                 if (start < node->start)
1880                         break;
1881                 if (start >= node->start && start < node->end) {
1882                         node->start = start;
1883                         break;
1884                 }
1885                 /* Delete nodes before new start */
1886                 list_del(&node->list);
1887                 kfree(node);
1888         }
1889
1890         /* Adjust iova list end */
1891         list_for_each_entry_safe(node, next, iova, list) {
1892                 if (end > node->end)
1893                         continue;
1894                 if (end > node->start && end <= node->end) {
1895                         node->end = end;
1896                         continue;
1897                 }
1898                 /* Delete nodes after new end */
1899                 list_del(&node->list);
1900                 kfree(node);
1901         }
1902
1903         return 0;
1904 }
1905
1906 /*
1907  * Check reserved region conflicts with existing dma mappings
1908  */
1909 static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
1910                                      struct list_head *resv_regions)
1911 {
1912         struct iommu_resv_region *region;
1913
1914         /* Check for conflict with existing dma mappings */
1915         list_for_each_entry(region, resv_regions, list) {
1916                 if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
1917                         continue;
1918
1919                 if (vfio_find_dma(iommu, region->start, region->length))
1920                         return true;
1921         }
1922
1923         return false;
1924 }
1925
1926 /*
1927  * Check iova region overlap with  reserved regions and
1928  * exclude them from the iommu iova range
1929  */
1930 static int vfio_iommu_resv_exclude(struct list_head *iova,
1931                                    struct list_head *resv_regions)
1932 {
1933         struct iommu_resv_region *resv;
1934         struct vfio_iova *n, *next;
1935
1936         list_for_each_entry(resv, resv_regions, list) {
1937                 phys_addr_t start, end;
1938
1939                 if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
1940                         continue;
1941
1942                 start = resv->start;
1943                 end = resv->start + resv->length - 1;
1944
1945                 list_for_each_entry_safe(n, next, iova, list) {
1946                         int ret = 0;
1947
1948                         /* No overlap */
1949                         if (start > n->end || end < n->start)
1950                                 continue;
1951                         /*
1952                          * Insert a new node if current node overlaps with the
1953                          * reserve region to exlude that from valid iova range.
1954                          * Note that, new node is inserted before the current
1955                          * node and finally the current node is deleted keeping
1956                          * the list updated and sorted.
1957                          */
1958                         if (start > n->start)
1959                                 ret = vfio_iommu_iova_insert(&n->list, n->start,
1960                                                              start - 1);
1961                         if (!ret && end < n->end)
1962                                 ret = vfio_iommu_iova_insert(&n->list, end + 1,
1963                                                              n->end);
1964                         if (ret)
1965                                 return ret;
1966
1967                         list_del(&n->list);
1968                         kfree(n);
1969                 }
1970         }
1971
1972         if (list_empty(iova))
1973                 return -EINVAL;
1974
1975         return 0;
1976 }
1977
1978 static void vfio_iommu_resv_free(struct list_head *resv_regions)
1979 {
1980         struct iommu_resv_region *n, *next;
1981
1982         list_for_each_entry_safe(n, next, resv_regions, list) {
1983                 list_del(&n->list);
1984                 kfree(n);
1985         }
1986 }
1987
1988 static void vfio_iommu_iova_free(struct list_head *iova)
1989 {
1990         struct vfio_iova *n, *next;
1991
1992         list_for_each_entry_safe(n, next, iova, list) {
1993                 list_del(&n->list);
1994                 kfree(n);
1995         }
1996 }
1997
1998 static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
1999                                     struct list_head *iova_copy)
2000 {
2001         struct list_head *iova = &iommu->iova_list;
2002         struct vfio_iova *n;
2003         int ret;
2004
2005         list_for_each_entry(n, iova, list) {
2006                 ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
2007                 if (ret)
2008                         goto out_free;
2009         }
2010
2011         return 0;
2012
2013 out_free:
2014         vfio_iommu_iova_free(iova_copy);
2015         return ret;
2016 }
2017
2018 static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2019                                         struct list_head *iova_copy)
2020 {
2021         struct list_head *iova = &iommu->iova_list;
2022
2023         vfio_iommu_iova_free(iova);
2024
2025         list_splice_tail(iova_copy, iova);
2026 }
2027
2028 static int vfio_iommu_type1_attach_group(void *iommu_data,
2029                                          struct iommu_group *iommu_group)
2030 {
2031         struct vfio_iommu *iommu = iommu_data;
2032         struct vfio_group *group;
2033         struct vfio_domain *domain, *d;
2034         struct bus_type *bus = NULL;
2035         int ret;
2036         bool resv_msi, msi_remap;
2037         phys_addr_t resv_msi_base = 0;
2038         struct iommu_domain_geometry geo;
2039         LIST_HEAD(iova_copy);
2040         LIST_HEAD(group_resv_regions);
2041
2042         mutex_lock(&iommu->lock);
2043
2044         /* Check for duplicates */
2045         if (vfio_iommu_find_iommu_group(iommu, iommu_group)) {
2046                 mutex_unlock(&iommu->lock);
2047                 return -EINVAL;
2048         }
2049
2050         group = kzalloc(sizeof(*group), GFP_KERNEL);
2051         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2052         if (!group || !domain) {
2053                 ret = -ENOMEM;
2054                 goto out_free;
2055         }
2056
2057         group->iommu_group = iommu_group;
2058
2059         /* Determine bus_type in order to allocate a domain */
2060         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
2061         if (ret)
2062                 goto out_free;
2063
2064         if (vfio_bus_is_mdev(bus)) {
2065                 struct device *iommu_device = NULL;
2066
2067                 group->mdev_group = true;
2068
2069                 /* Determine the isolation type */
2070                 ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
2071                                                vfio_mdev_iommu_device);
2072                 if (ret || !iommu_device) {
2073                         if (!iommu->external_domain) {
2074                                 INIT_LIST_HEAD(&domain->group_list);
2075                                 iommu->external_domain = domain;
2076                                 vfio_update_pgsize_bitmap(iommu);
2077                         } else {
2078                                 kfree(domain);
2079                         }
2080
2081                         list_add(&group->next,
2082                                  &iommu->external_domain->group_list);
2083                         /*
2084                          * Non-iommu backed group cannot dirty memory directly,
2085                          * it can only use interfaces that provide dirty
2086                          * tracking.
2087                          * The iommu scope can only be promoted with the
2088                          * addition of a dirty tracking group.
2089                          */
2090                         group->pinned_page_dirty_scope = true;
2091                         if (!iommu->pinned_page_dirty_scope)
2092                                 update_pinned_page_dirty_scope(iommu);
2093                         mutex_unlock(&iommu->lock);
2094
2095                         return 0;
2096                 }
2097
2098                 bus = iommu_device->bus;
2099         }
2100
2101         domain->domain = iommu_domain_alloc(bus);
2102         if (!domain->domain) {
2103                 ret = -EIO;
2104                 goto out_free;
2105         }
2106
2107         if (iommu->nesting) {
2108                 int attr = 1;
2109
2110                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
2111                                             &attr);
2112                 if (ret)
2113                         goto out_domain;
2114         }
2115
2116         ret = vfio_iommu_attach_group(domain, group);
2117         if (ret)
2118                 goto out_domain;
2119
2120         /* Get aperture info */
2121         iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY, &geo);
2122
2123         if (vfio_iommu_aper_conflict(iommu, geo.aperture_start,
2124                                      geo.aperture_end)) {
2125                 ret = -EINVAL;
2126                 goto out_detach;
2127         }
2128
2129         ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2130         if (ret)
2131                 goto out_detach;
2132
2133         if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2134                 ret = -EINVAL;
2135                 goto out_detach;
2136         }
2137
2138         /*
2139          * We don't want to work on the original iova list as the list
2140          * gets modified and in case of failure we have to retain the
2141          * original list. Get a copy here.
2142          */
2143         ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2144         if (ret)
2145                 goto out_detach;
2146
2147         ret = vfio_iommu_aper_resize(&iova_copy, geo.aperture_start,
2148                                      geo.aperture_end);
2149         if (ret)
2150                 goto out_detach;
2151
2152         ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2153         if (ret)
2154                 goto out_detach;
2155
2156         resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2157
2158         INIT_LIST_HEAD(&domain->group_list);
2159         list_add(&group->next, &domain->group_list);
2160
2161         msi_remap = irq_domain_check_msi_remap() ||
2162                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
2163
2164         if (!allow_unsafe_interrupts && !msi_remap) {
2165                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2166                        __func__);
2167                 ret = -EPERM;
2168                 goto out_detach;
2169         }
2170
2171         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
2172                 domain->prot |= IOMMU_CACHE;
2173
2174         /*
2175          * Try to match an existing compatible domain.  We don't want to
2176          * preclude an IOMMU driver supporting multiple bus_types and being
2177          * able to include different bus_types in the same IOMMU domain, so
2178          * we test whether the domains use the same iommu_ops rather than
2179          * testing if they're on the same bus_type.
2180          */
2181         list_for_each_entry(d, &iommu->domain_list, next) {
2182                 if (d->domain->ops == domain->domain->ops &&
2183                     d->prot == domain->prot) {
2184                         vfio_iommu_detach_group(domain, group);
2185                         if (!vfio_iommu_attach_group(d, group)) {
2186                                 list_add(&group->next, &d->group_list);
2187                                 iommu_domain_free(domain->domain);
2188                                 kfree(domain);
2189                                 goto done;
2190                         }
2191
2192                         ret = vfio_iommu_attach_group(domain, group);
2193                         if (ret)
2194                                 goto out_domain;
2195                 }
2196         }
2197
2198         vfio_test_domain_fgsp(domain);
2199
2200         /* replay mappings on new domains */
2201         ret = vfio_iommu_replay(iommu, domain);
2202         if (ret)
2203                 goto out_detach;
2204
2205         if (resv_msi) {
2206                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2207                 if (ret && ret != -ENODEV)
2208                         goto out_detach;
2209         }
2210
2211         list_add(&domain->next, &iommu->domain_list);
2212         vfio_update_pgsize_bitmap(iommu);
2213 done:
2214         /* Delete the old one and insert new iova list */
2215         vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2216
2217         /*
2218          * An iommu backed group can dirty memory directly and therefore
2219          * demotes the iommu scope until it declares itself dirty tracking
2220          * capable via the page pinning interface.
2221          */
2222         iommu->pinned_page_dirty_scope = false;
2223         mutex_unlock(&iommu->lock);
2224         vfio_iommu_resv_free(&group_resv_regions);
2225
2226         return 0;
2227
2228 out_detach:
2229         vfio_iommu_detach_group(domain, group);
2230 out_domain:
2231         iommu_domain_free(domain->domain);
2232         vfio_iommu_iova_free(&iova_copy);
2233         vfio_iommu_resv_free(&group_resv_regions);
2234 out_free:
2235         kfree(domain);
2236         kfree(group);
2237         mutex_unlock(&iommu->lock);
2238         return ret;
2239 }
2240
2241 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2242 {
2243         struct rb_node *node;
2244
2245         while ((node = rb_first(&iommu->dma_list)))
2246                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2247 }
2248
2249 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2250 {
2251         struct rb_node *n, *p;
2252
2253         n = rb_first(&iommu->dma_list);
2254         for (; n; n = rb_next(n)) {
2255                 struct vfio_dma *dma;
2256                 long locked = 0, unlocked = 0;
2257
2258                 dma = rb_entry(n, struct vfio_dma, node);
2259                 unlocked += vfio_unmap_unpin(iommu, dma, false);
2260                 p = rb_first(&dma->pfn_list);
2261                 for (; p; p = rb_next(p)) {
2262                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2263                                                          node);
2264
2265                         if (!is_invalid_reserved_pfn(vpfn->pfn))
2266                                 locked++;
2267                 }
2268                 vfio_lock_acct(dma, locked - unlocked, true);
2269         }
2270 }
2271
2272 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
2273 {
2274         struct rb_node *n;
2275
2276         n = rb_first(&iommu->dma_list);
2277         for (; n; n = rb_next(n)) {
2278                 struct vfio_dma *dma;
2279
2280                 dma = rb_entry(n, struct vfio_dma, node);
2281
2282                 if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
2283                         break;
2284         }
2285         /* mdev vendor driver must unregister notifier */
2286         WARN_ON(iommu->notifier.head);
2287 }
2288
2289 /*
2290  * Called when a domain is removed in detach. It is possible that
2291  * the removed domain decided the iova aperture window. Modify the
2292  * iova aperture with the smallest window among existing domains.
2293  */
2294 static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2295                                    struct list_head *iova_copy)
2296 {
2297         struct vfio_domain *domain;
2298         struct iommu_domain_geometry geo;
2299         struct vfio_iova *node;
2300         dma_addr_t start = 0;
2301         dma_addr_t end = (dma_addr_t)~0;
2302
2303         if (list_empty(iova_copy))
2304                 return;
2305
2306         list_for_each_entry(domain, &iommu->domain_list, next) {
2307                 iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY,
2308                                       &geo);
2309                 if (geo.aperture_start > start)
2310                         start = geo.aperture_start;
2311                 if (geo.aperture_end < end)
2312                         end = geo.aperture_end;
2313         }
2314
2315         /* Modify aperture limits. The new aper is either same or bigger */
2316         node = list_first_entry(iova_copy, struct vfio_iova, list);
2317         node->start = start;
2318         node = list_last_entry(iova_copy, struct vfio_iova, list);
2319         node->end = end;
2320 }
2321
2322 /*
2323  * Called when a group is detached. The reserved regions for that
2324  * group can be part of valid iova now. But since reserved regions
2325  * may be duplicated among groups, populate the iova valid regions
2326  * list again.
2327  */
2328 static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2329                                    struct list_head *iova_copy)
2330 {
2331         struct vfio_domain *d;
2332         struct vfio_group *g;
2333         struct vfio_iova *node;
2334         dma_addr_t start, end;
2335         LIST_HEAD(resv_regions);
2336         int ret;
2337
2338         if (list_empty(iova_copy))
2339                 return -EINVAL;
2340
2341         list_for_each_entry(d, &iommu->domain_list, next) {
2342                 list_for_each_entry(g, &d->group_list, next) {
2343                         ret = iommu_get_group_resv_regions(g->iommu_group,
2344                                                            &resv_regions);
2345                         if (ret)
2346                                 goto done;
2347                 }
2348         }
2349
2350         node = list_first_entry(iova_copy, struct vfio_iova, list);
2351         start = node->start;
2352         node = list_last_entry(iova_copy, struct vfio_iova, list);
2353         end = node->end;
2354
2355         /* purge the iova list and create new one */
2356         vfio_iommu_iova_free(iova_copy);
2357
2358         ret = vfio_iommu_aper_resize(iova_copy, start, end);
2359         if (ret)
2360                 goto done;
2361
2362         /* Exclude current reserved regions from iova ranges */
2363         ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2364 done:
2365         vfio_iommu_resv_free(&resv_regions);
2366         return ret;
2367 }
2368
2369 static void vfio_iommu_type1_detach_group(void *iommu_data,
2370                                           struct iommu_group *iommu_group)
2371 {
2372         struct vfio_iommu *iommu = iommu_data;
2373         struct vfio_domain *domain;
2374         struct vfio_group *group;
2375         bool update_dirty_scope = false;
2376         LIST_HEAD(iova_copy);
2377
2378         mutex_lock(&iommu->lock);
2379
2380         if (iommu->external_domain) {
2381                 group = find_iommu_group(iommu->external_domain, iommu_group);
2382                 if (group) {
2383                         update_dirty_scope = !group->pinned_page_dirty_scope;
2384                         list_del(&group->next);
2385                         kfree(group);
2386
2387                         if (list_empty(&iommu->external_domain->group_list)) {
2388                                 vfio_sanity_check_pfn_list(iommu);
2389
2390                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
2391                                         vfio_iommu_unmap_unpin_all(iommu);
2392
2393                                 kfree(iommu->external_domain);
2394                                 iommu->external_domain = NULL;
2395                         }
2396                         goto detach_group_done;
2397                 }
2398         }
2399
2400         /*
2401          * Get a copy of iova list. This will be used to update
2402          * and to replace the current one later. Please note that
2403          * we will leave the original list as it is if update fails.
2404          */
2405         vfio_iommu_iova_get_copy(iommu, &iova_copy);
2406
2407         list_for_each_entry(domain, &iommu->domain_list, next) {
2408                 group = find_iommu_group(domain, iommu_group);
2409                 if (!group)
2410                         continue;
2411
2412                 vfio_iommu_detach_group(domain, group);
2413                 update_dirty_scope = !group->pinned_page_dirty_scope;
2414                 list_del(&group->next);
2415                 kfree(group);
2416                 /*
2417                  * Group ownership provides privilege, if the group list is
2418                  * empty, the domain goes away. If it's the last domain with
2419                  * iommu and external domain doesn't exist, then all the
2420                  * mappings go away too. If it's the last domain with iommu and
2421                  * external domain exist, update accounting
2422                  */
2423                 if (list_empty(&domain->group_list)) {
2424                         if (list_is_singular(&iommu->domain_list)) {
2425                                 if (!iommu->external_domain)
2426                                         vfio_iommu_unmap_unpin_all(iommu);
2427                                 else
2428                                         vfio_iommu_unmap_unpin_reaccount(iommu);
2429                         }
2430                         iommu_domain_free(domain->domain);
2431                         list_del(&domain->next);
2432                         kfree(domain);
2433                         vfio_iommu_aper_expand(iommu, &iova_copy);
2434                         vfio_update_pgsize_bitmap(iommu);
2435                 }
2436                 break;
2437         }
2438
2439         if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2440                 vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2441         else
2442                 vfio_iommu_iova_free(&iova_copy);
2443
2444 detach_group_done:
2445         /*
2446          * Removal of a group without dirty tracking may allow the iommu scope
2447          * to be promoted.
2448          */
2449         if (update_dirty_scope)
2450                 update_pinned_page_dirty_scope(iommu);
2451         mutex_unlock(&iommu->lock);
2452 }
2453
2454 static void *vfio_iommu_type1_open(unsigned long arg)
2455 {
2456         struct vfio_iommu *iommu;
2457
2458         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2459         if (!iommu)
2460                 return ERR_PTR(-ENOMEM);
2461
2462         switch (arg) {
2463         case VFIO_TYPE1_IOMMU:
2464                 break;
2465         case VFIO_TYPE1_NESTING_IOMMU:
2466                 iommu->nesting = true;
2467                 fallthrough;
2468         case VFIO_TYPE1v2_IOMMU:
2469                 iommu->v2 = true;
2470                 break;
2471         default:
2472                 kfree(iommu);
2473                 return ERR_PTR(-EINVAL);
2474         }
2475
2476         INIT_LIST_HEAD(&iommu->domain_list);
2477         INIT_LIST_HEAD(&iommu->iova_list);
2478         iommu->dma_list = RB_ROOT;
2479         iommu->dma_avail = dma_entry_limit;
2480         mutex_init(&iommu->lock);
2481         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
2482
2483         return iommu;
2484 }
2485
2486 static void vfio_release_domain(struct vfio_domain *domain, bool external)
2487 {
2488         struct vfio_group *group, *group_tmp;
2489
2490         list_for_each_entry_safe(group, group_tmp,
2491                                  &domain->group_list, next) {
2492                 if (!external)
2493                         vfio_iommu_detach_group(domain, group);
2494                 list_del(&group->next);
2495                 kfree(group);
2496         }
2497
2498         if (!external)
2499                 iommu_domain_free(domain->domain);
2500 }
2501
2502 static void vfio_iommu_type1_release(void *iommu_data)
2503 {
2504         struct vfio_iommu *iommu = iommu_data;
2505         struct vfio_domain *domain, *domain_tmp;
2506
2507         if (iommu->external_domain) {
2508                 vfio_release_domain(iommu->external_domain, true);
2509                 vfio_sanity_check_pfn_list(iommu);
2510                 kfree(iommu->external_domain);
2511         }
2512
2513         vfio_iommu_unmap_unpin_all(iommu);
2514
2515         list_for_each_entry_safe(domain, domain_tmp,
2516                                  &iommu->domain_list, next) {
2517                 vfio_release_domain(domain, false);
2518                 list_del(&domain->next);
2519                 kfree(domain);
2520         }
2521
2522         vfio_iommu_iova_free(&iommu->iova_list);
2523
2524         kfree(iommu);
2525 }
2526
2527 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
2528 {
2529         struct vfio_domain *domain;
2530         int ret = 1;
2531
2532         mutex_lock(&iommu->lock);
2533         list_for_each_entry(domain, &iommu->domain_list, next) {
2534                 if (!(domain->prot & IOMMU_CACHE)) {
2535                         ret = 0;
2536                         break;
2537                 }
2538         }
2539         mutex_unlock(&iommu->lock);
2540
2541         return ret;
2542 }
2543
2544 static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2545                                             unsigned long arg)
2546 {
2547         switch (arg) {
2548         case VFIO_TYPE1_IOMMU:
2549         case VFIO_TYPE1v2_IOMMU:
2550         case VFIO_TYPE1_NESTING_IOMMU:
2551         case VFIO_UNMAP_ALL:
2552                 return 1;
2553         case VFIO_DMA_CC_IOMMU:
2554                 if (!iommu)
2555                         return 0;
2556                 return vfio_domains_have_iommu_cache(iommu);
2557         default:
2558                 return 0;
2559         }
2560 }
2561
2562 static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2563                  struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2564                  size_t size)
2565 {
2566         struct vfio_info_cap_header *header;
2567         struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2568
2569         header = vfio_info_cap_add(caps, size,
2570                                    VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2571         if (IS_ERR(header))
2572                 return PTR_ERR(header);
2573
2574         iova_cap = container_of(header,
2575                                 struct vfio_iommu_type1_info_cap_iova_range,
2576                                 header);
2577         iova_cap->nr_iovas = cap_iovas->nr_iovas;
2578         memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2579                cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2580         return 0;
2581 }
2582
2583 static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2584                                       struct vfio_info_cap *caps)
2585 {
2586         struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2587         struct vfio_iova *iova;
2588         size_t size;
2589         int iovas = 0, i = 0, ret;
2590
2591         list_for_each_entry(iova, &iommu->iova_list, list)
2592                 iovas++;
2593
2594         if (!iovas) {
2595                 /*
2596                  * Return 0 as a container with a single mdev device
2597                  * will have an empty list
2598                  */
2599                 return 0;
2600         }
2601
2602         size = sizeof(*cap_iovas) + (iovas * sizeof(*cap_iovas->iova_ranges));
2603
2604         cap_iovas = kzalloc(size, GFP_KERNEL);
2605         if (!cap_iovas)
2606                 return -ENOMEM;
2607
2608         cap_iovas->nr_iovas = iovas;
2609
2610         list_for_each_entry(iova, &iommu->iova_list, list) {
2611                 cap_iovas->iova_ranges[i].start = iova->start;
2612                 cap_iovas->iova_ranges[i].end = iova->end;
2613                 i++;
2614         }
2615
2616         ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2617
2618         kfree(cap_iovas);
2619         return ret;
2620 }
2621
2622 static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2623                                            struct vfio_info_cap *caps)
2624 {
2625         struct vfio_iommu_type1_info_cap_migration cap_mig;
2626
2627         cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2628         cap_mig.header.version = 1;
2629
2630         cap_mig.flags = 0;
2631         /* support minimum pgsize */
2632         cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2633         cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2634
2635         return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2636 }
2637
2638 static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2639                                            struct vfio_info_cap *caps)
2640 {
2641         struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2642
2643         cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2644         cap_dma_avail.header.version = 1;
2645
2646         cap_dma_avail.avail = iommu->dma_avail;
2647
2648         return vfio_info_add_capability(caps, &cap_dma_avail.header,
2649                                         sizeof(cap_dma_avail));
2650 }
2651
2652 static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2653                                      unsigned long arg)
2654 {
2655         struct vfio_iommu_type1_info info;
2656         unsigned long minsz;
2657         struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2658         unsigned long capsz;
2659         int ret;
2660
2661         minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2662
2663         /* For backward compatibility, cannot require this */
2664         capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2665
2666         if (copy_from_user(&info, (void __user *)arg, minsz))
2667                 return -EFAULT;
2668
2669         if (info.argsz < minsz)
2670                 return -EINVAL;
2671
2672         if (info.argsz >= capsz) {
2673                 minsz = capsz;
2674                 info.cap_offset = 0; /* output, no-recopy necessary */
2675         }
2676
2677         mutex_lock(&iommu->lock);
2678         info.flags = VFIO_IOMMU_INFO_PGSIZES;
2679
2680         info.iova_pgsizes = iommu->pgsize_bitmap;
2681
2682         ret = vfio_iommu_migration_build_caps(iommu, &caps);
2683
2684         if (!ret)
2685                 ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2686
2687         if (!ret)
2688                 ret = vfio_iommu_iova_build_caps(iommu, &caps);
2689
2690         mutex_unlock(&iommu->lock);
2691
2692         if (ret)
2693                 return ret;
2694
2695         if (caps.size) {
2696                 info.flags |= VFIO_IOMMU_INFO_CAPS;
2697
2698                 if (info.argsz < sizeof(info) + caps.size) {
2699                         info.argsz = sizeof(info) + caps.size;
2700                 } else {
2701                         vfio_info_cap_shift(&caps, sizeof(info));
2702                         if (copy_to_user((void __user *)arg +
2703                                         sizeof(info), caps.buf,
2704                                         caps.size)) {
2705                                 kfree(caps.buf);
2706                                 return -EFAULT;
2707                         }
2708                         info.cap_offset = sizeof(info);
2709                 }
2710
2711                 kfree(caps.buf);
2712         }
2713
2714         return copy_to_user((void __user *)arg, &info, minsz) ?
2715                         -EFAULT : 0;
2716 }
2717
2718 static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2719                                     unsigned long arg)
2720 {
2721         struct vfio_iommu_type1_dma_map map;
2722         unsigned long minsz;
2723         uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE;
2724
2725         minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2726
2727         if (copy_from_user(&map, (void __user *)arg, minsz))
2728                 return -EFAULT;
2729
2730         if (map.argsz < minsz || map.flags & ~mask)
2731                 return -EINVAL;
2732
2733         return vfio_dma_do_map(iommu, &map);
2734 }
2735
2736 static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2737                                       unsigned long arg)
2738 {
2739         struct vfio_iommu_type1_dma_unmap unmap;
2740         struct vfio_bitmap bitmap = { 0 };
2741         uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
2742                         VFIO_DMA_UNMAP_FLAG_ALL;
2743         unsigned long minsz;
2744         int ret;
2745
2746         minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2747
2748         if (copy_from_user(&unmap, (void __user *)arg, minsz))
2749                 return -EFAULT;
2750
2751         if (unmap.argsz < minsz || unmap.flags & ~mask)
2752                 return -EINVAL;
2753
2754         if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
2755             (unmap.flags & VFIO_DMA_UNMAP_FLAG_ALL))
2756                 return -EINVAL;
2757
2758         if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2759                 unsigned long pgshift;
2760
2761                 if (unmap.argsz < (minsz + sizeof(bitmap)))
2762                         return -EINVAL;
2763
2764                 if (copy_from_user(&bitmap,
2765                                    (void __user *)(arg + minsz),
2766                                    sizeof(bitmap)))
2767                         return -EFAULT;
2768
2769                 if (!access_ok((void __user *)bitmap.data, bitmap.size))
2770                         return -EINVAL;
2771
2772                 pgshift = __ffs(bitmap.pgsize);
2773                 ret = verify_bitmap_size(unmap.size >> pgshift,
2774                                          bitmap.size);
2775                 if (ret)
2776                         return ret;
2777         }
2778
2779         ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2780         if (ret)
2781                 return ret;
2782
2783         return copy_to_user((void __user *)arg, &unmap, minsz) ?
2784                         -EFAULT : 0;
2785 }
2786
2787 static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
2788                                         unsigned long arg)
2789 {
2790         struct vfio_iommu_type1_dirty_bitmap dirty;
2791         uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2792                         VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2793                         VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2794         unsigned long minsz;
2795         int ret = 0;
2796
2797         if (!iommu->v2)
2798                 return -EACCES;
2799
2800         minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
2801
2802         if (copy_from_user(&dirty, (void __user *)arg, minsz))
2803                 return -EFAULT;
2804
2805         if (dirty.argsz < minsz || dirty.flags & ~mask)
2806                 return -EINVAL;
2807
2808         /* only one flag should be set at a time */
2809         if (__ffs(dirty.flags) != __fls(dirty.flags))
2810                 return -EINVAL;
2811
2812         if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
2813                 size_t pgsize;
2814
2815                 mutex_lock(&iommu->lock);
2816                 pgsize = 1 << __ffs(iommu->pgsize_bitmap);
2817                 if (!iommu->dirty_page_tracking) {
2818                         ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
2819                         if (!ret)
2820                                 iommu->dirty_page_tracking = true;
2821                 }
2822                 mutex_unlock(&iommu->lock);
2823                 return ret;
2824         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
2825                 mutex_lock(&iommu->lock);
2826                 if (iommu->dirty_page_tracking) {
2827                         iommu->dirty_page_tracking = false;
2828                         vfio_dma_bitmap_free_all(iommu);
2829                 }
2830                 mutex_unlock(&iommu->lock);
2831                 return 0;
2832         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
2833                 struct vfio_iommu_type1_dirty_bitmap_get range;
2834                 unsigned long pgshift;
2835                 size_t data_size = dirty.argsz - minsz;
2836                 size_t iommu_pgsize;
2837
2838                 if (!data_size || data_size < sizeof(range))
2839                         return -EINVAL;
2840
2841                 if (copy_from_user(&range, (void __user *)(arg + minsz),
2842                                    sizeof(range)))
2843                         return -EFAULT;
2844
2845                 if (range.iova + range.size < range.iova)
2846                         return -EINVAL;
2847                 if (!access_ok((void __user *)range.bitmap.data,
2848                                range.bitmap.size))
2849                         return -EINVAL;
2850
2851                 pgshift = __ffs(range.bitmap.pgsize);
2852                 ret = verify_bitmap_size(range.size >> pgshift,
2853                                          range.bitmap.size);
2854                 if (ret)
2855                         return ret;
2856
2857                 mutex_lock(&iommu->lock);
2858
2859                 iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2860
2861                 /* allow only smallest supported pgsize */
2862                 if (range.bitmap.pgsize != iommu_pgsize) {
2863                         ret = -EINVAL;
2864                         goto out_unlock;
2865                 }
2866                 if (range.iova & (iommu_pgsize - 1)) {
2867                         ret = -EINVAL;
2868                         goto out_unlock;
2869                 }
2870                 if (!range.size || range.size & (iommu_pgsize - 1)) {
2871                         ret = -EINVAL;
2872                         goto out_unlock;
2873                 }
2874
2875                 if (iommu->dirty_page_tracking)
2876                         ret = vfio_iova_dirty_bitmap(range.bitmap.data,
2877                                                      iommu, range.iova,
2878                                                      range.size,
2879                                                      range.bitmap.pgsize);
2880                 else
2881                         ret = -EINVAL;
2882 out_unlock:
2883                 mutex_unlock(&iommu->lock);
2884
2885                 return ret;
2886         }
2887
2888         return -EINVAL;
2889 }
2890
2891 static long vfio_iommu_type1_ioctl(void *iommu_data,
2892                                    unsigned int cmd, unsigned long arg)
2893 {
2894         struct vfio_iommu *iommu = iommu_data;
2895
2896         switch (cmd) {
2897         case VFIO_CHECK_EXTENSION:
2898                 return vfio_iommu_type1_check_extension(iommu, arg);
2899         case VFIO_IOMMU_GET_INFO:
2900                 return vfio_iommu_type1_get_info(iommu, arg);
2901         case VFIO_IOMMU_MAP_DMA:
2902                 return vfio_iommu_type1_map_dma(iommu, arg);
2903         case VFIO_IOMMU_UNMAP_DMA:
2904                 return vfio_iommu_type1_unmap_dma(iommu, arg);
2905         case VFIO_IOMMU_DIRTY_PAGES:
2906                 return vfio_iommu_type1_dirty_pages(iommu, arg);
2907         default:
2908                 return -ENOTTY;
2909         }
2910 }
2911
2912 static int vfio_iommu_type1_register_notifier(void *iommu_data,
2913                                               unsigned long *events,
2914                                               struct notifier_block *nb)
2915 {
2916         struct vfio_iommu *iommu = iommu_data;
2917
2918         /* clear known events */
2919         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
2920
2921         /* refuse to register if still events remaining */
2922         if (*events)
2923                 return -EINVAL;
2924
2925         return blocking_notifier_chain_register(&iommu->notifier, nb);
2926 }
2927
2928 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
2929                                                 struct notifier_block *nb)
2930 {
2931         struct vfio_iommu *iommu = iommu_data;
2932
2933         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
2934 }
2935
2936 static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
2937                                          dma_addr_t user_iova, void *data,
2938                                          size_t count, bool write,
2939                                          size_t *copied)
2940 {
2941         struct mm_struct *mm;
2942         unsigned long vaddr;
2943         struct vfio_dma *dma;
2944         bool kthread = current->mm == NULL;
2945         size_t offset;
2946
2947         *copied = 0;
2948
2949         dma = vfio_find_dma(iommu, user_iova, 1);
2950         if (!dma)
2951                 return -EINVAL;
2952
2953         if ((write && !(dma->prot & IOMMU_WRITE)) ||
2954                         !(dma->prot & IOMMU_READ))
2955                 return -EPERM;
2956
2957         mm = get_task_mm(dma->task);
2958
2959         if (!mm)
2960                 return -EPERM;
2961
2962         if (kthread)
2963                 kthread_use_mm(mm);
2964         else if (current->mm != mm)
2965                 goto out;
2966
2967         offset = user_iova - dma->iova;
2968
2969         if (count > dma->size - offset)
2970                 count = dma->size - offset;
2971
2972         vaddr = dma->vaddr + offset;
2973
2974         if (write) {
2975                 *copied = copy_to_user((void __user *)vaddr, data,
2976                                          count) ? 0 : count;
2977                 if (*copied && iommu->dirty_page_tracking) {
2978                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
2979                         /*
2980                          * Bitmap populated with the smallest supported page
2981                          * size
2982                          */
2983                         bitmap_set(dma->bitmap, offset >> pgshift,
2984                                    ((offset + *copied - 1) >> pgshift) -
2985                                    (offset >> pgshift) + 1);
2986                 }
2987         } else
2988                 *copied = copy_from_user(data, (void __user *)vaddr,
2989                                            count) ? 0 : count;
2990         if (kthread)
2991                 kthread_unuse_mm(mm);
2992 out:
2993         mmput(mm);
2994         return *copied ? 0 : -EFAULT;
2995 }
2996
2997 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
2998                                    void *data, size_t count, bool write)
2999 {
3000         struct vfio_iommu *iommu = iommu_data;
3001         int ret = 0;
3002         size_t done;
3003
3004         mutex_lock(&iommu->lock);
3005         while (count > 0) {
3006                 ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
3007                                                     count, write, &done);
3008                 if (ret)
3009                         break;
3010
3011                 count -= done;
3012                 data += done;
3013                 user_iova += done;
3014         }
3015
3016         mutex_unlock(&iommu->lock);
3017         return ret;
3018 }
3019
3020 static struct iommu_domain *
3021 vfio_iommu_type1_group_iommu_domain(void *iommu_data,
3022                                     struct iommu_group *iommu_group)
3023 {
3024         struct iommu_domain *domain = ERR_PTR(-ENODEV);
3025         struct vfio_iommu *iommu = iommu_data;
3026         struct vfio_domain *d;
3027
3028         if (!iommu || !iommu_group)
3029                 return ERR_PTR(-EINVAL);
3030
3031         mutex_lock(&iommu->lock);
3032         list_for_each_entry(d, &iommu->domain_list, next) {
3033                 if (find_iommu_group(d, iommu_group)) {
3034                         domain = d->domain;
3035                         break;
3036                 }
3037         }
3038         mutex_unlock(&iommu->lock);
3039
3040         return domain;
3041 }
3042
3043 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3044         .name                   = "vfio-iommu-type1",
3045         .owner                  = THIS_MODULE,
3046         .open                   = vfio_iommu_type1_open,
3047         .release                = vfio_iommu_type1_release,
3048         .ioctl                  = vfio_iommu_type1_ioctl,
3049         .attach_group           = vfio_iommu_type1_attach_group,
3050         .detach_group           = vfio_iommu_type1_detach_group,
3051         .pin_pages              = vfio_iommu_type1_pin_pages,
3052         .unpin_pages            = vfio_iommu_type1_unpin_pages,
3053         .register_notifier      = vfio_iommu_type1_register_notifier,
3054         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
3055         .dma_rw                 = vfio_iommu_type1_dma_rw,
3056         .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
3057 };
3058
3059 static int __init vfio_iommu_type1_init(void)
3060 {
3061         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3062 }
3063
3064 static void __exit vfio_iommu_type1_cleanup(void)
3065 {
3066         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3067 }
3068
3069 module_init(vfio_iommu_type1_init);
3070 module_exit(vfio_iommu_type1_cleanup);
3071
3072 MODULE_VERSION(DRIVER_VERSION);
3073 MODULE_LICENSE("GPL v2");
3074 MODULE_AUTHOR(DRIVER_AUTHOR);
3075 MODULE_DESCRIPTION(DRIVER_DESC);