drm/i915: Refactor setting dma info to a common helper
[linux-2.6-microblaze.git] / fs / io_uring.c
index 358f97b..5190bfb 100644 (file)
@@ -186,14 +186,23 @@ struct fixed_file_table {
        struct file             **files;
 };
 
+struct fixed_file_ref_node {
+       struct percpu_ref               refs;
+       struct list_head                node;
+       struct list_head                file_list;
+       struct fixed_file_data          *file_data;
+       struct work_struct              work;
+};
+
 struct fixed_file_data {
        struct fixed_file_table         *table;
        struct io_ring_ctx              *ctx;
 
+       struct percpu_ref               *cur_refs;
        struct percpu_ref               refs;
-       struct llist_head               put_llist;
-       struct work_struct              ref_work;
        struct completion               done;
+       struct list_head                ref_list;
+       spinlock_t                      lock;
 };
 
 struct io_buffer {
@@ -317,6 +326,8 @@ struct io_ring_ctx {
                spinlock_t              inflight_lock;
                struct list_head        inflight_list;
        } ____cacheline_aligned_in_smp;
+
+       struct work_struct              exit_work;
 };
 
 /*
@@ -599,6 +610,7 @@ struct io_kiocb {
        };
 
        struct io_async_ctx             *io;
+       int                             cflags;
        bool                            needs_fixed_file;
        u8                              opcode;
 
@@ -606,10 +618,8 @@ struct io_kiocb {
        struct list_head        list;
        unsigned int            flags;
        refcount_t              refs;
-       union {
-               struct task_struct      *task;
-               unsigned long           fsize;
-       };
+       struct task_struct      *task;
+       unsigned long           fsize;
        u64                     user_data;
        u32                     result;
        u32                     sequence;
@@ -618,6 +628,8 @@ struct io_kiocb {
 
        struct list_head        inflight_entry;
 
+       struct percpu_ref       *fixed_file_refs;
+
        union {
                /*
                 * Only commands that never go async can use the below fields,
@@ -629,7 +641,6 @@ struct io_kiocb {
                        struct callback_head    task_work;
                        struct hlist_node       hash_node;
                        struct async_poll       *apoll;
-                       int                     cflags;
                };
                struct io_wq_work       work;
        };
@@ -848,7 +859,6 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
                                 struct io_uring_files_update *ip,
                                 unsigned nr_args);
 static int io_grab_files(struct io_kiocb *req);
-static void io_ring_file_ref_flush(struct fixed_file_data *data);
 static void io_cleanup_req(struct io_kiocb *req);
 static int io_file_get(struct io_submit_state *state, struct io_kiocb *req,
                       int fd, struct file **out_file, bool fixed);
@@ -1285,8 +1295,8 @@ static struct io_kiocb *io_get_fallback_req(struct io_ring_ctx *ctx)
        return NULL;
 }
 
-static struct io_kiocb *io_get_req(struct io_ring_ctx *ctx,
-                                  struct io_submit_state *state)
+static struct io_kiocb *io_alloc_req(struct io_ring_ctx *ctx,
+                                    struct io_submit_state *state)
 {
        gfp_t gfp = GFP_KERNEL | __GFP_NOWARN;
        struct io_kiocb *req;
@@ -1319,41 +1329,20 @@ static struct io_kiocb *io_get_req(struct io_ring_ctx *ctx,
                req = state->reqs[state->free_reqs];
        }
 
-got_it:
-       req->io = NULL;
-       req->file = NULL;
-       req->ctx = ctx;
-       req->flags = 0;
-       /* one is dropped after submission, the other at completion */
-       refcount_set(&req->refs, 2);
-       req->result = 0;
-       INIT_IO_WORK(&req->work, io_wq_submit_work);
        return req;
 fallback:
-       req = io_get_fallback_req(ctx);
-       if (req)
-               goto got_it;
-       percpu_ref_put(&ctx->refs);
-       return NULL;
+       return io_get_fallback_req(ctx);
 }
 
 static inline void io_put_file(struct io_kiocb *req, struct file *file,
                          bool fixed)
 {
        if (fixed)
-               percpu_ref_put(&req->ctx->file_data->refs);
+               percpu_ref_put(req->fixed_file_refs);
        else
                fput(file);
 }
 
