Linux 6.11-rc1
[linux-2.6-microblaze.git] / fs / namespace.c
index 5a51315..328087a 100644 (file)
@@ -70,7 +70,8 @@ static DEFINE_IDA(mnt_id_ida);
 static DEFINE_IDA(mnt_group_ida);
 
 /* Don't allow confusion with old 32bit mount ID */
-static atomic64_t mnt_id_ctr = ATOMIC64_INIT(1ULL << 32);
+#define MNT_UNIQUE_ID_OFFSET (1ULL << 31)
+static atomic64_t mnt_id_ctr = ATOMIC64_INIT(MNT_UNIQUE_ID_OFFSET);
 
 static struct hlist_head *mount_hashtable __ro_after_init;
 static struct hlist_head *mountpoint_hashtable __ro_after_init;
@@ -78,6 +79,8 @@ static struct kmem_cache *mnt_cache __ro_after_init;
 static DECLARE_RWSEM(namespace_sem);
 static HLIST_HEAD(unmounted);  /* protected by namespace_sem */
 static LIST_HEAD(ex_mountpoints); /* protected by namespace_sem */
+static DEFINE_RWLOCK(mnt_ns_tree_lock);
+static struct rb_root mnt_ns_tree = RB_ROOT; /* protected by mnt_ns_tree_lock */
 
 struct mount_kattr {
        unsigned int attr_set;
@@ -103,6 +106,109 @@ EXPORT_SYMBOL_GPL(fs_kobj);
  */
 __cacheline_aligned_in_smp DEFINE_SEQLOCK(mount_lock);
 
+static int mnt_ns_cmp(u64 seq, const struct mnt_namespace *ns)
+{
+       u64 seq_b = ns->seq;
+
+       if (seq < seq_b)
+               return -1;
+       if (seq > seq_b)
+               return 1;
+       return 0;
+}
+
+static inline struct mnt_namespace *node_to_mnt_ns(const struct rb_node *node)
+{
+       if (!node)
+               return NULL;
+       return rb_entry(node, struct mnt_namespace, mnt_ns_tree_node);
+}
+
+static bool mnt_ns_less(struct rb_node *a, const struct rb_node *b)
+{
+       struct mnt_namespace *ns_a = node_to_mnt_ns(a);
+       struct mnt_namespace *ns_b = node_to_mnt_ns(b);
+       u64 seq_a = ns_a->seq;
+
+       return mnt_ns_cmp(seq_a, ns_b) < 0;
+}
+
+static void mnt_ns_tree_add(struct mnt_namespace *ns)
+{
+       guard(write_lock)(&mnt_ns_tree_lock);
+       rb_add(&ns->mnt_ns_tree_node, &mnt_ns_tree, mnt_ns_less);
+}
+
+static void mnt_ns_release(struct mnt_namespace *ns)
+{
+       lockdep_assert_not_held(&mnt_ns_tree_lock);
+
+       /* keep alive for {list,stat}mount() */
+       if (refcount_dec_and_test(&ns->passive)) {
+               put_user_ns(ns->user_ns);
+               kfree(ns);
+       }
+}
+DEFINE_FREE(mnt_ns_release, struct mnt_namespace *, if (_T) mnt_ns_release(_T))
+
+static void mnt_ns_tree_remove(struct mnt_namespace *ns)
+{
+       /* remove from global mount namespace list */
+       if (!is_anon_ns(ns)) {
+               guard(write_lock)(&mnt_ns_tree_lock);
+               rb_erase(&ns->mnt_ns_tree_node, &mnt_ns_tree);
+       }
+
+       mnt_ns_release(ns);
+}
+
+/*
+ * Returns the mount namespace which either has the specified id, or has the
+ * next smallest id afer the specified one.
+ */
+static struct mnt_namespace *mnt_ns_find_id_at(u64 mnt_ns_id)
+{
+       struct rb_node *node = mnt_ns_tree.rb_node;
+       struct mnt_namespace *ret = NULL;
+
+       lockdep_assert_held(&mnt_ns_tree_lock);
+
+       while (node) {
+               struct mnt_namespace *n = node_to_mnt_ns(node);
+
+               if (mnt_ns_id <= n->seq) {
+                       ret = node_to_mnt_ns(node);
+                       if (mnt_ns_id == n->seq)
+                               break;
+                       node = node->rb_left;
+               } else {
+                       node = node->rb_right;
+               }
+       }
+       return ret;
+}
+
+/*
+ * Lookup a mount namespace by id and take a passive reference count. Taking a
+ * passive reference means the mount namespace can be emptied if e.g., the last
+ * task holding an active reference exits. To access the mounts of the
+ * namespace the @namespace_sem must first be acquired. If the namespace has
+ * already shut down before acquiring @namespace_sem, {list,stat}mount() will
+ * see that the mount rbtree of the namespace is empty.
+ */
+static struct mnt_namespace *lookup_mnt_ns(u64 mnt_ns_id)
+{
+       struct mnt_namespace *ns;
+
+       guard(read_lock)(&mnt_ns_tree_lock);
+       ns = mnt_ns_find_id_at(mnt_ns_id);
+       if (!ns || ns->seq != mnt_ns_id)
+               return NULL;
+
+       refcount_inc(&ns->passive);
+       return ns;
+}
+
 static inline void lock_mount_hash(void)
 {
        write_seqlock(&mount_lock);
@@ -1448,6 +1554,30 @@ static struct mount *mnt_find_id_at(struct mnt_namespace *ns, u64 mnt_id)
        return ret;
 }
 
+/*
+ * Returns the mount which either has the specified mnt_id, or has the next
+ * greater id before the specified one.
+ */
+static struct mount *mnt_find_id_at_reverse(struct mnt_namespace *ns, u64 mnt_id)
+{
+       struct rb_node *node = ns->mounts.rb_node;
+       struct mount *ret = NULL;
+
+       while (node) {
+               struct mount *m = node_to_mount(node);
+
+               if (mnt_id >= m->mnt_id_unique) {
+                       ret = node_to_mount(node);
+                       if (mnt_id == m->mnt_id_unique)
+                               break;
+                       node = node->rb_right;
+               } else {
+                       node = node->rb_left;
+               }
+       }
+       return ret;
+}
+
 #ifdef CONFIG_PROC_FS
 
 /* iterator; we want it to have access to namespace_sem, thus here... */
@@ -1846,19 +1976,6 @@ bool may_mount(void)
        return ns_capable(current->nsproxy->mnt_ns->user_ns, CAP_SYS_ADMIN);
 }
 
