Merge https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
[linux-2.6-microblaze.git] / net / unix / af_unix.c
index ba7ced9..7cad52b 100644 (file)
 #include <linux/security.h>
 #include <linux/freezer.h>
 #include <linux/file.h>
+#include <linux/btf_ids.h>
 
 #include "scm.h"
 
@@ -494,6 +495,7 @@ static void unix_dgram_disconnected(struct sock *sk, struct sock *other)
                        sk_error_report(other);
                }
        }
+       sk->sk_state = other->sk_state = TCP_CLOSE;
 }
 
 static void unix_sock_destructor(struct sock *sk)
@@ -502,6 +504,12 @@ static void unix_sock_destructor(struct sock *sk)
 
        skb_queue_purge(&sk->sk_receive_queue);
 
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+       if (u->oob_skb) {
+               kfree_skb(u->oob_skb);
+               u->oob_skb = NULL;
+       }
+#endif
        WARN_ON(refcount_read(&sk->sk_wmem_alloc));
        WARN_ON(!sk_unhashed(sk));
        WARN_ON(sk->sk_socket);
@@ -669,6 +677,10 @@ static ssize_t unix_stream_splice_read(struct socket *,  loff_t *ppos,
                                       unsigned int flags);
 static int unix_dgram_sendmsg(struct socket *, struct msghdr *, size_t);
 static int unix_dgram_recvmsg(struct socket *, struct msghdr *, size_t, int);
+static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
+                         sk_read_actor_t recv_actor);
+static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc,
+                                sk_read_actor_t recv_actor);
 static int unix_dgram_connect(struct socket *, struct sockaddr *,
                              int, int);
 static int unix_seqpacket_sendmsg(struct socket *, struct msghdr *, size_t);
@@ -722,6 +734,7 @@ static const struct proto_ops unix_stream_ops = {
        .shutdown =     unix_shutdown,
        .sendmsg =      unix_stream_sendmsg,
        .recvmsg =      unix_stream_recvmsg,
+       .read_sock =    unix_stream_read_sock,
        .mmap =         sock_no_mmap,
        .sendpage =     unix_stream_sendpage,
        .splice_read =  unix_stream_splice_read,
@@ -746,6 +759,7 @@ static const struct proto_ops unix_dgram_ops = {
        .listen =       sock_no_listen,
        .shutdown =     unix_shutdown,
        .sendmsg =      unix_dgram_sendmsg,
+       .read_sock =    unix_read_sock,
        .recvmsg =      unix_dgram_recvmsg,
        .mmap =         sock_no_mmap,
        .sendpage =     sock_no_sendpage,
@@ -777,13 +791,42 @@ static const struct proto_ops unix_seqpacket_ops = {
        .show_fdinfo =  unix_show_fdinfo,
 };
 
-static struct proto unix_proto = {
-       .name                   = "UNIX",
+static void unix_close(struct sock *sk, long timeout)
+{
+       /* Nothing to do here, unix socket does not need a ->close().
+        * This is merely for sockmap.
+        */
+}
+
+static void unix_unhash(struct sock *sk)
+{
+       /* Nothing to do here, unix socket does not need a ->unhash().
+        * This is merely for sockmap.
+        */
+}
+
+struct proto unix_dgram_proto = {
+       .name                   = "UNIX-DGRAM",
        .owner                  = THIS_MODULE,
        .obj_size               = sizeof(struct unix_sock),
+       .close                  = unix_close,
+#ifdef CONFIG_BPF_SYSCALL
+       .psock_update_sk_prot   = unix_dgram_bpf_update_proto,
+#endif
 };
 
-static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
+struct proto unix_stream_proto = {
+       .name                   = "UNIX-STREAM",
+       .owner                  = THIS_MODULE,
+       .obj_size               = sizeof(struct unix_sock),
+       .close                  = unix_close,
+       .unhash                 = unix_unhash,
+#ifdef CONFIG_BPF_SYSCALL
+       .psock_update_sk_prot   = unix_stream_bpf_update_proto,
+#endif
+};
+
+static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, int type)
 {
        struct sock *sk = NULL;
        struct unix_sock *u;
@@ -792,7 +835,11 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
        if (atomic_long_read(&unix_nr_socks) > 2 * get_max_files())
                goto out;
 
-       sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_proto, kern);
+       if (type == SOCK_STREAM)
+               sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_stream_proto, kern);
+       else /*dgram and  seqpacket */
+               sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_dgram_proto, kern);
+
        if (!sk)
                goto out;
 