-static void __io_req_do_free(struct io_kiocb *req)
-{
-       if (likely(!io_is_fallback_req(req)))
-               kmem_cache_free(req_cachep, req);
-       else
-               clear_bit_unlock(0, (unsigned long *) req->ctx->fallback_req);
-}
-
 static void __io_req_aux_free(struct io_kiocb *req)
 {
        if (req->flags & REQ_F_NEED_CLEANUP)
@@ -1362,6 +1351,8 @@ static void __io_req_aux_free(struct io_kiocb *req)
        kfree(req->io);
        if (req->file)
                io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE));
+       if (req->task)
+               put_task_struct(req->task);
 
        io_req_work_drop_env(req);
 }
@@ -1382,7 +1373,10 @@ static void __io_free_req(struct io_kiocb *req)
        }
 
        percpu_ref_put(&req->ctx->refs);
-       __io_req_do_free(req);
+       if (likely(!io_is_fallback_req(req)))
+               kmem_cache_free(req_cachep, req);
+       else
+               clear_bit_unlock(0, (unsigned long *) req->ctx->fallback_req);
 }
 
 struct req_batch {
@@ -1393,21 +1387,18 @@ struct req_batch {
 
 static void io_free_req_many(struct io_ring_ctx *ctx, struct req_batch *rb)
 {
-       int fixed_refs = rb->to_free;
-
        if (!rb->to_free)
                return;
        if (rb->need_iter) {
                int i, inflight = 0;
                unsigned long flags;
 
-               fixed_refs = 0;
                for (i = 0; i < rb->to_free; i++) {
                        struct io_kiocb *req = rb->reqs[i];
 
                        if (req->flags & REQ_F_FIXED_FILE) {
                                req->file = NULL;
-                               fixed_refs++;
+                               percpu_ref_put(req->fixed_file_refs);
                        }
                        if (req->flags & REQ_F_INFLIGHT)
                                inflight++;
@@ -1433,8 +1424,6 @@ static void io_free_req_many(struct io_ring_ctx *ctx, struct req_batch *rb)
        }
 do_free:
        kmem_cache_free_bulk(req_cachep, rb->to_free, rb->reqs);
-       if (fixed_refs)
-               percpu_ref_put_many(&ctx->file_data->refs, fixed_refs);
        percpu_ref_put_many(&ctx->refs, rb->to_free);
        rb->to_free = rb->need_iter = 0;
 }
@@ -1738,11 +1727,24 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
        io_free_req_many(ctx, &rb);
 }
 
+static void io_iopoll_queue(struct list_head *again)
+{
+       struct io_kiocb *req;
+
+       do {
+               req = list_first_entry(again, struct io_kiocb, list);
+               list_del(&req->list);
+               refcount_inc(&req->refs);
+               io_queue_async_work(req);
+       } while (!list_empty(again));
+}
+
 static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
                        long min)
 {
        struct io_kiocb *req, *tmp;
        LIST_HEAD(done);
+       LIST_HEAD(again);
        bool spin;
        int ret;
 
@@ -1757,9 +1759,9 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
                struct kiocb *kiocb = &req->rw.kiocb;
 
                /*
-                * Move completed entries to our local list. If we find a
-                * request that requires polling, break out and complete
-                * the done list first, if we have entries there.
+                * Move completed and retryable entries to our local lists.
+                * If we find a request that requires polling, break out
+                * and complete those lists first, if we have entries there.
                 */
                if (req->flags & REQ_F_IOPOLL_COMPLETED) {
                        list_move_tail(&req->list, &done);
@@ -1768,6 +1770,13 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
                if (!list_empty(&done))
                        break;
 
+               if (req->result == -EAGAIN) {
+                       list_move_tail(&req->list, &again);
+                       continue;
+               }
+               if (!list_empty(&again))
+                       break;
+
                ret = kiocb->ki_filp->f_op->iopoll(kiocb, spin);
                if (ret < 0)
                        break;
@@ -1780,6 +1789,9 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
        if (!list_empty(&done))
                io_iopoll_complete(ctx, nr_events, &done);
 
+       if (!list_empty(&again))
+               io_iopoll_queue(&again);
+
        return ret;
 }
 
@@ -2465,8 +2477,9 @@ static void io_req_map_rw(struct io_kiocb *req, ssize_t io_size,
        req->io->rw.iov = iovec;
        if (!req->io->rw.iov) {
                req->io->rw.iov = req->io->rw.fast_iov;
-               memcpy(req->io->rw.iov, fast_iov,
-                       sizeof(struct iovec) * iter->nr_segs);
+               if (req->io->rw.iov != fast_iov)
+                       memcpy(req->io->rw.iov, fast_iov,
+                              sizeof(struct iovec) * iter->nr_segs);
        } else {
                req->flags |= REQ_F_NEED_CLEANUP;
        }
@@ -2920,7 +2933,7 @@ static int io_openat_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
-       if (sqe->flags & IOSQE_FIXED_FILE)
+       if (req->flags & REQ_F_FIXED_FILE)
                return -EBADF;
        if (req->flags & REQ_F_NEED_CLEANUP)
                return 0;
@@ -2929,6 +2942,8 @@ static int io_openat_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        req->open.how.mode = READ_ONCE(sqe->len);
        fname = u64_to_user_ptr(READ_ONCE(sqe->addr));
        req->open.how.flags = READ_ONCE(sqe->open_flags);
+       if (force_o_largefile())
+               req->open.how.flags |= O_LARGEFILE;
 
        req->open.filename = getname(fname);
        if (IS_ERR(req->open.filename)) {
@@ -2951,7 +2966,7 @@ static int io_openat2_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
-       if (sqe->flags & IOSQE_FIXED_FILE)
+       if (req->flags & REQ_F_FIXED_FILE)
                return -EBADF;
        if (req->flags & REQ_F_NEED_CLEANUP)
                return 0;
@@ -3305,7 +3320,7 @@ static int io_statx_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
-       if (sqe->flags & IOSQE_FIXED_FILE)
+       if (req->flags & REQ_F_FIXED_FILE)
                return -EBADF;
        if (req->flags & REQ_F_NEED_CLEANUP)
                return 0;
@@ -3382,7 +3397,7 @@ static int io_close_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        if (sqe->ioprio || sqe->off || sqe->addr || sqe->len ||
            sqe->rw_flags || sqe->buf_index)
                return -EINVAL;
-       if (sqe->flags & IOSQE_FIXED_FILE)
+       if (req->flags & REQ_F_FIXED_FILE)
                return -EBADF;
 
        req->close.fd = READ_ONCE(sqe->fd);
@@ -3481,14 +3496,11 @@ static void __io_sync_file_range(struct io_kiocb *req)
 static void io_sync_file_range_finish(struct io_wq_work **workptr)
 {
        struct io_kiocb *req = container_of(*workptr, struct io_kiocb, work);
-       struct io_kiocb *nxt = NULL;
 
        if (io_req_cancelled(req))
                return;
        __io_sync_file_range(req);
        io_put_req(req); /* put submission ref */
-       if (nxt)
-               io_wq_assign_next(workptr, nxt);
 }
 
 static int io_sync_file_range(struct io_kiocb *req, bool force_nonblock)
@@ -4114,6 +4126,7 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
                           __poll_t mask, task_work_func_t func)
 {
        struct task_struct *tsk;
+       int ret;
 
        /* for instances that support it check for an event match first: */
        if (mask && !(mask & poll->events))
@@ -4127,11 +4140,15 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
        req->result = mask;
        init_task_work(&req->task_work, func);
        /*
-        * If this fails, then the task is exiting. If that is the case, then
-        * the exit check will ultimately cancel these work items. Hence we
-        * don't need to check here and handle it specifically.
+        * If this fails, then the task is exiting. Punt to one of the io-wq
+        * threads to ensure the work gets run, we can't always rely on exit
+        * cancelation taking care of this.
         */
-       task_work_add(tsk, &req->task_work, true);
+       ret = task_work_add(tsk, &req->task_work, true);
+       if (unlikely(ret)) {
+               tsk = io_wq_get_task(req->ctx->io_wq);
+               task_work_add(tsk, &req->task_work, true);
+       }
        wake_up_process(tsk);
        return 1;
 }
@@ -4251,10 +4268,7 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
        req->flags |= REQ_F_POLLED;
        memcpy(&apoll->work, &req->work, sizeof(req->work));
 
-       /*
-        * Don't need a reference here, as we're adding it to the task
-        * task_works list. If the task exits, the list is pruned.
-        */
+       get_task_struct(current);
        req->task = current;
        req->apoll = apoll;
        INIT_HLIST_NODE(&req->hash_node);
@@ -4407,8 +4421,20 @@ static void io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
 static void io_poll_task_handler(struct io_kiocb *req, struct io_kiocb **nxt)
 {
        struct io_ring_ctx *ctx = req->ctx;
+       struct io_poll_iocb *poll = &req->poll;
+
+       if (!req->result && !READ_ONCE(poll->canceled)) {
+               struct poll_table_struct pt = { ._key = poll->events };
+
+               req->result = vfs_poll(req->file, &pt) & poll->events;
+       }
 
        spin_lock_irq(&ctx->completion_lock);
+       if (!req->result && !READ_ONCE(poll->canceled)) {
+               add_wait_queue(poll->head, &poll->wait);
+               spin_unlock_irq(&ctx->completion_lock);
+               return;
+       }
        hash_del(&req->hash_node);
        io_poll_complete(req, req->result, 0);
        req->flags |= REQ_F_COMP_LOCKED;
@@ -4465,10 +4491,7 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
        events = READ_ONCE(sqe->poll_events);
        poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP;
 
-       /*
-        * Don't need a reference here, as we're adding it to the task
-        * task_works list. If the task exits, the list is pruned.
-        */
+       get_task_struct(current);
        req->task = current;
        return 0;
 }
@@ -5331,7 +5354,8 @@ static int io_file_get(struct io_submit_state *state, struct io_kiocb *req,
                file = io_file_from_index(ctx, fd);
                if (!file)
                        return -EBADF;
-               percpu_ref_get(&ctx->file_data->refs);
+               req->fixed_file_refs = ctx->file_data->cur_refs;
+               percpu_ref_get(req->fixed_file_refs);
        } else {
                trace_io_uring_file_get(ctx, fd);
                file = __io_file_get(state, fd);
@@ -5344,15 +5368,10 @@ static int io_file_get(struct io_submit_state *state, struct io_kiocb *req,
 }
 
 static int io_req_set_file(struct io_submit_state *state, struct io_kiocb *req,
-                          const struct io_uring_sqe *sqe)
+                          int fd, unsigned int flags)
 {
-       unsigned flags;
-       int fd;
        bool fixed;
 
-       flags = READ_ONCE(sqe->flags);
-       fd = READ_ONCE(sqe->fd);
-
        if (!io_req_needs_file(req, fd))
                return 0;
 
@@ -5594,7 +5613,7 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 {
        struct io_ring_ctx *ctx = req->ctx;
        unsigned int sqe_flags;
-       int ret, id;
+       int ret, id, fd;
 
        sqe_flags = READ_ONCE(sqe->flags);
 
@@ -5625,7 +5644,8 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                                        IOSQE_ASYNC | IOSQE_FIXED_FILE |
                                        IOSQE_BUFFER_SELECT);
 
-       ret = io_req_set_file(state, req, sqe);
+       fd = READ_ONCE(sqe->fd);
+       ret = io_req_set_file(state, req, fd, sqe_flags);
        if (unlikely(ret)) {
 err_req:
                io_cqring_add_event(req, ret);
@@ -5741,8 +5761,7 @@ static void io_commit_sqring(struct io_ring_ctx *ctx)
  * used, it's important that those reads are done through READ_ONCE() to
  * prevent a re-load down the line.
  */
-static bool io_get_sqring(struct io_ring_ctx *ctx, struct io_kiocb *req,
-                         const struct io_uring_sqe **sqe_ptr)
+static const struct io_uring_sqe *io_get_sqe(struct io_ring_ctx *ctx)
 {
        u32 *sq_array = ctx->sq_array;
        unsigned head;
@@ -5756,25 +5775,40 @@ static bool io_get_sqring(struct io_ring_ctx *ctx, struct io_kiocb *req,
         *    though the application is the one updating it.
         */
        head = READ_ONCE(sq_array[ctx->cached_sq_head & ctx->sq_mask]);
-       if (likely(head < ctx->sq_entries)) {
-               /*
-                * All io need record the previous position, if LINK vs DARIN,
-                * it can be used to mark the position of the first IO in the
-                * link list.
-                */
-               req->sequence = ctx->cached_sq_head;
-               *sqe_ptr = &ctx->sq_sqes[head];
-               req->opcode = READ_ONCE((*sqe_ptr)->opcode);
-               req->user_data = READ_ONCE((*sqe_ptr)->user_data);
-               ctx->cached_sq_head++;
-               return true;
-       }
+       if (likely(head < ctx->sq_entries))
+               return &ctx->sq_sqes[head];
 
        /* drop invalid entries */
-       ctx->cached_sq_head++;
        ctx->cached_sq_dropped++;
        WRITE_ONCE(ctx->rings->sq_dropped, ctx->cached_sq_dropped);
-       return false;
+       return NULL;
+}
+
+static inline void io_consume_sqe(struct io_ring_ctx *ctx)
+{
+       ctx->cached_sq_head++;
+}
+
+static void io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
+                       const struct io_uring_sqe *sqe)
+{
+       /*
+        * All io need record the previous position, if LINK vs DARIN,
+        * it can be used to mark the position of the first IO in the
+        * link list.
+        */
+       req->sequence = ctx->cached_sq_head;
+       req->opcode = READ_ONCE(sqe->opcode);
+       req->user_data = READ_ONCE(sqe->user_data);
+       req->io = NULL;
+       req->file = NULL;
+       req->ctx = ctx;
+       req->flags = 0;
+       /* one is dropped after submission, the other at completion */
+       refcount_set(&req->refs, 2);
+       req->task = NULL;
+       req->result = 0;
+       INIT_IO_WORK(&req->work, io_wq_submit_work);
 }
 
 static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
@@ -5812,17 +5846,20 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
                struct io_kiocb *req;
                int err;
 
-               req = io_get_req(ctx, statep);
+               sqe = io_get_sqe(ctx);
+               if (unlikely(!sqe)) {
+                       io_consume_sqe(ctx);
+                       break;
+               }
+               req = io_alloc_req(ctx, statep);
                if (unlikely(!req)) {
                        if (!submitted)
                                submitted = -EAGAIN;
                        break;
                }
-               if (!io_get_sqring(ctx, req, &sqe)) {
-                       __io_req_do_free(req);
-                       break;
-               }
 
+               io_init_req(ctx, req, sqe);
+               io_consume_sqe(ctx);
                /* will complete beyond this point, count as submitted */
                submitted++;
 
@@ -5962,6 +5999,7 @@ static int io_sq_thread(void *data)
                                }
                                if (current->task_works) {
                                        task_work_run();
+                                       finish_wait(&ctx->sqo_wait, &wait);
                                        continue;
                                }
                                if (signal_pending(current))
@@ -6124,43 +6162,36 @@ static void io_file_ref_kill(struct percpu_ref *ref)
        complete(&data->done);
 }
 
-static void io_file_ref_exit_and_free(struct work_struct *work)
-{
-       struct fixed_file_data *data;
-
-       data = container_of(work, struct fixed_file_data, ref_work);
-
-       /*
-        * Ensure any percpu-ref atomic switch callback has run, it could have
-        * been in progress when the files were being unregistered. Once
-        * that's done, we can safely exit and free the ref and containing
-        * data structure.
-        */
-       rcu_barrier();
-       percpu_ref_exit(&data->refs);
-       kfree(data);
-}
-
 static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 {
        struct fixed_file_data *data = ctx->file_data;
+       struct fixed_file_ref_node *ref_node = NULL;
        unsigned nr_tables, i;
+       unsigned long flags;
 
        if (!data)
                return -ENXIO;
 
-       percpu_ref_kill_and_confirm(&data->refs, io_file_ref_kill);
-       flush_work(&data->ref_work);
+       spin_lock_irqsave(&data->lock, flags);
+       if (!list_empty(&data->ref_list))
+               ref_node = list_first_entry(&data->ref_list,
+                               struct fixed_file_ref_node, node);
+       spin_unlock_irqrestore(&data->lock, flags);
+       if (ref_node)
+               percpu_ref_kill(&ref_node->refs);
+
+       percpu_ref_kill(&data->refs);
+
+       /* wait for all refs nodes to complete */
        wait_for_completion(&data->done);
-       io_ring_file_ref_flush(data);
 
        __io_sqe_files_unregister(ctx);
        nr_tables = DIV_ROUND_UP(ctx->nr_user_files, IORING_MAX_FILES_TABLE);
        for (i = 0; i < nr_tables; i++)
                kfree(data->table[i].files);
        kfree(data->table);
-       INIT_WORK(&data->ref_work, io_file_ref_exit_and_free);
-       queue_work(system_wq, &data->ref_work);
+       percpu_ref_exit(&data->refs);
+       kfree(data);
        ctx->file_data = NULL;
        ctx->nr_user_files = 0;
        return 0;
@@ -6204,13 +6235,6 @@ static int __io_sqe_files_scm(struct io_ring_ctx *ctx, int nr, int offset)
        struct sk_buff *skb;
        int i, nr_files;
 
-       if (!capable(CAP_SYS_RESOURCE) && !capable(CAP_SYS_ADMIN)) {
-               unsigned long inflight = ctx->user->unix_inflight + nr;
-
-               if (inflight > task_rlimit(current, RLIMIT_NOFILE))
-                       return -EMFILE;
-       }
-
        fpl = kzalloc(sizeof(*fpl), GFP_KERNEL);
        if (!fpl)
                return -ENOMEM;
@@ -6385,46 +6409,72 @@ static void io_ring_file_put(struct io_ring_ctx *ctx, struct file *file)
 }
 
 struct io_file_put {
-       struct llist_node llist;
+       struct list_head list;
        struct file *file;
 };
 
-static void io_ring_file_ref_flush(struct fixed_file_data *data)
+static void io_file_put_work(struct work_struct *work)
 {
+       struct fixed_file_ref_node *ref_node;
+       struct fixed_file_data *file_data;
+       struct io_ring_ctx *ctx;
        struct io_file_put *pfile, *tmp;
-       struct llist_node *node;
+       unsigned long flags;
 
-       while ((node = llist_del_all(&data->put_llist)) != NULL) {
-               llist_for_each_entry_safe(pfile, tmp, node, llist) {
-                       io_ring_file_put(data->ctx, pfile->file);
-                       kfree(pfile);
-               }
+       ref_node = container_of(work, struct fixed_file_ref_node, work);
+       file_data = ref_node->file_data;
+       ctx = file_data->ctx;
+
+       list_for_each_entry_safe(pfile, tmp, &ref_node->file_list, list) {
+               list_del_init(&pfile->list);
+               io_ring_file_put(ctx, pfile->file);
+               kfree(pfile);
        }
+
+       spin_lock_irqsave(&file_data->lock, flags);
+       list_del_init(&ref_node->node);
+       spin_unlock_irqrestore(&file_data->lock, flags);
+
+       percpu_ref_exit(&ref_node->refs);
+       kfree(ref_node);
+       percpu_ref_put(&file_data->refs);
 }
 
-static void io_ring_file_ref_switch(struct work_struct *work)
+static void io_file_data_ref_zero(struct percpu_ref *ref)
 {
-       struct fixed_file_data *data;
+       struct fixed_file_ref_node *ref_node;
+
+       ref_node = container_of(ref, struct fixed_file_ref_node, refs);
 
-       data = container_of(work, struct fixed_file_data, ref_work);
-       io_ring_file_ref_flush(data);
-       percpu_ref_switch_to_percpu(&data->refs);
+       queue_work(system_wq, &ref_node->work);
 }
 
-static void io_file_data_ref_zero(struct percpu_ref *ref)
+static struct fixed_file_ref_node *alloc_fixed_file_ref_node(
+                       struct io_ring_ctx *ctx)
 {
-       struct fixed_file_data *data;
+       struct fixed_file_ref_node *ref_node;
 
-       data = container_of(ref, struct fixed_file_data, refs);
+       ref_node = kzalloc(sizeof(*ref_node), GFP_KERNEL);
+       if (!ref_node)
+               return ERR_PTR(-ENOMEM);
+
+       if (percpu_ref_init(&ref_node->refs, io_file_data_ref_zero,
+                           0, GFP_KERNEL)) {
+               kfree(ref_node);
+               return ERR_PTR(-ENOMEM);
+       }
+       INIT_LIST_HEAD(&ref_node->node);
+       INIT_LIST_HEAD(&ref_node->file_list);
+       INIT_WORK(&ref_node->work, io_file_put_work);
+       ref_node->file_data = ctx->file_data;
+       return ref_node;
 
-       /*
-        * We can't safely switch from inside this context, punt to wq. If
-        * the table ref is going away, the table is being unregistered.
-        * Don't queue up the async work for that case, the caller will
-        * handle it.
-        */
-       if (!percpu_ref_is_dying(&data->refs))
-               queue_work(system_wq, &data->ref_work);
+}
+
+static void destroy_fixed_file_ref_node(struct fixed_file_ref_node *ref_node)
+{
+       percpu_ref_exit(&ref_node->refs);
+       kfree(ref_node);
 }
 
 static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
@@ -6435,6 +6485,8 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
        struct file *file;
        int fd, ret = 0;
        unsigned i;
+       struct fixed_file_ref_node *ref_node;
+       unsigned long flags;
 
        if (ctx->file_data)
                return -EBUSY;
@@ -6448,6 +6500,8 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
                return -ENOMEM;
        ctx->file_data->ctx = ctx;
        init_completion(&ctx->file_data->done);
+       INIT_LIST_HEAD(&ctx->file_data->ref_list);
+       spin_lock_init(&ctx->file_data->lock);
 
        nr_tables = DIV_ROUND_UP(nr_args, IORING_MAX_FILES_TABLE);
        ctx->file_data->table = kcalloc(nr_tables,
@@ -6459,15 +6513,13 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
                return -ENOMEM;
        }
 
-       if (percpu_ref_init(&ctx->file_data->refs, io_file_data_ref_zero,
+       if (percpu_ref_init(&ctx->file_data->refs, io_file_ref_kill,
                                PERCPU_REF_ALLOW_REINIT, GFP_KERNEL)) {
                kfree(ctx->file_data->table);
                kfree(ctx->file_data);
                ctx->file_data = NULL;
                return -ENOMEM;
        }
-       ctx->file_data->put_llist.first = NULL;
-       INIT_WORK(&ctx->file_data->ref_work, io_ring_file_ref_switch);
 
        if (io_sqe_alloc_file_tables(ctx, nr_tables, nr_args)) {
                percpu_ref_exit(&ctx->file_data->refs);
@@ -6530,9 +6582,22 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
        }
 
        ret = io_sqe_files_scm(ctx);
-       if (ret)
+       if (ret) {
                io_sqe_files_unregister(ctx);
+               return ret;
+       }
+
+       ref_node = alloc_fixed_file_ref_node(ctx);
+       if (IS_ERR(ref_node)) {
+               io_sqe_files_unregister(ctx);
+               return PTR_ERR(ref_node);
+       }
 
+       ctx->file_data->cur_refs = &ref_node->refs;
+       spin_lock_irqsave(&ctx->file_data->lock, flags);
+       list_add(&ref_node->node, &ctx->file_data->ref_list);
+       spin_unlock_irqrestore(&ctx->file_data->lock, flags);
+       percpu_ref_get(&ctx->file_data->refs);
        return ret;
 }
 
@@ -6579,30 +6644,21 @@ static int io_sqe_file_register(struct io_ring_ctx *ctx, struct file *file,
 #endif
 }
 
-static void io_atomic_switch(struct percpu_ref *ref)
-{
-       struct fixed_file_data *data;
-
-       /*
-        * Juggle reference to ensure we hit zero, if needed, so we can
-        * switch back to percpu mode
-        */
-       data = container_of(ref, struct fixed_file_data, refs);
-       percpu_ref_put(&data->refs);
-       percpu_ref_get(&data->refs);
-}
-
 static int io_queue_file_removal(struct fixed_file_data *data,
-                                 struct file *file)
+                                struct file *file)
 {
        struct io_file_put *pfile;
+       struct percpu_ref *refs = data->cur_refs;
+       struct fixed_file_ref_node *ref_node;
 
        pfile = kzalloc(sizeof(*pfile), GFP_KERNEL);
        if (!pfile)
                return -ENOMEM;
 
+       ref_node = container_of(refs, struct fixed_file_ref_node, refs);
        pfile->file = file;
-       llist_add(&pfile->llist, &data->put_llist);
+       list_add(&pfile->list, &ref_node->file_list);
+
        return 0;
 }
 
@@ -6611,17 +6667,23 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
                                 unsigned nr_args)
 {
        struct fixed_file_data *data = ctx->file_data;
-       bool ref_switch = false;
+       struct fixed_file_ref_node *ref_node;
        struct file *file;
        __s32 __user *fds;
        int fd, i, err;
        __u32 done;
+       unsigned long flags;
+       bool needs_switch = false;
 
        if (check_add_overflow(up->offset, nr_args, &done))
                return -EOVERFLOW;
        if (done > ctx->nr_user_files)
                return -EINVAL;
 
+       ref_node = alloc_fixed_file_ref_node(ctx);
+       if (IS_ERR(ref_node))
+               return PTR_ERR(ref_node);
+
        done = 0;
        fds = u64_to_user_ptr(up->fds);
        while (nr_args) {
@@ -6642,7 +6704,7 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
                        if (err)
                                break;
                        table->files[index] = NULL;
-                       ref_switch = true;
+                       needs_switch = true;
                }
                if (fd != -1) {
                        file = fget(fd);
@@ -6673,11 +6735,19 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
                up->offset++;
        }
 
-       if (ref_switch)
-               percpu_ref_switch_to_atomic(&data->refs, io_atomic_switch);
+       if (needs_switch) {
+               percpu_ref_kill(data->cur_refs);
+               spin_lock_irqsave(&data->lock, flags);
+               list_add(&ref_node->node, &data->ref_list);
+               data->cur_refs = &ref_node->refs;
+               spin_unlock_irqrestore(&data->lock, flags);
+               percpu_ref_get(&ctx->file_data->refs);
+       } else
+               destroy_fixed_file_ref_node(ref_node);
 
        return done ? done : err;
 }
+
 static int io_sqe_files_update(struct io_ring_ctx *ctx, void __user *arg,
                               unsigned nr_args)
 {
@@ -7203,6 +7273,18 @@ static int io_remove_personalities(int id, void *p, void *data)
        return 0;
 }
 
+static void io_ring_exit_work(struct work_struct *work)
+{
+       struct io_ring_ctx *ctx;
+
+       ctx = container_of(work, struct io_ring_ctx, exit_work);
+       if (ctx->rings)
+               io_cqring_overflow_flush(ctx, true);
+
+       wait_for_completion(&ctx->completions[0]);
+       io_ring_ctx_free(ctx);
+}
+
 static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
 {
        mutex_lock(&ctx->uring_lock);
@@ -7230,8 +7312,8 @@ static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
        if (ctx->rings)
                io_cqring_overflow_flush(ctx, true);
        idr_for_each(&ctx->personality_idr, io_remove_personalities, ctx);
-       wait_for_completion(&ctx->completions[0]);
-       io_ring_ctx_free(ctx);
+       INIT_WORK(&ctx->exit_work, io_ring_exit_work);
+       queue_work(system_wq, &ctx->exit_work);
 }
 
 static int io_uring_release(struct inode *inode, struct file *file)