netfilter: netns: shrink netns_ct struct
[linux-2.6-microblaze.git] / drivers / infiniband / core / umem_odp.c
1 /*
2  * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * OpenIB.org BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  */
32
33 #include <linux/types.h>
34 #include <linux/sched.h>
35 #include <linux/sched/mm.h>
36 #include <linux/sched/task.h>
37 #include <linux/pid.h>
38 #include <linux/slab.h>
39 #include <linux/export.h>
40 #include <linux/vmalloc.h>
41 #include <linux/hugetlb.h>
42 #include <linux/interval_tree_generic.h>
43
44 #include <rdma/ib_verbs.h>
45 #include <rdma/ib_umem.h>
46 #include <rdma/ib_umem_odp.h>
47
48 /*
49  * The ib_umem list keeps track of memory regions for which the HW
50  * device request to receive notification when the related memory
51  * mapping is changed.
52  *
53  * ib_umem_lock protects the list.
54  */
55
56 static u64 node_start(struct umem_odp_node *n)
57 {
58         struct ib_umem_odp *umem_odp =
59                         container_of(n, struct ib_umem_odp, interval_tree);
60
61         return ib_umem_start(&umem_odp->umem);
62 }
63
64 /* Note that the representation of the intervals in the interval tree
65  * considers the ending point as contained in the interval, while the
66  * function ib_umem_end returns the first address which is not contained
67  * in the umem.
68  */
69 static u64 node_last(struct umem_odp_node *n)
70 {
71         struct ib_umem_odp *umem_odp =
72                         container_of(n, struct ib_umem_odp, interval_tree);
73
74         return ib_umem_end(&umem_odp->umem) - 1;
75 }
76
77 INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
78                      node_start, node_last, static, rbt_ib_umem)
79
80 static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
81 {
82         mutex_lock(&umem_odp->umem_mutex);
83         if (umem_odp->notifiers_count++ == 0)
84                 /*
85                  * Initialize the completion object for waiting on
86                  * notifiers. Since notifier_count is zero, no one should be
87                  * waiting right now.
88                  */
89                 reinit_completion(&umem_odp->notifier_completion);
90         mutex_unlock(&umem_odp->umem_mutex);
91 }
92
93 static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
94 {
95         mutex_lock(&umem_odp->umem_mutex);
96         /*
97          * This sequence increase will notify the QP page fault that the page
98          * that is going to be mapped in the spte could have been freed.
99          */
100         ++umem_odp->notifiers_seq;
101         if (--umem_odp->notifiers_count == 0)
102                 complete_all(&umem_odp->notifier_completion);
103         mutex_unlock(&umem_odp->umem_mutex);
104 }
105
106 static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
107                                                u64 start, u64 end, void *cookie)
108 {
109         struct ib_umem *umem = &umem_odp->umem;
110
111         /*
112          * Increase the number of notifiers running, to
113          * prevent any further fault handling on this MR.
114          */
115         ib_umem_notifier_start_account(umem_odp);
116         umem_odp->dying = 1;
117         /* Make sure that the fact the umem is dying is out before we release
118          * all pending page faults. */
119         smp_wmb();
120         complete_all(&umem_odp->notifier_completion);
121         umem->context->invalidate_range(umem_odp, ib_umem_start(umem),
122                                         ib_umem_end(umem));
123         return 0;
124 }
125
126 static void ib_umem_notifier_release(struct mmu_notifier *mn,
127                                      struct mm_struct *mm)
128 {
129         struct ib_ucontext_per_mm *per_mm =
130                 container_of(mn, struct ib_ucontext_per_mm, mn);
131
132         down_read(&per_mm->umem_rwsem);
133         if (per_mm->active)
134                 rbt_ib_umem_for_each_in_range(
135                         &per_mm->umem_tree, 0, ULLONG_MAX,
136                         ib_umem_notifier_release_trampoline, true, NULL);
137         up_read(&per_mm->umem_rwsem);
138 }
139
140 static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start,
141                                       u64 end, void *cookie)
142 {
143         ib_umem_notifier_start_account(item);
144         item->umem.context->invalidate_range(item, start, start + PAGE_SIZE);
145         ib_umem_notifier_end_account(item);
146         return 0;
147 }
148
149 static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
150                                              u64 start, u64 end, void *cookie)
151 {
152         ib_umem_notifier_start_account(item);
153         item->umem.context->invalidate_range(item, start, end);
154         return 0;
155 }
156
157 static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
158                                                     struct mm_struct *mm,
159                                                     unsigned long start,
160                                                     unsigned long end,
161                                                     bool blockable)
162 {
163         struct ib_ucontext_per_mm *per_mm =
164                 container_of(mn, struct ib_ucontext_per_mm, mn);
165
166         if (blockable)
167                 down_read(&per_mm->umem_rwsem);
168         else if (!down_read_trylock(&per_mm->umem_rwsem))
169                 return -EAGAIN;
170
171         if (!per_mm->active) {
172                 up_read(&per_mm->umem_rwsem);
173                 /*
174                  * At this point active is permanently set and visible to this
175                  * CPU without a lock, that fact is relied on to skip the unlock
176                  * in range_end.
177                  */
178                 return 0;
179         }
180
181         return rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start, end,
182                                              invalidate_range_start_trampoline,
183                                              blockable, NULL);
184 }
185
186 static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
187                                            u64 end, void *cookie)
188 {
189         ib_umem_notifier_end_account(item);
190         return 0;
191 }
192
193 static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
194                                                   struct mm_struct *mm,
195                                                   unsigned long start,
196                                                   unsigned long end)
197 {
198         struct ib_ucontext_per_mm *per_mm =
199                 container_of(mn, struct ib_ucontext_per_mm, mn);
200
201         if (unlikely(!per_mm->active))
202                 return;
203
204         rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
205                                       end,
206                                       invalidate_range_end_trampoline, true, NULL);
207         up_read(&per_mm->umem_rwsem);
208 }
209
210 static const struct mmu_notifier_ops ib_umem_notifiers = {
211         .release                    = ib_umem_notifier_release,
212         .invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
213         .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
214 };
215
216 static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
217 {
218         struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
219         struct ib_umem *umem = &umem_odp->umem;
220
221         down_write(&per_mm->umem_rwsem);
222         if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
223                 rbt_ib_umem_insert(&umem_odp->interval_tree,
224                                    &per_mm->umem_tree);
225         up_write(&per_mm->umem_rwsem);
226 }
227
228 static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
229 {
230         struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
231         struct ib_umem *umem = &umem_odp->umem;
232
233         down_write(&per_mm->umem_rwsem);
234         if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
235                 rbt_ib_umem_remove(&umem_odp->interval_tree,
236                                    &per_mm->umem_tree);
237         complete_all(&umem_odp->notifier_completion);
238
239         up_write(&per_mm->umem_rwsem);
240 }
241
242 static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
243                                                struct mm_struct *mm)
244 {
245         struct ib_ucontext_per_mm *per_mm;
246         int ret;
247
248         per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
249         if (!per_mm)
250                 return ERR_PTR(-ENOMEM);
251
252         per_mm->context = ctx;
253         per_mm->mm = mm;
254         per_mm->umem_tree = RB_ROOT_CACHED;
255         init_rwsem(&per_mm->umem_rwsem);
256         per_mm->active = ctx->invalidate_range;
257
258         rcu_read_lock();
259         per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
260         rcu_read_unlock();
261
262         WARN_ON(mm != current->mm);
263
264         per_mm->mn.ops = &ib_umem_notifiers;
265         ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
266         if (ret) {
267                 dev_err(&ctx->device->dev,
268                         "Failed to register mmu_notifier %d\n", ret);
269                 goto out_pid;
270         }
271
272         list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
273         return per_mm;
274
275 out_pid:
276         put_pid(per_mm->tgid);
277         kfree(per_mm);
278         return ERR_PTR(ret);
279 }
280
281 static int get_per_mm(struct ib_umem_odp *umem_odp)
282 {
283         struct ib_ucontext *ctx = umem_odp->umem.context;
284         struct ib_ucontext_per_mm *per_mm;
285
286         /*
287          * Generally speaking we expect only one or two per_mm in this list,
288          * so no reason to optimize this search today.
289          */
290         mutex_lock(&ctx->per_mm_list_lock);
291         list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
292                 if (per_mm->mm == umem_odp->umem.owning_mm)
293                         goto found;
294         }
295
296         per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
297         if (IS_ERR(per_mm)) {
298                 mutex_unlock(&ctx->per_mm_list_lock);
299                 return PTR_ERR(per_mm);
300         }
301
302 found:
303         umem_odp->per_mm = per_mm;
304         per_mm->odp_mrs_count++;
305         mutex_unlock(&ctx->per_mm_list_lock);
306
307         return 0;
308 }
309
310 static void free_per_mm(struct rcu_head *rcu)
311 {
312         kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
313 }
314
315 void put_per_mm(struct ib_umem_odp *umem_odp)
316 {
317         struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
318         struct ib_ucontext *ctx = umem_odp->umem.context;
319         bool need_free;
320
321         mutex_lock(&ctx->per_mm_list_lock);
322         umem_odp->per_mm = NULL;
323         per_mm->odp_mrs_count--;
324         need_free = per_mm->odp_mrs_count == 0;
325         if (need_free)
326                 list_del(&per_mm->ucontext_list);
327         mutex_unlock(&ctx->per_mm_list_lock);
328
329         if (!need_free)
330                 return;
331
332         /*
333          * NOTE! mmu_notifier_unregister() can happen between a start/end
334          * callback, resulting in an start/end, and thus an unbalanced
335          * lock. This doesn't really matter to us since we are about to kfree
336          * the memory that holds the lock, however LOCKDEP doesn't like this.
337          */
338         down_write(&per_mm->umem_rwsem);
339         per_mm->active = false;
340         up_write(&per_mm->umem_rwsem);
341
342         WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
343         mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
344         put_pid(per_mm->tgid);
345         mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
346 }
347
348 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
349                                       unsigned long addr, size_t size)
350 {
351         struct ib_ucontext *ctx = per_mm->context;
352         struct ib_umem_odp *odp_data;
353         struct ib_umem *umem;
354         int pages = size >> PAGE_SHIFT;
355         int ret;
356
357         odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
358         if (!odp_data)
359                 return ERR_PTR(-ENOMEM);
360         umem = &odp_data->umem;
361         umem->context    = ctx;
362         umem->length     = size;
363         umem->address    = addr;
364         umem->page_shift = PAGE_SHIFT;
365         umem->writable   = 1;
366         umem->is_odp = 1;
367         odp_data->per_mm = per_mm;
368
369         mutex_init(&odp_data->umem_mutex);
370         init_completion(&odp_data->notifier_completion);
371
372         odp_data->page_list =
373                 vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
374         if (!odp_data->page_list) {
375                 ret = -ENOMEM;
376                 goto out_odp_data;
377         }
378
379         odp_data->dma_list =
380                 vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
381         if (!odp_data->dma_list) {
382                 ret = -ENOMEM;
383                 goto out_page_list;
384         }
385
386         /*
387          * Caller must ensure that the umem_odp that the per_mm came from
388          * cannot be freed during the call to ib_alloc_odp_umem.
389          */
390         mutex_lock(&ctx->per_mm_list_lock);
391         per_mm->odp_mrs_count++;
392         mutex_unlock(&ctx->per_mm_list_lock);
393         add_umem_to_per_mm(odp_data);
394
395         return odp_data;
396
397 out_page_list:
398         vfree(odp_data->page_list);
399 out_odp_data:
400         kfree(odp_data);
401         return ERR_PTR(ret);
402 }
403 EXPORT_SYMBOL(ib_alloc_odp_umem);
404
405 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
406 {
407         struct ib_umem *umem = &umem_odp->umem;
408         /*
409          * NOTE: This must called in a process context where umem->owning_mm
410          * == current->mm
411          */
412         struct mm_struct *mm = umem->owning_mm;
413         int ret_val;
414
415         if (access & IB_ACCESS_HUGETLB) {
416                 struct vm_area_struct *vma;
417                 struct hstate *h;
418
419                 down_read(&mm->mmap_sem);
420                 vma = find_vma(mm, ib_umem_start(umem));
421                 if (!vma || !is_vm_hugetlb_page(vma)) {
422                         up_read(&mm->mmap_sem);
423                         return -EINVAL;
424                 }
425                 h = hstate_vma(vma);
426                 umem->page_shift = huge_page_shift(h);
427                 up_read(&mm->mmap_sem);
428                 umem->hugetlb = 1;
429         } else {
430                 umem->hugetlb = 0;
431         }
432
433         mutex_init(&umem_odp->umem_mutex);
434
435         init_completion(&umem_odp->notifier_completion);
436
437         if (ib_umem_num_pages(umem)) {
438                 umem_odp->page_list =
439                         vzalloc(array_size(sizeof(*umem_odp->page_list),
440                                            ib_umem_num_pages(umem)));
441                 if (!umem_odp->page_list)
442                         return -ENOMEM;
443
444                 umem_odp->dma_list =
445                         vzalloc(array_size(sizeof(*umem_odp->dma_list),
446                                            ib_umem_num_pages(umem)));
447                 if (!umem_odp->dma_list) {
448                         ret_val = -ENOMEM;
449                         goto out_page_list;
450                 }
451         }
452
453         ret_val = get_per_mm(umem_odp);
454         if (ret_val)
455                 goto out_dma_list;
456         add_umem_to_per_mm(umem_odp);
457
458         return 0;
459
460 out_dma_list:
461         vfree(umem_odp->dma_list);
462 out_page_list:
463         vfree(umem_odp->page_list);
464         return ret_val;
465 }
466
467 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
468 {
469         struct ib_umem *umem = &umem_odp->umem;
470
471         /*
472          * Ensure that no more pages are mapped in the umem.
473          *
474          * It is the driver's responsibility to ensure, before calling us,
475          * that the hardware will not attempt to access the MR any more.
476          */
477         ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
478                                     ib_umem_end(umem));
479
480         remove_umem_from_per_mm(umem_odp);
481         put_per_mm(umem_odp);
482         vfree(umem_odp->dma_list);
483         vfree(umem_odp->page_list);
484 }
485
486 /*
487  * Map for DMA and insert a single page into the on-demand paging page tables.
488  *
489  * @umem: the umem to insert the page to.
490  * @page_index: index in the umem to add the page to.
491  * @page: the page struct to map and add.
492  * @access_mask: access permissions needed for this page.
493  * @current_seq: sequence number for synchronization with invalidations.
494  *               the sequence number is taken from
495  *               umem_odp->notifiers_seq.
496  *
497  * The function returns -EFAULT if the DMA mapping operation fails. It returns
498  * -EAGAIN if a concurrent invalidation prevents us from updating the page.
499  *
500  * The page is released via put_page even if the operation failed. For
501  * on-demand pinning, the page is released whenever it isn't stored in the
502  * umem.
503  */
504 static int ib_umem_odp_map_dma_single_page(
505                 struct ib_umem_odp *umem_odp,
506                 int page_index,
507                 struct page *page,
508                 u64 access_mask,
509                 unsigned long current_seq)
510 {
511         struct ib_umem *umem = &umem_odp->umem;
512         struct ib_device *dev = umem->context->device;
513         dma_addr_t dma_addr;
514         int stored_page = 0;
515         int remove_existing_mapping = 0;
516         int ret = 0;
517
518         /*
519          * Note: we avoid writing if seq is different from the initial seq, to
520          * handle case of a racing notifier. This check also allows us to bail
521          * early if we have a notifier running in parallel with us.
522          */
523         if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
524                 ret = -EAGAIN;
525                 goto out;
526         }
527         if (!(umem_odp->dma_list[page_index])) {
528                 dma_addr = ib_dma_map_page(dev,
529                                            page,
530                                            0, BIT(umem->page_shift),
531                                            DMA_BIDIRECTIONAL);
532                 if (ib_dma_mapping_error(dev, dma_addr)) {
533                         ret = -EFAULT;
534                         goto out;
535                 }
536                 umem_odp->dma_list[page_index] = dma_addr | access_mask;
537                 umem_odp->page_list[page_index] = page;
538                 umem->npages++;
539                 stored_page = 1;
540         } else if (umem_odp->page_list[page_index] == page) {
541                 umem_odp->dma_list[page_index] |= access_mask;
542         } else {
543                 pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
544                        umem_odp->page_list[page_index], page);
545                 /* Better remove the mapping now, to prevent any further
546                  * damage. */
547                 remove_existing_mapping = 1;
548         }
549
550 out:
551         /* On Demand Paging - avoid pinning the page */
552         if (umem->context->invalidate_range || !stored_page)
553                 put_page(page);
554
555         if (remove_existing_mapping && umem->context->invalidate_range) {
556                 invalidate_page_trampoline(
557                         umem_odp,
558                         ib_umem_start(umem) + (page_index >> umem->page_shift),
559                         ib_umem_start(umem) + ((page_index + 1) >>
560                                                umem->page_shift),
561                         NULL);
562                 ret = -EAGAIN;
563         }
564
565         return ret;
566 }
567
568 /**
569  * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
570  *
571  * Pins the range of pages passed in the argument, and maps them to
572  * DMA addresses. The DMA addresses of the mapped pages is updated in
573  * umem_odp->dma_list.
574  *
575  * Returns the number of pages mapped in success, negative error code
576  * for failure.
577  * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
578  * the function from completing its task.
579  * An -ENOENT error code indicates that userspace process is being terminated
580  * and mm was already destroyed.
581  * @umem_odp: the umem to map and pin
582  * @user_virt: the address from which we need to map.
583  * @bcnt: the minimal number of bytes to pin and map. The mapping might be
584  *        bigger due to alignment, and may also be smaller in case of an error
585  *        pinning or mapping a page. The actual pages mapped is returned in
586  *        the return value.
587  * @access_mask: bit mask of the requested access permissions for the given
588  *               range.
589  * @current_seq: the MMU notifiers sequance value for synchronization with
590  *               invalidations. the sequance number is read from
591  *               umem_odp->notifiers_seq before calling this function
592  */
593 int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
594                               u64 bcnt, u64 access_mask,
595                               unsigned long current_seq)
596 {
597         struct ib_umem *umem = &umem_odp->umem;
598         struct task_struct *owning_process  = NULL;
599         struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
600         struct page       **local_page_list = NULL;
601         u64 page_mask, off;
602         int j, k, ret = 0, start_idx, npages = 0, page_shift;
603         unsigned int flags = 0;
604         phys_addr_t p = 0;
605
606         if (access_mask == 0)
607                 return -EINVAL;
608
609         if (user_virt < ib_umem_start(umem) ||
610             user_virt + bcnt > ib_umem_end(umem))
611                 return -EFAULT;
612
613         local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
614         if (!local_page_list)
615                 return -ENOMEM;
616
617         page_shift = umem->page_shift;
618         page_mask = ~(BIT(page_shift) - 1);
619         off = user_virt & (~page_mask);
620         user_virt = user_virt & page_mask;
621         bcnt += off; /* Charge for the first page offset as well. */
622
623         /*
624          * owning_process is allowed to be NULL, this means somehow the mm is
625          * existing beyond the lifetime of the originating process.. Presumably
626          * mmget_not_zero will fail in this case.
627          */
628         owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
629         if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) {
630                 ret = -EINVAL;
631                 goto out_put_task;
632         }
633
634         if (access_mask & ODP_WRITE_ALLOWED_BIT)
635                 flags |= FOLL_WRITE;
636
637         start_idx = (user_virt - ib_umem_start(umem)) >> page_shift;
638         k = start_idx;
639
640         while (bcnt > 0) {
641                 const size_t gup_num_pages = min_t(size_t,
642                                 (bcnt + BIT(page_shift) - 1) >> page_shift,
643                                 PAGE_SIZE / sizeof(struct page *));
644
645                 down_read(&owning_mm->mmap_sem);
646                 /*
647                  * Note: this might result in redundent page getting. We can
648                  * avoid this by checking dma_list to be 0 before calling
649                  * get_user_pages. However, this make the code much more
650                  * complex (and doesn't gain us much performance in most use
651                  * cases).
652                  */
653                 npages = get_user_pages_remote(owning_process, owning_mm,
654                                 user_virt, gup_num_pages,
655                                 flags, local_page_list, NULL, NULL);
656                 up_read(&owning_mm->mmap_sem);
657
658                 if (npages < 0)
659                         break;
660
661                 bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
662                 mutex_lock(&umem_odp->umem_mutex);
663                 for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
664                         if (user_virt & ~page_mask) {
665                                 p += PAGE_SIZE;
666                                 if (page_to_phys(local_page_list[j]) != p) {
667                                         ret = -EFAULT;
668                                         break;
669                                 }
670                                 put_page(local_page_list[j]);
671                                 continue;
672                         }
673
674                         ret = ib_umem_odp_map_dma_single_page(
675                                         umem_odp, k, local_page_list[j],
676                                         access_mask, current_seq);
677                         if (ret < 0)
678                                 break;
679
680                         p = page_to_phys(local_page_list[j]);
681                         k++;
682                 }
683                 mutex_unlock(&umem_odp->umem_mutex);
684
685                 if (ret < 0) {
686                         /* Release left over pages when handling errors. */
687                         for (++j; j < npages; ++j)
688                                 put_page(local_page_list[j]);
689                         break;
690                 }
691         }
692
693         if (ret >= 0) {
694                 if (npages < 0 && k == start_idx)
695                         ret = npages;
696                 else
697                         ret = k - start_idx;
698         }
699
700         mmput(owning_mm);
701 out_put_task:
702         if (owning_process)
703                 put_task_struct(owning_process);
704         free_page((unsigned long)local_page_list);
705         return ret;
706 }
707 EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
708
709 void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
710                                  u64 bound)
711 {
712         struct ib_umem *umem = &umem_odp->umem;
713         int idx;
714         u64 addr;
715         struct ib_device *dev = umem->context->device;
716
717         virt  = max_t(u64, virt,  ib_umem_start(umem));
718         bound = min_t(u64, bound, ib_umem_end(umem));
719         /* Note that during the run of this function, the
720          * notifiers_count of the MR is > 0, preventing any racing
721          * faults from completion. We might be racing with other
722          * invalidations, so we must make sure we free each page only
723          * once. */
724         mutex_lock(&umem_odp->umem_mutex);
725         for (addr = virt; addr < bound; addr += BIT(umem->page_shift)) {
726                 idx = (addr - ib_umem_start(umem)) >> umem->page_shift;
727                 if (umem_odp->page_list[idx]) {
728                         struct page *page = umem_odp->page_list[idx];
729                         dma_addr_t dma = umem_odp->dma_list[idx];
730                         dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;
731
732                         WARN_ON(!dma_addr);
733
734                         ib_dma_unmap_page(dev, dma_addr, PAGE_SIZE,
735                                           DMA_BIDIRECTIONAL);
736                         if (dma & ODP_WRITE_ALLOWED_BIT) {
737                                 struct page *head_page = compound_head(page);
738                                 /*
739                                  * set_page_dirty prefers being called with
740                                  * the page lock. However, MMU notifiers are
741                                  * called sometimes with and sometimes without
742                                  * the lock. We rely on the umem_mutex instead
743                                  * to prevent other mmu notifiers from
744                                  * continuing and allowing the page mapping to
745                                  * be removed.
746                                  */
747                                 set_page_dirty(head_page);
748                         }
749                         /* on demand pinning support */
750                         if (!umem->context->invalidate_range)
751                                 put_page(page);
752                         umem_odp->page_list[idx] = NULL;
753                         umem_odp->dma_list[idx] = 0;
754                         umem->npages--;
755                 }
756         }
757         mutex_unlock(&umem_odp->umem_mutex);
758 }
759 EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
760
761 /* @last is not a part of the interval. See comment for function
762  * node_last.
763  */
764 int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
765                                   u64 start, u64 last,
766                                   umem_call_back cb,
767                                   bool blockable,
768                                   void *cookie)
769 {
770         int ret_val = 0;
771         struct umem_odp_node *node, *next;
772         struct ib_umem_odp *umem;
773
774         if (unlikely(start == last))
775                 return ret_val;
776
777         for (node = rbt_ib_umem_iter_first(root, start, last - 1);
778                         node; node = next) {
779                 /* TODO move the blockable decision up to the callback */
780                 if (!blockable)
781                         return -EAGAIN;
782                 next = rbt_ib_umem_iter_next(node, start, last - 1);
783                 umem = container_of(node, struct ib_umem_odp, interval_tree);
784                 ret_val = cb(umem, start, last, cookie) || ret_val;
785         }
786
787         return ret_val;
788 }
789 EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
790
791 struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
792                                        u64 addr, u64 length)
793 {
794         struct umem_odp_node *node;
795
796         node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
797         if (node)
798                 return container_of(node, struct ib_umem_odp, interval_tree);
799         return NULL;
800
801 }
802 EXPORT_SYMBOL(rbt_ib_umem_lookup);