io_uring: COW io_identity on mismatch
[linux-2.6-microblaze.git] / fs / io_uring.c
index fc6de6b..ab30834 100644 (file)
@@ -574,7 +574,6 @@ enum {
        REQ_F_NOWAIT_BIT,
        REQ_F_LINK_TIMEOUT_BIT,
        REQ_F_ISREG_BIT,
-       REQ_F_COMP_LOCKED_BIT,
        REQ_F_NEED_CLEANUP_BIT,
        REQ_F_POLLED_BIT,
        REQ_F_BUFFER_SELECTED_BIT,
@@ -613,8 +612,6 @@ enum {
        REQ_F_LINK_TIMEOUT      = BIT(REQ_F_LINK_TIMEOUT_BIT),
        /* regular file */
        REQ_F_ISREG             = BIT(REQ_F_ISREG_BIT),
-       /* completion under lock */
-       REQ_F_COMP_LOCKED       = BIT(REQ_F_COMP_LOCKED_BIT),
        /* needs cleanup */
        REQ_F_NEED_CLEANUP      = BIT(REQ_F_NEED_CLEANUP_BIT),
        /* already went through poll handler */
@@ -692,6 +689,7 @@ struct io_kiocb {
        struct hlist_node               hash_node;
        struct async_poll               *apoll;
        struct io_wq_work               work;
+       struct io_identity              identity;
 };
 
 struct io_defer_entry {
@@ -732,8 +730,6 @@ struct io_submit_state {
 };
 
 struct io_op_def {
-       /* needs current->mm setup, does mm access */
-       unsigned                needs_mm : 1;
        /* needs req->file assigned */
        unsigned                needs_file : 1;
        /* don't fail if file grab fails */
@@ -744,10 +740,6 @@ struct io_op_def {
        unsigned                unbound_nonreg_file : 1;
        /* opcode is not supported by this kernel */
        unsigned                not_supported : 1;
-       /* needs file table */
-       unsigned                file_table : 1;
-       /* needs ->fs */
-       unsigned                needs_fs : 1;
        /* set if opcode supports polled "wait" */
        unsigned                pollin : 1;
        unsigned                pollout : 1;
@@ -757,45 +749,42 @@ struct io_op_def {
        unsigned                needs_fsize : 1;
        /* must always have async data allocated */
        unsigned                needs_async_data : 1;
-       /* needs blkcg context, issues async io potentially */
-       unsigned                needs_blkcg : 1;
        /* size of async data needed, if any */
        unsigned short          async_size;
+       unsigned                work_flags;
 };
 
-static const struct io_op_def io_op_defs[] __read_mostly = {
+static const struct io_op_def io_op_defs[] = {
        [IORING_OP_NOP] = {},
        [IORING_OP_READV] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollin                 = 1,
                .buffer_select          = 1,
                .needs_async_data       = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_WRITEV] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .hash_reg_file          = 1,
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
                .needs_fsize            = 1,
                .needs_async_data       = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_FSYNC] = {
                .needs_file             = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_READ_FIXED] = {
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollin                 = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
+               .work_flags             = IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_WRITE_FIXED] = {
                .needs_file             = 1,
@@ -803,8 +792,8 @@ static const struct io_op_def io_op_defs[] __read_mostly = {
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
                .needs_fsize            = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
+               .work_flags             = IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_POLL_ADD] = {
                .needs_file             = 1,
@@ -813,137 +802,123 @@ static const struct io_op_def io_op_defs[] __read_mostly = {
        [IORING_OP_POLL_REMOVE] = {},
        [IORING_OP_SYNC_FILE_RANGE] = {
                .needs_file             = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_SENDMSG] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
-               .needs_fs               = 1,
                .pollout                = 1,
                .needs_async_data       = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_msghdr),
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG |
+                                               IO_WQ_WORK_FS,
        },
        [IORING_OP_RECVMSG] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
-               .needs_fs               = 1,
                .pollin                 = 1,
                .buffer_select          = 1,
                .needs_async_data       = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_msghdr),
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG |
+                                               IO_WQ_WORK_FS,
        },
        [IORING_OP_TIMEOUT] = {
-               .needs_mm               = 1,
                .needs_async_data       = 1,
                .async_size             = sizeof(struct io_timeout_data),
+               .work_flags             = IO_WQ_WORK_MM,
        },
        [IORING_OP_TIMEOUT_REMOVE] = {},
        [IORING_OP_ACCEPT] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
-               .file_table             = 1,
                .pollin                 = 1,
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_FILES,
        },
        [IORING_OP_ASYNC_CANCEL] = {},
        [IORING_OP_LINK_TIMEOUT] = {
-               .needs_mm               = 1,
                .needs_async_data       = 1,
                .async_size             = sizeof(struct io_timeout_data),
+               .work_flags             = IO_WQ_WORK_MM,
        },
        [IORING_OP_CONNECT] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
                .needs_async_data       = 1,
                .async_size             = sizeof(struct io_async_connect),
+               .work_flags             = IO_WQ_WORK_MM,
        },
        [IORING_OP_FALLOCATE] = {
                .needs_file             = 1,
                .needs_fsize            = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_OPENAT] = {
-               .file_table             = 1,
-               .needs_fs               = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_FILES | IO_WQ_WORK_BLKCG |
+                                               IO_WQ_WORK_FS,
        },
        [IORING_OP_CLOSE] = {
                .needs_file             = 1,
                .needs_file_no_error    = 1,
-               .file_table             = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_FILES | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_FILES_UPDATE] = {
-               .needs_mm               = 1,
-               .file_table             = 1,
+               .work_flags             = IO_WQ_WORK_FILES | IO_WQ_WORK_MM,
        },
        [IORING_OP_STATX] = {
-               .needs_mm               = 1,
-               .needs_fs               = 1,
-               .file_table             = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_FILES | IO_WQ_WORK_MM |
+                                               IO_WQ_WORK_FS | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_READ] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollin                 = 1,
                .buffer_select          = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_WRITE] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
                .needs_fsize            = 1,
-               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_FADVISE] = {
                .needs_file             = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_MADVISE] = {
-               .needs_mm               = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_SEND] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_RECV] = {
-               .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollin                 = 1,
                .buffer_select          = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_MM | IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_OPENAT2] = {
-               .file_table             = 1,
-               .needs_fs               = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_FILES | IO_WQ_WORK_FS |
+                                               IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_EPOLL_CTL] = {
                .unbound_nonreg_file    = 1,
-               .file_table             = 1,
+               .work_flags             = IO_WQ_WORK_FILES,
        },
        [IORING_OP_SPLICE] = {
                .needs_file             = 1,
                .hash_reg_file          = 1,
                .unbound_nonreg_file    = 1,
-               .needs_blkcg            = 1,
+               .work_flags             = IO_WQ_WORK_BLKCG,
        },
        [IORING_OP_PROVIDE_BUFFERS] = {},
        [IORING_OP_REMOVE_BUFFERS] = {},