@@ -854,7 +901,7 @@ static int unix_create(struct net *net, struct socket *sock, int protocol,
                return -ESOCKTNOSUPPORT;
        }
 
-       return unix_create1(net, sock, kern) ? 0 : -ENOMEM;
+       return unix_create1(net, sock, kern, sock->type) ? 0 : -ENOMEM;
 }
 
 static int unix_release(struct socket *sock)
@@ -864,6 +911,7 @@ static int unix_release(struct socket *sock)
        if (!sk)
                return 0;
 
+       sk->sk_prot->close(sk, 0);
        unix_release_sock(sk, 0);
        sock->sk = NULL;
 
@@ -1199,6 +1247,9 @@ restart:
                unix_peer(sk) = other;
                unix_state_double_unlock(sk, other);
        }
+
+       if (unix_peer(sk))
+               sk->sk_state = other->sk_state = TCP_ESTABLISHED;
        return 0;
 
 out_unlock:
@@ -1264,7 +1315,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
        err = -ENOMEM;
 
        /* create new sock for complete connection */
-       newsk = unix_create1(sock_net(sk), NULL, 0);
+       newsk = unix_create1(sock_net(sk), NULL, 0, sock->type);
        if (newsk == NULL)
                goto out;
 
@@ -1431,12 +1482,10 @@ static int unix_socketpair(struct socket *socka, struct socket *sockb)
        init_peercred(ska);
        init_peercred(skb);
 
-       if (ska->sk_type != SOCK_DGRAM) {
-               ska->sk_state = TCP_ESTABLISHED;
-               skb->sk_state = TCP_ESTABLISHED;
-               socka->state  = SS_CONNECTED;
-               sockb->state  = SS_CONNECTED;
-       }
+       ska->sk_state = TCP_ESTABLISHED;
+       skb->sk_state = TCP_ESTABLISHED;
+       socka->state  = SS_CONNECTED;
+       sockb->state  = SS_CONNECTED;
        return 0;
 }
 
@@ -1872,6 +1921,53 @@ out:
  */
 #define UNIX_SKB_FRAGS_SZ (PAGE_SIZE << get_order(32768))
 
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+static int queue_oob(struct socket *sock, struct msghdr *msg, struct sock *other)
+{
+       struct unix_sock *ousk = unix_sk(other);
+       struct sk_buff *skb;
+       int err = 0;
+
+       skb = sock_alloc_send_skb(sock->sk, 1, msg->msg_flags & MSG_DONTWAIT, &err);
+
+       if (!skb)
+               return err;
+
+       skb_put(skb, 1);
+       err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, 1);
+
+       if (err) {
+               kfree_skb(skb);
+               return err;
+       }
+
+       unix_state_lock(other);
+
+       if (sock_flag(other, SOCK_DEAD) ||
+           (other->sk_shutdown & RCV_SHUTDOWN)) {
+               unix_state_unlock(other);
+               kfree_skb(skb);
+               return -EPIPE;
+       }
+
+       maybe_add_creds(skb, sock, other);
+       skb_get(skb);
+
+       if (ousk->oob_skb)
+               consume_skb(ousk->oob_skb);
+
+       ousk->oob_skb = skb;
+
+       scm_stat_add(other, skb);
+       skb_queue_tail(&other->sk_receive_queue, skb);
+       sk_send_sigurg(other);
+       unix_state_unlock(other);
+       other->sk_data_ready(other);
+
+       return err;
+}
+#endif
+
 static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                               size_t len)
 {
@@ -1890,8 +1986,14 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                return err;
 
        err = -EOPNOTSUPP;
-       if (msg->msg_flags&MSG_OOB)
-               goto out_err;
+       if (msg->msg_flags & MSG_OOB) {
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+               if (len)
+                       len--;
+               else
+#endif
+                       goto out_err;
+       }
 
        if (msg->msg_namelen) {
                err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
@@ -1956,6 +2058,15 @@ static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                sent += size;
        }
 
+#if (IS_ENABLED(CONFIG_AF_UNIX_OOB))
+       if (msg->msg_flags & MSG_OOB) {
+               err = queue_oob(sock, msg, other);
+               if (err)
+                       goto out_err;
+               sent++;
+       }
+#endif
+
        scm_destroy(&scm);
 
        return sent;
@@ -2128,11 +2239,11 @@ static void unix_copy_addr(struct msghdr *msg, struct sock *sk)
        }
 }
 
