io_uring: flags-based creds init in queue
[linux-2.6-microblaze.git] / fs / io_uring.c
index bd6fd51..e9e8006 100644 (file)
@@ -81,6 +81,7 @@
 #include <linux/pagemap.h>
 #include <linux/io_uring.h>
 #include <linux/blk-cgroup.h>
+#include <linux/audit.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/io_uring.h>
@@ -327,6 +328,11 @@ struct io_ring_ctx {
 
        const struct cred       *creds;
 
+#ifdef CONFIG_AUDIT
+       kuid_t                  loginuid;
+       unsigned int            sessionid;
+#endif
+
        struct completion       ref_comp;
        struct completion       sq_thread_comp;
 
@@ -689,7 +695,6 @@ struct io_kiocb {
        struct hlist_node               hash_node;
        struct async_poll               *apoll;
        struct io_wq_work               work;
-       struct io_identity              identity;
 };
 
 struct io_defer_entry {
@@ -1040,18 +1045,49 @@ static inline void req_set_fail_links(struct io_kiocb *req)
                req->flags |= REQ_F_FAIL_LINK;
 }
 
+/*
+ * None of these are dereferenced, they are simply used to check if any of
+ * them have changed. If we're under current and check they are still the
+ * same, we're fine to grab references to them for actual out-of-line use.
+ */
+static void io_init_identity(struct io_identity *id)
+{
+       id->files = current->files;
+       id->mm = current->mm;
+#ifdef CONFIG_BLK_CGROUP
+       rcu_read_lock();
+       id->blkcg_css = blkcg_css();
+       rcu_read_unlock();
+#endif
+       id->creds = current_cred();
+       id->nsproxy = current->nsproxy;
+       id->fs = current->fs;
+       id->fsize = rlimit(RLIMIT_FSIZE);
+#ifdef CONFIG_AUDIT
+       id->loginuid = current->loginuid;
+       id->sessionid = current->sessionid;
+#endif
+       refcount_set(&id->count, 1);
+}
+
 /*
  * Note: must call io_req_init_async() for the first time you
  * touch any members of io_wq_work.
  */
 static inline void io_req_init_async(struct io_kiocb *req)
 {
+       struct io_uring_task *tctx = current->io_uring;
+
        if (req->flags & REQ_F_WORK_INITIALIZED)
                return;
 
        memset(&req->work, 0, sizeof(req->work));
        req->flags |= REQ_F_WORK_INITIALIZED;
-       req->work.identity = &req->identity;
+
+       /* Grab a ref if this isn't our static identity */
+       req->work.identity = tctx->identity;
+       if (tctx->identity != &tctx->__identity)
+               refcount_inc(&req->work.identity->count);
 }
 
 static inline bool io_async_submit(struct io_ring_ctx *ctx)
@@ -1157,6 +1193,14 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
        }
 }
 
+static void io_put_identity(struct io_uring_task *tctx, struct io_kiocb *req)
+{
+       if (req->work.identity == &tctx->__identity)
+               return;
+       if (refcount_dec_and_test(&req->work.identity->count))
+               kfree(req->work.identity);
+}
+
 static void io_req_clean_work(struct io_kiocb *req)
 {
        if (!(req->flags & REQ_F_WORK_INITIALIZED))
@@ -1189,28 +1233,69 @@ static void io_req_clean_work(struct io_kiocb *req)
                        free_fs_struct(fs);
                req->work.flags &= ~IO_WQ_WORK_FS;
        }
+
+       io_put_identity(req->task->io_uring, req);
 }
 
