Linux 6.9-rc1
[linux-2.6-microblaze.git] / net / socket.c
index 9b27c5e..e5f3af4 100644 (file)
@@ -57,6 +57,7 @@
 #include <linux/mm.h>
 #include <linux/socket.h>
 #include <linux/file.h>
+#include <linux/splice.h>
 #include <linux/net.h>
 #include <linux/interrupt.h>
 #include <linux/thread_info.h>
@@ -87,6 +88,7 @@
 #include <linux/xattr.h>
 #include <linux/nospec.h>
 #include <linux/indirect_call_wrapper.h>
+#include <linux/io_uring.h>
 
 #include <linux/uaccess.h>
 #include <asm/unistd.h>
 #include <net/busy_poll.h>
 #include <linux/errqueue.h>
 #include <linux/ptp_clock_kernel.h>
+#include <trace/events/sock.h>
 
 #ifdef CONFIG_NET_RX_BUSY_POLL
 unsigned int sysctl_net_busy_read __read_mostly;
@@ -125,19 +128,19 @@ static long compat_sock_ioctl(struct file *file,
                              unsigned int cmd, unsigned long arg);
 #endif
 static int sock_fasync(int fd, struct file *filp, int on);
-static ssize_t sock_sendpage(struct file *file, struct page *page,
-                            int offset, size_t size, loff_t *ppos, int more);
 static ssize_t sock_splice_read(struct file *file, loff_t *ppos,
                                struct pipe_inode_info *pipe, size_t len,
                                unsigned int flags);
+static void sock_splice_eof(struct file *file);
 
 #ifdef CONFIG_PROC_FS
 static void sock_show_fdinfo(struct seq_file *m, struct file *f)
 {
        struct socket *sock = f->private_data;
+       const struct proto_ops *ops = READ_ONCE(sock->ops);
 
-       if (sock->ops->show_fdinfo)
-               sock->ops->show_fdinfo(m, sock);
+       if (ops->show_fdinfo)
+               ops->show_fdinfo(m, sock);
 }
 #else
 #define sock_show_fdinfo NULL
@@ -158,12 +161,13 @@ static const struct file_operations socket_file_ops = {
 #ifdef CONFIG_COMPAT
        .compat_ioctl = compat_sock_ioctl,
 #endif
+       .uring_cmd =    io_uring_cmd_sock,
        .mmap =         sock_mmap,
        .release =      sock_close,
        .fasync =       sock_fasync,
-       .sendpage =     sock_sendpage,
-       .splice_write = generic_splice_sendpage,
+       .splice_write = splice_to_socket,
        .splice_read =  sock_splice_read,
+       .splice_eof =   sock_splice_eof,
        .show_fdinfo =  sock_show_fdinfo,
 };
 
@@ -339,7 +343,7 @@ static void init_inodecache(void)
                                              0,
                                              (SLAB_HWCACHE_ALIGN |
                                               SLAB_RECLAIM_ACCOUNT |
-                                              SLAB_MEM_SPREAD | SLAB_ACCOUNT),
+                                              SLAB_ACCOUNT),
                                              init_once);
        BUG_ON(sock_inode_cachep == NULL);
 }