@@ -963,8 +938,8 @@ static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
                             struct io_comp_state *cs);
 static void io_cqring_fill_event(struct io_kiocb *req, long res);
 static void io_put_req(struct io_kiocb *req);
+static void io_put_req_deferred(struct io_kiocb *req, int nr);
 static void io_double_put_req(struct io_kiocb *req);
-static void __io_double_put_req(struct io_kiocb *req);
 static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req);
 static void __io_queue_linked_timeout(struct io_kiocb *req);
 static void io_queue_linked_timeout(struct io_kiocb *req);
@@ -986,7 +961,7 @@ static int io_setup_async_rw(struct io_kiocb *req, const struct iovec *iovec,
 
 static struct kmem_cache *req_cachep;
 
-static const struct file_operations io_uring_fops __read_mostly;
+static const struct file_operations io_uring_fops;
 
 struct sock *io_uring_get_socket(struct file *file)
 {
@@ -1034,7 +1009,7 @@ static int __io_sq_thread_acquire_mm(struct io_ring_ctx *ctx)
 static int io_sq_thread_acquire_mm(struct io_ring_ctx *ctx,
                                   struct io_kiocb *req)
 {
-       if (!io_op_defs[req->opcode].needs_mm)
+       if (!(io_op_defs[req->opcode].work_flags & IO_WQ_WORK_MM))
                return 0;
        return __io_sq_thread_acquire_mm(ctx);
 }
@@ -1065,6 +1040,27 @@ static inline void req_set_fail_links(struct io_kiocb *req)
                req->flags |= REQ_F_FAIL_LINK;
 }
 
+/*
+ * None of these are dereferenced, they are simply used to check if any of
+ * them have changed. If we're under current and check they are still the
+ * same, we're fine to grab references to them for actual out-of-line use.
+ */
+static void io_init_identity(struct io_identity *id)
+{
+       id->files = current->files;
+       id->mm = current->mm;
+#ifdef CONFIG_BLK_CGROUP
+       rcu_read_lock();
+       id->blkcg_css = blkcg_css();
+       rcu_read_unlock();
+#endif
+       id->creds = current_cred();
+       id->nsproxy = current->nsproxy;
+       id->fs = current->fs;
+       id->fsize = rlimit(RLIMIT_FSIZE);
+       refcount_set(&id->count, 1);
+}
+
 /*
  * Note: must call io_req_init_async() for the first time you
  * touch any members of io_wq_work.
@@ -1076,6 +1072,8 @@ static inline void io_req_init_async(struct io_kiocb *req)
 
        memset(&req->work, 0, sizeof(req->work));
        req->flags |= REQ_F_WORK_INITIALIZED;
+       io_init_identity(&req->identity);
+       req->work.identity = &req->identity;
 }
 
 static inline bool io_async_submit(struct io_ring_ctx *ctx)
@@ -1181,105 +1179,187 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
        }
 }
 
-/*
- * Returns true if we need to defer file table putting. This can only happen
- * from the error path with REQ_F_COMP_LOCKED set.
- */
-static bool io_req_clean_work(struct io_kiocb *req)
+static void io_put_identity(struct io_kiocb *req)
+{
+       if (req->work.identity == &req->identity)
+               return;
+       if (refcount_dec_and_test(&req->work.identity->count))
+               kfree(req->work.identity);
+}
+
+static void io_req_clean_work(struct io_kiocb *req)
 {
        if (!(req->flags & REQ_F_WORK_INITIALIZED))
-               return false;
+               return;
 
        req->flags &= ~REQ_F_WORK_INITIALIZED;
 
-       if (req->work.mm) {
-               mmdrop(req->work.mm);
-               req->work.mm = NULL;
+       if (req->work.flags & IO_WQ_WORK_MM) {
+               mmdrop(req->work.identity->mm);
+               req->work.flags &= ~IO_WQ_WORK_MM;
        }
 #ifdef CONFIG_BLK_CGROUP
-       if (req->work.blkcg_css)
-               css_put(req->work.blkcg_css);
+       if (req->work.flags & IO_WQ_WORK_BLKCG) {
+               css_put(req->work.identity->blkcg_css);
+               req->work.flags &= ~IO_WQ_WORK_BLKCG;
+       }
 #endif
-       if (req->work.creds) {
-               put_cred(req->work.creds);
-               req->work.creds = NULL;
+       if (req->work.flags & IO_WQ_WORK_CREDS) {
+               put_cred(req->work.identity->creds);
+               req->work.flags &= ~IO_WQ_WORK_CREDS;
        }
-       if (req->work.fs) {
-               struct fs_struct *fs = req->work.fs;
+       if (req->work.flags & IO_WQ_WORK_FS) {
+               struct fs_struct *fs = req->work.identity->fs;
 
-               if (req->flags & REQ_F_COMP_LOCKED)
-                       return true;
-
-               spin_lock(&req->work.fs->lock);
+               spin_lock(&req->work.identity->fs->lock);
                if (--fs->users)
                        fs = NULL;
-               spin_unlock(&req->work.fs->lock);
+               spin_unlock(&req->work.identity->fs->lock);
                if (fs)
                        free_fs_struct(fs);
-               req->work.fs = NULL;
+               req->work.flags &= ~IO_WQ_WORK_FS;
        }
 
-       return false;
+       io_put_identity(req);
 }
 
-static void io_prep_async_work(struct io_kiocb *req)
+/*
+ * Create a private copy of io_identity, since some fields don't match
+ * the current context.
+ */
+static bool io_identity_cow(struct io_kiocb *req)
+{
+       const struct cred *creds = NULL;
+       struct io_identity *id;
+
+       if (req->work.flags & IO_WQ_WORK_CREDS)
+               creds = req->work.identity->creds;
+
+       id = kmemdup(req->work.identity, sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id)) {
+               req->work.flags |= IO_WQ_WORK_CANCEL;
+               return false;
+       }
+
+       /*
+        * We can safely just re-init the creds we copied  Either the field
+        * matches the current one, or we haven't grabbed it yet. The only
+        * exception is ->creds, through registered personalities, so handle
+        * that one separately.
+        */
+       io_init_identity(id);
+       if (creds)
+               req->work.identity->creds = creds;
+
+       /* add one for this request */
+       refcount_inc(&id->count);
+
+       /* drop old identity, assign new one. one ref for req, one for tctx */
+       if (req->work.identity != &req->identity &&
+           refcount_sub_and_test(2, &req->work.identity->count))
+               kfree(req->work.identity);
+
+       req->work.identity = id;
+       return true;
+}
+
+static bool io_grab_identity(struct io_kiocb *req)
 {
        const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_identity *id = &req->identity;
        struct io_ring_ctx *ctx = req->ctx;
 
-       io_req_init_async(req);
+       if (def->needs_fsize && id->fsize != rlimit(RLIMIT_FSIZE))
+               return false;
 
-       if (req->flags & REQ_F_ISREG) {
-               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
-                       io_wq_hash_work(&req->work, file_inode(req->file));
-       } else {
-               if (def->unbound_nonreg_file)
-                       req->work.flags |= IO_WQ_WORK_UNBOUND;
-       }
-       if (!req->work.files && io_op_defs[req->opcode].file_table &&
+       if (!(req->work.flags & IO_WQ_WORK_FILES) &&
+           (def->work_flags & IO_WQ_WORK_FILES) &&
            !(req->flags & REQ_F_NO_FILE_TABLE)) {
-               req->work.files = get_files_struct(current);
-               get_nsproxy(current->nsproxy);
-               req->work.nsproxy = current->nsproxy;
+               if (id->files != current->files ||
+                   id->nsproxy != current->nsproxy)
+                       return false;
+               atomic_inc(&id->files->count);
+               get_nsproxy(id->nsproxy);
                req->flags |= REQ_F_INFLIGHT;
 
                spin_lock_irq(&ctx->inflight_lock);
                list_add(&req->inflight_entry, &ctx->inflight_list);
                spin_unlock_irq(&ctx->inflight_lock);
-       }
-       if (!req->work.mm && def->needs_mm) {
-               mmgrab(current->mm);
-               req->work.mm = current->mm;
+               req->work.flags |= IO_WQ_WORK_FILES;
        }
 #ifdef CONFIG_BLK_CGROUP
-       if (!req->work.blkcg_css && def->needs_blkcg) {
+       if (!(req->work.flags & IO_WQ_WORK_BLKCG) &&
+           (def->work_flags & IO_WQ_WORK_BLKCG)) {
                rcu_read_lock();
-               req->work.blkcg_css = blkcg_css();
+               if (id->blkcg_css != blkcg_css()) {
+                       rcu_read_unlock();
+                       return false;
+               }
                /*
                 * This should be rare, either the cgroup is dying or the task
                 * is moving cgroups. Just punt to root for the handful of ios.
                 */
-               if (!css_tryget_online(req->work.blkcg_css))
-                       req->work.blkcg_css = NULL;
+               if (css_tryget_online(id->blkcg_css))
+                       req->work.flags |= IO_WQ_WORK_BLKCG;
                rcu_read_unlock();
        }
 #endif
-       if (!req->work.creds)
-               req->work.creds = get_current_cred();
-       if (!req->work.fs && def->needs_fs) {
-               spin_lock(&current->fs->lock);
-               if (!current->fs->in_exec) {
-                       req->work.fs = current->fs;
-                       req->work.fs->users++;
+       if (!(req->work.flags & IO_WQ_WORK_CREDS)) {
+               if (id->creds != current_cred())
+                       return false;
+               get_cred(id->creds);
+               req->work.flags |= IO_WQ_WORK_CREDS;
+       }
+       if (!(req->work.flags & IO_WQ_WORK_FS) &&
+           (def->work_flags & IO_WQ_WORK_FS)) {
+               if (current->fs != id->fs)
+                       return false;
+               spin_lock(&id->fs->lock);
+               if (!id->fs->in_exec) {
+                       id->fs->users++;
+                       req->work.flags |= IO_WQ_WORK_FS;
                } else {
                        req->work.flags |= IO_WQ_WORK_CANCEL;
                }
                spin_unlock(&current->fs->lock);
        }
-       if (def->needs_fsize)
-               req->work.fsize = rlimit(RLIMIT_FSIZE);
-       else
-               req->work.fsize = RLIM_INFINITY;
+
+       return true;
+}
+
+static void io_prep_async_work(struct io_kiocb *req)
+{
+       const struct io_op_def *def = &io_op_defs[req->opcode];
+       struct io_identity *id = &req->identity;
+       struct io_ring_ctx *ctx = req->ctx;
+
+       io_req_init_async(req);
+
+       if (req->flags & REQ_F_ISREG) {
+               if (def->hash_reg_file || (ctx->flags & IORING_SETUP_IOPOLL))
+                       io_wq_hash_work(&req->work, file_inode(req->file));
+       } else {
+               if (def->unbound_nonreg_file)
+                       req->work.flags |= IO_WQ_WORK_UNBOUND;
+       }
+
+       /* ->mm can never change on us */
+       if (!(req->work.flags & IO_WQ_WORK_MM) &&
+           (def->work_flags & IO_WQ_WORK_MM)) {
+               mmgrab(id->mm);
+               req->work.flags |= IO_WQ_WORK_MM;
+       }
+
+       /* if we fail grabbing identity, we must COW, regrab, and retry */
+       if (io_grab_identity(req))
+               return;
+
+       if (!io_identity_cow(req))
+               return;
+
+       /* can't fail at this point */
+       if (!io_grab_identity(req))
+               WARN_ON(1);
 }
 
 static void io_prep_async_link(struct io_kiocb *req)
@@ -1325,9 +1405,8 @@ static void io_kill_timeout(struct io_kiocb *req)
                atomic_set(&req->ctx->cq_timeouts,
                        atomic_read(&req->ctx->cq_timeouts) + 1);
                list_del_init(&req->timeout.list);
-               req->flags |= REQ_F_COMP_LOCKED;
                io_cqring_fill_event(req, 0);
-               io_put_req(req);
+               io_put_req_deferred(req, 1);
        }
 }
 
@@ -1378,8 +1457,7 @@ static void __io_queue_deferred(struct io_ring_ctx *ctx)
                if (link) {
                        __io_queue_linked_timeout(link);
                        /* drop submission reference */
-                       link->flags |= REQ_F_COMP_LOCKED;
-                       io_put_req(link);
+                       io_put_req_deferred(link, 1);
                }
                kfree(de);
        } while (!list_empty(&ctx->defer_list));
@@ -1471,8 +1549,9 @@ static inline bool io_match_files(struct io_kiocb *req,
 {
        if (!files)
                return true;
-       if (req->flags & REQ_F_WORK_INITIALIZED)
-               return req->work.files == files;
+       if ((req->flags & REQ_F_WORK_INITIALIZED) &&
+           (req->work.flags & IO_WQ_WORK_FILES))
+               return req->work.identity->files == files;
        return false;
 }
 
@@ -1606,13 +1685,19 @@ static void io_submit_flush_completions(struct io_comp_state *cs)
                req = list_first_entry(&cs->list, struct io_kiocb, compl.list);
                list_del(&req->compl.list);
                __io_cqring_fill_event(req, req->result, req->compl.cflags);
-               if (!(req->flags & REQ_F_LINK_HEAD)) {
-                       req->flags |= REQ_F_COMP_LOCKED;
-                       io_put_req(req);
-               } else {
+
+               /*
+                * io_free_req() doesn't care about completion_lock unless one
+                * of these flags is set. REQ_F_WORK_INITIALIZED is in the list
+                * because of a potential deadlock with req->work.fs->lock
+                */
+               if (req->flags & (REQ_F_FAIL_LINK|REQ_F_LINK_TIMEOUT
+                                |REQ_F_WORK_INITIALIZED)) {
                        spin_unlock_irq(&ctx->completion_lock);
                        io_put_req(req);
                        spin_lock_irq(&ctx->completion_lock);
+               } else {
+                       io_put_req(req);
                }
        }
        io_commit_cqring(ctx);
@@ -1699,7 +1784,7 @@ static inline void io_put_file(struct io_kiocb *req, struct file *file,
                fput(file);
 }
 
-static bool io_dismantle_req(struct io_kiocb *req)
+static void io_dismantle_req(struct io_kiocb *req)
 {
        io_clean_op(req);
 
@@ -1708,14 +1793,16 @@ static bool io_dismantle_req(struct io_kiocb *req)
        if (req->file)
                io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE));
 
-       return io_req_clean_work(req);
+       io_req_clean_work(req);
 }
 
-static void __io_free_req_finish(struct io_kiocb *req)
+static void __io_free_req(struct io_kiocb *req)
 {
        struct io_uring_task *tctx = req->task->io_uring;
        struct io_ring_ctx *ctx = req->ctx;
 
+       io_dismantle_req(req);
+
        atomic_long_inc(&tctx->req_complete);
        if (tctx->in_idle)
                wake_up(&tctx->wait);
@@ -1728,39 +1815,6 @@ static void __io_free_req_finish(struct io_kiocb *req)
        percpu_ref_put(&ctx->refs);
 }
 
-static void io_req_task_file_table_put(struct callback_head *cb)
-{
-       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
-       struct fs_struct *fs = req->work.fs;
-
-       spin_lock(&req->work.fs->lock);
-       if (--fs->users)
-               fs = NULL;
-       spin_unlock(&req->work.fs->lock);
-       if (fs)
-               free_fs_struct(fs);
-       req->work.fs = NULL;
-       __io_free_req_finish(req);
-}
-
-static void __io_free_req(struct io_kiocb *req)
-{
-       if (!io_dismantle_req(req)) {
-               __io_free_req_finish(req);
-       } else {
-               int ret;
-
-               init_task_work(&req->task_work, io_req_task_file_table_put);
-               ret = task_work_add(req->task, &req->task_work, TWA_RESUME);
-               if (unlikely(ret)) {
-                       struct task_struct *tsk;
-
-                       tsk = io_wq_get_task(req->ctx->io_wq);
-                       task_work_add(tsk, &req->task_work, 0);
-               }
-       }
-}
-
 static bool io_link_cancel_timeout(struct io_kiocb *req)
 {
        struct io_timeout_data *io = req->async_data;
@@ -1772,7 +1826,7 @@ static bool io_link_cancel_timeout(struct io_kiocb *req)
                io_cqring_fill_event(req, -ECANCELED);
                io_commit_cqring(ctx);
                req->flags &= ~REQ_F_LINK_HEAD;
-               io_put_req(req);
+               io_put_req_deferred(req, 1);
                return true;
        }
 
@@ -1791,7 +1845,6 @@ static bool __io_kill_linked_timeout(struct io_kiocb *req)
                return false;
 
        list_del_init(&link->link_list);
-       link->flags |= REQ_F_COMP_LOCKED;
        wake_ev = io_link_cancel_timeout(link);
        req->flags &= ~REQ_F_LINK_TIMEOUT;
        return wake_ev;
@@ -1800,17 +1853,12 @@ static bool __io_kill_linked_timeout(struct io_kiocb *req)
 static void io_kill_linked_timeout(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
+       unsigned long flags;
        bool wake_ev;
 
-       if (!(req->flags & REQ_F_COMP_LOCKED)) {
-               unsigned long flags;
-
-               spin_lock_irqsave(&ctx->completion_lock, flags);
-               wake_ev = __io_kill_linked_timeout(req);
-               spin_unlock_irqrestore(&ctx->completion_lock, flags);
-       } else {
-               wake_ev = __io_kill_linked_timeout(req);
-       }
+       spin_lock_irqsave(&ctx->completion_lock, flags);
+       wake_ev = __io_kill_linked_timeout(req);
+       spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
        if (wake_ev)
                io_cqring_ev_posted(ctx);
@@ -1850,28 +1898,29 @@ static void __io_fail_links(struct io_kiocb *req)
                trace_io_uring_fail_link(req, link);
 
                io_cqring_fill_event(link, -ECANCELED);
-               link->flags |= REQ_F_COMP_LOCKED;
-               __io_double_put_req(link);
-               req->flags &= ~REQ_F_LINK_TIMEOUT;
+
+               /*
+                * It's ok to free under spinlock as they're not linked anymore,
+                * but avoid REQ_F_WORK_INITIALIZED because it may deadlock on
+                * work.fs->lock.
+                */
+               if (link->flags & REQ_F_WORK_INITIALIZED)
+                       io_put_req_deferred(link, 2);
+               else
+                       io_double_put_req(link);
        }
 
        io_commit_cqring(ctx);
-       io_cqring_ev_posted(ctx);
 }
 
 static void io_fail_links(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
+       unsigned long flags;
 
-       if (!(req->flags & REQ_F_COMP_LOCKED)) {
-               unsigned long flags;
-
-               spin_lock_irqsave(&ctx->completion_lock, flags);
-               __io_fail_links(req);
-               spin_unlock_irqrestore(&ctx->completion_lock, flags);
-       } else {
-               __io_fail_links(req);
-       }
+       spin_lock_irqsave(&ctx->completion_lock, flags);
+       __io_fail_links(req);
+       spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
        io_cqring_ev_posted(ctx);
 }
@@ -2058,7 +2107,7 @@ static void io_req_free_batch(struct req_batch *rb, struct io_kiocb *req)
        }
        rb->task_refs++;
 
-       WARN_ON_ONCE(io_dismantle_req(req));
+       io_dismantle_req(req);
        rb->reqs[rb->to_free++] = req;
        if (unlikely(rb->to_free == ARRAY_SIZE(rb->reqs)))
                __io_req_free_batch_flush(req->ctx, rb);
@@ -2085,6 +2134,34 @@ static void io_put_req(struct io_kiocb *req)
                io_free_req(req);
 }
 
