io_uring: COW io_identity on mismatch
[linux-2.6-microblaze.git] / fs / io_uring.c
index bd6fd51..ab30834 100644 (file)
@@ -1040,6 +1040,27 @@ static inline void req_set_fail_links(struct io_kiocb *req)
                req->flags |= REQ_F_FAIL_LINK;
 }
 
+/*
+ * None of these are dereferenced, they are simply used to check if any of
+ * them have changed. If we're under current and check they are still the
+ * same, we're fine to grab references to them for actual out-of-line use.
+ */
+static void io_init_identity(struct io_identity *id)
+{
+       id->files = current->files;
+       id->mm = current->mm;
+#ifdef CONFIG_BLK_CGROUP
+       rcu_read_lock();
+       id->blkcg_css = blkcg_css();
+       rcu_read_unlock();
+#endif
+       id->creds = current_cred();
+       id->nsproxy = current->nsproxy;
+       id->fs = current->fs;
+       id->fsize = rlimit(RLIMIT_FSIZE);
+       refcount_set(&id->count, 1);
+}
+
 /*
  * Note: must call io_req_init_async() for the first time you
  * touch any members of io_wq_work.
@@ -1051,6 +1072,7 @@ static inline void io_req_init_async(struct io_kiocb *req)
 
        memset(&req->work, 0, sizeof(req->work));
        req->flags |= REQ_F_WORK_INITIALIZED;
+       io_init_identity(&req->identity);
        req->work.identity = &req->identity;
 }
 
@@ -1157,6 +1179,14 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
        }
 }
 
+static void io_put_identity(struct io_kiocb *req)
+{
+       if (req->work.identity == &req->identity)
+               return;
+       if (refcount_dec_and_test(&req->work.identity->count))
+               kfree(req->work.identity);
+}
+
 static void io_req_clean_work(struct io_kiocb *req)
 {
        if (!(req->flags & REQ_F_WORK_INITIALIZED))
@@ -1189,28 +1219,67 @@ static void io_req_clean_work(struct io_kiocb *req)
                        free_fs_struct(fs);
                req->work.flags &= ~IO_WQ_WORK_FS;
        }
+
+       io_put_identity(req);
 }
 
-static void io_prep_async_work(struct io_kiocb *req)
+/*
+ * Create a private copy of io_identity, since some fields don't match
+ * the current context.
+ */
+static bool io_identity_cow(struct io_kiocb *req)
+{
+       const struct cred *creds = NULL;
+       struct io_identity *id;
+
+       if (req->work.flags & IO_WQ_WORK_CREDS)
+               creds = req->work.identity->creds;
+
+       id = kmemdup(req->work.identity, sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id)) {
+               req->work.flags |= IO_WQ_WORK_CANCEL;
+               return false;
+       }
+
+       /*
+        * We can safely just re-init the creds we copied. Either the field
+        * matches the current one, or we haven't grabbed it yet. The only
+        * exception is ->creds, through registered personalities, so handle
+        * that one separately.
+        */
+       io_init_identity(id);
+       if (creds)
+               req->work.identity->creds = creds;
+
+       /* add one for this request */
+       refcount_inc(&id->count);
+
+       /* drop old identity, assign new one. one ref for req, one for tctx */
+       if (req->work.identity != &req->identity &&
+           refcount_sub_and_test(2, &req->work.identity->count))
+               kfree(req->work.identity);
+
+       req->work.identity = id;
+       return true;
+}
+
+static bool io_grab_identity(struct io_kiocb *req)
 {
        const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_identity *id = &req->identity;
        struct io_ring_ctx *ctx = req->ctx;
 
-       io_req_init_async(req);
+       if (def->needs_fsize && id->fsize != rlimit(RLIMIT_FSIZE))
+               return false;
 
-       if (req->flags & REQ_F_ISREG) {
-               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
-                       io_wq_hash_work(&req->work, file_inode(req->file));
-       } else {
-               if (def->unbound_nonreg_file)
-                       req->work.flags |= IO_WQ_WORK_UNBOUND;
-       }
        if (!(req->work.flags & IO_WQ_WORK_FILES) &&
-           (io_op_defs[req->opcode].work_flags & IO_WQ_WORK_FILES) &&
+           (def->work_flags & IO_WQ_WORK_FILES) &&
            !(req->flags & REQ_F_NO_FILE_TABLE)) {
-               req->work.identity->files = get_files_struct(current);
-               get_nsproxy(current->nsproxy);
-               req->work.identity->nsproxy = current->nsproxy;
+               if (id->files != current->files ||
+                   id->nsproxy != current->nsproxy)
+                       return false;
+               atomic_inc(&id->files->count);
+               get_nsproxy(id->nsproxy);
                req->flags |= REQ_F_INFLIGHT;
 
                spin_lock_irq(&ctx->inflight_lock);
@@ -1218,46 +1287,79 @@ static void io_prep_async_work(struct io_kiocb *req)
                spin_unlock_irq(&ctx->inflight_lock);
                req->work.flags |= IO_WQ_WORK_FILES;
        }
-       if (!(req->work.flags & IO_WQ_WORK_MM) &&
-           (def->work_flags & IO_WQ_WORK_MM)) {
-               mmgrab(current->mm);
-               req->work.identity->mm = current->mm;
-               req->work.flags |= IO_WQ_WORK_MM;
-       }
 #ifdef CONFIG_BLK_CGROUP
        if (!(req->work.flags & IO_WQ_WORK_BLKCG) &&
            (def->work_flags & IO_WQ_WORK_BLKCG)) {
                rcu_read_lock();
-               req->work.identity->blkcg_css = blkcg_css();
+               if (id->blkcg_css != blkcg_css()) {
+                       rcu_read_unlock();
+                       return false;
+               }
                /*
                 * This should be rare, either the cgroup is dying or the task
                 * is moving cgroups. Just punt to root for the handful of ios.
                 */
-               if (css_tryget_online(req->work.identity->blkcg_css))
+               if (css_tryget_online(id->blkcg_css))
                        req->work.flags |= IO_WQ_WORK_BLKCG;
                rcu_read_unlock();
        }
 #endif
        if (!(req->work.flags & IO_WQ_WORK_CREDS)) {
-               req->work.identity->creds = get_current_cred();
+               if (id->creds != current_cred())
+                       return false;
+               get_cred(id->creds);
                req->work.flags |= IO_WQ_WORK_CREDS;
        }
        if (!(req->work.flags & IO_WQ_WORK_FS) &&
            (def->work_flags & IO_WQ_WORK_FS)) {
-               spin_lock(&current->fs->lock);
-               if (!current->fs->in_exec) {
-                       req->work.identity->fs = current->fs;
-                       req->work.identity->fs->users++;
+               if (current->fs != id->fs)
+                       return false;
+               spin_lock(&id->fs->lock);
+               if (!id->fs->in_exec) {
+                       id->fs->users++;
                        req->work.flags |= IO_WQ_WORK_FS;
                } else {
                        req->work.flags |= IO_WQ_WORK_CANCEL;
                }
                spin_unlock(&current->fs->lock);
        }
-       if (def->needs_fsize)
-               req->work.identity->fsize = rlimit(RLIMIT_FSIZE);
-       else
-               req->work.identity->fsize = RLIM_INFINITY;
+
+       return true;
+}
+
+static void io_prep_async_work(struct io_kiocb *req)
+{
+       const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_identity *id = &req->identity;
+       struct io_ring_ctx *ctx = req->ctx;
+
+       io_req_init_async(req);
+
+       if (req->flags & REQ_F_ISREG) {
+               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
+                       io_wq_hash_work(&req->work, file_inode(req->file));
+       } else {
+               if (def->unbound_nonreg_file)
+                       req->work.flags |= IO_WQ_WORK_UNBOUND;
+       }
+
+       /* ->mm can never change on us */
+       if (!(req->work.flags & IO_WQ_WORK_MM) &&
+           (def->work_flags & IO_WQ_WORK_MM)) {
+               mmgrab(id->mm);
+               req->work.flags |= IO_WQ_WORK_MM;
+       }
+
+       /* if we fail grabbing identity, we must COW, regrab, and retry */
+       if (io_grab_identity(req))
+               return;
+
+       if (!io_identity_cow(req))
+               return;
+
+       /* can't fail at this point */
+       if (!io_grab_identity(req))
+               WARN_ON(1);
 }
 
 static void io_prep_async_link(struct io_kiocb *req)
@@ -1696,12 +1798,10 @@ static void io_dismantle_req(struct io_kiocb *req)
 
 static void __io_free_req(struct io_kiocb *req)
 {
-       struct io_uring_task *tctx;
-       struct io_ring_ctx *ctx;
+       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_ring_ctx *ctx = req->ctx;
 
        io_dismantle_req(req);
-       tctx = req->task->io_uring;
-       ctx = req->ctx;
 
        atomic_long_inc(&tctx->req_complete);
        if (tctx->in_idle)
@@ -6374,11 +6474,16 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 
        id = READ_ONCE(sqe->personality);
        if (id) {
+               struct io_identity *iod;
+
                io_req_init_async(req);
-               req->work.identity->creds = idr_find(&ctx->personality_idr, id);
-               if (unlikely(!req->work.identity->creds))
+               iod = idr_find(&ctx->personality_idr, id);
+               if (unlikely(!iod))
                        return -EINVAL;
-               get_cred(req->work.identity->creds);
+               refcount_inc(&iod->count);
+               io_put_identity(req);
+               get_cred(iod->creds);
+               req->work.identity = iod;
                req->work.flags |= IO_WQ_WORK_CREDS;
        }
 
@@ -8171,11 +8276,14 @@ static int io_uring_fasync(int fd, struct file *file, int on)
 static int io_remove_personalities(int id, void *p, void *data)
 {
        struct io_ring_ctx *ctx = data;
-       const struct cred *cred;
+       struct io_identity *iod;
 
-       cred = idr_remove(&ctx->personality_idr, id);
-       if (cred)
-               put_cred(cred);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
+       }
        return 0;
 }
 
@@ -9245,23 +9353,33 @@ out:
 
 static int io_register_personality(struct io_ring_ctx *ctx)
 {
-       const struct cred *creds = get_current_cred();
-       int id;
+       struct io_identity *id;
+       int ret;
 
-       id = idr_alloc_cyclic(&ctx->personality_idr, (void *) creds, 1,
-                               USHRT_MAX, GFP_KERNEL);
-       if (id < 0)
-               put_cred(creds);
-       return id;
+       id = kmalloc(sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id))
+               return -ENOMEM;
+
+       io_init_identity(id);
+       id->creds = get_current_cred();
+
+       ret = idr_alloc_cyclic(&ctx->personality_idr, id, 1, USHRT_MAX, GFP_KERNEL);
+       if (ret < 0) {
+               put_cred(id->creds);
+               kfree(id);
+       }
+       return ret;
 }
 
 static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 {
-       const struct cred *old_creds;
+       struct io_identity *iod;
 
-       old_creds = idr_remove(&ctx->personality_idr, id);
-       if (old_creds) {
-               put_cred(old_creds);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
                return 0;
        }