@@ -355,7 +359,7 @@ static const struct super_operations sockfs_ops = {
  */
 static char *sockfs_dname(struct dentry *dentry, char *buffer, int buflen)
 {
-       return dynamic_dname(dentry, buffer, buflen, "socket:[%lu]",
+       return dynamic_dname(buffer, buflen, "socket:[%lu]",
                                d_inode(dentry)->i_ino);
 }
 
@@ -385,7 +389,7 @@ static const struct xattr_handler sockfs_xattr_handler = {
 };
 
 static int sockfs_security_xattr_set(const struct xattr_handler *handler,
-                                    struct user_namespace *mnt_userns,
+                                    struct mnt_idmap *idmap,
                                     struct dentry *dentry, struct inode *inode,
                                     const char *suffix, const void *value,
                                     size_t size, int flags)
@@ -399,7 +403,7 @@ static const struct xattr_handler sockfs_security_xattr_handler = {
        .set = sockfs_security_xattr_set,
 };
 
-static const struct xattr_handler *sockfs_xattr_handlers[] = {
+static const struct xattr_handler * const sockfs_xattr_handlers[] = {
        &sockfs_xattr_handler,
        &sockfs_security_xattr_handler,
        NULL
@@ -449,7 +453,9 @@ static struct file_system_type sock_fs_type = {
  *
  *     Returns the &file bound with @sock, implicitly storing it
  *     in sock->file. If dname is %NULL, sets to "".
- *     On failure the return is a ERR pointer (see linux/err.h).
+ *
+ *     On failure @sock is released, and an ERR pointer is returned.
+ *
  *     This function uses GFP_KERNEL internally.
  */
 
@@ -468,6 +474,7 @@ struct file *sock_alloc_file(struct socket *sock, int flags, const char *dname)
                return file;
        }
 
+       file->f_mode |= FMODE_NOWAIT;
        sock->file = file;
        file->private_data = sock;
        stream_open(SOCK_INODE(sock), file);
@@ -589,10 +596,10 @@ static ssize_t sockfs_listxattr(struct dentry *dentry, char *buffer,
        return used;
 }
 
-static int sockfs_setattr(struct user_namespace *mnt_userns,
+static int sockfs_setattr(struct mnt_idmap *idmap,
                          struct dentry *dentry, struct iattr *iattr)
 {
-       int err = simple_setattr(&init_user_ns, dentry, iattr);
+       int err = simple_setattr(&nop_mnt_idmap, dentry, iattr);
 
        if (!err && (iattr->ia_valid & ATTR_UID)) {
                struct socket *sock = SOCKET_I(d_inode(dentry));
@@ -642,12 +649,14 @@ EXPORT_SYMBOL(sock_alloc);
 
 static void __sock_release(struct socket *sock, struct inode *inode)
 {
-       if (sock->ops) {
-               struct module *owner = sock->ops->owner;
+       const struct proto_ops *ops = READ_ONCE(sock->ops);
+
+       if (ops) {
+               struct module *owner = ops->owner;
 
                if (inode)
                        inode_lock(inode);
-               sock->ops->release(sock);
+               ops->release(sock);
                sock->sk = NULL;
                if (inode)
                        inode_unlock(inode);
@@ -709,15 +718,33 @@ INDIRECT_CALLABLE_DECLARE(int inet_sendmsg(struct socket *, struct msghdr *,
                                           size_t));
 INDIRECT_CALLABLE_DECLARE(int inet6_sendmsg(struct socket *, struct msghdr *,
                                            size_t));
+
+static noinline void call_trace_sock_send_length(struct sock *sk, int ret,
+                                                int flags)
+{
+       trace_sock_send_length(sk, ret, 0);
+}
+
 static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg)
 {
-       int ret = INDIRECT_CALL_INET(sock->ops->sendmsg, inet6_sendmsg,
+       int ret = INDIRECT_CALL_INET(READ_ONCE(sock->ops)->sendmsg, inet6_sendmsg,
                                     inet_sendmsg, sock, msg,
                                     msg_data_left(msg));
        BUG_ON(ret == -EIOCBQUEUED);
+
+       if (trace_sock_send_length_enabled())
+               call_trace_sock_send_length(sock->sk, ret, 0);
        return ret;
 }
 
+static int __sock_sendmsg(struct socket *sock, struct msghdr *msg)
+{
+       int err = security_socket_sendmsg(sock, msg,
+                                         msg_data_left(msg));
+
+       return err ?: sock_sendmsg_nosec(sock, msg);
+}
+
 /**
  *     sock_sendmsg - send a message through @sock
  *     @sock: socket
@@ -728,10 +755,21 @@ static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg)
  */
 int sock_sendmsg(struct socket *sock, struct msghdr *msg)
 {
-       int err = security_socket_sendmsg(sock, msg,
-                                         msg_data_left(msg));
+       struct sockaddr_storage *save_addr = (struct sockaddr_storage *)msg->msg_name;
+       struct sockaddr_storage address;
+       int save_len = msg->msg_namelen;
+       int ret;
 
-       return err ?: sock_sendmsg_nosec(sock, msg);
+       if (msg->msg_name) {
+               memcpy(&address, msg->msg_name, msg->msg_namelen);
+               msg->msg_name = &address;
+       }
+
+       ret = __sock_sendmsg(sock, msg);
+       msg->msg_name = save_addr;
+       msg->msg_namelen = save_len;
+
+       return ret;
 }
 EXPORT_SYMBOL(sock_sendmsg);
 
@@ -750,7 +788,7 @@ EXPORT_SYMBOL(sock_sendmsg);
 int kernel_sendmsg(struct socket *sock, struct msghdr *msg,
                   struct kvec *vec, size_t num, size_t size)
 {
-       iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size);
+       iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size);
        return sock_sendmsg(sock, msg);
 }
 EXPORT_SYMBOL(kernel_sendmsg);
@@ -772,13 +810,14 @@ int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg,
                          struct kvec *vec, size_t num, size_t size)
 {
        struct socket *sock = sk->sk_socket;
+       const struct proto_ops *ops = READ_ONCE(sock->ops);
 
-       if (!sock->ops->sendmsg_locked)
+       if (!ops->sendmsg_locked)
                return sock_no_sendmsg_locked(sk, msg, size);
 
-       iov_iter_kvec(&msg->msg_iter, WRITE, vec, num, size);
+       iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size);
 
-       return sock->ops->sendmsg_locked(sk, msg, msg_data_left(msg));
+       return ops->sendmsg_locked(sk, msg, msg_data_left(msg));
 }
 EXPORT_SYMBOL(kernel_sendmsg_locked);
 
@@ -807,7 +846,7 @@ static bool skb_is_swtx_tstamp(const struct sk_buff *skb, int false_tstamp)
 
 static ktime_t get_timestamp(struct sock *sk, struct sk_buff *skb, int *if_index)
 {
-       bool cycles = sk->sk_tsflags & SOF_TIMESTAMPING_BIND_PHC;
+       bool cycles = READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_BIND_PHC;
        struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb);
        struct net_device *orig_dev;
        ktime_t hwtstamp;
@@ -859,12 +898,12 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk,
        int need_software_tstamp = sock_flag(sk, SOCK_RCVTSTAMP);
        int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW);
        struct scm_timestamping_internal tss;
-
        int empty = 1, false_tstamp = 0;
        struct skb_shared_hwtstamps *shhwtstamps =
                skb_hwtstamps(skb);
        int if_index;
        ktime_t hwtstamp;
+       u32 tsflags;
 
        /* Race occurred between timestamp enabling and packet
           receiving.  Fill in the current time for now. */
@@ -906,11 +945,12 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk,
        }
 
        memset(&tss, 0, sizeof(tss));
-       if ((sk->sk_tsflags & SOF_TIMESTAMPING_SOFTWARE) &&
+       tsflags = READ_ONCE(sk->sk_tsflags);
+       if ((tsflags & SOF_TIMESTAMPING_SOFTWARE) &&
            ktime_to_timespec64_cond(skb->tstamp, tss.ts + 0))
                empty = 0;
        if (shhwtstamps &&
-           (sk->sk_tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) &&
+           (tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) &&
            !skb_is_swtx_tstamp(skb, false_tstamp)) {
                if_index = 0;
                if (skb_shinfo(skb)->tx_flags & SKBTX_HW_TSTAMP_NETDEV)
@@ -918,14 +958,14 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk,
                else
                        hwtstamp = shhwtstamps->hwtstamp;
 
-               if (sk->sk_tsflags & SOF_TIMESTAMPING_BIND_PHC)
+               if (tsflags & SOF_TIMESTAMPING_BIND_PHC)
                        hwtstamp = ptp_convert_timestamp(&hwtstamp,
-                                                        sk->sk_bind_phc);
+                                                        READ_ONCE(sk->sk_bind_phc));
 
                if (ktime_to_timespec64_cond(hwtstamp, tss.ts + 2)) {
                        empty = 0;
 
-                       if ((sk->sk_tsflags & SOF_TIMESTAMPING_OPT_PKTINFO) &&
+                       if ((tsflags & SOF_TIMESTAMPING_OPT_PKTINFO) &&
                            !skb_is_err_queue(skb))
                                put_ts_pktinfo(msg, skb, if_index);
                }
@@ -944,6 +984,7 @@ void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk,
 }
 EXPORT_SYMBOL_GPL(__sock_recv_timestamp);
 
+#ifdef CONFIG_WIRELESS
 void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk,
        struct sk_buff *skb)
 {
@@ -959,6 +1000,7 @@ void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk,
        put_cmsg(msg, SOL_SOCKET, SCM_WIFI_STATUS, sizeof(ack), &ack);
 }
 EXPORT_SYMBOL_GPL(__sock_recv_wifi_status);
+#endif
 
 static inline void sock_recv_drops(struct msghdr *msg, struct sock *sk,
                                   struct sk_buff *skb)
@@ -971,9 +1013,12 @@ static inline void sock_recv_drops(struct msghdr *msg, struct sock *sk,
 static void sock_recv_mark(struct msghdr *msg, struct sock *sk,
                           struct sk_buff *skb)
 {
-       if (sock_flag(sk, SOCK_RCVMARK) && skb)
-               put_cmsg(msg, SOL_SOCKET, SO_MARK, sizeof(__u32),
-                        &skb->mark);
+       if (sock_flag(sk, SOCK_RCVMARK) && skb) {
+               /* We must use a bounce buffer for CONFIG_HARDENED_USERCOPY=y */
+               __u32 mark = skb->mark;
+
+               put_cmsg(msg, SOL_SOCKET, SO_MARK, sizeof(__u32), &mark);
+       }
 }
 
 void __sock_recv_cmsgs(struct msghdr *msg, struct sock *sk,
@@ -989,12 +1034,22 @@ INDIRECT_CALLABLE_DECLARE(int inet_recvmsg(struct socket *, struct msghdr *,
                                           size_t, int));
 INDIRECT_CALLABLE_DECLARE(int inet6_recvmsg(struct socket *, struct msghdr *,
                                            size_t, int));
+
+static noinline void call_trace_sock_recv_length(struct sock *sk, int ret, int flags)
+{
+       trace_sock_recv_length(sk, ret, flags);
+}
+
 static inline int sock_recvmsg_nosec(struct socket *sock, struct msghdr *msg,
                                     int flags)
 {
-       return INDIRECT_CALL_INET(sock->ops->recvmsg, inet6_recvmsg,
-                                 inet_recvmsg, sock, msg, msg_data_left(msg),
-                                 flags);
+       int ret = INDIRECT_CALL_INET(READ_ONCE(sock->ops)->recvmsg,
+                                    inet6_recvmsg,
+                                    inet_recvmsg, sock, msg,
+                                    msg_data_left(msg), flags);
+       if (trace_sock_recv_length_enabled())
+               call_trace_sock_recv_length(sock->sk, ret, flags);
+       return ret;
 }
 
 /**
@@ -1034,36 +1089,33 @@ int kernel_recvmsg(struct socket *sock, struct msghdr *msg,
                   struct kvec *vec, size_t num, size_t size, int flags)
 {
        msg->msg_control_is_user = false;
-       iov_iter_kvec(&msg->msg_iter, READ, vec, num, size);
+       iov_iter_kvec(&msg->msg_iter, ITER_DEST, vec, num, size);
        return sock_recvmsg(sock, msg, flags);
 }
 EXPORT_SYMBOL(kernel_recvmsg);
 
-static ssize_t sock_sendpage(struct file *file, struct page *page,
-                            int offset, size_t size, loff_t *ppos, int more)
+static ssize_t sock_splice_read(struct file *file, loff_t *ppos,
+                               struct pipe_inode_info *pipe, size_t len,
+                               unsigned int flags)
 {
-       struct socket *sock;
-       int flags;
-
-       sock = file->private_data;
+       struct socket *sock = file->private_data;
+       const struct proto_ops *ops;
 
-       flags = (file->f_flags & O_NONBLOCK) ? MSG_DONTWAIT : 0;
-       /* more is a combination of MSG_MORE and MSG_SENDPAGE_NOTLAST */
-       flags |= more;
+       ops = READ_ONCE(sock->ops);
+       if (unlikely(!ops->splice_read))
+               return copy_splice_read(file, ppos, pipe, len, flags);
 
-       return kernel_sendpage(sock, page, offset, size, flags);
+       return ops->splice_read(sock, ppos, pipe, len, flags);
 }
 
-static ssize_t sock_splice_read(struct file *file, loff_t *ppos,
-                               struct pipe_inode_info *pipe, size_t len,
-                               unsigned int flags)
+static void sock_splice_eof(struct file *file)
 {
        struct socket *sock = file->private_data;
+       const struct proto_ops *ops;
 
-       if (unlikely(!sock->ops->splice_read))
-               return generic_file_splice_read(file, ppos, pipe, len, flags);
-
-       return sock->ops->splice_read(sock, ppos, pipe, len, flags);
+       ops = READ_ONCE(sock->ops);
+       if (ops->splice_eof)
+               ops->splice_eof(sock);
 }
 
 static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to)
@@ -1105,7 +1157,7 @@ static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from)
        if (sock->type == SOCK_SEQPACKET)
                msg.msg_flags |= MSG_EOR;
 
-       res = sock_sendmsg(sock, &msg);
+       res = __sock_sendmsg(sock, &msg);
        *from = msg.msg_iter;
        return res;
 }
@@ -1160,13 +1212,14 @@ EXPORT_SYMBOL(vlan_ioctl_set);
 static long sock_do_ioctl(struct net *net, struct socket *sock,
                          unsigned int cmd, unsigned long arg)
 {
+       const struct proto_ops *ops = READ_ONCE(sock->ops);
        struct ifreq ifr;
        bool need_copyout;
        int err;
        void __user *argp = (void __user *)arg;
        void __user *data;
 
-       err = sock->ops->ioctl(sock, cmd, arg);
+       err = ops->ioctl(sock, cmd, arg);
 
        /*
         * If this ioctl is unknown try to hand it down
@@ -1195,6 +1248,7 @@ static long sock_do_ioctl(struct net *net, struct socket *sock,
 
 static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg)
 {
+       const struct proto_ops  *ops;
        struct socket *sock;
        struct sock *sk;
        void __user *argp = (void __user *)arg;
@@ -1202,6 +1256,7 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg)
        struct net *net;
 
        sock = file->private_data;
+       ops = READ_ONCE(sock->ops);
        sk = sock->sk;
        net = sock_net(sk);
        if (unlikely(cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15))) {
@@ -1259,23 +1314,23 @@ static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg)
                        break;
                case SIOCGSTAMP_OLD:
                case SIOCGSTAMPNS_OLD:
-                       if (!sock->ops->gettstamp) {
+                       if (!ops->gettstamp) {
                                err = -ENOIOCTLCMD;
                                break;
                        }
-                       err = sock->ops->gettstamp(sock, argp,
-                                                  cmd == SIOCGSTAMP_OLD,
-                                                  !IS_ENABLED(CONFIG_64BIT));
+                       err = ops->gettstamp(sock, argp,
+                                            cmd == SIOCGSTAMP_OLD,
+                                            !IS_ENABLED(CONFIG_64BIT));
                        break;
                case SIOCGSTAMP_NEW:
                case SIOCGSTAMPNS_NEW:
-                       if (!sock->ops->gettstamp) {
+                       if (!ops->gettstamp) {
                                err = -ENOIOCTLCMD;
                                break;
                        }
-                       err = sock->ops->gettstamp(sock, argp,
-                                                  cmd == SIOCGSTAMP_NEW,
-                                                  false);
+                       err = ops->gettstamp(sock, argp,
+                                            cmd == SIOCGSTAMP_NEW,
+                                            false);
                        break;
 
                case SIOCGIFCONF:
@@ -1336,9 +1391,10 @@ EXPORT_SYMBOL(sock_create_lite);
 static __poll_t sock_poll(struct file *file, poll_table *wait)
 {
        struct socket *sock = file->private_data;
+       const struct proto_ops *ops = READ_ONCE(sock->ops);
        __poll_t events = poll_requested_events(wait), flag = 0;
 
-       if (!sock->ops->poll)
+       if (!ops->poll)
                return 0;
 
        if (sk_can_busy_loop(sock->sk)) {
@@ -1350,14 +1406,14 @@ static __poll_t sock_poll(struct file *file, poll_table *wait)
                flag = POLL_BUSY_LOOP;
        }
 
-       return sock->ops->poll(file, sock, wait) | flag;
+       return ops->poll(file, sock, wait) | flag;
 }
 
 static int sock_mmap(struct file *file, struct vm_area_struct *vma)
 {
        struct socket *sock = file->private_data;
 
-       return sock->ops->mmap(file, sock, vma);
+       return READ_ONCE(sock->ops)->mmap(file, sock, vma);
 }
 
 static int sock_close(struct inode *inode, struct file *filp)
@@ -1610,7 +1666,6 @@ static struct socket *__sys_socket_create(int family, int type, int protocol)
 struct file *__sys_socket_file(int family, int type, int protocol)
 {
        struct socket *sock;
-       struct file *file;
        int flags;
 
        sock = __sys_socket_create(family, type, protocol);
@@ -1621,19 +1676,35 @@ struct file *__sys_socket_file(int family, int type, int protocol)
        if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK))
                flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
 
-       file = sock_alloc_file(sock, flags, NULL);
-       if (IS_ERR(file))
-               sock_release(sock);
+       return sock_alloc_file(sock, flags, NULL);
+}
 
-       return file;
+/*     A hook for bpf progs to attach to and update socket protocol.
+ *
+ *     A static noinline declaration here could cause the compiler to
+ *     optimize away the function. A global noinline declaration will
+ *     keep the definition, but may optimize away the callsite.
+ *     Therefore, __weak is needed to ensure that the call is still
+ *     emitted, by telling the compiler that we don't know what the
+ *     function might eventually be.
+ */
+
+__bpf_hook_start();
+
+__weak noinline int update_socket_protocol(int family, int type, int protocol)
+{
+       return protocol;
 }
 
+__bpf_hook_end();
+
 int __sys_socket(int family, int type, int protocol)
 {
        struct socket *sock;
        int flags;
 
-       sock = __sys_socket_create(family, type, protocol);
+       sock = __sys_socket_create(family, type,
+                                  update_socket_protocol(family, type, protocol));
        if (IS_ERR(sock))
                return PTR_ERR(sock);
 
@@ -1712,7 +1783,7 @@ int __sys_socketpair(int family, int type, int protocol, int __user *usockvec)
                goto out;
        }
 
-       err = sock1->ops->socketpair(sock1, sock2);
+       err = READ_ONCE(sock1->ops)->socketpair(sock1, sock2);
        if (unlikely(err < 0)) {
                sock_release(sock2);
                sock_release(sock1);
@@ -1773,7 +1844,7 @@ int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen)
                                                   (struct sockaddr *)&address,
                                                   addrlen);
                        if (!err)
-                               err = sock->ops->bind(sock,
+                               err = READ_ONCE(sock->ops)->bind(sock,
                                                      (struct sockaddr *)
                                                      &address, addrlen);
                }
@@ -1801,13 +1872,13 @@ int __sys_listen(int fd, int backlog)
 
        sock = sockfd_lookup_light(fd, &err, &fput_needed);
        if (sock) {
-               somaxconn = sock_net(sock->sk)->core.sysctl_somaxconn;
+               somaxconn = READ_ONCE(sock_net(sock->sk)->core.sysctl_somaxconn);
                if ((unsigned int)backlog > somaxconn)
                        backlog = somaxconn;
 
                err = security_socket_listen(sock, backlog);
                if (!err)
-                       err = sock->ops->listen(sock, backlog);
+                       err = READ_ONCE(sock->ops)->listen(sock, backlog);
 
                fput_light(sock->file, fput_needed);
        }
@@ -1827,6 +1898,7 @@ struct file *do_accept(struct file *file, unsigned file_flags,
        struct file *newfile;
        int err, len;
        struct sockaddr_storage address;
+       const struct proto_ops *ops;
 
        sock = sock_from_file(file);
        if (!sock)
@@ -1835,15 +1907,16 @@ struct file *do_accept(struct file *file, unsigned file_flags,
        newsock = sock_alloc();
        if (!newsock)
                return ERR_PTR(-ENFILE);
+       ops = READ_ONCE(sock->ops);
 
        newsock->type = sock->type;
-       newsock->ops = sock->ops;
+       newsock->ops = ops;
 
        /*
         * We don't need try_module_get here, as the listening socket (sock)
         * has the protocol module (sock->ops->owner) held.
         */
-       __module_get(newsock->ops->owner);
+       __module_get(ops->owner);
 
        newfile = sock_alloc_file(newsock, flags, sock->sk->sk_prot_creator->name);
        if (IS_ERR(newfile))
@@ -1853,14 +1926,13 @@ struct file *do_accept(struct file *file, unsigned file_flags,
        if (err)
                goto out_fd;
 
-       err = sock->ops->accept(sock, newsock, sock->file->f_flags | file_flags,
+       err = ops->accept(sock, newsock, sock->file->f_flags | file_flags,
                                        false);
        if (err < 0)
                goto out_fd;
 
        if (upeer_sockaddr) {
-               len = newsock->ops->getname(newsock,
-                                       (struct sockaddr *)&address, 2);
+               len = ops->getname(newsock, (struct sockaddr *)&address, 2);
                if (len < 0) {
                        err = -ECONNABORTED;
                        goto out_fd;
@@ -1973,8 +2045,8 @@ int __sys_connect_file(struct file *file, struct sockaddr_storage *address,
        if (err)
                goto out;
 
-       err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen,
-                                sock->file->f_flags | file_flags);
+       err = READ_ONCE(sock->ops)->connect(sock, (struct sockaddr *)address,
+                               addrlen, sock->file->f_flags | file_flags);
 out:
        return err;
 }
@@ -2023,7 +2095,7 @@ int __sys_getsockname(int fd, struct sockaddr __user *usockaddr,
        if (err)
                goto out_put;
 
-       err = sock->ops->getname(sock, (struct sockaddr *)&address, 0);
+       err = READ_ONCE(sock->ops)->getname(sock, (struct sockaddr *)&address, 0);
        if (err < 0)
                goto out_put;
        /* "err" is actually length in this case */
@@ -2055,13 +2127,15 @@ int __sys_getpeername(int fd, struct sockaddr __user *usockaddr,
 
        sock = sockfd_lookup_light(fd, &err, &fput_needed);
        if (sock != NULL) {
+               const struct proto_ops *ops = READ_ONCE(sock->ops);
+
                err = security_socket_getpeername(sock);
                if (err) {
                        fput_light(sock->file, fput_needed);
                        return err;
                }
 
-               err = sock->ops->getname(sock, (struct sockaddr *)&address, 1);
+               err = ops->getname(sock, (struct sockaddr *)&address, 1);
                if (err >= 0)
                        /* "err" is actually length in this case */
                        err = move_addr_to_user(&address, err, usockaddr,
@@ -2089,10 +2163,9 @@ int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags,
        struct sockaddr_storage address;
        int err;
        struct msghdr msg;
-       struct iovec iov;
        int fput_needed;
 
-       err = import_single_range(WRITE, buff, len, &iov, &msg.msg_iter);
+       err = import_ubuf(ITER_SOURCE, buff, len, &msg.msg_iter);
        if (unlikely(err))
                return err;
        sock = sockfd_lookup_light(fd, &err, &fput_needed);
@@ -2111,10 +2184,11 @@ int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags,
                msg.msg_name = (struct sockaddr *)&address;
                msg.msg_namelen = addr_len;
        }
+       flags &= ~MSG_INTERNAL_SENDMSG_FLAGS;
        if (sock->file->f_flags & O_NONBLOCK)
                flags |= MSG_DONTWAIT;
        msg.msg_flags = flags;
-       err = sock_sendmsg(sock, &msg);
+       err = __sock_sendmsg(sock, &msg);
 
 out_put:
        fput_light(sock->file, fput_needed);
@@ -2153,11 +2227,10 @@ int __sys_recvfrom(int fd, void __user *ubuf, size_t size, unsigned int flags,
                .msg_name = addr ? (struct sockaddr *)&address : NULL,
        };
        struct socket *sock;
-       struct iovec iov;
        int err, err2;
        int fput_needed;
 
-       err = import_single_range(READ, ubuf, size, &iov, &msg.msg_iter);
+       err = import_ubuf(ITER_DEST, ubuf, size, &msg.msg_iter);
        if (unlikely(err))
                return err;
        sock = sockfd_lookup_light(fd, &err, &fput_needed);
@@ -2199,41 +2272,26 @@ SYSCALL_DEFINE4(recv, int, fd, void __user *, ubuf, size_t, size,
 
 static bool sock_use_custom_sol_socket(const struct socket *sock)
 {
-       const struct sock *sk = sock->sk;
-
-       /* Use sock->ops->setsockopt() for MPTCP */
-       return IS_ENABLED(CONFIG_MPTCP) &&
-              sk->sk_protocol == IPPROTO_MPTCP &&
-              sk->sk_type == SOCK_STREAM &&
-              (sk->sk_family == AF_INET || sk->sk_family == AF_INET6);
+       return test_bit(SOCK_CUSTOM_SOCKOPT, &sock->flags);
 }
 
-/*
- *     Set a socket option. Because we don't know the option lengths we have
- *     to pass the user mode parameter for the protocols to sort out.
- */
-int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval,
-               int optlen)
+int do_sock_setsockopt(struct socket *sock, bool compat, int level,
+                      int optname, sockptr_t optval, int optlen)
 {
-       sockptr_t optval = USER_SOCKPTR(user_optval);
+       const struct proto_ops *ops;
        char *kernel_optval = NULL;
-       int err, fput_needed;
-       struct socket *sock;
+       int err;
 
        if (optlen < 0)
                return -EINVAL;
 
-       sock = sockfd_lookup_light(fd, &err, &fput_needed);
-       if (!sock)
-               return err;
-
        err = security_socket_setsockopt(sock, level, optname);
        if (err)
                goto out_put;
 
-       if (!in_compat_syscall())
+       if (!compat)
                err = BPF_CGROUP_RUN_PROG_SETSOCKOPT(sock->sk, &level, &optname,
-                                                    user_optval, &optlen,
+                                                    optval, &optlen,
                                                     &kernel_optval);
        if (err < 0)
                goto out_put;
@@ -2244,15 +2302,37 @@ int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval,
 
        if (kernel_optval)
                optval = KERNEL_SOCKPTR(kernel_optval);
+       ops = READ_ONCE(sock->ops);
        if (level == SOL_SOCKET && !sock_use_custom_sol_socket(sock))
                err = sock_setsockopt(sock, level, optname, optval, optlen);
-       else if (unlikely(!sock->ops->setsockopt))
+       else if (unlikely(!ops->setsockopt))
                err = -EOPNOTSUPP;
        else
-               err = sock->ops->setsockopt(sock, level, optname, optval,
+               err = ops->setsockopt(sock, level, optname, optval,
                                            optlen);
        kfree(kernel_optval);
 out_put:
+       return err;
+}
+EXPORT_SYMBOL(do_sock_setsockopt);
+
+/* Set a socket option. Because we don't know the option lengths we have
+ * to pass the user mode parameter for the protocols to sort out.
+ */
+int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval,
+                    int optlen)
+{
+       sockptr_t optval = USER_SOCKPTR(user_optval);
+       bool compat = in_compat_syscall();
+       int err, fput_needed;
+       struct socket *sock;
+
+       sock = sockfd_lookup_light(fd, &err, &fput_needed);
+       if (!sock)
+               return err;
+
+       err = do_sock_setsockopt(sock, compat, level, optname, optval, optlen);
+
        fput_light(sock->file, fput_needed);
        return err;
 }
@@ -2266,6 +2346,43 @@ SYSCALL_DEFINE5(setsockopt, int, fd, int, level, int, optname,
 INDIRECT_CALLABLE_DECLARE(bool tcp_bpf_bypass_getsockopt(int level,
                                                         int optname));
 
+int do_sock_getsockopt(struct socket *sock, bool compat, int level,
+                      int optname, sockptr_t optval, sockptr_t optlen)
+{
+       int max_optlen __maybe_unused;
+       const struct proto_ops *ops;
+       int err;
+
+       err = security_socket_getsockopt(sock, level, optname);
+       if (err)
+               return err;
+
+       if (!compat)
+               max_optlen = BPF_CGROUP_GETSOCKOPT_MAX_OPTLEN(optlen);
+
+       ops = READ_ONCE(sock->ops);
+       if (level == SOL_SOCKET) {
+               err = sk_getsockopt(sock->sk, level, optname, optval, optlen);
+       } else if (unlikely(!ops->getsockopt)) {
+               err = -EOPNOTSUPP;
+       } else {
+               if (WARN_ONCE(optval.is_kernel || optlen.is_kernel,
+                             "Invalid argument type"))
+                       return -EOPNOTSUPP;
+
+               err = ops->getsockopt(sock, level, optname, optval.user,
+                                     optlen.user);
+       }
+
+       if (!compat)
+               err = BPF_CGROUP_RUN_PROG_GETSOCKOPT(sock->sk, level, optname,
+                                                    optval, optlen, max_optlen,
+                                                    err);
+
+       return err;
+}
+EXPORT_SYMBOL(do_sock_getsockopt);
+
 /*
  *     Get a socket option. Because we don't know the option lengths we have
  *     to pass a user mode parameter for the protocols to sort out.
@@ -2275,32 +2392,16 @@ int __sys_getsockopt(int fd, int level, int optname, char __user *optval,
 {
        int err, fput_needed;
        struct socket *sock;
-       int max_optlen;
+       bool compat;
 
        sock = sockfd_lookup_light(fd, &err, &fput_needed);
        if (!sock)
                return err;
 
-       err = security_socket_getsockopt(sock, level, optname);
-       if (err)
-               goto out_put;
-
-       if (!in_compat_syscall())
-               max_optlen = BPF_CGROUP_GETSOCKOPT_MAX_OPTLEN(optlen);
-
-       if (level == SOL_SOCKET)
-               err = sock_getsockopt(sock, level, optname, optval, optlen);
-       else if (unlikely(!sock->ops->getsockopt))
-               err = -EOPNOTSUPP;
-       else
-               err = sock->ops->getsockopt(sock, level, optname, optval,
-                                           optlen);
+       compat = in_compat_syscall();
+       err = do_sock_getsockopt(sock, compat, level, optname,
+                                USER_SOCKPTR(optval), USER_SOCKPTR(optlen));
 
-       if (!in_compat_syscall())
-               err = BPF_CGROUP_RUN_PROG_GETSOCKOPT(sock->sk, level, optname,
-                                                    optval, optlen, max_optlen,
-                                                    err);
-out_put:
        fput_light(sock->file, fput_needed);
        return err;
 }
@@ -2321,7 +2422,7 @@ int __sys_shutdown_sock(struct socket *sock, int how)
 
        err = security_socket_shutdown(sock, how);
        if (!err)
-               err = sock->ops->shutdown(sock, how);
+               err = READ_ONCE(sock->ops)->shutdown(sock, how);
 
        return err;
 }
@@ -2417,7 +2518,7 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
        if (err)
                return err;
 
-       err = import_iovec(save_addr ? READ : WRITE,
+       err = import_iovec(save_addr ? ITER_DEST : ITER_SOURCE,
                            msg.msg_iov, msg.msg_iovlen,
                            UIO_FASTIOV, iov, &kmsg->msg_iter);
        return err < 0 ? err : 0;
@@ -2462,6 +2563,7 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
                msg_sys->msg_control = ctl_buf;
                msg_sys->msg_control_is_user = false;
        }
+       flags &= ~MSG_INTERNAL_SENDMSG_FLAGS;
        msg_sys->msg_flags = flags;
 
        if (sock->file->f_flags & O_NONBLOCK)
@@ -2479,7 +2581,7 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
                err = sock_sendmsg_nosec(sock, msg_sys);
                goto out_freectl;
        }
-       err = sock_sendmsg(sock, msg_sys);
+       err = __sock_sendmsg(sock, msg_sys);
        /*
         * If this is sendmmsg() and sending to current destination address was
         * successful, remember it.
@@ -2498,9 +2600,9 @@ out:
        return err;
 }
 
-int sendmsg_copy_msghdr(struct msghdr *msg,
-                       struct user_msghdr __user *umsg, unsigned flags,
-                       struct iovec **iov)
+static int sendmsg_copy_msghdr(struct msghdr *msg,
+                              struct user_msghdr __user *umsg, unsigned flags,
+                              struct iovec **iov)
 {
        int err;
 
@@ -2651,10 +2753,10 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg,
        return __sys_sendmmsg(fd, mmsg, vlen, flags, true);
 }
 
-int recvmsg_copy_msghdr(struct msghdr *msg,
-                       struct user_msghdr __user *umsg, unsigned flags,
-                       struct sockaddr __user **uaddr,
-                       struct iovec **iov)
+static int recvmsg_copy_msghdr(struct msghdr *msg,
+                              struct user_msghdr __user *umsg, unsigned flags,
+                              struct sockaddr __user **uaddr,
+                              struct iovec **iov)
 {
        ssize_t err;
 
@@ -2890,7 +2992,7 @@ static int do_recvmmsg(int fd, struct mmsghdr __user *mmsg,
                 * error to return on the next call or if the
                 * app asks about it using getsockopt(SO_ERROR).
                 */
-               sock->sk->sk_err = -err;
+               WRITE_ONCE(sock->sk->sk_err, -err);
        }
 out_put:
        fput_light(sock->file, fput_needed);
@@ -3312,6 +3414,7 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock,
        void __user *argp = compat_ptr(arg);
        struct sock *sk = sock->sk;
        struct net *net = sock_net(sk);
+       const struct proto_ops *ops;
 
        if (cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15))
                return sock_ioctl(file, cmd, (unsigned long)argp);
@@ -3321,10 +3424,11 @@ static int compat_sock_ioctl_trans(struct file *file, struct socket *sock,
                return compat_siocwandev(net, argp);
        case SIOCGSTAMP_OLD:
        case SIOCGSTAMPNS_OLD:
-               if (!sock->ops->gettstamp)
+               ops = READ_ONCE(sock->ops);
+               if (!ops->gettstamp)
                        return -ENOIOCTLCMD;
-               return sock->ops->gettstamp(sock, argp, cmd == SIOCGSTAMP_OLD,
-                                           !COMPAT_USE_64BIT_TIME);
+               return ops->gettstamp(sock, argp, cmd == SIOCGSTAMP_OLD,
+                                     !COMPAT_USE_64BIT_TIME);
 
        case SIOCETHTOOL:
        case SIOCBONDSLAVEINFOQUERY:
@@ -3405,6 +3509,7 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd,
                              unsigned long arg)
 {
        struct socket *sock = file->private_data;
+       const struct proto_ops *ops = READ_ONCE(sock->ops);
        int ret = -ENOIOCTLCMD;
        struct sock *sk;
        struct net *net;
@@ -3412,8 +3517,8 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd,
        sk = sock->sk;
        net = sock_net(sk);
 
-       if (sock->ops->compat_ioctl)
-               ret = sock->ops->compat_ioctl(sock, cmd, arg);
+       if (ops->compat_ioctl)
+               ret = ops->compat_ioctl(sock, cmd, arg);
 
        if (ret == -ENOIOCTLCMD &&
            (cmd >= SIOCIWFIRST && cmd <= SIOCIWLAST))
@@ -3437,7 +3542,12 @@ static long compat_sock_ioctl(struct file *file, unsigned int cmd,
 
 int kernel_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
 {
-       return sock->ops->bind(sock, addr, addrlen);
+       struct sockaddr_storage address;
+
+       memcpy(&address, addr, addrlen);
+
+       return READ_ONCE(sock->ops)->bind(sock, (struct sockaddr *)&address,
+                                         addrlen);
 }
 EXPORT_SYMBOL(kernel_bind);
 
@@ -3451,7 +3561,7 @@ EXPORT_SYMBOL(kernel_bind);
 
 int kernel_listen(struct socket *sock, int backlog)
 {
-       return sock->ops->listen(sock, backlog);
+       return READ_ONCE(sock->ops)->listen(sock, backlog);
 }
 EXPORT_SYMBOL(kernel_listen);
 
@@ -3469,6 +3579,7 @@ EXPORT_SYMBOL(kernel_listen);
 int kernel_accept(struct socket *sock, struct socket **newsock, int flags)
 {
        struct sock *sk = sock->sk;
+       const struct proto_ops *ops = READ_ONCE(sock->ops);
        int err;
 
        err = sock_create_lite(sk->sk_family, sk->sk_type, sk->sk_protocol,
@@ -3476,15 +3587,15 @@ int kernel_accept(struct socket *sock, struct socket **newsock, int flags)
        if (err < 0)
                goto done;
 
-       err = sock->ops->accept(sock, *newsock, flags, true);
+       err = ops->accept(sock, *newsock, flags, true);
        if (err < 0) {
                sock_release(*newsock);
                *newsock = NULL;
                goto done;
        }
 
-       (*newsock)->ops = sock->ops;
-       __module_get((*newsock)->ops->owner);
+       (*newsock)->ops = ops;
+       __module_get(ops->owner);
 
 done:
        return err;
@@ -3507,7 +3618,12 @@ EXPORT_SYMBOL(kernel_accept);
 int kernel_connect(struct socket *sock, struct sockaddr *addr, int addrlen,
                   int flags)
 {
-       return sock->ops->connect(sock, addr, addrlen, flags);
+       struct sockaddr_storage address;
+
+       memcpy(&address, addr, addrlen);
+
+       return READ_ONCE(sock->ops)->connect(sock, (struct sockaddr *)&address,
+                                            addrlen, flags);
 }
 EXPORT_SYMBOL(kernel_connect);
 
@@ -3522,7 +3638,7 @@ EXPORT_SYMBOL(kernel_connect);
 
 int kernel_getsockname(struct socket *sock, struct sockaddr *addr)
 {
-       return sock->ops->getname(sock, addr, 0);
+       return READ_ONCE(sock->ops)->getname(sock, addr, 0);
 }
 EXPORT_SYMBOL(kernel_getsockname);
 
@@ -3537,58 +3653,10 @@ EXPORT_SYMBOL(kernel_getsockname);
 
 int kernel_getpeername(struct socket *sock, struct sockaddr *addr)
 {
-       return sock->ops->getname(sock, addr, 1);
+       return READ_ONCE(sock->ops)->getname(sock, addr, 1);
 }
 EXPORT_SYMBOL(kernel_getpeername);
 
-/**
- *     kernel_sendpage - send a &page through a socket (kernel space)
- *     @sock: socket
- *     @page: page
- *     @offset: page offset
- *     @size: total size in bytes
- *     @flags: flags (MSG_DONTWAIT, ...)
- *
- *     Returns the total amount sent in bytes or an error.
- */
-
-int kernel_sendpage(struct socket *sock, struct page *page, int offset,
-                   size_t size, int flags)
-{
-       if (sock->ops->sendpage) {
-               /* Warn in case the improper page to zero-copy send */
-               WARN_ONCE(!sendpage_ok(page), "improper page for zero-copy send");
-               return sock->ops->sendpage(sock, page, offset, size, flags);
-       }
-       return sock_no_sendpage(sock, page, offset, size, flags);
-}
-EXPORT_SYMBOL(kernel_sendpage);
-
-/**
- *     kernel_sendpage_locked - send a &page through the locked sock (kernel space)
- *     @sk: sock
- *     @page: page
- *     @offset: page offset
- *     @size: total size in bytes
- *     @flags: flags (MSG_DONTWAIT, ...)
- *
- *     Returns the total amount sent in bytes or an error.
- *     Caller must hold @sk.
- */
-
-int kernel_sendpage_locked(struct sock *sk, struct page *page, int offset,
-                          size_t size, int flags)
-{
-       struct socket *sock = sk->sk_socket;
-
-       if (sock->ops->sendpage_locked)
-               return sock->ops->sendpage_locked(sk, page, offset, size,
-                                                 flags);
-
-       return sock_no_sendpage_locked(sk, page, offset, size, flags);
-}
-EXPORT_SYMBOL(kernel_sendpage_locked);
-
 /**
  *     kernel_sock_shutdown - shut down part of a full-duplex connection (kernel space)
  *     @sock: socket
@@ -3599,7 +3667,7 @@ EXPORT_SYMBOL(kernel_sendpage_locked);
 
 int kernel_sock_shutdown(struct socket *sock, enum sock_shutdown_cmd how)
 {
-       return sock->ops->shutdown(sock, how);
+       return READ_ONCE(sock->ops)->shutdown(sock, how);
 }
 EXPORT_SYMBOL(kernel_sock_shutdown);