+static void io_put_req_deferred_cb(struct callback_head *cb)
+{
+       struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+
+       io_free_req(req);
+}
+
+static void io_free_req_deferred(struct io_kiocb *req)
+{
+       int ret;
+
+       init_task_work(&req->task_work, io_put_req_deferred_cb);
+       ret = io_req_task_work_add(req, true);
+       if (unlikely(ret)) {
+               struct task_struct *tsk;
+
+               tsk = io_wq_get_task(req->ctx->io_wq);
+               task_work_add(tsk, &req->task_work, 0);
+               wake_up_process(tsk);
+       }
+}
+
+static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
+{
+       if (refcount_sub_and_test(refs, &req->refs))
+               io_free_req_deferred(req);
+}
+
 static struct io_wq_work *io_steal_work(struct io_kiocb *req)
 {
        struct io_kiocb *nxt;
@@ -2101,17 +2178,6 @@ static struct io_wq_work *io_steal_work(struct io_kiocb *req)
        return nxt ? &nxt->work : NULL;
 }
 
-/*
- * Must only be used if we don't need to care about links, usually from
- * within the completion handling itself.
- */
-static void __io_double_put_req(struct io_kiocb *req)
-{
-       /* drop both submit and complete references */
-       if (refcount_sub_and_test(2, &req->refs))
-               __io_free_req(req);
-}
-
 static void io_double_put_req(struct io_kiocb *req)
 {
        /* drop both submit and complete references */
@@ -4123,7 +4189,7 @@ static int io_close(struct io_kiocb *req, bool force_nonblock,
        }
 
        /* No ->flush() or already async, safely close from here */
-       ret = filp_close(close->put_file, req->work.files);
+       ret = filp_close(close->put_file, req->work.identity->files);
        if (ret < 0)
                req_set_fail_links(req);
        fput(close->put_file);
@@ -4845,10 +4911,9 @@ static void io_poll_task_handler(struct io_kiocb *req, struct io_kiocb **nxt)
 
        hash_del(&req->hash_node);
        io_poll_complete(req, req->result, 0);
-       req->flags |= REQ_F_COMP_LOCKED;
-       *nxt = io_put_req_find_next(req);
        spin_unlock_irq(&ctx->completion_lock);
 
+       *nxt = io_put_req_find_next(req);
        io_cqring_ev_posted(ctx);
 }
 