-static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg,
-                             size_t size, int flags)
+int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
+                        int flags)
 {
        struct scm_cookie scm;
-       struct sock *sk = sock->sk;
+       struct socket *sock = sk->sk_socket;
        struct unix_sock *u = unix_sk(sk);
        struct sk_buff *skb, *last;
        long timeo;
@@ -2235,6 +2346,55 @@ out:
        return err;
 }
 
+static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
+                             int flags)
+{
+       struct sock *sk = sock->sk;
+
+#ifdef CONFIG_BPF_SYSCALL
+       const struct proto *prot = READ_ONCE(sk->sk_prot);
+
+       if (prot != &unix_dgram_proto)
+               return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
+                                           flags & ~MSG_DONTWAIT, NULL);
+#endif
+       return __unix_dgram_recvmsg(sk, msg, size, flags);
+}
+
+static int unix_read_sock(struct sock *sk, read_descriptor_t *desc,
+                         sk_read_actor_t recv_actor)
+{
+       int copied = 0;
+
+       while (1) {
+               struct unix_sock *u = unix_sk(sk);
+               struct sk_buff *skb;
+               int used, err;
+
+               mutex_lock(&u->iolock);
+               skb = skb_recv_datagram(sk, 0, 1, &err);
+               mutex_unlock(&u->iolock);
+               if (!skb)
+                       return err;
+
+               used = recv_actor(desc, skb, 0, skb->len);
+               if (used <= 0) {
+                       if (!copied)
+                               copied = used;
+                       kfree_skb(skb);
+                       break;
+               } else if (used <= skb->len) {
+                       copied += used;
+               }
+
+               kfree_skb(skb);
+               if (!desc->count)
+                       break;
+       }
+
+       return copied;
+}
+
 /*
  *     Sleep until more data has arrived. But check for races..
  */
@@ -2294,6 +2454,86 @@ struct unix_stream_read_state {
        unsigned int splice_flags;
 };
 
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+static int unix_stream_recv_urg(struct unix_stream_read_state *state)
+{
+       struct socket *sock = state->socket;
+       struct sock *sk = sock->sk;
+       struct unix_sock *u = unix_sk(sk);
+       int chunk = 1;
+       struct sk_buff *oob_skb;
+
+       mutex_lock(&u->iolock);
+       unix_state_lock(sk);
+
+       if (sock_flag(sk, SOCK_URGINLINE) || !u->oob_skb) {
+               unix_state_unlock(sk);
+               mutex_unlock(&u->iolock);
+               return -EINVAL;
+       }
+
+       oob_skb = u->oob_skb;
+
+       if (!(state->flags & MSG_PEEK)) {
+               u->oob_skb = NULL;
+       }
+
+       unix_state_unlock(sk);
+
+       chunk = state->recv_actor(oob_skb, 0, chunk, state);
+
+       if (!(state->flags & MSG_PEEK)) {
+               UNIXCB(oob_skb).consumed += 1;
+               kfree_skb(oob_skb);
+       }
+
+       mutex_unlock(&u->iolock);
+
+       if (chunk < 0)
+               return -EFAULT;
+
+       state->msg->msg_flags |= MSG_OOB;
+       return 1;
+}
+
+static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk,
+                                 int flags, int copied)
+{
+       struct unix_sock *u = unix_sk(sk);
+
+       if (!unix_skb_len(skb) && !(flags & MSG_PEEK)) {
+               skb_unlink(skb, &sk->sk_receive_queue);
+               consume_skb(skb);
+               skb = NULL;
+       } else {
+               if (skb == u->oob_skb) {
+                       if (copied) {
+                               skb = NULL;
+                       } else if (sock_flag(sk, SOCK_URGINLINE)) {
+                               if (!(flags & MSG_PEEK)) {
+                                       u->oob_skb = NULL;
+                                       consume_skb(skb);
+                               }
+                       } else if (!(flags & MSG_PEEK)) {
+                               skb_unlink(skb, &sk->sk_receive_queue);
+                               consume_skb(skb);
+                               skb = skb_peek(&sk->sk_receive_queue);
+                       }
+               }
+       }
+       return skb;
+}
+#endif
+
+static int unix_stream_read_sock(struct sock *sk, read_descriptor_t *desc,
+                                sk_read_actor_t recv_actor)
+{
+       if (unlikely(sk->sk_state != TCP_ESTABLISHED))
+               return -ENOTCONN;
+
+       return unix_read_sock(sk, desc, recv_actor);
+}
+
 static int unix_stream_read_generic(struct unix_stream_read_state *state,
                                    bool freezable)
 {
@@ -2319,6 +2559,9 @@ static int unix_stream_read_generic(struct unix_stream_read_state *state,
 
        if (unlikely(flags & MSG_OOB)) {
                err = -EOPNOTSUPP;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+               err = unix_stream_recv_urg(state);
+#endif
                goto out;
        }
 
@@ -2347,6 +2590,18 @@ redo:
                }
                last = skb = skb_peek(&sk->sk_receive_queue);
                last_len = last ? last->len : 0;
+
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+               if (skb) {
+                       skb = manage_oob(skb, sk, flags, copied);
+                       if (!skb) {
+                               unix_state_unlock(sk);
+                               if (copied)
+                                       break;
+                               goto redo;
+                       }
+               }
+#endif
 again:
                if (skb == NULL) {
                        if (copied >= target)
@@ -2504,6 +2759,20 @@ static int unix_stream_read_actor(struct sk_buff *skb,
        return ret ?: chunk;
 }
 
+int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg,
+                         size_t size, int flags)
+{
+       struct unix_stream_read_state state = {
+               .recv_actor = unix_stream_read_actor,
+               .socket = sk->sk_socket,
+               .msg = msg,
+               .size = size,
+               .flags = flags
+       };
+
+       return unix_stream_read_generic(&state, true);
+}
+
 static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
                               size_t size, int flags)
 {
@@ -2515,6 +2784,14 @@ static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
                .flags = flags
        };
 
