kernel/bpf/task_iter.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2020 Facebook */
3
4 #include <linux/init.h>
5 #include <linux/namei.h>
6 #include <linux/pid_namespace.h>
7 #include <linux/fs.h>
8 #include <linux/fdtable.h>
9 #include <linux/filter.h>
10 #include <linux/bpf_mem_alloc.h>
11 #include <linux/btf_ids.h>
12 #include <linux/mm_types.h>
13 #include "mmap_unlock_work.h"
14
15 static const char * const iter_task_type_names[] = {
16         "ALL",
17         "TID",
18         "PID",
19 };
20
21 struct bpf_iter_seq_task_common {
22         struct pid_namespace *ns;
23         enum bpf_iter_task_type type;
24         u32 pid;
25         u32 pid_visiting;
26 };
27
28 struct bpf_iter_seq_task_info {
29         /* The first field must be struct bpf_iter_seq_task_common.
30          * This is assumed by the {init, fini}_seq_pidns() callback functions.
31          */
32         struct bpf_iter_seq_task_common common;
33         u32 tid;
34 };
35
36 static struct task_struct *task_group_seq_get_next(struct bpf_iter_seq_task_common *common,
37                                                    u32 *tid,
38                                                    bool skip_if_dup_files)
39 {
40         struct task_struct *task;
41         struct pid *pid;
42         u32 next_tid;
43
44         if (!*tid) {
45                 /* The first time the iterator calls this function. */
46                 pid = find_pid_ns(common->pid, common->ns);
47                 task = get_pid_task(pid, PIDTYPE_TGID);
48                 if (!task)
49                         return NULL;
50
51                 *tid = common->pid;
52                 common->pid_visiting = common->pid;
53
54                 return task;
55         }
56
57         /* If control returns to user space and comes back to the
58          * kernel again, *tid and common->pid_visiting should be the
59          * same for task_seq_start() to pick up the correct task.
60          */
61         if (*tid == common->pid_visiting) {
62                 pid = find_pid_ns(common->pid_visiting, common->ns);
63                 task = get_pid_task(pid, PIDTYPE_PID);
64
65                 return task;
66         }
67
68         task = find_task_by_pid_ns(common->pid_visiting, common->ns);
69         if (!task)
70                 return NULL;
71
72 retry:
73         task = next_thread(task);
74
75         next_tid = __task_pid_nr_ns(task, PIDTYPE_PID, common->ns);
76         if (!next_tid || next_tid == common->pid) {
77                 /* We have run out of tasks in this process.  The tasks of
78                  * a thread_group are linked in a circular list.
79                  */
80                 return NULL;
81         }
82
83         if (skip_if_dup_files && task->files == task->group_leader->files)
84                 goto retry;
85
86         *tid = common->pid_visiting = next_tid;
87         get_task_struct(task);
88         return task;
89 }
90
91 static struct task_struct *task_seq_get_next(struct bpf_iter_seq_task_common *common,
92                                              u32 *tid,
93                                              bool skip_if_dup_files)
94 {
95         struct task_struct *task = NULL;
96         struct pid *pid;
97
98         if (common->type == BPF_TASK_ITER_TID) {
99                 if (*tid && *tid != common->pid)
100                         return NULL;
101                 rcu_read_lock();
102                 pid = find_pid_ns(common->pid, common->ns);
103                 if (pid) {
104                         task = get_pid_task(pid, PIDTYPE_TGID);
105                         *tid = common->pid;
106                 }
107                 rcu_read_unlock();
108
109                 return task;
110         }
111
112         if (common->type == BPF_TASK_ITER_TGID) {
113                 rcu_read_lock();
114                 task = task_group_seq_get_next(common, tid, skip_if_dup_files);
115                 rcu_read_unlock();
116
117                 return task;
118         }
119
120         rcu_read_lock();
121 retry:
122         pid = find_ge_pid(*tid, common->ns);
123         if (pid) {
124                 *tid = pid_nr_ns(pid, common->ns);
125                 task = get_pid_task(pid, PIDTYPE_PID);
126                 if (!task) {
127                         ++*tid;
128                         goto retry;
129                 } else if (skip_if_dup_files && !thread_group_leader(task) &&
130                            task->files == task->group_leader->files) {
131                         put_task_struct(task);
132                         task = NULL;
133                         ++*tid;
134                         goto retry;
135                 }
136         }
137         rcu_read_unlock();
138
139         return task;
140 }
141
142 static void *task_seq_start(struct seq_file *seq, loff_t *pos)
143 {
144         struct bpf_iter_seq_task_info *info = seq->private;
145         struct task_struct *task;
146
147         task = task_seq_get_next(&info->common, &info->tid, false);
148         if (!task)
149                 return NULL;
150
151         if (*pos == 0)
152                 ++*pos;
153         return task;
154 }
155
156 static void *task_seq_next(struct seq_file *seq, void *v, loff_t *pos)
157 {
158         struct bpf_iter_seq_task_info *info = seq->private;
159         struct task_struct *task;
160
161         ++*pos;
162         ++info->tid;
163         put_task_struct((struct task_struct *)v);
164         task = task_seq_get_next(&info->common, &info->tid, false);
165         if (!task)
166                 return NULL;
167
168         return task;
169 }
170
171 struct bpf_iter__task {
172         __bpf_md_ptr(struct bpf_iter_meta *, meta);
173         __bpf_md_ptr(struct task_struct *, task);
174 };
175
176 DEFINE_BPF_ITER_FUNC(task, struct bpf_iter_meta *meta, struct task_struct *task)
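
/*
 * Illustrative sketch, not part of this file: a minimal BPF-side program for
 * the "task" iterator defined by DEFINE_BPF_ITER_FUNC() above. The context
 * layout mirrors struct bpf_iter__task; SEC("iter/task"), BPF_SEQ_PRINTF()
 * and the output format follow libbpf/selftest conventions and are
 * assumptions of this example rather than anything mandated by this file.
 *
 *        SEC("iter/task")
 *        int dump_task(struct bpf_iter__task *ctx)
 *        {
 *                struct seq_file *seq = ctx->meta->seq;
 *                struct task_struct *task = ctx->task;
 *
 *                if (!task)
 *                        return 0;
 *
 *                BPF_SEQ_PRINTF(seq, "%8d %16s\n", task->pid, task->comm);
 *                return 0;
 *        }
 */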
177
178 static int __task_seq_show(struct seq_file *seq, struct task_struct *task,
179                            bool in_stop)
180 {
181         struct bpf_iter_meta meta;
182         struct bpf_iter__task ctx;
183         struct bpf_prog *prog;
184
185         meta.seq = seq;
186         prog = bpf_iter_get_info(&meta, in_stop);
187         if (!prog)
188                 return 0;
189
190         ctx.meta = &meta;
191         ctx.task = task;
192         return bpf_iter_run_prog(prog, &ctx);
193 }
194
195 static int task_seq_show(struct seq_file *seq, void *v)
196 {
197         return __task_seq_show(seq, v, false);
198 }
199
200 static void task_seq_stop(struct seq_file *seq, void *v)
201 {
202         if (!v)
203                 (void)__task_seq_show(seq, v, true);
204         else
205                 put_task_struct((struct task_struct *)v);
206 }
207
208 static int bpf_iter_attach_task(struct bpf_prog *prog,
209                                 union bpf_iter_link_info *linfo,
210                                 struct bpf_iter_aux_info *aux)
211 {
212         unsigned int flags;
213         struct pid *pid;
214         pid_t tgid;
215
216         if ((!!linfo->task.tid + !!linfo->task.pid + !!linfo->task.pid_fd) > 1)
217                 return -EINVAL;
218
219         aux->task.type = BPF_TASK_ITER_ALL;
220         if (linfo->task.tid != 0) {
221                 aux->task.type = BPF_TASK_ITER_TID;
222                 aux->task.pid = linfo->task.tid;
223         }
224         if (linfo->task.pid != 0) {
225                 aux->task.type = BPF_TASK_ITER_TGID;
226                 aux->task.pid = linfo->task.pid;
227         }
228         if (linfo->task.pid_fd != 0) {
229                 aux->task.type = BPF_TASK_ITER_TGID;
230
231                 pid = pidfd_get_pid(linfo->task.pid_fd, &flags);
232                 if (IS_ERR(pid))
233                         return PTR_ERR(pid);
234
235                 tgid = pid_nr_ns(pid, task_active_pid_ns(current));
236                 aux->task.pid = tgid;
237                 put_pid(pid);
238         }
239
240         return 0;
241 }
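
/*
 * Illustrative sketch, not part of this file: how user space might supply the
 * tid/pid/pid_fd parameters validated above, using libbpf. Exactly one of the
 * three fields may be non-zero; setting linfo.task.pid selects TGID mode (walk
 * one whole process), while .tid or .pid_fd select the other modes.
 * bpf_program__attach_iter() and bpf_iter_attach_opts are real libbpf APIs;
 * "skel->progs.dump_task" is a hypothetical skeleton program name used only
 * for this example.
 *
 *        union bpf_iter_link_info linfo = {};
 *        DECLARE_LIBBPF_OPTS(bpf_iter_attach_opts, opts);
 *        struct bpf_link *link;
 *
 *        linfo.task.pid = getpid();
 *        opts.link_info = &linfo;
 *        opts.link_info_len = sizeof(linfo);
 *        link = bpf_program__attach_iter(skel->progs.dump_task, &opts);
 */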
242
243 static const struct seq_operations task_seq_ops = {
244         .start  = task_seq_start,
245         .next   = task_seq_next,
246         .stop   = task_seq_stop,
247         .show   = task_seq_show,
248 };
249
250 struct bpf_iter_seq_task_file_info {
251         /* The first field must be struct bpf_iter_seq_task_common.
252          * This is assumed by the {init, fini}_seq_pidns() callback functions.
253          */
254         struct bpf_iter_seq_task_common common;
255         struct task_struct *task;
256         u32 tid;
257         u32 fd;
258 };
259
260 static struct file *
261 task_file_seq_get_next(struct bpf_iter_seq_task_file_info *info)
262 {
263         u32 saved_tid = info->tid;
264         struct task_struct *curr_task;
265         unsigned int curr_fd = info->fd;
266
267         /* If this function returns a non-NULL file object,
268          * it holds a reference to both the task and the file.
269          * Otherwise, it does not hold any reference.
270          */
271 again:
272         if (info->task) {
273                 curr_task = info->task;
274                 curr_fd = info->fd;
275         } else {
276                 curr_task = task_seq_get_next(&info->common, &info->tid, true);
277                 if (!curr_task) {
278                         info->task = NULL;
279                         return NULL;
280                 }
281
282                 /* set info->task */
283                 info->task = curr_task;
284                 if (saved_tid == info->tid)
285                         curr_fd = info->fd;
286                 else
287                         curr_fd = 0;
288         }
289
290         rcu_read_lock();
291         for (;; curr_fd++) {
292                 struct file *f;
293                 f = task_lookup_next_fd_rcu(curr_task, &curr_fd);
294                 if (!f)
295                         break;
296                 if (!get_file_rcu(f))
297                         continue;
298
299                 /* set info->fd */
300                 info->fd = curr_fd;
301                 rcu_read_unlock();
302                 return f;
303         }
304
305         /* the current task is done, go to the next task */
306         rcu_read_unlock();
307         put_task_struct(curr_task);
308
309         if (info->common.type == BPF_TASK_ITER_TID) {
310                 info->task = NULL;
311                 return NULL;
312         }
313
314         info->task = NULL;
315         info->fd = 0;
316         saved_tid = ++(info->tid);
317         goto again;
318 }
319
320 static void *task_file_seq_start(struct seq_file *seq, loff_t *pos)
321 {
322         struct bpf_iter_seq_task_file_info *info = seq->private;
323         struct file *file;
324
325         info->task = NULL;
326         file = task_file_seq_get_next(info);
327         if (file && *pos == 0)
328                 ++*pos;
329
330         return file;
331 }
332
333 static void *task_file_seq_next(struct seq_file *seq, void *v, loff_t *pos)
334 {
335         struct bpf_iter_seq_task_file_info *info = seq->private;
336
337         ++*pos;
338         ++info->fd;
339         fput((struct file *)v);
340         return task_file_seq_get_next(info);
341 }
342
343 struct bpf_iter__task_file {
344         __bpf_md_ptr(struct bpf_iter_meta *, meta);
345         __bpf_md_ptr(struct task_struct *, task);
346         u32 fd __aligned(8);
347         __bpf_md_ptr(struct file *, file);
348 };
349
350 DEFINE_BPF_ITER_FUNC(task_file, struct bpf_iter_meta *meta,
351                      struct task_struct *task, u32 fd,
352                      struct file *file)
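
/*
 * Illustrative sketch, not part of this file: a minimal BPF-side program for
 * the "task_file" iterator defined above. The context fields mirror
 * struct bpf_iter__task_file; the SEC() name, BPF_SEQ_PRINTF() and the output
 * format follow selftest conventions and are assumptions of this example.
 *
 *        SEC("iter/task_file")
 *        int dump_task_file(struct bpf_iter__task_file *ctx)
 *        {
 *                struct seq_file *seq = ctx->meta->seq;
 *                struct task_struct *task = ctx->task;
 *                struct file *file = ctx->file;
 *
 *                if (!task || !file)
 *                        return 0;
 *
 *                BPF_SEQ_PRINTF(seq, "%8d %8d %lx\n", task->tgid, ctx->fd,
 *                               (long)file->f_op);
 *                return 0;
 *        }
 */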
353
354 static int __task_file_seq_show(struct seq_file *seq, struct file *file,
355                                 bool in_stop)
356 {
357         struct bpf_iter_seq_task_file_info *info = seq->private;
358         struct bpf_iter__task_file ctx;
359         struct bpf_iter_meta meta;
360         struct bpf_prog *prog;
361
362         meta.seq = seq;
363         prog = bpf_iter_get_info(&meta, in_stop);
364         if (!prog)
365                 return 0;
366
367         ctx.meta = &meta;
368         ctx.task = info->task;
369         ctx.fd = info->fd;
370         ctx.file = file;
371         return bpf_iter_run_prog(prog, &ctx);
372 }
373
374 static int task_file_seq_show(struct seq_file *seq, void *v)
375 {
376         return __task_file_seq_show(seq, v, false);
377 }
378
379 static void task_file_seq_stop(struct seq_file *seq, void *v)
380 {
381         struct bpf_iter_seq_task_file_info *info = seq->private;
382
383         if (!v) {
384                 (void)__task_file_seq_show(seq, v, true);
385         } else {
386                 fput((struct file *)v);
387                 put_task_struct(info->task);
388                 info->task = NULL;
389         }
390 }
391
392 static int init_seq_pidns(void *priv_data, struct bpf_iter_aux_info *aux)
393 {
394         struct bpf_iter_seq_task_common *common = priv_data;
395
396         common->ns = get_pid_ns(task_active_pid_ns(current));
397         common->type = aux->task.type;
398         common->pid = aux->task.pid;
399
400         return 0;
401 }
402
403 static void fini_seq_pidns(void *priv_data)
404 {
405         struct bpf_iter_seq_task_common *common = priv_data;
406
407         put_pid_ns(common->ns);
408 }
409
410 static const struct seq_operations task_file_seq_ops = {
411         .start  = task_file_seq_start,
412         .next   = task_file_seq_next,
413         .stop   = task_file_seq_stop,
414         .show   = task_file_seq_show,
415 };
416
417 struct bpf_iter_seq_task_vma_info {
418         /* The first field must be struct bpf_iter_seq_task_common.
419          * This is assumed by the {init, fini}_seq_pidns() callback functions.
420          */
421         struct bpf_iter_seq_task_common common;
422         struct task_struct *task;
423         struct mm_struct *mm;
424         struct vm_area_struct *vma;
425         u32 tid;
426         unsigned long prev_vm_start;
427         unsigned long prev_vm_end;
428 };
429
430 enum bpf_task_vma_iter_find_op {
431         task_vma_iter_first_vma,   /* use find_vma() with addr 0 */
432         task_vma_iter_next_vma,    /* use find_vma() with curr_vma->vm_end */
433         task_vma_iter_find_vma,    /* use find_vma() to find next vma */
434 };
435
436 static struct vm_area_struct *
437 task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
438 {
439         enum bpf_task_vma_iter_find_op op;
440         struct vm_area_struct *curr_vma;
441         struct task_struct *curr_task;
442         struct mm_struct *curr_mm;
443         u32 saved_tid = info->tid;
444
445         /* If this function returns a non-NULL vma, it holds a reference to
446          * the task_struct, holds a refcount on mm->mm_users, and holds
447          * the read lock on vma->vm_mm->mmap_lock.
448          * If this function returns NULL, it does not hold any reference or
449          * lock.
450          */
451         if (info->task) {
452                 curr_task = info->task;
453                 curr_vma = info->vma;
454                 curr_mm = info->mm;
455                 /* In case of lock contention, drop mmap_lock to unblock
456                  * the writer.
457                  *
458                  * After relock, call find_vma(mm, prev_vm_end - 1) to find
459                  * the new vma to process.
460                  *
461                  *   +------+------+-----------+
462                  *   | VMA1 | VMA2 | VMA3      |
463                  *   +------+------+-----------+
464                  *   |      |      |           |
465                  *  4k     8k     16k         400k
466                  *
467                  * For example, curr_vma == VMA2. Before unlock, we set
468                  *
469                  *    prev_vm_start = 8k
470                  *    prev_vm_end   = 16k
471                  *
472                  * There are a few cases:
473                  *
474                  * 1) VMA2 is freed, but VMA3 exists.
475                  *
476                  *    find_vma() will return VMA3, just process VMA3.
477                  *
478                  * 2) VMA2 still exists.
479                  *
480                  *    find_vma() will return VMA2, process VMA2->next.
481                  *
482                  * 3) no more vma in this mm.
483                  *
484                  *    Process the next task.
485                  *
486                  * 4) find_vma() returns a different vma, VMA2'.
487                  *
488                  *    4.1) If VMA2 covers the same range as VMA2', skip VMA2',
489                  *         because we already covered the range;
490                  *    4.2) If VMA2 and VMA2' cover different ranges, process
491                  *         VMA2'.
492                  */
493                 if (mmap_lock_is_contended(curr_mm)) {
494                         info->prev_vm_start = curr_vma->vm_start;
495                         info->prev_vm_end = curr_vma->vm_end;
496                         op = task_vma_iter_find_vma;
497                         mmap_read_unlock(curr_mm);
498                         if (mmap_read_lock_killable(curr_mm)) {
499                                 mmput(curr_mm);
500                                 goto finish;
501                         }
502                 } else {
503                         op = task_vma_iter_next_vma;
504                 }
505         } else {
506 again:
507                 curr_task = task_seq_get_next(&info->common, &info->tid, true);
508                 if (!curr_task) {
509                         info->tid++;
510                         goto finish;
511                 }
512
513                 if (saved_tid != info->tid) {
514                         /* new task, process the first vma */
515                         op = task_vma_iter_first_vma;
516                 } else {
517                         /* Found the same tid, which means user space has
518                          * finished the previous buffer and is reading more.
519                          * We dropped mmap_lock before returning to user
520                          * space, so it is necessary to use find_vma() to
521                          * find the next vma to process.
522                          */
523                         op = task_vma_iter_find_vma;
524                 }
525
526                 curr_mm = get_task_mm(curr_task);
527                 if (!curr_mm)
528                         goto next_task;
529
530                 if (mmap_read_lock_killable(curr_mm)) {
531                         mmput(curr_mm);
532                         goto finish;
533                 }
534         }
535
536         switch (op) {
537         case task_vma_iter_first_vma:
538                 curr_vma = find_vma(curr_mm, 0);
539                 break;
540         case task_vma_iter_next_vma:
541                 curr_vma = find_vma(curr_mm, curr_vma->vm_end);
542                 break;
543         case task_vma_iter_find_vma:
544                 /* We dropped mmap_lock so it is necessary to use find_vma
545                  * to find the next vma. This is similar to the mechanism
546                  * in show_smaps_rollup().
547                  */
548                 curr_vma = find_vma(curr_mm, info->prev_vm_end - 1);
549                 /* case 1) and 4.2) above just use curr_vma */
550
551                 /* check for case 2) or case 4.1) above */
552                 if (curr_vma &&
553                     curr_vma->vm_start == info->prev_vm_start &&
554                     curr_vma->vm_end == info->prev_vm_end)
555                         curr_vma = find_vma(curr_mm, curr_vma->vm_end);
556                 break;
557         }
558         if (!curr_vma) {
559                 /* case 3) above, or case 2) 4.1) with no further vma */
560                 mmap_read_unlock(curr_mm);
561                 mmput(curr_mm);
562                 goto next_task;
563         }
564         info->task = curr_task;
565         info->vma = curr_vma;
566         info->mm = curr_mm;
567         return curr_vma;
568
569 next_task:
570         if (info->common.type == BPF_TASK_ITER_TID)
571                 goto finish;
572
573         put_task_struct(curr_task);
574         info->task = NULL;
575         info->mm = NULL;
576         info->tid++;
577         goto again;
578
579 finish:
580         if (curr_task)
581                 put_task_struct(curr_task);
582         info->task = NULL;
583         info->vma = NULL;
584         info->mm = NULL;
585         return NULL;
586 }
587
588 static void *task_vma_seq_start(struct seq_file *seq, loff_t *pos)
589 {
590         struct bpf_iter_seq_task_vma_info *info = seq->private;
591         struct vm_area_struct *vma;
592
593         vma = task_vma_seq_get_next(info);
594         if (vma && *pos == 0)
595                 ++*pos;
596
597         return vma;
598 }
599
600 static void *task_vma_seq_next(struct seq_file *seq, void *v, loff_t *pos)
601 {
602         struct bpf_iter_seq_task_vma_info *info = seq->private;
603
604         ++*pos;
605         return task_vma_seq_get_next(info);
606 }
607
608 struct bpf_iter__task_vma {
609         __bpf_md_ptr(struct bpf_iter_meta *, meta);
610         __bpf_md_ptr(struct task_struct *, task);
611         __bpf_md_ptr(struct vm_area_struct *, vma);
612 };
613
614 DEFINE_BPF_ITER_FUNC(task_vma, struct bpf_iter_meta *meta,
615                      struct task_struct *task, struct vm_area_struct *vma)
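
/*
 * Illustrative sketch, not part of this file: a minimal BPF-side program for
 * the "task_vma" iterator defined above. Context fields mirror
 * struct bpf_iter__task_vma; the SEC() name and output format are assumptions
 * of this example. The mmap read lock is held while the program runs, as
 * described in task_vma_seq_get_next().
 *
 *        SEC("iter/task_vma")
 *        int dump_task_vma(struct bpf_iter__task_vma *ctx)
 *        {
 *                struct seq_file *seq = ctx->meta->seq;
 *                struct task_struct *task = ctx->task;
 *                struct vm_area_struct *vma = ctx->vma;
 *
 *                if (!task || !vma)
 *                        return 0;
 *
 *                BPF_SEQ_PRINTF(seq, "%8d %08lx-%08lx\n", task->tgid,
 *                               vma->vm_start, vma->vm_end);
 *                return 0;
 *        }
 */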
616
617 static int __task_vma_seq_show(struct seq_file *seq, bool in_stop)
618 {
619         struct bpf_iter_seq_task_vma_info *info = seq->private;
620         struct bpf_iter__task_vma ctx;
621         struct bpf_iter_meta meta;
622         struct bpf_prog *prog;
623
624         meta.seq = seq;
625         prog = bpf_iter_get_info(&meta, in_stop);
626         if (!prog)
627                 return 0;
628
629         ctx.meta = &meta;
630         ctx.task = info->task;
631         ctx.vma = info->vma;
632         return bpf_iter_run_prog(prog, &ctx);
633 }
634
635 static int task_vma_seq_show(struct seq_file *seq, void *v)
636 {
637         return __task_vma_seq_show(seq, false);
638 }
639
640 static void task_vma_seq_stop(struct seq_file *seq, void *v)
641 {
642         struct bpf_iter_seq_task_vma_info *info = seq->private;
643
644         if (!v) {
645                 (void)__task_vma_seq_show(seq, true);
646         } else {
647                 /* info->vma has not been seen by the BPF program. If the
648                  * user space reads more, task_vma_seq_get_next should
649                  * return this vma again. Set prev_vm_start to ~0UL,
650                  * so that we don't skip the vma returned by the next
651                  * find_vma() (case task_vma_iter_find_vma in
652                  * task_vma_seq_get_next()).
653                  */
654                 info->prev_vm_start = ~0UL;
655                 info->prev_vm_end = info->vma->vm_end;
656                 mmap_read_unlock(info->mm);
657                 mmput(info->mm);
658                 info->mm = NULL;
659                 put_task_struct(info->task);
660                 info->task = NULL;
661         }
662 }
663
664 static const struct seq_operations task_vma_seq_ops = {
665         .start  = task_vma_seq_start,
666         .next   = task_vma_seq_next,
667         .stop   = task_vma_seq_stop,
668         .show   = task_vma_seq_show,
669 };
670
671 static const struct bpf_iter_seq_info task_seq_info = {
672         .seq_ops                = &task_seq_ops,
673         .init_seq_private       = init_seq_pidns,
674         .fini_seq_private       = fini_seq_pidns,
675         .seq_priv_size          = sizeof(struct bpf_iter_seq_task_info),
676 };
677
678 static int bpf_iter_fill_link_info(const struct bpf_iter_aux_info *aux, struct bpf_link_info *info)
679 {
680         switch (aux->task.type) {
681         case BPF_TASK_ITER_TID:
682                 info->iter.task.tid = aux->task.pid;
683                 break;
684         case BPF_TASK_ITER_TGID:
685                 info->iter.task.pid = aux->task.pid;
686                 break;
687         default:
688                 break;
689         }
690         return 0;
691 }
692
693 static void bpf_iter_task_show_fdinfo(const struct bpf_iter_aux_info *aux, struct seq_file *seq)
694 {
695         seq_printf(seq, "task_type:\t%s\n", iter_task_type_names[aux->task.type]);
696         if (aux->task.type == BPF_TASK_ITER_TID)
697                 seq_printf(seq, "tid:\t%u\n", aux->task.pid);
698         else if (aux->task.type == BPF_TASK_ITER_TGID)
699                 seq_printf(seq, "pid:\t%u\n", aux->task.pid);
700 }
701
702 static struct bpf_iter_reg task_reg_info = {
703         .target                 = "task",
704         .attach_target          = bpf_iter_attach_task,
705         .feature                = BPF_ITER_RESCHED,
706         .ctx_arg_info_size      = 1,
707         .ctx_arg_info           = {
708                 { offsetof(struct bpf_iter__task, task),
709                   PTR_TO_BTF_ID_OR_NULL },
710         },
711         .seq_info               = &task_seq_info,
712         .fill_link_info         = bpf_iter_fill_link_info,
713         .show_fdinfo            = bpf_iter_task_show_fdinfo,
714 };
715
716 static const struct bpf_iter_seq_info task_file_seq_info = {
717         .seq_ops                = &task_file_seq_ops,
718         .init_seq_private       = init_seq_pidns,
719         .fini_seq_private       = fini_seq_pidns,
720         .seq_priv_size          = sizeof(struct bpf_iter_seq_task_file_info),
721 };
722
723 static struct bpf_iter_reg task_file_reg_info = {
724         .target                 = "task_file",
725         .attach_target          = bpf_iter_attach_task,
726         .feature                = BPF_ITER_RESCHED,
727         .ctx_arg_info_size      = 2,
728         .ctx_arg_info           = {
729                 { offsetof(struct bpf_iter__task_file, task),
730                   PTR_TO_BTF_ID_OR_NULL },
731                 { offsetof(struct bpf_iter__task_file, file),
732                   PTR_TO_BTF_ID_OR_NULL },
733         },
734         .seq_info               = &task_file_seq_info,
735         .fill_link_info         = bpf_iter_fill_link_info,
736         .show_fdinfo            = bpf_iter_task_show_fdinfo,
737 };
738
739 static const struct bpf_iter_seq_info task_vma_seq_info = {
740         .seq_ops                = &task_vma_seq_ops,
741         .init_seq_private       = init_seq_pidns,
742         .fini_seq_private       = fini_seq_pidns,
743         .seq_priv_size          = sizeof(struct bpf_iter_seq_task_vma_info),
744 };
745
746 static struct bpf_iter_reg task_vma_reg_info = {
747         .target                 = "task_vma",
748         .attach_target          = bpf_iter_attach_task,
749         .feature                = BPF_ITER_RESCHED,
750         .ctx_arg_info_size      = 2,
751         .ctx_arg_info           = {
752                 { offsetof(struct bpf_iter__task_vma, task),
753                   PTR_TO_BTF_ID_OR_NULL },
754                 { offsetof(struct bpf_iter__task_vma, vma),
755                   PTR_TO_BTF_ID_OR_NULL },
756         },
757         .seq_info               = &task_vma_seq_info,
758         .fill_link_info         = bpf_iter_fill_link_info,
759         .show_fdinfo            = bpf_iter_task_show_fdinfo,
760 };
761
762 BPF_CALL_5(bpf_find_vma, struct task_struct *, task, u64, start,
763            bpf_callback_t, callback_fn, void *, callback_ctx, u64, flags)
764 {
765         struct mmap_unlock_irq_work *work = NULL;
766         struct vm_area_struct *vma;
767         bool irq_work_busy = false;
768         struct mm_struct *mm;
769         int ret = -ENOENT;
770
771         if (flags)
772                 return -EINVAL;
773
774         if (!task)
775                 return -ENOENT;
776
777         mm = task->mm;
778         if (!mm)
779                 return -ENOENT;
780
781         irq_work_busy = bpf_mmap_unlock_get_irq_work(&work);
782
783         if (irq_work_busy || !mmap_read_trylock(mm))
784                 return -EBUSY;
785
786         vma = find_vma(mm, start);
787
788         if (vma && vma->vm_start <= start && vma->vm_end > start) {
789                 callback_fn((u64)(long)task, (u64)(long)vma,
790                             (u64)(long)callback_ctx, 0, 0);
791                 ret = 0;
792         }
793         bpf_mmap_unlock_mm(work, mm);
794         return ret;
795 }
796
797 const struct bpf_func_proto bpf_find_vma_proto = {
798         .func           = bpf_find_vma,
799         .ret_type       = RET_INTEGER,
800         .arg1_type      = ARG_PTR_TO_BTF_ID,
801         .arg1_btf_id    = &btf_tracing_ids[BTF_TRACING_TYPE_TASK],
802         .arg2_type      = ARG_ANYTHING,
803         .arg3_type      = ARG_PTR_TO_FUNC,
804         .arg4_type      = ARG_PTR_TO_STACK_OR_NULL,
805         .arg5_type      = ARG_ANYTHING,
806 };
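
/*
 * Illustrative sketch, not part of this file: how a tracing BPF program might
 * use the bpf_find_vma() helper declared above. The callback runs at most
 * once, for the vma containing the given address, while bpf_find_vma() holds
 * the mmap read lock. "addr", struct cb_ctx and the elided attach point are
 * hypothetical names for this example; bpf_get_current_task_btf() and
 * libbpf's BPF_PROG() macro are real APIs.
 *
 *        struct cb_ctx { unsigned long vm_start; };
 *
 *        static long check_vma(struct task_struct *task,
 *                              struct vm_area_struct *vma, struct cb_ctx *data)
 *        {
 *                data->vm_start = vma->vm_start;
 *                return 0;
 *        }
 *
 *        SEC("fentry/...")
 *        int BPF_PROG(probe)
 *        {
 *                struct task_struct *task = bpf_get_current_task_btf();
 *                struct cb_ctx data = {};
 *                __u64 addr = 0x1000;
 *
 *                bpf_find_vma(task, addr, check_vma, &data, 0);
 *                return 0;
 *        }
 */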
807
808 struct bpf_iter_task_vma_kern_data {
809         struct task_struct *task;
810         struct mm_struct *mm;
811         struct mmap_unlock_irq_work *work;
812         struct vma_iterator vmi;
813 };
814
815 struct bpf_iter_task_vma {
816         /* Opaque iterator state; having __u64 here allows us to preserve
817          * correct alignment requirements in vmlinux.h, generated from BTF.
818          */
819         __u64 __opaque[1];
820 } __attribute__((aligned(8)));
821
822 /* Non-opaque version of bpf_iter_task_vma */
823 struct bpf_iter_task_vma_kern {
824         struct bpf_iter_task_vma_kern_data *data;
825 } __attribute__((aligned(8)));
826
827 __diag_push();
828 __diag_ignore_all("-Wmissing-prototypes",
829                   "Global functions as their definitions will be in vmlinux BTF");
830
831 __bpf_kfunc int bpf_iter_task_vma_new(struct bpf_iter_task_vma *it,
832                                       struct task_struct *task, u64 addr)
833 {
834         struct bpf_iter_task_vma_kern *kit = (void *)it;
835         bool irq_work_busy = false;
836         int err;
837
838         BUILD_BUG_ON(sizeof(struct bpf_iter_task_vma_kern) != sizeof(struct bpf_iter_task_vma));
839         BUILD_BUG_ON(__alignof__(struct bpf_iter_task_vma_kern) != __alignof__(struct bpf_iter_task_vma));
840
841         /* is_iter_reg_valid_uninit guarantees that kit hasn't been initialized
842          * before, so a non-NULL kit->data cannot point to a previously
843          * bpf_mem_alloc'd bpf_iter_task_vma_kern_data.
844          */
845         kit->data = bpf_mem_alloc(&bpf_global_ma, sizeof(struct bpf_iter_task_vma_kern_data));
846         if (!kit->data)
847                 return -ENOMEM;
848
849         kit->data->task = get_task_struct(task);
850         kit->data->mm = task->mm;
851         if (!kit->data->mm) {
852                 err = -ENOENT;
853                 goto err_cleanup_iter;
854         }
855
856         /* kit->data->work == NULL is valid after bpf_mmap_unlock_get_irq_work */
857         irq_work_busy = bpf_mmap_unlock_get_irq_work(&kit->data->work);
858         if (irq_work_busy || !mmap_read_trylock(kit->data->mm)) {
859                 err = -EBUSY;
860                 goto err_cleanup_iter;
861         }
862
863         vma_iter_init(&kit->data->vmi, kit->data->mm, addr);
864         return 0;
865
866 err_cleanup_iter:
867         if (kit->data->task)
868                 put_task_struct(kit->data->task);
869         bpf_mem_free(&bpf_global_ma, kit->data);
870         /* NULL kit->data signals failed bpf_iter_task_vma initialization */
871         kit->data = NULL;
872         return err;
873 }
874
875 __bpf_kfunc struct vm_area_struct *bpf_iter_task_vma_next(struct bpf_iter_task_vma *it)
876 {
877         struct bpf_iter_task_vma_kern *kit = (void *)it;
878
879         if (!kit->data) /* bpf_iter_task_vma_new failed */
880                 return NULL;
881         return vma_next(&kit->data->vmi);
882 }
883
884 __bpf_kfunc void bpf_iter_task_vma_destroy(struct bpf_iter_task_vma *it)
885 {
886         struct bpf_iter_task_vma_kern *kit = (void *)it;
887
888         if (kit->data) {
889                 bpf_mmap_unlock_mm(kit->data->work, kit->data->mm);
890                 put_task_struct(kit->data->task);
891                 bpf_mem_free(&bpf_global_ma, kit->data);
892         }
893 }
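
/*
 * Illustrative sketch, not part of this file: using the three kfuncs above as
 * an open-coded iterator from a BPF program. The iterator state lives on the
 * BPF program stack; bpf_iter_task_vma_destroy() must be called on every
 * path, and it is a no-op when _new() failed. "task" and the loop body are
 * assumptions of this example.
 *
 *        struct bpf_iter_task_vma vma_it;
 *        struct vm_area_struct *vma;
 *        int vm_ranges = 0;
 *
 *        if (!bpf_iter_task_vma_new(&vma_it, task, 0)) {
 *                while ((vma = bpf_iter_task_vma_next(&vma_it)))
 *                        vm_ranges++;
 *        }
 *        bpf_iter_task_vma_destroy(&vma_it);
 */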
894
895 __diag_pop();
896
897 DEFINE_PER_CPU(struct mmap_unlock_irq_work, mmap_unlock_work);
898
899 static void do_mmap_read_unlock(struct irq_work *entry)
900 {
901         struct mmap_unlock_irq_work *work;
902
903         if (WARN_ON_ONCE(IS_ENABLED(CONFIG_PREEMPT_RT)))
904                 return;
905
906         work = container_of(entry, struct mmap_unlock_irq_work, irq_work);
907         mmap_read_unlock_non_owner(work->mm);
908 }
909
910 static int __init task_iter_init(void)
911 {
912         struct mmap_unlock_irq_work *work;
913         int ret, cpu;
914
915         for_each_possible_cpu(cpu) {
916                 work = per_cpu_ptr(&mmap_unlock_work, cpu);
917                 init_irq_work(&work->irq_work, do_mmap_read_unlock);
918         }
919
920         task_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
921         ret = bpf_iter_reg_target(&task_reg_info);
922         if (ret)
923                 return ret;
924
925         task_file_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
926         task_file_reg_info.ctx_arg_info[1].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_FILE];
927         ret = bpf_iter_reg_target(&task_file_reg_info);
928         if (ret)
929                 return ret;
930
931         task_vma_reg_info.ctx_arg_info[0].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_TASK];
932         task_vma_reg_info.ctx_arg_info[1].btf_id = btf_tracing_ids[BTF_TRACING_TYPE_VMA];
933         return bpf_iter_reg_target(&task_vma_reg_info);
934 }
935 late_initcall(task_iter_init);