Merge branch 'gate-page-refcount' (patches from Dave Hansen)
[linux-2.6-microblaze.git] / net / bpfilter / bpfilter_kern.c
1 // SPDX-License-Identifier: GPL-2.0
2 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
3 #include <linux/init.h>
4 #include <linux/module.h>
5 #include <linux/umh.h>
6 #include <linux/bpfilter.h>
7 #include <linux/sched.h>
8 #include <linux/sched/signal.h>
9 #include <linux/fs.h>
10 #include <linux/file.h>
11 #include "msgfmt.h"
12
13 extern char bpfilter_umh_start;
14 extern char bpfilter_umh_end;
15
16 static void shutdown_umh(void)
17 {
18         struct umd_info *info = &bpfilter_ops.info;
19         struct pid *tgid = info->tgid;
20
21         if (tgid) {
22                 kill_pid(tgid, SIGKILL, 1);
23                 wait_event(tgid->wait_pidfd, thread_group_exited(tgid));
24                 bpfilter_umh_cleanup(info);
25         }
26 }
27
28 static void __stop_umh(void)
29 {
30         if (IS_ENABLED(CONFIG_INET))
31                 shutdown_umh();
32 }
33
34 static int bpfilter_send_req(struct mbox_request *req)
35 {
36         struct mbox_reply reply;
37         loff_t pos = 0;
38         ssize_t n;
39
40         if (!bpfilter_ops.info.tgid)
41                 return -EFAULT;
42         pos = 0;
43         n = kernel_write(bpfilter_ops.info.pipe_to_umh, req, sizeof(*req),
44                            &pos);
45         if (n != sizeof(*req)) {
46                 pr_err("write fail %zd\n", n);
47                 goto stop;
48         }
49         pos = 0;
50         n = kernel_read(bpfilter_ops.info.pipe_from_umh, &reply, sizeof(reply),
51                         &pos);
52         if (n != sizeof(reply)) {
53                 pr_err("read fail %zd\n", n);
54                 goto stop;
55         }
56         return reply.status;
57 stop:
58         __stop_umh();
59         return -EFAULT;
60 }
61
62 static int bpfilter_process_sockopt(struct sock *sk, int optname,
63                                     sockptr_t optval, unsigned int optlen,
64                                     bool is_set)
65 {
66         struct mbox_request req = {
67                 .is_set         = is_set,
68                 .pid            = current->pid,
69                 .cmd            = optname,
70                 .addr           = (uintptr_t)optval.user,
71                 .len            = optlen,
72         };
73         if (uaccess_kernel() || sockptr_is_kernel(optval)) {
74                 pr_err("kernel access not supported\n");
75                 return -EFAULT;
76         }
77         return bpfilter_send_req(&req);
78 }
79
80 static int start_umh(void)
81 {
82         struct mbox_request req = { .pid = current->pid };
83         int err;
84
85         /* fork usermode process */
86         err = fork_usermode_driver(&bpfilter_ops.info);
87         if (err)
88                 return err;
89         pr_info("Loaded bpfilter_umh pid %d\n", pid_nr(bpfilter_ops.info.tgid));
90
91         /* health check that usermode process started correctly */
92         if (bpfilter_send_req(&req) != 0) {
93                 shutdown_umh();
94                 return -EFAULT;
95         }
96
97         return 0;
98 }
99
100 static int __init load_umh(void)
101 {
102         int err;
103
104         err = umd_load_blob(&bpfilter_ops.info,
105                             &bpfilter_umh_start,
106                             &bpfilter_umh_end - &bpfilter_umh_start);
107         if (err)
108                 return err;
109
110         mutex_lock(&bpfilter_ops.lock);
111         err = start_umh();
112         if (!err && IS_ENABLED(CONFIG_INET)) {
113                 bpfilter_ops.sockopt = &bpfilter_process_sockopt;
114                 bpfilter_ops.start = &start_umh;
115         }
116         mutex_unlock(&bpfilter_ops.lock);
117         if (err)
118                 umd_unload_blob(&bpfilter_ops.info);
119         return err;
120 }
121
122 static void __exit fini_umh(void)
123 {
124         mutex_lock(&bpfilter_ops.lock);
125         if (IS_ENABLED(CONFIG_INET)) {
126                 shutdown_umh();
127                 bpfilter_ops.start = NULL;
128                 bpfilter_ops.sockopt = NULL;
129         }
130         mutex_unlock(&bpfilter_ops.lock);
131
132         umd_unload_blob(&bpfilter_ops.info);
133 }
134 module_init(load_umh);
135 module_exit(fini_umh);
136 MODULE_LICENSE("GPL");