-static void io_prep_async_work(struct io_kiocb *req)
+/*
+ * Create a private copy of io_identity, since some fields don't match
+ * the current context.
+ */
+static bool io_identity_cow(struct io_kiocb *req)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       const struct cred *creds = NULL;
+       struct io_identity *id;
+
+       if (req->work.flags & IO_WQ_WORK_CREDS)
+               creds = req->work.identity->creds;
+
+       id = kmemdup(req->work.identity, sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id)) {
+               req->work.flags |= IO_WQ_WORK_CANCEL;
+               return false;
+       }
+
+       /*
+        * We can safely just re-init the creds we copied  Either the field
+        * matches the current one, or we haven't grabbed it yet. The only
+        * exception is ->creds, through registered personalities, so handle
+        * that one separately.
+        */
+       io_init_identity(id);
+       if (creds)
+               req->work.identity->creds = creds;
+
+       /* add one for this request */
+       refcount_inc(&id->count);
+
+       /* drop old identity, assign new one. one ref for req, one for tctx */
+       if (req->work.identity != tctx->identity &&
+           refcount_sub_and_test(2, &req->work.identity->count))
+               kfree(req->work.identity);
+
+       req->work.identity = id;
+       tctx->identity = id;
+       return true;
+}
+
+static bool io_grab_identity(struct io_kiocb *req)
 {
        const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_identity *id = req->work.identity;
        struct io_ring_ctx *ctx = req->ctx;
 
-       io_req_init_async(req);
+       if (def->needs_fsize && id->fsize != rlimit(RLIMIT_FSIZE))
+               return false;
 
-       if (req->flags & REQ_F_ISREG) {
-               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
-                       io_wq_hash_work(&req->work, file_inode(req->file));
-       } else {
-               if (def->unbound_nonreg_file)
-                       req->work.flags |= IO_WQ_WORK_UNBOUND;
-       }
        if (!(req->work.flags & IO_WQ_WORK_FILES) &&
-           (io_op_defs[req->opcode].work_flags & IO_WQ_WORK_FILES) &&
+           (def->work_flags & IO_WQ_WORK_FILES) &&
            !(req->flags & REQ_F_NO_FILE_TABLE)) {
-               req->work.identity->files = get_files_struct(current);
-               get_nsproxy(current->nsproxy);
-               req->work.identity->nsproxy = current->nsproxy;
+               if (id->files != current->files ||
+                   id->nsproxy != current->nsproxy)
+                       return false;
+               atomic_inc(&id->files->count);
+               get_nsproxy(id->nsproxy);
                req->flags |= REQ_F_INFLIGHT;
 
                spin_lock_irq(&ctx->inflight_lock);
@@ -1218,46 +1303,85 @@ static void io_prep_async_work(struct io_kiocb *req)
                spin_unlock_irq(&ctx->inflight_lock);
                req->work.flags |= IO_WQ_WORK_FILES;
        }
-       if (!(req->work.flags & IO_WQ_WORK_MM) &&
-           (def->work_flags & IO_WQ_WORK_MM)) {
-               mmgrab(current->mm);
-               req->work.identity->mm = current->mm;
-               req->work.flags |= IO_WQ_WORK_MM;
-       }
 #ifdef CONFIG_BLK_CGROUP
        if (!(req->work.flags & IO_WQ_WORK_BLKCG) &&
            (def->work_flags & IO_WQ_WORK_BLKCG)) {
                rcu_read_lock();
-               req->work.identity->blkcg_css = blkcg_css();
+               if (id->blkcg_css != blkcg_css()) {
+                       rcu_read_unlock();
+                       return false;
+               }
                /*
                 * This should be rare, either the cgroup is dying or the task
                 * is moving cgroups. Just punt to root for the handful of ios.
                 */
-               if (css_tryget_online(req->work.identity->blkcg_css))
+               if (css_tryget_online(id->blkcg_css))
                        req->work.flags |= IO_WQ_WORK_BLKCG;
                rcu_read_unlock();
        }
 #endif
        if (!(req->work.flags & IO_WQ_WORK_CREDS)) {
-               req->work.identity->creds = get_current_cred();
+               if (id->creds != current_cred())
+                       return false;
+               get_cred(id->creds);
                req->work.flags |= IO_WQ_WORK_CREDS;
        }
+#ifdef CONFIG_AUDIT
+       if (!uid_eq(current->loginuid, id->loginuid) ||
+           current->sessionid != id->sessionid)
+               return false;
+#endif
        if (!(req->work.flags & IO_WQ_WORK_FS) &&
            (def->work_flags & IO_WQ_WORK_FS)) {
-               spin_lock(&current->fs->lock);
-               if (!current->fs->in_exec) {
-                       req->work.identity->fs = current->fs;
-                       req->work.identity->fs->users++;
+               if (current->fs != id->fs)
+                       return false;
+               spin_lock(&id->fs->lock);
+               if (!id->fs->in_exec) {
+                       id->fs->users++;
                        req->work.flags |= IO_WQ_WORK_FS;
                } else {
                        req->work.flags |= IO_WQ_WORK_CANCEL;
                }
                spin_unlock(&current->fs->lock);
        }
-       if (def->needs_fsize)
-               req->work.identity->fsize = rlimit(RLIMIT_FSIZE);
-       else
-               req->work.identity->fsize = RLIM_INFINITY;
+
+       return true;
+}
+
+static void io_prep_async_work(struct io_kiocb *req)
+{
+       const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_ring_ctx *ctx = req->ctx;
+       struct io_identity *id;
+
+       io_req_init_async(req);
+       id = req->work.identity;
+
+       if (req->flags & REQ_F_ISREG) {
+               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
+                       io_wq_hash_work(&req->work, file_inode(req->file));
+       } else {
+               if (def->unbound_nonreg_file)
+                       req->work.flags |= IO_WQ_WORK_UNBOUND;
+       }
+
+       /* ->mm can never change on us */
+       if (!(req->work.flags & IO_WQ_WORK_MM) &&
+           (def->work_flags & IO_WQ_WORK_MM)) {
+               mmgrab(id->mm);
+               req->work.flags |= IO_WQ_WORK_MM;
+       }
+
+       /* if we fail grabbing identity, we must COW, regrab, and retry */
+       if (io_grab_identity(req))
+               return;
+
+       if (!io_identity_cow(req))
+               return;
+
+       /* can't fail at this point */
+       if (!io_grab_identity(req))
+               WARN_ON(1);
 }
 
 static void io_prep_async_link(struct io_kiocb *req)
@@ -1696,14 +1820,12 @@ static void io_dismantle_req(struct io_kiocb *req)
 
 static void __io_free_req(struct io_kiocb *req)
 {
-       struct io_uring_task *tctx;
-       struct io_ring_ctx *ctx;
+       struct io_uring_task *tctx = req->task->io_uring;
+       struct io_ring_ctx *ctx = req->ctx;
 
        io_dismantle_req(req);
-       tctx = req->task->io_uring;
-       ctx = req->ctx;
 
-       atomic_long_inc(&tctx->req_complete);
+       percpu_counter_dec(&tctx->inflight);
        if (tctx->in_idle)
                wake_up(&tctx->wait);
        put_task_struct(req->task);
@@ -1982,7 +2104,9 @@ static void io_req_free_batch_finish(struct io_ring_ctx *ctx,
        if (rb->to_free)
                __io_req_free_batch_flush(ctx, rb);
        if (rb->task) {
-               atomic_long_add(rb->task_refs, &rb->task->io_uring->req_complete);
+               struct io_uring_task *tctx = rb->task->io_uring;
+
+               percpu_counter_sub(&tctx->inflight, rb->task_refs);
                put_task_struct_many(rb->task, rb->task_refs);
                rb->task = NULL;
        }
@@ -1999,7 +2123,9 @@ static void io_req_free_batch(struct req_batch *rb, struct io_kiocb *req)
 
        if (req->task != rb->task) {
                if (rb->task) {
-                       atomic_long_add(rb->task_refs, &rb->task->io_uring->req_complete);
+                       struct io_uring_task *tctx = rb->task->io_uring;
+
+                       percpu_counter_sub(&tctx->inflight, rb->task_refs);
                        put_task_struct_many(rb->task, rb->task_refs);
                }
                rb->task = req->task;
@@ -2567,7 +2693,7 @@ static struct file *__io_file_get(struct io_submit_state *state, int fd)
/* True if the block device (or absence of one) supports nowait issue */
static bool io_bdev_nowait(struct block_device *bdev)
{
#ifndef CONFIG_BLOCK
	return true;
#else
	if (!bdev)
		return true;
	return blk_queue_nowait(bdev_get_queue(bdev));
#endif
}
@@ -4882,6 +5008,8 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
         * for write). Setup a separate io_poll_iocb if this happens.
         */
        if (unlikely(poll->head)) {
+               struct io_poll_iocb *poll_one = poll;
+
                /* already have a 2nd entry, fail a third attempt */
                if (*poll_ptr) {
                        pt->error = -EINVAL;
@@ -4892,7 +5020,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
                        pt->error = -ENOMEM;
                        return;
                }
-               io_init_poll_iocb(poll, req->poll.events, io_poll_double_wake);
+               io_init_poll_iocb(poll, poll_one->events, io_poll_double_wake);
                refcount_inc(&req->refs);
                poll->wait.private = req;
                *poll_ptr = poll;
@@ -6063,7 +6191,8 @@ static void __io_queue_sqe(struct io_kiocb *req, struct io_comp_state *cs)
 again:
        linked_timeout = io_prep_linked_timeout(req);
 
-       if ((req->flags & REQ_F_WORK_INITIALIZED) && req->work.identity->creds &&
+       if ((req->flags & REQ_F_WORK_INITIALIZED) &&
+           (req->work.flags & IO_WQ_WORK_CREDS) &&
            req->work.identity->creds != current_cred()) {
                if (old_creds)
                        revert_creds(old_creds);
@@ -6071,7 +6200,6 @@ again:
                        old_creds = NULL; /* restored original creds */
                else
                        old_creds = override_creds(req->work.identity->creds);
-               req->work.flags |= IO_WQ_WORK_CREDS;
        }
 
        ret = io_issue_sqe(req, true, cs);
@@ -6374,11 +6502,16 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 
        id = READ_ONCE(sqe->personality);
        if (id) {
+               struct io_identity *iod;
+
                io_req_init_async(req);
-               req->work.identity->creds = idr_find(&ctx->personality_idr, id);
-               if (unlikely(!req->work.identity->creds))
+               iod = idr_find(&ctx->personality_idr, id);
+               if (unlikely(!iod))
                        return -EINVAL;
-               get_cred(req->work.identity->creds);
+               refcount_inc(&iod->count);
+               io_put_identity(current->io_uring, req);
+               get_cred(iod->creds);
+               req->work.identity = iod;
                req->work.flags |= IO_WQ_WORK_CREDS;
        }
 
@@ -6412,7 +6545,7 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
        if (!percpu_ref_tryget_many(&ctx->refs, nr))
                return -EAGAIN;
 
-       atomic_long_add(nr, &current->io_uring->req_issue);
+       percpu_counter_add(&current->io_uring->inflight, nr);
        refcount_add(nr, &current->usage);
 
        io_submit_state_start(&state, ctx, nr);
@@ -6454,10 +6587,12 @@ fail_req:
 
        if (unlikely(submitted != nr)) {
                int ref_used = (submitted == -EAGAIN) ? 0 : submitted;
+               struct io_uring_task *tctx = current->io_uring;
+               int unused = nr - ref_used;
 
-               percpu_ref_put_many(&ctx->refs, nr - ref_used);
-               atomic_long_sub(nr - ref_used, &current->io_uring->req_issue);
-               put_task_struct_many(current, nr - ref_used);
+               percpu_ref_put_many(&ctx->refs, unused);
+               percpu_counter_sub(&tctx->inflight, unused);
+               put_task_struct_many(current, unused);
        }
        if (link)
                io_queue_link_head(link, &state.comp);
@@ -6637,6 +6772,10 @@ static int io_sq_thread(void *data)
                                old_cred = override_creds(ctx->creds);
                        }
                        io_sq_thread_associate_blkcg(ctx, &cur_css);
+#ifdef CONFIG_AUDIT
+                       current->loginuid = ctx->loginuid;
+                       current->sessionid = ctx->sessionid;
+#endif
 
                        ret |= __io_sq_thread(ctx, start_jiffies, cap_entries);
 
@@ -7575,17 +7714,24 @@ out_fput:
 static int io_uring_alloc_task_context(struct task_struct *task)
 {
        struct io_uring_task *tctx;
+       int ret;
 
        tctx = kmalloc(sizeof(*tctx), GFP_KERNEL);
        if (unlikely(!tctx))
                return -ENOMEM;
 
+       ret = percpu_counter_init(&tctx->inflight, 0, GFP_KERNEL);
+       if (unlikely(ret)) {
+               kfree(tctx);
+               return ret;
+       }
+
        xa_init(&tctx->xa);
        init_waitqueue_head(&tctx->wait);
        tctx->last = NULL;
        tctx->in_idle = 0;
-       atomic_long_set(&tctx->req_issue, 0);
-       atomic_long_set(&tctx->req_complete, 0);
+       io_init_identity(&tctx->__identity);
+       tctx->identity = &tctx->__identity;
        task->io_uring = tctx;
        return 0;
 }
@@ -7595,6 +7741,10 @@ void __io_uring_free(struct task_struct *tsk)
        struct io_uring_task *tctx = tsk->io_uring;
 
        WARN_ON_ONCE(!xa_empty(&tctx->xa));
+       WARN_ON_ONCE(refcount_read(&tctx->identity->count) != 1);
+       if (tctx->identity != &tctx->__identity)
+               kfree(tctx->identity);
+       percpu_counter_destroy(&tctx->inflight);
        kfree(tctx);
        tsk->io_uring = NULL;
 }
@@ -8171,11 +8321,14 @@ static int io_uring_fasync(int fd, struct file *file, int on)
 static int io_remove_personalities(int id, void *p, void *data)
 {
        struct io_ring_ctx *ctx = data;
-       const struct cred *cred;
+       struct io_identity *iod;
 
-       cred = idr_remove(&ctx->personality_idr, id);
-       if (cred)
-               put_cred(cred);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
+       }
        return 0;
 }
 
@@ -8576,12 +8729,6 @@ void __io_uring_files_cancel(struct files_struct *files)
        }
 }
 