@@ -5080,6 +5145,12 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
                mask |= POLLIN | POLLRDNORM;
        if (def->pollout)
                mask |= POLLOUT | POLLWRNORM;
+
+       /* If reading from MSG_ERRQUEUE using recvmsg, ignore POLLIN */
+       if ((req->opcode == IORING_OP_RECVMSG) &&
+           (req->sr_msg.msg_flags & MSG_ERRQUEUE))
+               mask &= ~POLLIN;
+
        mask |= POLLERR | POLLPRI;
 
        ipt.pt._qproc = io_async_queue_proc;
@@ -5138,9 +5209,8 @@ static bool io_poll_remove_one(struct io_kiocb *req)
        if (do_complete) {
                io_cqring_fill_event(req, -ECANCELED);
                io_commit_cqring(req->ctx);
-               req->flags |= REQ_F_COMP_LOCKED;
                req_set_fail_links(req);
-               io_put_req(req);
+               io_put_req_deferred(req, 1);
        }
 
        return do_complete;
@@ -5322,9 +5392,8 @@ static int __io_timeout_cancel(struct io_kiocb *req)
        list_del_init(&req->timeout.list);
 
        req_set_fail_links(req);
-       req->flags |= REQ_F_COMP_LOCKED;
        io_cqring_fill_event(req, -ECANCELED);
