Merge https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
[linux-2.6-microblaze.git] / net / unix / af_unix.c
index 4cf0b1c..7cad52b 100644 (file)
 #include <linux/security.h>
 #include <linux/freezer.h>
 #include <linux/file.h>
+#include <linux/btf_ids.h>
 
 #include "scm.h"
 
@@ -678,6 +679,8 @@ static int unix_dgram_sendmsg(struct socket *, struct msghdr *, size_t);
 static int unix_dgram_recvmsg(struct socket *, struct msghdr *, size_t, int);
 static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
                          sk_read_actor_t recv_actor);
+static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc,
+                                sk_read_actor_t recv_actor);
 static int unix_dgram_connect(struct socket *, struct sockaddr *,
                              int, int);
 static int unix_seqpacket_sendmsg(struct socket *, struct msghdr *, size_t);
@@ -731,6 +734,7 @@ static const struct proto_ops unix_stream_ops = {
        .shutdown =     unix_shutdown,
        .sendmsg =      unix_stream_sendmsg,
        .recvmsg =      unix_stream_recvmsg,
+       .read_sock =    unix_stream_read_sock,
        .mmap =         sock_no_mmap,
        .sendpage =     unix_stream_sendpage,
        .splice_read =  unix_stream_splice_read,
@@ -794,17 +798,35 @@ static void unix_close(struct sock *sk, long timeout)
         */
 }
 
-struct proto unix_proto = {
-       .name                   = "UNIX",
+static void unix_unhash(struct sock *sk)
+{
+       /* Nothing to do here, unix socket does not need a ->unhash().
+        * This is merely for sockmap.
+        */
+}
+
+struct proto unix_dgram_proto = {
+       .name                   = "UNIX-DGRAM",
        .owner                  = THIS_MODULE,
        .obj_size               = sizeof(struct unix_sock),
        .close                  = unix_close,
 #ifdef CONFIG_BPF_SYSCALL
-       .psock_update_sk_prot   = unix_bpf_update_proto,
+       .psock_update_sk_prot   = unix_dgram_bpf_update_proto,
 #endif
 };
 
-static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
+struct proto unix_stream_proto = {
+       .name                   = "UNIX-STREAM",
+       .owner                  = THIS_MODULE,
+       .obj_size               = sizeof(struct unix_sock),
+       .close                  = unix_close,
+       .unhash                 = unix_unhash,
+#ifdef CONFIG_BPF_SYSCALL
+       .psock_update_sk_prot   = unix_stream_bpf_update_proto,
+#endif
+};
+
+static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, int type)
 {
        struct sock *sk = NULL;
        struct unix_sock *u;
@@ -813,7 +835,11 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
        if (atomic_long_read(&unix_nr_socks) > 2 * get_max_files())
                goto out;
 
-       sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_proto, kern);
+       if (type == SOCK_STREAM)
+               sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_stream_proto, kern);
+       else /*dgram and  seqpacket */
+               sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_dgram_proto, kern);
+
        if (!sk)
                goto out;
 
@@ -875,7 +901,7 @@ static int unix_create(struct net *net, struct socket *sock, int protocol,
                return -ESOCKTNOSUPPORT;
        }
 
-       return unix_create1(net, sock, kern) ? 0 : -ENOMEM;
+       return unix_create1(net, sock, kern, sock->type) ? 0 : -ENOMEM;
 }
 
 static int unix_release(struct socket *sock)
@@ -1289,7 +1315,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
        err = -ENOMEM;
 
        /* create new sock for complete connection */
-       newsk = unix_create1(sock_net(sk), NULL, 0);
+       newsk = unix_create1(sock_net(sk), NULL, 0, sock->type);
        if (newsk == NULL)
                goto out;
 
@@ -2326,8 +2352,10 @@ static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t si
        struct sock *sk = sock->sk;
 
 #ifdef CONFIG_BPF_SYSCALL
-       if (sk->sk_prot != &unix_proto)
-               return sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
+       const struct proto *prot = READ_ONCE(sk->sk_prot);
+
+       if (prot != &unix_dgram_proto)
+               return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
                                            flags & ~MSG_DONTWAIT, NULL);
 #endif
        return __unix_dgram_recvmsg(sk, msg, size, flags);
@@ -2497,6 +2525,15 @@ static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk,
 }
 #endif
 
+static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc,
+                                sk_read_actor_t recv_actor)
+{
+       if (unlikely(sk->sk_state != TCP_ESTABLISHED))
+               return -ENOTCONN;
+
+       return unix_read_sock(sk, desc, recv_actor);
+}
+
 static int unix_stream_read_generic(struct unix_stream_read_state *state,
                                    bool freezable)
 {
@@ -2722,6 +2759,20 @@ static int unix_stream_read_actor(struct sk_buff *skb,
        return ret ?: chunk;
 }
 
+int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg,
+                         size_t size, int flags)
+{
+       struct unix_stream_read_state state = {
+               .recv_actor = unix_stream_read_actor,
+               .socket = sk->sk_socket,
+               .msg = msg,
+               .size = size,
+               .flags = flags
+       };
+
+       return unix_stream_read_generic(&state, true);
+}
+
 static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
                               size_t size, int flags)
 {
@@ -2733,6 +2784,14 @@ static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
                .flags = flags
        };
 
+#ifdef CONFIG_BPF_SYSCALL
+       struct sock *sk = sock->sk;
+       const struct proto *prot = READ_ONCE(sk->sk_prot);
+
+       if (prot != &unix_stream_proto)
+               return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
+                                           flags & ~MSG_DONTWAIT, NULL);
+#endif
        return unix_stream_read_generic(&state, true);
 }
 