-static inline bool io_uring_task_idle(struct io_uring_task *tctx)
-{
-       return atomic_long_read(&tctx->req_issue) ==
-               atomic_long_read(&tctx->req_complete);
-}
-
 /*
  * Find any io_uring fd that this task has registered or done IO on, and cancel
  * requests.
@@ -8590,14 +8737,16 @@ void __io_uring_task_cancel(void)
 {
        struct io_uring_task *tctx = current->io_uring;
        DEFINE_WAIT(wait);
-       long completions;
+       s64 inflight;
 
        /* make sure overflow events are dropped */
        tctx->in_idle = true;
 
-       while (!io_uring_task_idle(tctx)) {
+       do {
                /* read completions before cancelations */
-               completions = atomic_long_read(&tctx->req_complete);
+               inflight = percpu_counter_sum(&tctx->inflight);
+               if (!inflight)
+                       break;
                __io_uring_files_cancel(NULL);
 
                prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
@@ -8606,12 +8755,10 @@ void __io_uring_task_cancel(void)
                 * If we've seen completions, retry. This avoids a race where
                 * a completion comes in before we did prepare_to_wait().
                 */
-               if (completions != atomic_long_read(&tctx->req_complete))
+               if (inflight != percpu_counter_sum(&tctx->inflight))
                        continue;
-               if (io_uring_task_idle(tctx))
-                       break;
                schedule();
-       }
+       } while (1);
 
        finish_wait(&tctx->wait, &wait);
        tctx->in_idle = false;
