io_uring: COW io_identity on mismatch
authorJens Axboe <axboe@kernel.dk>
Thu, 15 Oct 2020 14:46:24 +0000 (08:46 -0600)
committerJens Axboe <axboe@kernel.dk>
Sat, 17 Oct 2020 15:25:46 +0000 (09:25 -0600)
If the io_identity doesn't completely match the task, then create a
copy of it and use that. The existing copy remains valid until the last
user of it has gone away.

This also changes the personality lookup to be indexed by io_identity,
instead of creds directly.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c
include/linux/io_uring.h

index bd6fd51..ab30834 100644 (file)
@@ -1040,6 +1040,27 @@ static inline void req_set_fail_links(struct io_kiocb *req)
                req->flags |= REQ_F_FAIL_LINK;
 }
 
+/*
+ * None of these are dereferenced, they are simply used to check if any of
+ * them have changed. If we're under current and check they are still the
+ * same, we're fine to grab references to them for actual out-of-line use.
+ */
+static void io_init_identity(struct io_identity *id)
+{
+       id->files = current->files;
+       id->mm = current->mm;
+#ifdef CONFIG_BLK_CGROUP
+       rcu_read_lock();
+       id->blkcg_css = blkcg_css();
+       rcu_read_unlock();
+#endif
+       id->creds = current_cred();
+       id->nsproxy = current->nsproxy;
+       id->fs = current->fs;
+       id->fsize = rlimit(RLIMIT_FSIZE);
+       refcount_set(&id->count, 1);
+}
+
 /*
  * Note: must call io_req_init_async() for the first time you
  * touch any members of io_wq_work.
@@ -1051,6 +1072,7 @@ static inline void io_req_init_async(struct io_kiocb *req)
 
        memset(&req->work, 0, sizeof(req->work));
        req->flags |= REQ_F_WORK_INITIALIZED;
+       io_init_identity(&req->identity);
        req->work.identity = &req->identity;
 }
 
@@ -1157,6 +1179,14 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
        }
 }
 
+static void io_put_identity(struct io_kiocb *req)
+{
+       if (req->work.identity == &req->identity)
+               return;
+       if (refcount_dec_and_test(&req->work.identity->count))
+               kfree(req->work.identity);
+}
+
 static void io_req_clean_work(struct io_kiocb *req)
 {
        if (!(req->flags & REQ_F_WORK_INITIALIZED))
@@ -1189,28 +1219,67 @@ static void io_req_clean_work(struct io_kiocb *req)
                        free_fs_struct(fs);
                req->work.flags &= ~IO_WQ_WORK_FS;
        }
+
+       io_put_identity(req);
 }
 
-static void io_prep_async_work(struct io_kiocb *req)
+/*
+ * Create a private copy of io_identity, since some fields don't match
+ * the current context.
+ */
+static bool io_identity_cow(struct io_kiocb *req)
+{
+       const struct cred *creds = NULL;
+       struct io_identity *id;
+
+       if (req->work.flags & IO_WQ_WORK_CREDS)
+               creds = req->work.identity->creds;
+
+       id = kmemdup(req->work.identity, sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id)) {
+               req->work.flags |= IO_WQ_WORK_CANCEL;
+               return false;
+       }
+
+       /*
+        * We can safely just re-init the creds we copied  Either the field
+        * matches the current one, or we haven't grabbed it yet. The only
+        * exception is ->creds, through registered personalities, so handle
+        * that one separately.
+        */
+       io_init_identity(id);
+       if (creds)
+               req->work.identity->creds = creds;
+
+       /* add one for this request */
+       refcount_inc(&id->count);
+
+       /* drop old identity, assign new one. one ref for req, one for tctx */
+       if (req->work.identity != &req->identity &&
+           refcount_sub_and_test(2, &req->work.identity->count))
+               kfree(req->work.identity);
+
+       req->work.identity = id;
+       return true;
+}
+
+static bool io_grab_identity(struct io_kiocb *req)
 {
        const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_identity *id = &req->identity;
        struct io_ring_ctx *ctx = req->ctx;
 
-       io_req_init_async(req);
+       if (def->needs_fsize && id->fsize != rlimit(RLIMIT_FSIZE))
+               return false;
 
-       if (req->flags & REQ_F_ISREG) {
-               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
-                       io_wq_hash_work(&req->work, file_inode(req->file));
-       } else {
-               if (def->unbound_nonreg_file)
-                       req->work.flags |= IO_WQ_WORK_UNBOUND;
-       }
        if (!(req->work.flags & IO_WQ_WORK_FILES) &&
-           (io_op_defs[req->opcode].work_flags & IO_WQ_WORK_FILES) &&
+           (def->work_flags & IO_WQ_WORK_FILES) &&
            !(req->flags & REQ_F_NO_FILE_TABLE)) {
-               req->work.identity->files = get_files_struct(current);
-               get_nsproxy(current->nsproxy);
-               req->work.identity->nsproxy = current->nsproxy;
+               if (id->files != current->files ||
+                   id->nsproxy != current->nsproxy)
+                       return false;
+               atomic_inc(&id->files->count);
+               get_nsproxy(id->nsproxy);
                req->flags |= REQ_F_INFLIGHT;
 
                spin_lock_irq(&ctx->inflight_lock);
@@ -1218,46 +1287,79 @@ static void io_prep_async_work(struct io_kiocb *req)
                spin_unlock_irq(&ctx->inflight_lock);
                req->work.flags |= IO_WQ_WORK_FILES;
        }
-       if (!(req->work.flags & IO_WQ_WORK_MM) &&
-           (def->work_flags & IO_WQ_WORK_MM)) {
-               mmgrab(current->mm);
-               req->work.identity->mm = current->mm;
-               req->work.flags |= IO_WQ_WORK_MM;
-       }
 #ifdef CONFIG_BLK_CGROUP
        if (!(req->work.flags & IO_WQ_WORK_BLKCG) &&
            (def->work_flags & IO_WQ_WORK_BLKCG)) {
                rcu_read_lock();
-               req->work.identity->blkcg_css = blkcg_css();
+               if (id->blkcg_css != blkcg_css()) {
+                       rcu_read_unlock();
+                       return false;
+               }
                /*
                 * This should be rare, either the cgroup is dying or the task
                 * is moving cgroups. Just punt to root for the handful of ios.
                 */
-               if (css_tryget_online(req->work.identity->blkcg_css))
+               if (css_tryget_online(id->blkcg_css))
                        req->work.flags |= IO_WQ_WORK_BLKCG;
                rcu_read_unlock();
        }
 #endif
        if (!(req->work.flags & IO_WQ_WORK_CREDS)) {
-               req->work.identity->creds = get_current_cred();
+               if (id->creds != current_cred())
+                       return false;
+               get_cred(id->creds);
                req->work.flags |= IO_WQ_WORK_CREDS;
        }
        if (!(req->work.flags & IO_WQ_WORK_FS) &&
            (def->work_flags & IO_WQ_WORK_FS)) {
-               spin_lock(&current->fs->lock);
-               if (!current->fs->in_exec) {
-                       req->work.identity->fs = current->fs;
-                       req->work.identity->fs->users++;
+               if (current->fs != id->fs)
+                       return false;
+               spin_lock(&id->fs->lock);
+               if (!id->fs->in_exec) {
+                       id->fs->users++;
                        req->work.flags |= IO_WQ_WORK_FS;
                } else {
                        req->work.flags |= IO_WQ_WORK_CANCEL;
                }
                spin_unlock(&current->fs->lock);
        }
-       if (def->needs_fsize)
-               req->work.identity->fsize = rlimit(RLIMIT_FSIZE);
-       else
-               req->work.identity->fsize = RLIM_INFINITY;
+
+       return true;
+}
+
+static void io_prep_async_work(struct io_kiocb *req)
+{
+       const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_identity *id = &req->identity;
+       struct io_ring_ctx *ctx = req->ctx;
+
+       io_req_init_async(req);
+
+       if (req->flags & REQ_F_ISREG) {
+               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
+                       io_wq_hash_work(&req->work, file_inode(req->file));
+       } else {
+               if (def->unbound_nonreg_file)
+                       req->work.flags |= IO_WQ_WORK_UNBOUND;
+       }
+
+       /* ->mm can never change on us */
+       if (!(req->work.flags & IO_WQ_WORK_MM) &&
+           (def->work_flags & IO_WQ_WORK_MM)) {
+               mmgrab(id->mm);
+               req->work.flags |= IO_WQ_WORK_MM;
+       }
+
+       /* if we fail grabbing identity, we must COW, regrab, and retry */
+       if (io_grab_identity(req))
+               return;
+
+       if (!io_identity_cow(req))
+               return;
+
+       /* can't fail at this point */
+       if (!io_grab_identity(req))
+               WARN_ON(1);
 }
 
 static void io_prep_async_link(struct io_kiocb *req)
@@ -1696,12 +1798,10 @@ static void io_dismantle_req(struct io_kiocb *req)
 
 static void __io_free_req(struct io_kiocb *req)
 {
-       struct io_uring_task *tctx;
-       struct io_ring_ctx *ctx;
+       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_ring_ctx *ctx = req->ctx;
 
        io_dismantle_req(req);
-       tctx = req->task->io_uring;
-       ctx = req->ctx;
 
        atomic_long_inc(&tctx->req_complete);
        if (tctx->in_idle)
@@ -6374,11 +6474,16 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 
        id = READ_ONCE(sqe->personality);
        if (id) {
+               struct io_identity *iod;
+
                io_req_init_async(req);
-               req->work.identity->creds = idr_find(&ctx->personality_idr, id);
-               if (unlikely(!req->work.identity->creds))
+               iod = idr_find(&ctx->personality_idr, id);
+               if (unlikely(!iod))
                        return -EINVAL;
-               get_cred(req->work.identity->creds);
+               refcount_inc(&iod->count);
+               io_put_identity(req);
+               get_cred(iod->creds);
+               req->work.identity = iod;
                req->work.flags |= IO_WQ_WORK_CREDS;
        }
 
@@ -8171,11 +8276,14 @@ static int io_uring_fasync(int fd, struct file *file, int on)
 static int io_remove_personalities(int id, void *p, void *data)
 {
        struct io_ring_ctx *ctx = data;
-       const struct cred *cred;
+       struct io_identity *iod;
 
-       cred = idr_remove(&ctx->personality_idr, id);
-       if (cred)
-               put_cred(cred);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
+       }
        return 0;
 }
 
@@ -9245,23 +9353,33 @@ out:
 
 static int io_register_personality(struct io_ring_ctx *ctx)
 {
-       const struct cred *creds = get_current_cred();
-       int id;
+       struct io_identity *id;
+       int ret;
 
-       id = idr_alloc_cyclic(&ctx->personality_idr, (void *) creds, 1,
-                               USHRT_MAX, GFP_KERNEL);
-       if (id < 0)
-               put_cred(creds);
-       return id;
+       id = kmalloc(sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id))
+               return -ENOMEM;
+
+       io_init_identity(id);
+       id->creds = get_current_cred();
+
+       ret = idr_alloc_cyclic(&ctx->personality_idr, id, 1, USHRT_MAX, GFP_KERNEL);
+       if (ret < 0) {
+               put_cred(id->creds);
+               kfree(id);
+       }
+       return ret;
 }
 
 static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 {
-       const struct cred *old_creds;
+       struct io_identity *iod;
 
-       old_creds = idr_remove(&ctx->personality_idr, id);
-       if (old_creds) {
-               put_cred(old_creds);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
                return 0;
        }
 
index 352aa6b..342cc57 100644 (file)
@@ -15,6 +15,7 @@ struct io_identity {
        struct nsproxy                  *nsproxy;
        struct fs_struct                *fs;
        unsigned long                   fsize;
+       refcount_t                      count;
 };
 
 struct io_uring_task {