@@ -2793,7 +2852,10 @@ static int unix_shutdown(struct socket *sock, int mode)
                (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)) {
 
                int peer_mode = 0;
+               const struct proto *prot = READ_ONCE(other->sk_prot);
 
+               if (prot->unhash)
+                       prot->unhash(other);
                if (mode&RCV_SHUTDOWN)
                        peer_mode |= SEND_SHUTDOWN;
                if (mode&SEND_SHUTDOWN)
@@ -2802,10 +2864,12 @@ static int unix_shutdown(struct socket *sock, int mode)
                other->sk_shutdown |= peer_mode;
                unix_state_unlock(other);
                other->sk_state_change(other);
-               if (peer_mode == SHUTDOWN_MASK)
+               if (peer_mode == SHUTDOWN_MASK) {
                        sk_wake_async(other, SOCK_WAKE_WAITD, POLL_HUP);
-               else if (peer_mode & RCV_SHUTDOWN)
+                       other->sk_state = TCP_CLOSE;
+               } else if (peer_mode & RCV_SHUTDOWN) {
                        sk_wake_async(other, SOCK_WAKE_WAITD, POLL_IN);
+               }
        }
        if (other)
                sock_put(other);
@@ -3150,6 +3214,64 @@ static const struct seq_operations unix_seq_ops = {
        .stop   = unix_seq_stop,
        .show   = unix_seq_show,
 };
+
+#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL)
+struct bpf_iter__unix {
+       __bpf_md_ptr(struct bpf_iter_meta *, meta);
+       __bpf_md_ptr(struct unix_sock *, unix_sk);
+       uid_t uid __aligned(8);
+};
+
+static int unix_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
+                             struct unix_sock *unix_sk, uid_t uid)
+{
+       struct bpf_iter__unix ctx;
+
+       meta->seq_num--;  /* skip SEQ_START_TOKEN */
+       ctx.meta = meta;
+       ctx.unix_sk = unix_sk;
+       ctx.uid = uid;
+       return bpf_iter_run_prog(prog, &ctx);
+}
+
+static int bpf_iter_unix_seq_show(struct seq_file *seq, void *v)
+{
+       struct bpf_iter_meta meta;
+       struct bpf_prog *prog;
+       struct sock *sk = v;
+       uid_t uid;
+
+       if (v == SEQ_START_TOKEN)
+               return 0;
+
+       uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
+       meta.seq = seq;
+       prog = bpf_iter_get_info(&meta, false);
+       return unix_prog_seq_show(prog, &meta, v, uid);
+}
+
+static void bpf_iter_unix_seq_stop(struct seq_file *seq, void *v)
+{
+       struct bpf_iter_meta meta;
+       struct bpf_prog *prog;
+
+       if (!v) {
+               meta.seq = seq;
+               prog = bpf_iter_get_info(&meta, true);
+               if (prog)
+                       (void)unix_prog_seq_show(prog, &meta, v, 0);
+       }
+
+       unix_seq_stop(seq, v);
+}
+
+static const struct seq_operations bpf_iter_unix_seq_ops = {
+       .start  = unix_seq_start,
+       .next   = unix_seq_next,
+       .stop   = bpf_iter_unix_seq_stop,
+       .show   = bpf_iter_unix_seq_show,
+};
+#endif
 #endif
 
 static const struct net_proto_family unix_family_ops = {
@@ -3190,13 +3312,48 @@ static struct pernet_operations unix_net_ops = {
        .exit = unix_net_exit,
 };
 
+#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+DEFINE_BPF_ITER_FUNC(unix, struct bpf_iter_meta *meta,
+                    struct unix_sock *unix_sk, uid_t uid)
+
+static const struct bpf_iter_seq_info unix_seq_info = {
+       .seq_ops                = &bpf_iter_unix_seq_ops,
+       .init_seq_private       = bpf_iter_init_seq_net,
+       .fini_seq_private       = bpf_iter_fini_seq_net,
+       .seq_priv_size          = sizeof(struct seq_net_private),
+};
+
+static struct bpf_iter_reg unix_reg_info = {
+       .target                 = "unix",
+       .ctx_arg_info_size      = 1,
+       .ctx_arg_info           = {
+               { offsetof(struct bpf_iter__unix, unix_sk),
+                 PTR_TO_BTF_ID_OR_NULL },
+       },
+       .seq_info               = &unix_seq_info,
+};
+
+static void __init bpf_iter_register(void)
+{
+       unix_reg_info.ctx_arg_info[0].btf_id = btf_sock_ids[BTF_SOCK_TYPE_UNIX];
+       if (bpf_iter_reg_target(&unix_reg_info))
+               pr_warn("Warning: could not register bpf iterator unix\n");
+}
+#endif
+
 static int __init af_unix_init(void)
 {
        int rc = -1;
 
        BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb));
 
-       rc = proto_register(&unix_proto, 1);
+       rc = proto_register(&unix_dgram_proto, 1);
+       if (rc != 0) {
+               pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
+               goto out;
+       }
+
+       rc = proto_register(&unix_stream_proto, 1);
        if (rc != 0) {
                pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
                goto out;
@@ -3205,6 +3362,11 @@ static int __init af_unix_init(void)
        sock_register(&unix_family_ops);
        register_pernet_subsys(&unix_net_ops);
        unix_bpf_build_proto();
+
+#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+       bpf_iter_register();
+#endif
+
 out:
        return rc;
 }
@@ -3212,7 +3374,8 @@ out:
 static void __exit af_unix_exit(void)
 {
        sock_unregister(PF_UNIX);
-       proto_unregister(&unix_proto);
+       proto_unregister(&unix_dgram_proto);
+       proto_unregister(&unix_stream_proto);
        unregister_pernet_subsys(&unix_net_ops);
 }