@@ -9077,7 +9224,10 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        ctx->compat = in_compat_syscall();
        ctx->user = user;
        ctx->creds = get_current_cred();
-
+#ifdef CONFIG_AUDIT
+       ctx->loginuid = current->loginuid;
+       ctx->sessionid = current->sessionid;
+#endif
        ctx->sqo_task = get_task_struct(current);
 
        /*
@@ -9245,23 +9395,33 @@ out:
 
 static int io_register_personality(struct io_ring_ctx *ctx)
 {
-       const struct cred *creds = get_current_cred();
-       int id;
+       struct io_identity *id;
+       int ret;
+
+       id = kmalloc(sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id))
+               return -ENOMEM;
+
+       io_init_identity(id);
+       id->creds = get_current_cred();
 
-       id = idr_alloc_cyclic(&ctx->personality_idr, (void *) creds, 1,
-                               USHRT_MAX, GFP_KERNEL);
-       if (id < 0)
-               put_cred(creds);
-       return id;
+       ret = idr_alloc_cyclic(&ctx->personality_idr, id, 1, USHRT_MAX, GFP_KERNEL);
+       if (ret < 0) {
+               put_cred(id->creds);
+               kfree(id);
+       }
+       return ret;
 }
 
 static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 {
-       const struct cred *old_creds;
+       struct io_identity *iod;
 
-       old_creds = idr_remove(&ctx->personality_idr, id);
-       if (old_creds) {
-               put_cred(old_creds);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
                return 0;
        }