+#ifdef CONFIG_BPF_SYSCALL
+       struct sock *sk = sock->sk;
+       const struct proto *prot = READ_ONCE(sk->sk_prot);
+
+       if (prot != &unix_stream_proto)
+               return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
+                                           flags & ~MSG_DONTWAIT, NULL);
+#endif
        return unix_stream_read_generic(&state, true);
 }
 
@@ -2575,7 +2852,10 @@ static int unix_shutdown(struct socket *sock, int mode)
                (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)) {
 
                int peer_mode = 0;
+               const struct proto *prot = READ_ONCE(other->sk_prot);
 
+               if (prot->unhash)
+                       prot->unhash(other);
                if (mode&RCV_SHUTDOWN)
                        peer_mode |= SEND_SHUTDOWN;
                if (mode&SEND_SHUTDOWN)
@@ -2584,10 +2864,12 @@ static int unix_shutdown(struct socket *sock, int mode)
                other->sk_shutdown |= peer_mode;
                unix_state_unlock(other);
                other->sk_state_change(other);
-               if (peer_mode == SHUTDOWN_MASK)
+               if (peer_mode == SHUTDOWN_MASK) {
                        sk_wake_async(other, SOCK_WAKE_WAITD, POLL_HUP);
-               else if (peer_mode & RCV_SHUTDOWN)
+                       other->sk_state = TCP_CLOSE;
+               } else if (peer_mode & RCV_SHUTDOWN) {
                        sk_wake_async(other, SOCK_WAKE_WAITD, POLL_IN);
+               }
        }
        if (other)
                sock_put(other);
@@ -2682,6 +2964,20 @@ static int unix_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
        case SIOCUNIXFILE:
                err = unix_open_file(sk);
                break;
+#if IS_ENABLED(CONFIG_AF_UNIX_OOB)
+       case SIOCATMARK:
+               {
+                       struct sk_buff *skb;
+                       struct unix_sock *u = unix_sk(sk);
+                       int answ = 0;
+
+                       skb = skb_peek(&sk->sk_receive_queue);
+                       if (skb && skb == u->oob_skb)
+                               answ = 1;
+                       err = put_user(answ, (int __user *)arg);
+               }
+               break;
+#endif
        default:
                err = -ENOIOCTLCMD;
                break;
@@ -2918,6 +3214,64 @@ static const struct seq_operations unix_seq_ops = {
        .stop   = unix_seq_stop,
        .show   = unix_seq_show,
 };