-/**
- * path_mounted - check whether path is mounted
- * @path: path to check
- *
- * Determine whether @path refers to the root of a mount.
- *
- * Return: true if @path is the root of a mount, false if not.
- */
-static inline bool path_mounted(const struct path *path)
-{
-       return path->mnt->mnt_root == path->dentry;
-}
-
 static void warn_mandlock(void)
 {
        pr_warn_once("=======================================================\n"
@@ -1966,69 +2083,72 @@ static bool mnt_ns_loop(struct dentry *dentry)
        return current->nsproxy->mnt_ns->seq >= mnt_ns->seq;
 }
 
-struct mount *copy_tree(struct mount *mnt, struct dentry *dentry,
+struct mount *copy_tree(struct mount *src_root, struct dentry *dentry,
                                        int flag)
 {
-       struct mount *res, *p, *q, *r, *parent;
+       struct mount *res, *src_parent, *src_root_child, *src_mnt,
+               *dst_parent, *dst_mnt;
 
-       if (!(flag & CL_COPY_UNBINDABLE) && IS_MNT_UNBINDABLE(mnt))
+       if (!(flag & CL_COPY_UNBINDABLE) && IS_MNT_UNBINDABLE(src_root))
                return ERR_PTR(-EINVAL);
 
        if (!(flag & CL_COPY_MNT_NS_FILE) && is_mnt_ns_file(dentry))
                return ERR_PTR(-EINVAL);
 
-       res = q = clone_mnt(mnt, dentry, flag);
-       if (IS_ERR(q))
-               return q;
+       res = dst_mnt = clone_mnt(src_root, dentry, flag);
+       if (IS_ERR(dst_mnt))
+               return dst_mnt;
 
-       q->mnt_mountpoint = mnt->mnt_mountpoint;
+       src_parent = src_root;
+       dst_mnt->mnt_mountpoint = src_root->mnt_mountpoint;
 
-       p = mnt;
-       list_for_each_entry(r, &mnt->mnt_mounts, mnt_child) {
-               struct mount *s;
-               if (!is_subdir(r->mnt_mountpoint, dentry))
+       list_for_each_entry(src_root_child, &src_root->mnt_mounts, mnt_child) {
+               if (!is_subdir(src_root_child->mnt_mountpoint, dentry))
                        continue;
 
-               for (s = r; s; s = next_mnt(s, r)) {
+               for (src_mnt = src_root_child; src_mnt;
+                   src_mnt = next_mnt(src_mnt, src_root_child)) {
                        if (!(flag & CL_COPY_UNBINDABLE) &&
-                           IS_MNT_UNBINDABLE(s)) {
-                               if (s->mnt.mnt_flags & MNT_LOCKED) {
+                           IS_MNT_UNBINDABLE(src_mnt)) {
+                               if (src_mnt->mnt.mnt_flags & MNT_LOCKED) {
                                        /* Both unbindable and locked. */
-                                       q = ERR_PTR(-EPERM);
+                                       dst_mnt = ERR_PTR(-EPERM);
                                        goto out;
                                } else {
-                                       s = skip_mnt_tree(s);
+                                       src_mnt = skip_mnt_tree(src_mnt);
                                        continue;
                                }
                        }
                        if (!(flag & CL_COPY_MNT_NS_FILE) &&
-                           is_mnt_ns_file(s->mnt.mnt_root)) {
-                               s = skip_mnt_tree(s);
+                           is_mnt_ns_file(src_mnt->mnt.mnt_root)) {
+                               src_mnt = skip_mnt_tree(src_mnt);
                                continue;
                        }
-                       while (p != s->mnt_parent) {
-                               p = p->mnt_parent;
-                               q = q->mnt_parent;
+                       while (src_parent != src_mnt->mnt_parent) {
+                               src_parent = src_parent->mnt_parent;
+                               dst_mnt = dst_mnt->mnt_parent;
                        }
-                       p = s;
-                       parent = q;
-                       q = clone_mnt(p, p->mnt.mnt_root, flag);
-                       if (IS_ERR(q))
+
+                       src_parent = src_mnt;
+                       dst_parent = dst_mnt;
+                       dst_mnt = clone_mnt(src_mnt, src_mnt->mnt.mnt_root, flag);
+                       if (IS_ERR(dst_mnt))
                                goto out;
                        lock_mount_hash();
-                       list_add_tail(&q->mnt_list, &res->mnt_list);
-                       attach_mnt(q, parent, p->mnt_mp, false);
+                       list_add_tail(&dst_mnt->mnt_list, &res->mnt_list);
+                       attach_mnt(dst_mnt, dst_parent, src_parent->mnt_mp, false);
                        unlock_mount_hash();
                }
        }
        return res;
+
 out:
        if (res) {
                lock_mount_hash();
                umount_tree(res, UMOUNT_SYNC);
                unlock_mount_hash();
        }
-       return q;
+       return dst_mnt;
 }
 
 /* Caller should check returned pointer for errors */
@@ -2078,7 +2198,7 @@ void drop_collected_mounts(struct vfsmount *mnt)
        namespace_unlock();
 }
 
-static bool has_locked_children(struct mount *mnt, struct dentry *dentry)
+bool has_locked_children(struct mount *mnt, struct dentry *dentry)
 {
        struct mount *child;
 
@@ -3709,8 +3829,7 @@ static void free_mnt_ns(struct mnt_namespace *ns)
        if (!is_anon_ns(ns))
                ns_free_inum(&ns->ns);
        dec_mnt_namespaces(ns->ucounts);
-       put_user_ns(ns->user_ns);
-       kfree(ns);
+       mnt_ns_tree_remove(ns);
 }
 
 /*
@@ -3749,7 +3868,9 @@ static struct mnt_namespace *alloc_mnt_ns(struct user_namespace *user_ns, bool a
        if (!anon)
                new_ns->seq = atomic64_add_return(1, &mnt_ns_seq);
        refcount_set(&new_ns->ns.count, 1);
+       refcount_set(&new_ns->passive, 1);
        new_ns->mounts = RB_ROOT;
+       RB_CLEAR_NODE(&new_ns->mnt_ns_tree_node);
        init_waitqueue_head(&new_ns->poll);
        new_ns->user_ns = get_user_ns(user_ns);
        new_ns->ucounts = ucounts;
@@ -3826,6 +3947,7 @@ struct mnt_namespace *copy_mnt_ns(unsigned long flags, struct mnt_namespace *ns,
                while (p->mnt.mnt_root != q->mnt.mnt_root)
                        p = next_mnt(skip_mnt_tree(p), old);
        }
+       mnt_ns_tree_add(new_ns);
        namespace_unlock();
 
        if (rootmnt)
@@ -4843,6 +4965,40 @@ static int statmount_fs_type(struct kstatmount *s, struct seq_file *seq)
        return 0;
 }
 
+static void statmount_mnt_ns_id(struct kstatmount *s, struct mnt_namespace *ns)
+{
+       s->sm.mask |= STATMOUNT_MNT_NS_ID;
+       s->sm.mnt_ns_id = ns->seq;
+}
+
+static int statmount_mnt_opts(struct kstatmount *s, struct seq_file *seq)
+{
+       struct vfsmount *mnt = s->mnt;
+       struct super_block *sb = mnt->mnt_sb;
+       int err;
+
+       if (sb->s_op->show_options) {
+               size_t start = seq->count;
+
+               err = sb->s_op->show_options(seq, mnt->mnt_root);
+               if (err)
+                       return err;
+
+               if (unlikely(seq_has_overflowed(seq)))
+                       return -EAGAIN;
+
+               if (seq->count == start)
+                       return 0;
+
+               /* skip leading comma */
+               memmove(seq->buf + start, seq->buf + start + 1,
+                       seq->count - start - 1);
+               seq->count--;
+       }
+
+       return 0;
+}
+
 static int statmount_string(struct kstatmount *s, u64 flag)
 {
        int ret;
@@ -4863,6 +5019,10 @@ static int statmount_string(struct kstatmount *s, u64 flag)
                sm->mnt_point = seq->count;
                ret = statmount_mnt_point(s, seq);
                break;
+       case STATMOUNT_MNT_OPTS:
+               sm->mnt_opts = seq->count;
+               ret = statmount_mnt_opts(s, seq);
+               break;
        default:
                WARN_ON_ONCE(true);
                return -EINVAL;
@@ -4903,23 +5063,84 @@ static int copy_statmount_to_user(struct kstatmount *s)
        return 0;
 }
 
-static int do_statmount(struct kstatmount *s)
+static struct mount *listmnt_next(struct mount *curr, bool reverse)
 {
-       struct mount *m = real_mount(s->mnt);
+       struct rb_node *node;
+
+       if (reverse)
+               node = rb_prev(&curr->mnt_node);
+       else
+               node = rb_next(&curr->mnt_node);
+
+       return node_to_mount(node);
+}
+
+static int grab_requested_root(struct mnt_namespace *ns, struct path *root)
+{
+       struct mount *first, *child;
+
+       rwsem_assert_held(&namespace_sem);
+
+       /* We're looking at our own ns, just use get_fs_root. */
+       if (ns == current->nsproxy->mnt_ns) {
+               get_fs_root(current->fs, root);
+               return 0;
+       }
+
+       /*
+        * We have to find the first mount in our ns and use that, however it
+        * may not exist, so handle that properly.
+        */
+       if (RB_EMPTY_ROOT(&ns->mounts))
+               return -ENOENT;
+
+       first = child = ns->root;
+       for (;;) {
+               child = listmnt_next(child, false);
+               if (!child)
+                       return -ENOENT;
+               if (child->mnt_parent == first)
+                       break;
+       }
+
+       root->mnt = mntget(&child->mnt);
+       root->dentry = dget(root->mnt->mnt_root);
+       return 0;
+}
+
+static int do_statmount(struct kstatmount *s, u64 mnt_id, u64 mnt_ns_id,
+                       struct mnt_namespace *ns)
+{
+       struct path root __free(path_put) = {};
+       struct mount *m;
        int err;
 
+       /* Has the namespace already been emptied? */
+       if (mnt_ns_id && RB_EMPTY_ROOT(&ns->mounts))
+               return -ENOENT;
+
+       s->mnt = lookup_mnt_in_ns(mnt_id, ns);
+       if (!s->mnt)
+               return -ENOENT;
+
+       err = grab_requested_root(ns, &root);
+       if (err)
+               return err;
+
        /*
         * Don't trigger audit denials. We just want to determine what
         * mounts to show users.
         */
-       if (!is_path_reachable(m, m->mnt.mnt_root, &s->root) &&
-           !ns_capable_noaudit(&init_user_ns, CAP_SYS_ADMIN))
+       m = real_mount(s->mnt);
+       if (!is_path_reachable(m, m->mnt.mnt_root, &root) &&
+           !ns_capable_noaudit(ns->user_ns, CAP_SYS_ADMIN))
                return -EPERM;
 
        err = security_sb_statfs(s->mnt->mnt_root);
        if (err)
                return err;
 
+       s->root = root;
        if (s->mask & STATMOUNT_SB_BASIC)
                statmount_sb_basic(s);
 
@@ -4938,6 +5159,12 @@ static int do_statmount(struct kstatmount *s)
        if (!err && s->mask & STATMOUNT_MNT_POINT)
                err = statmount_string(s, STATMOUNT_MNT_POINT);
 
+       if (!err && s->mask & STATMOUNT_MNT_OPTS)
+               err = statmount_string(s, STATMOUNT_MNT_OPTS);
+
+       if (!err && s->mask & STATMOUNT_MNT_NS_ID)
+               statmount_mnt_ns_id(s, ns);
+
        if (err)
                return err;
 
@@ -4955,6 +5182,9 @@ static inline bool retry_statmount(const long ret, size_t *seq_size)
        return true;
 }
 
+#define STATMOUNT_STRING_REQ (STATMOUNT_MNT_ROOT | STATMOUNT_MNT_POINT | \
+                             STATMOUNT_FS_TYPE | STATMOUNT_MNT_OPTS)
+
 static int prepare_kstatmount(struct kstatmount *ks, struct mnt_id_req *kreq,
                              struct statmount __user *buf, size_t bufsize,
                              size_t seq_size)
@@ -4966,10 +5196,18 @@ static int prepare_kstatmount(struct kstatmount *ks, struct mnt_id_req *kreq,
        ks->mask = kreq->param;
        ks->buf = buf;
        ks->bufsize = bufsize;
-       ks->seq.size = seq_size;
-       ks->seq.buf = kvmalloc(seq_size, GFP_KERNEL_ACCOUNT);
-       if (!ks->seq.buf)
-               return -ENOMEM;
+
+       if (ks->mask & STATMOUNT_STRING_REQ) {
+               if (bufsize == sizeof(ks->sm))
+                       return -EOVERFLOW;
+
+               ks->seq.buf = kvmalloc(seq_size, GFP_KERNEL_ACCOUNT);
+               if (!ks->seq.buf)
+                       return -ENOMEM;
+
+               ks->seq.size = seq_size;
+       }
+
        return 0;
 }
 
@@ -4979,7 +5217,7 @@ static int copy_mnt_id_req(const struct mnt_id_req __user *req,
        int ret;
        size_t usize;
 
-       BUILD_BUG_ON(sizeof(struct mnt_id_req) != MNT_ID_REQ_SIZE_VER0);
+       BUILD_BUG_ON(sizeof(struct mnt_id_req) != MNT_ID_REQ_SIZE_VER1);
 
        ret = get_user(usize, &req->size);
        if (ret)
@@ -4994,16 +5232,32 @@ static int copy_mnt_id_req(const struct mnt_id_req __user *req,
                return ret;
        if (kreq->spare != 0)
                return -EINVAL;
+       /* The first valid unique mount id is MNT_UNIQUE_ID_OFFSET + 1. */
+       if (kreq->mnt_id <= MNT_UNIQUE_ID_OFFSET)
+               return -EINVAL;
        return 0;
 }
 
+/*
+ * If the user requested a specific mount namespace id, look that up and return
+ * that, or if not simply grab a passive reference on our mount namespace and
+ * return that.
+ */
+static struct mnt_namespace *grab_requested_mnt_ns(u64 mnt_ns_id)
+{
+       if (mnt_ns_id)
+               return lookup_mnt_ns(mnt_ns_id);
+       refcount_inc(&current->nsproxy->mnt_ns->passive);
+       return current->nsproxy->mnt_ns;
+}
+
 SYSCALL_DEFINE4(statmount, const struct mnt_id_req __user *, req,
                struct statmount __user *, buf, size_t, bufsize,
                unsigned int, flags)
 {
-       struct vfsmount *mnt;
+       struct mnt_namespace *ns __free(mnt_ns_release) = NULL;
+       struct kstatmount *ks __free(kfree) = NULL;
        struct mnt_id_req kreq;
-       struct kstatmount ks;
        /* We currently support retrieval of 3 strings. */
        size_t seq_size = 3 * PATH_MAX;
        int ret;
@@ -5015,64 +5269,88 @@ SYSCALL_DEFINE4(statmount, const struct mnt_id_req __user *, req,
        if (ret)
                return ret;
 
+       ns = grab_requested_mnt_ns(kreq.mnt_ns_id);
+       if (!ns)
+               return -ENOENT;
+
+       if (kreq.mnt_ns_id && (ns != current->nsproxy->mnt_ns) &&
+           !ns_capable_noaudit(ns->user_ns, CAP_SYS_ADMIN))
+               return -ENOENT;
+
+       ks = kmalloc(sizeof(*ks), GFP_KERNEL_ACCOUNT);
+       if (!ks)
+               return -ENOMEM;
+
 retry:
-       ret = prepare_kstatmount(&ks, &kreq, buf, bufsize, seq_size);
+       ret = prepare_kstatmount(ks, &kreq, buf, bufsize, seq_size);
        if (ret)
                return ret;
 
-       down_read(&namespace_sem);
-       mnt = lookup_mnt_in_ns(kreq.mnt_id, current->nsproxy->mnt_ns);
-       if (!mnt) {
-               up_read(&namespace_sem);
-               kvfree(ks.seq.buf);
-               return -ENOENT;
-       }
-
-       ks.mnt = mnt;
-       get_fs_root(current->fs, &ks.root);
-       ret = do_statmount(&ks);
-       path_put(&ks.root);
-       up_read(&namespace_sem);
+       scoped_guard(rwsem_read, &namespace_sem)
+               ret = do_statmount(ks, kreq.mnt_id, kreq.mnt_ns_id, ns);
 
        if (!ret)
-               ret = copy_statmount_to_user(&ks);
-       kvfree(ks.seq.buf);
+               ret = copy_statmount_to_user(ks);
+       kvfree(ks->seq.buf);
        if (retry_statmount(ret, &seq_size))
                goto retry;
        return ret;
 }
 
-static struct mount *listmnt_next(struct mount *curr)
+static ssize_t do_listmount(struct mnt_namespace *ns, u64 mnt_parent_id,
+                           u64 last_mnt_id, u64 *mnt_ids, size_t nr_mnt_ids,
+                           bool reverse)
 {
-       return node_to_mount(rb_next(&curr->mnt_node));
-}
-
-static ssize_t do_listmount(struct mount *first, struct path *orig,
-                           u64 mnt_parent_id, u64 __user *mnt_ids,
-                           size_t nr_mnt_ids, const struct path *root)
-{
-       struct mount *r;
+       struct path root __free(path_put) = {};
+       struct path orig;
+       struct mount *r, *first;
        ssize_t ret;
 
+       rwsem_assert_held(&namespace_sem);
+
+       ret = grab_requested_root(ns, &root);
+       if (ret)
+               return ret;
+
+       if (mnt_parent_id == LSMT_ROOT) {
+               orig = root;
+       } else {
+               orig.mnt = lookup_mnt_in_ns(mnt_parent_id, ns);
+               if (!orig.mnt)
+                       return -ENOENT;
+               orig.dentry = orig.mnt->mnt_root;
+       }
+
        /*
         * Don't trigger audit denials. We just want to determine what
         * mounts to show users.
         */
-       if (!is_path_reachable(real_mount(orig->mnt), orig->dentry, root) &&
-           !ns_capable_noaudit(&init_user_ns, CAP_SYS_ADMIN))
+       if (!is_path_reachable(real_mount(orig.mnt), orig.dentry, &root) &&
+           !ns_capable_noaudit(ns->user_ns, CAP_SYS_ADMIN))
                return -EPERM;
 
-       ret = security_sb_statfs(orig->dentry);
+       ret = security_sb_statfs(orig.dentry);
        if (ret)
                return ret;
 
-       for (ret = 0, r = first; r && nr_mnt_ids; r = listmnt_next(r)) {
+       if (!last_mnt_id) {
+               if (reverse)
+                       first = node_to_mount(rb_last(&ns->mounts));
+               else
+                       first = node_to_mount(rb_first(&ns->mounts));
+       } else {
+               if (reverse)
+                       first = mnt_find_id_at_reverse(ns, last_mnt_id - 1);
+               else
+                       first = mnt_find_id_at(ns, last_mnt_id + 1);
+       }
+
+       for (ret = 0, r = first; r && nr_mnt_ids; r = listmnt_next(r, reverse)) {
                if (r->mnt_id_unique == mnt_parent_id)
                        continue;
-               if (!is_path_reachable(r, r->mnt.mnt_root, orig))
+               if (!is_path_reachable(r, r->mnt.mnt_root, &orig))
                        continue;
-               if (put_user(r->mnt_id_unique, mnt_ids))
-                       return -EFAULT;
+               *mnt_ids = r->mnt_id_unique;
                mnt_ids++;
                nr_mnt_ids--;
                ret++;
@@ -5080,22 +5358,26 @@ static ssize_t do_listmount(struct mount *first, struct path *orig,
        return ret;
 }
 
-SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req, u64 __user *,
-               mnt_ids, size_t, nr_mnt_ids, unsigned int, flags)
+SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req,
+               u64 __user *, mnt_ids, size_t, nr_mnt_ids, unsigned int, flags)
 {
-       struct mnt_namespace *ns = current->nsproxy->mnt_ns;
+       u64 *kmnt_ids __free(kvfree) = NULL;
+       const size_t maxcount = 1000000;
+       struct mnt_namespace *ns __free(mnt_ns_release) = NULL;
        struct mnt_id_req kreq;
-       struct mount *first;
-       struct path root, orig;
-       u64 mnt_parent_id, last_mnt_id;
-       const size_t maxcount = (size_t)-1 >> 3;
+       u64 last_mnt_id;
        ssize_t ret;
 
-       if (flags)
+       if (flags & ~LISTMOUNT_REVERSE)
                return -EINVAL;
 
+       /*
+        * If the mount namespace really has more than 1 million mounts the
+        * caller must iterate over the mount namespace (and reconsider their
+        * system design...).
+        */
        if (unlikely(nr_mnt_ids > maxcount))
-               return -EFAULT;
+               return -EOVERFLOW;
 
        if (!access_ok(mnt_ids, nr_mnt_ids * sizeof(*mnt_ids)))
                return -EFAULT;
@@ -5103,33 +5385,37 @@ SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req, u64 __user *,
        ret = copy_mnt_id_req(req, &kreq);
        if (ret)
                return ret;
-       mnt_parent_id = kreq.mnt_id;
+
        last_mnt_id = kreq.param;
+       /* The first valid unique mount id is MNT_UNIQUE_ID_OFFSET + 1. */
+       if (last_mnt_id != 0 && last_mnt_id <= MNT_UNIQUE_ID_OFFSET)
+               return -EINVAL;
 
-       down_read(&namespace_sem);
-       get_fs_root(current->fs, &root);
-       if (mnt_parent_id == LSMT_ROOT) {
-               orig = root;
-       } else {
-               ret = -ENOENT;
-               orig.mnt = lookup_mnt_in_ns(mnt_parent_id, ns);
-               if (!orig.mnt)
-                       goto err;
-               orig.dentry = orig.mnt->mnt_root;
-       }
-       if (!last_mnt_id)
-               first = node_to_mount(rb_first(&ns->mounts));
-       else
-               first = mnt_find_id_at(ns, last_mnt_id + 1);
+       kmnt_ids = kvmalloc_array(nr_mnt_ids, sizeof(*kmnt_ids),
+                                 GFP_KERNEL_ACCOUNT);
+       if (!kmnt_ids)
+               return -ENOMEM;
+
+       ns = grab_requested_mnt_ns(kreq.mnt_ns_id);
+       if (!ns)
+               return -ENOENT;
+
+       if (kreq.mnt_ns_id && (ns != current->nsproxy->mnt_ns) &&
+           !ns_capable_noaudit(ns->user_ns, CAP_SYS_ADMIN))
+               return -ENOENT;
+
+       scoped_guard(rwsem_read, &namespace_sem)
+               ret = do_listmount(ns, kreq.mnt_id, last_mnt_id, kmnt_ids,
+                                  nr_mnt_ids, (flags & LISTMOUNT_REVERSE));
+       if (ret <= 0)
+               return ret;
+
+       if (copy_to_user(mnt_ids, kmnt_ids, ret * sizeof(*mnt_ids)))
+               return -EFAULT;
 
-       ret = do_listmount(first, &orig, mnt_parent_id, mnt_ids, nr_mnt_ids, &root);
-err:
-       path_put(&root);
-       up_read(&namespace_sem);
        return ret;
 }
 
-
 static void __init init_mount_tree(void)
 {
        struct vfsmount *mnt;
@@ -5157,6 +5443,8 @@ static void __init init_mount_tree(void)
 
        set_fs_pwd(current->fs, &root);
        set_fs_root(current->fs, &root);
+
+       mnt_ns_tree_add(ns);
 }
 
 void __init mnt_init(void)