-       io_put_req(req);
+       io_put_req_deferred(req, 1);
        return 0;
 }
 
@@ -5734,9 +5803,9 @@ static void io_req_drop_files(struct io_kiocb *req)
                wake_up(&ctx->inflight_wait);
        spin_unlock_irqrestore(&ctx->inflight_lock, flags);
        req->flags &= ~REQ_F_INFLIGHT;
-       put_files_struct(req->work.files);
-       put_nsproxy(req->work.nsproxy);
-       req->work.files = NULL;
+       put_files_struct(req->work.identity->files);
+       put_nsproxy(req->work.identity->nsproxy);
+       req->work.flags &= ~IO_WQ_WORK_FILES;
 }
 
 static void __io_clean_op(struct io_kiocb *req)
@@ -6094,14 +6163,15 @@ static void __io_queue_sqe(struct io_kiocb *req, struct io_comp_state *cs)
 again:
        linked_timeout = io_prep_linked_timeout(req);
 
-       if ((req->flags & REQ_F_WORK_INITIALIZED) && req->work.creds &&
-           req->work.creds != current_cred()) {
+       if ((req->flags & REQ_F_WORK_INITIALIZED) && req->work.identity->creds &&
+           req->work.identity->creds != current_cred()) {
                if (old_creds)
                        revert_creds(old_creds);
-               if (old_creds == req->work.creds)
+               if (old_creds == req->work.identity->creds)
                        old_creds = NULL; /* restored original creds */
                else
-                       old_creds = override_creds(req->work.creds);
+                       old_creds = override_creds(req->work.identity->creds);
+               req->work.flags |= IO_WQ_WORK_CREDS;
        }
 
        ret = io_issue_sqe(req, true, cs);
@@ -6404,11 +6474,17 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 
        id = READ_ONCE(sqe->personality);
        if (id) {
+               struct io_identity *iod;
+
                io_req_init_async(req);
-               req->work.creds = idr_find(&ctx->personality_idr, id);
-               if (unlikely(!req->work.creds))
+               iod = idr_find(&ctx->personality_idr, id);
+               if (unlikely(!iod))
                        return -EINVAL;
-               get_cred(req->work.creds);
+               refcount_inc(&iod->count);
+               io_put_identity(req);
+               get_cred(iod->creds);
+               req->work.identity = iod;
+               req->work.flags |= IO_WQ_WORK_CREDS;
        }
 
        /* same numerical values with corresponding REQ_F_*, safe to copy */
@@ -7300,7 +7376,7 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
        spin_lock_init(&file_data->lock);
 
        nr_tables = DIV_ROUND_UP(nr_args, IORING_MAX_FILES_TABLE);
-       file_data->table = kcalloc(nr_tables, sizeof(file_data->table),
+       file_data->table = kcalloc(nr_tables, sizeof(*file_data->table),
                                   GFP_KERNEL);
        if (!file_data->table)
                goto out_free;
@@ -7311,6 +7387,7 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
 
        if (io_sqe_alloc_file_tables(file_data, nr_tables, nr_args))
                goto out_ref;
+       ctx->file_data = file_data;
 
        for (i = 0; i < nr_args; i++, ctx->nr_user_files++) {
                struct fixed_file_table *table;
@@ -7345,7 +7422,6 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
                table->files[index] = file;
        }
 
-       ctx->file_data = file_data;
        ret = io_sqe_files_scm(ctx);
        if (ret) {
                io_sqe_files_unregister(ctx);
@@ -7378,6 +7454,7 @@ out_ref:
 out_free:
        kfree(file_data->table);
        kfree(file_data);
+       ctx->file_data = NULL;
        return ret;
 }
 
@@ -8199,11 +8276,14 @@ static int io_uring_fasync(int fd, struct file *file, int on)
 static int io_remove_personalities(int id, void *p, void *data)
 {
        struct io_ring_ctx *ctx = data;
-       const struct cred *cred;
+       struct io_identity *iod;
 
-       cred = idr_remove(&ctx->personality_idr, id);
-       if (cred)
-               put_cred(cred);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
+       }
        return 0;
 }
 
@@ -8275,7 +8355,8 @@ static bool io_wq_files_match(struct io_wq_work *work, void *data)
 {
        struct files_struct *files = data;
 
-       return !files || work->files == files;
+       return !files || ((work->flags & IO_WQ_WORK_FILES) &&
+                               work->identity->files == files);
 }
 
 /*
@@ -8430,7 +8511,8 @@ static bool io_uring_cancel_files(struct io_ring_ctx *ctx,
 
                spin_lock_irq(&ctx->inflight_lock);
                list_for_each_entry(req, &ctx->inflight_list, inflight_entry) {
-                       if (files && req->work.files != files)
+                       if (files && (req->work.flags & IO_WQ_WORK_FILES) &&
+                           req->work.identity->files != files)
                                continue;
                        /* req is being completed, ignore */
                        if (!refcount_inc_not_zero(&req->refs))
@@ -9271,23 +9353,33 @@ out:
 
 static int io_register_personality(struct io_ring_ctx *ctx)
 {
-       const struct cred *creds = get_current_cred();
-       int id;
+       struct io_identity *id;
+       int ret;
 
-       id = idr_alloc_cyclic(&ctx->personality_idr, (void *) creds, 1,
-                               USHRT_MAX, GFP_KERNEL);
-       if (id < 0)
-               put_cred(creds);
-       return id;
+       id = kmalloc(sizeof(*id), GFP_KERNEL);
+       if (unlikely(!id))
+               return -ENOMEM;
+
+       io_init_identity(id);
+       id->creds = get_current_cred();
+
+       ret = idr_alloc_cyclic(&ctx->personality_idr, id, 1, USHRT_MAX, GFP_KERNEL);
+       if (ret < 0) {
+               put_cred(id->creds);
+               kfree(id);
+       }
+       return ret;
 }
 
 static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 {
-       const struct cred *old_creds;
+       struct io_identity *iod;
 
-       old_creds = idr_remove(&ctx->personality_idr, id);
-       if (old_creds) {
-               put_cred(old_creds);
+       iod = idr_remove(&ctx->personality_idr, id);
+       if (iod) {
+               put_cred(iod->creds);
+               if (refcount_dec_and_test(&iod->count))
+                       kfree(iod);
                return 0;
        }