+
+#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL)
+struct bpf_iter__unix {
+       __bpf_md_ptr(struct bpf_iter_meta *, meta);
+       __bpf_md_ptr(struct unix_sock *, unix_sk);
+       uid_t uid __aligned(8);
+};
+
+static int unix_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
+                             struct unix_sock *unix_sk, uid_t uid)
+{
+       struct bpf_iter__unix ctx;
+
+       meta->seq_num--;  /* skip SEQ_START_TOKEN */
+       ctx.meta = meta;
+       ctx.unix_sk = unix_sk;
+       ctx.uid = uid;
+       return bpf_iter_run_prog(prog, &ctx);
+}
+
+static int bpf_iter_unix_seq_show(struct seq_file *seq, void *v)
+{
+       struct bpf_iter_meta meta;
+       struct bpf_prog *prog;
+       struct sock *sk = v;
+       uid_t uid;
+
+       if (v == SEQ_START_TOKEN)
+               return 0;
+
+       uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk));
+       meta.seq = seq;
+       prog = bpf_iter_get_info(&meta, false);
+       return unix_prog_seq_show(prog, &meta, v, uid);
+}
+
+static void bpf_iter_unix_seq_stop(struct seq_file *seq, void *v)
+{
+       struct bpf_iter_meta meta;
+       struct bpf_prog *prog;
+
+       if (!v) {
+               meta.seq = seq;
+               prog = bpf_iter_get_info(&meta, true);
+               if (prog)
+                       (void)unix_prog_seq_show(prog, &meta, v, 0);
+       }
+
+       unix_seq_stop(seq, v);
+}
+
+static const struct seq_operations bpf_iter_unix_seq_ops = {
+       .start  = unix_seq_start,
+       .next   = unix_seq_next,
+       .stop   = bpf_iter_unix_seq_stop,
+       .show   = bpf_iter_unix_seq_show,
+};
+#endif
 #endif
 
 static const struct net_proto_family unix_family_ops = {
@@ -2958,13 +3312,48 @@ static struct pernet_operations unix_net_ops = {
        .exit = unix_net_exit,
 };
 
+#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+DEFINE_BPF_ITER_FUNC(unix, struct bpf_iter_meta *meta,
+                    struct unix_sock *unix_sk, uid_t uid)
+
+static const struct bpf_iter_seq_info unix_seq_info = {
+       .seq_ops                = &bpf_iter_unix_seq_ops,
+       .init_seq_private       = bpf_iter_init_seq_net,
+       .fini_seq_private       = bpf_iter_fini_seq_net,
+       .seq_priv_size          = sizeof(struct seq_net_private),
+};
+
+static struct bpf_iter_reg unix_reg_info = {
+       .target                 = "unix",
+       .ctx_arg_info_size      = 1,
+       .ctx_arg_info           = {
+               { offsetof(struct bpf_iter__unix, unix_sk),
+                 PTR_TO_BTF_ID_OR_NULL },
+       },
+       .seq_info               = &unix_seq_info,
+};
+
+static void __init bpf_iter_register(void)
+{
+       unix_reg_info.ctx_arg_info[0].btf_id = btf_sock_ids[BTF_SOCK_TYPE_UNIX];
+       if (bpf_iter_reg_target(&unix_reg_info))
+               pr_warn("Warning: could not register bpf iterator unix\n");
+}
+#endif
+
 static int __init af_unix_init(void)
 {
        int rc = -1;
 
        BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb));
 
-       rc = proto_register(&unix_proto, 1);
+       rc = proto_register(&unix_dgram_proto, 1);
+       if (rc != 0) {
+               pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
+               goto out;
+       }
+
+       rc = proto_register(&unix_stream_proto, 1);
        if (rc != 0) {
                pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
                goto out;
@@ -2972,6 +3361,12 @@ static int __init af_unix_init(void)
 
        sock_register(&unix_family_ops);
        register_pernet_subsys(&unix_net_ops);
+       unix_bpf_build_proto();
+
+#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS)
+       bpf_iter_register();
+#endif
+
 out:
        return rc;
 }
@@ -2979,7 +3374,8 @@ out:
 static void __exit af_unix_exit(void)
 {
        sock_unregister(PF_UNIX);
-       proto_unregister(&unix_proto);
+       proto_unregister(&unix_dgram_proto);
+       proto_unregister(&unix_stream_proto);
        unregister_pernet_subsys(&unix_net_ops);
 }