Merge branches 'iommu/fixes', 'arm/mediatek', 'arm/smmu', 'arm/exynos', 'unisoc'...
[linux-2.6-microblaze.git] / drivers / vfio / vfio_iommu_type1.c
index 4bb162c..3c8048d 100644 (file)
@@ -189,7 +189,7 @@ static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 }
 
 static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
-                                               dma_addr_t start, size_t size)
+                                               dma_addr_t start, u64 size)
 {
        struct rb_node *res = NULL;
        struct rb_node *node = iommu->dma_list.rb_node;
@@ -739,6 +739,12 @@ out:
        ret = vfio_lock_acct(dma, lock_acct, false);
 
 unpin_out:
+       if (batch->size == 1 && !batch->offset) {
+               /* May be a VM_PFNMAP pfn, which the batch can't remember. */
+               put_pfn(pfn, dma->prot);
+               batch->size = 0;
+       }
+
        if (ret < 0) {
                if (pinned && !rsvd) {
                        for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
@@ -785,7 +791,12 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
                return -ENODEV;
 
        ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
-       if (ret == 1 && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
+       if (ret != 1)
+               goto out;
+
+       ret = 0;
+
+       if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
                ret = vfio_lock_acct(dma, 1, true);
                if (ret) {
                        put_pfn(*pfn_base, dma->prot);
@@ -797,6 +808,7 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
                }
        }
 
+out:
        mmput(mm);
        return ret;
 }
@@ -1288,7 +1300,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
        int ret = -EINVAL, retries = 0;
        unsigned long pgshift;
        dma_addr_t iova = unmap->iova;
-       unsigned long size = unmap->size;
+       u64 size = unmap->size;
        bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
        bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
        struct rb_node *n, *first_n;
@@ -1304,14 +1316,12 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
        if (unmap_all) {
                if (iova || size)
                        goto unlock;
-               size = SIZE_MAX;
-       } else if (!size || size & (pgsize - 1)) {
+               size = U64_MAX;
+       } else if (!size || size & (pgsize - 1) ||
+                  iova + size - 1 < iova || size > SIZE_MAX) {
                goto unlock;
        }
 
-       if (iova + size - 1 < iova || size > SIZE_MAX)
-               goto unlock;
-
        /* When dirty tracking is enabled, allow only min supported pgsize */
        if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
            (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
@@ -2252,7 +2262,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
        int ret;
        bool resv_msi, msi_remap;
        phys_addr_t resv_msi_base = 0;
-       struct iommu_domain_geometry geo;
+       struct iommu_domain_geometry *geo;
        LIST_HEAD(iova_copy);
        LIST_HEAD(group_resv_regions);
 
@@ -2320,10 +2330,7 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
        }
 
        if (iommu->nesting) {
-               int attr = 1;
-
-               ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
-                                           &attr);
+               ret = iommu_enable_nesting(domain->domain);
                if (ret)
                        goto out_domain;
        }
@@ -2333,10 +2340,9 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
                goto out_domain;
 
        /* Get aperture info */
-       iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY, &geo);
-
-       if (vfio_iommu_aper_conflict(iommu, geo.aperture_start,
-                                    geo.aperture_end)) {
+       geo = &domain->domain->geometry;
+       if (vfio_iommu_aper_conflict(iommu, geo->aperture_start,
+                                    geo->aperture_end)) {
                ret = -EINVAL;
                goto out_detach;
        }
@@ -2359,8 +2365,8 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
        if (ret)
                goto out_detach;
 
-       ret = vfio_iommu_aper_resize(&iova_copy, geo.aperture_start,
-                                    geo.aperture_end);
+       ret = vfio_iommu_aper_resize(&iova_copy, geo->aperture_start,
+                                    geo->aperture_end);
        if (ret)
                goto out_detach;
 
@@ -2493,7 +2499,6 @@ static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
                                   struct list_head *iova_copy)
 {
        struct vfio_domain *domain;
-       struct iommu_domain_geometry geo;
        struct vfio_iova *node;
        dma_addr_t start = 0;
        dma_addr_t end = (dma_addr_t)~0;
@@ -2502,12 +2507,12 @@ static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
                return;
 
        list_for_each_entry(domain, &iommu->domain_list, next) {
-               iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY,
-                                     &geo);
-               if (geo.aperture_start > start)
-                       start = geo.aperture_start;
-               if (geo.aperture_end < end)
-                       end = geo.aperture_end;
+               struct iommu_domain_geometry *geo = &domain->domain->geometry;
+
+               if (geo->aperture_start > start)
+                       start = geo->aperture_start;
+               if (geo->aperture_end < end)
+                       end = geo->aperture_end;
        }
 
        /* Modify aperture limits. The new aper is either same or bigger */