Merge tag 'nds32-for-linux-5.12' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-microblaze.git] / net / bpfilter / bpfilter_kern.c
index 1905e01..51a941b 100644 (file)
@@ -15,15 +15,13 @@ extern char bpfilter_umh_end;
 
 static void shutdown_umh(void)
 {
-       struct task_struct *tsk;
+       struct umd_info *info = &bpfilter_ops.info;
+       struct pid *tgid = info->tgid;
 
-       if (bpfilter_ops.stop)
-               return;
-
-       tsk = get_pid_task(find_vpid(bpfilter_ops.info.pid), PIDTYPE_PID);
-       if (tsk) {
-               send_sig(SIGKILL, tsk, 1);
-               put_task_struct(tsk);
+       if (tgid) {
+               kill_pid(tgid, SIGKILL, 1);
+               wait_event(tgid->wait_pidfd, thread_group_exited(tgid));
+               bpfilter_umh_cleanup(info);
        }
 }
 
@@ -33,60 +31,65 @@ static void __stop_umh(void)
                shutdown_umh();
 }
 
-static int __bpfilter_process_sockopt(struct sock *sk, int optname,
-                                     char __user *optval,
-                                     unsigned int optlen, bool is_set)
+static int bpfilter_send_req(struct mbox_request *req)
 {
-       struct mbox_request req;
        struct mbox_reply reply;
-       loff_t pos;
+       loff_t pos = 0;
        ssize_t n;
-       int ret = -EFAULT;
-
-       req.is_set = is_set;
-       req.pid = current->pid;
-       req.cmd = optname;
-       req.addr = (long __force __user)optval;
-       req.len = optlen;
-       if (!bpfilter_ops.info.pid)
-               goto out;
-       n = kernel_write(bpfilter_ops.info.pipe_to_umh, &req, sizeof(req),
+
+       if (!bpfilter_ops.info.tgid)
+               return -EFAULT;
+       pos = 0;
+       n = kernel_write(bpfilter_ops.info.pipe_to_umh, req, sizeof(*req),
                           &pos);
-       if (n != sizeof(req)) {
+       if (n != sizeof(*req)) {
                pr_err("write fail %zd\n", n);
-               __stop_umh();
-               ret = -EFAULT;
-               goto out;
+               goto stop;
        }
        pos = 0;
        n = kernel_read(bpfilter_ops.info.pipe_from_umh, &reply, sizeof(reply),
                        &pos);
        if (n != sizeof(reply)) {
                pr_err("read fail %zd\n", n);
-               __stop_umh();
-               ret = -EFAULT;
-               goto out;
+               goto stop;
        }
-       ret = reply.status;
-out:
-       return ret;
+       return reply.status;
+stop:
+       __stop_umh();
+       return -EFAULT;
+}
+
+static int bpfilter_process_sockopt(struct sock *sk, int optname,
+                                   sockptr_t optval, unsigned int optlen,
+                                   bool is_set)
+{
+       struct mbox_request req = {
+               .is_set         = is_set,
+               .pid            = current->pid,
+               .cmd            = optname,
+               .addr           = (uintptr_t)optval.user,
+               .len            = optlen,
+       };
+       if (uaccess_kernel() || sockptr_is_kernel(optval)) {
+               pr_err("kernel access not supported\n");
+               return -EFAULT;
+       }
+       return bpfilter_send_req(&req);
 }
 
 static int start_umh(void)
 {
+       struct mbox_request req = { .pid = current->pid };
        int err;
 
        /* fork usermode process */
-       err = fork_usermode_blob(&bpfilter_umh_start,
-                                &bpfilter_umh_end - &bpfilter_umh_start,
-                                &bpfilter_ops.info);
+       err = fork_usermode_driver(&bpfilter_ops.info);
        if (err)
                return err;
-       bpfilter_ops.stop = false;
-       pr_info("Loaded bpfilter_umh pid %d\n", bpfilter_ops.info.pid);
+       pr_info("Loaded bpfilter_umh pid %d\n", pid_nr(bpfilter_ops.info.tgid));
 
        /* health check that usermode process started correctly */
-       if (__bpfilter_process_sockopt(NULL, 0, NULL, 0, 0) != 0) {
+       if (bpfilter_send_req(&req) != 0) {
                shutdown_umh();
                return -EFAULT;
        }
@@ -98,18 +101,21 @@ static int __init load_umh(void)
 {
        int err;
 
+       err = umd_load_blob(&bpfilter_ops.info,
+                           &bpfilter_umh_start,
+                           &bpfilter_umh_end - &bpfilter_umh_start);
+       if (err)
+               return err;
+
        mutex_lock(&bpfilter_ops.lock);
-       if (!bpfilter_ops.stop) {
-               err = -EFAULT;
-               goto out;
-       }
        err = start_umh();
        if (!err && IS_ENABLED(CONFIG_INET)) {
-               bpfilter_ops.sockopt = &__bpfilter_process_sockopt;
+               bpfilter_ops.sockopt = &bpfilter_process_sockopt;
                bpfilter_ops.start = &start_umh;
        }
-out:
        mutex_unlock(&bpfilter_ops.lock);
+       if (err)
+               umd_unload_blob(&bpfilter_ops.info);
        return err;
 }
 
@@ -122,6 +128,8 @@ static void __exit fini_umh(void)
                bpfilter_ops.sockopt = NULL;
        }
        mutex_unlock(&bpfilter_ops.lock);
+
+       umd_unload_blob(&bpfilter_ops.info);
 }
 module_init(load_umh);
 module_exit(fini_umh);