io_uring: allow non-fixed files with SQPOLL
authorJens Axboe <axboe@kernel.dk>
Mon, 14 Sep 2020 16:51:17 +0000 (10:51 -0600)
committerJens Axboe <axboe@kernel.dk>
Wed, 9 Dec 2020 19:03:54 +0000 (12:03 -0700)
The restriction of needing fixed files for SQPOLL is problematic, and
prevents/inhibits several valid use cases. With the referenced
files_struct that we have now, it's trivially supportable.

Treat ->files like we do the mm for the SQPOLL thread - grab a reference
to it (and assign it), and drop it when we're done.

This feature is exposed as IORING_FEAT_SQPOLL_NONFIXED.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c
include/uapi/linux/io_uring.h

index d171987..c1f3980 100644 (file)
@@ -999,8 +999,9 @@ static inline void io_clean_op(struct io_kiocb *req)
                __io_clean_op(req);
 }
 
-static void io_sq_thread_drop_mm(void)
+static void io_sq_thread_drop_mm_files(void)
 {
+       struct files_struct *files = current->files;
        struct mm_struct *mm = current->mm;
 
        if (mm) {
@@ -1008,6 +1009,40 @@ static void io_sq_thread_drop_mm(void)
                mmput(mm);
                current->mm = NULL;
        }
+       if (files) {
+               struct nsproxy *nsproxy = current->nsproxy;
+
+               task_lock(current);
+               current->files = NULL;
+               current->nsproxy = NULL;
+               task_unlock(current);
+               put_files_struct(files);
+               put_nsproxy(nsproxy);
+       }
+}
+
+static void __io_sq_thread_acquire_files(struct io_ring_ctx *ctx)
+{
+       if (!current->files) {
+               struct files_struct *files;
+               struct nsproxy *nsproxy;
+
+               task_lock(ctx->sqo_task);
+               files = ctx->sqo_task->files;
+               if (!files) {
+                       task_unlock(ctx->sqo_task);
+                       return;
+               }
+               atomic_inc(&files->count);
+               get_nsproxy(ctx->sqo_task->nsproxy);
+               nsproxy = ctx->sqo_task->nsproxy;
+               task_unlock(ctx->sqo_task);
+
+               task_lock(current);
+               current->files = files;
+               current->nsproxy = nsproxy;
+               task_unlock(current);
+       }
 }
 
 static int __io_sq_thread_acquire_mm(struct io_ring_ctx *ctx)
@@ -1035,12 +1070,21 @@ static int __io_sq_thread_acquire_mm(struct io_ring_ctx *ctx)
        return -EFAULT;
 }
 
-static int io_sq_thread_acquire_mm(struct io_ring_ctx *ctx,
-                                  struct io_kiocb *req)
+static int io_sq_thread_acquire_mm_files(struct io_ring_ctx *ctx,
+                                        struct io_kiocb *req)
 {
-       if (!(io_op_defs[req->opcode].work_flags & IO_WQ_WORK_MM))
-               return 0;
-       return __io_sq_thread_acquire_mm(ctx);
+       const struct io_op_def *def = &io_op_defs[req->opcode];
+
+       if (def->work_flags & IO_WQ_WORK_MM) {
+               int ret = __io_sq_thread_acquire_mm(ctx);
+               if (unlikely(ret))
+                       return ret;
+       }
+
+       if (def->needs_file || (def->work_flags & IO_WQ_WORK_FILES))
+               __io_sq_thread_acquire_files(ctx);
+
+       return 0;
 }
 
 static void io_sq_thread_associate_blkcg(struct io_ring_ctx *ctx,
@@ -2061,6 +2105,7 @@ static void __io_req_task_submit(struct io_kiocb *req)
        struct io_ring_ctx *ctx = req->ctx;
 
        if (!__io_sq_thread_acquire_mm(ctx)) {
+               __io_sq_thread_acquire_files(ctx);
                mutex_lock(&ctx->uring_lock);
                __io_queue_sqe(req, NULL);
                mutex_unlock(&ctx->uring_lock);
@@ -2603,7 +2648,7 @@ static bool io_rw_reissue(struct io_kiocb *req, long res)
        if ((res != -EAGAIN && res != -EOPNOTSUPP) || io_wq_current_is_worker())
                return false;
 
-       ret = io_sq_thread_acquire_mm(req->ctx, req);
+       ret = io_sq_thread_acquire_mm_files(req->ctx, req);
 
        if (io_resubmit_prep(req, ret)) {
                refcount_inc(&req->refs);
@@ -6168,13 +6213,7 @@ static struct file *io_file_get(struct io_submit_state *state,
 static int io_req_set_file(struct io_submit_state *state, struct io_kiocb *req,
                           int fd)
 {
-       bool fixed;
-
-       fixed = (req->flags & REQ_F_FIXED_FILE) != 0;
-       if (unlikely(!fixed && io_async_submit(req->ctx)))
-               return -EBADF;
-
-       req->file = io_file_get(state, req, fd, fixed);
+       req->file = io_file_get(state, req, fd, req->flags & REQ_F_FIXED_FILE);
        if (req->file || io_op_defs[req->opcode].needs_file_no_error)
                return 0;
        return -EBADF;
@@ -6551,7 +6590,7 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        if (unlikely(req->opcode >= IORING_OP_LAST))
                return -EINVAL;
 
-       if (unlikely(io_sq_thread_acquire_mm(ctx, req)))
+       if (unlikely(io_sq_thread_acquire_mm_files(ctx, req)))
                return -EFAULT;
 
        sqe_flags = READ_ONCE(sqe->flags);
@@ -6739,7 +6778,7 @@ again:
                 * adding ourselves to the waitqueue, as the unuse/drop
                 * may sleep.
                 */
-               io_sq_thread_drop_mm();
+               io_sq_thread_drop_mm_files();
 
                /*
                 * We're polling. If we're within the defined idle
@@ -6808,11 +6847,18 @@ static void io_sqd_init_new(struct io_sq_data *sqd)
 static int io_sq_thread(void *data)
 {
        struct cgroup_subsys_state *cur_css = NULL;
+       struct files_struct *old_files = current->files;
+       struct nsproxy *old_nsproxy = current->nsproxy;
        const struct cred *old_cred = NULL;
        struct io_sq_data *sqd = data;
        struct io_ring_ctx *ctx;
        unsigned long start_jiffies;
 
+       task_lock(current);
+       current->files = NULL;
+       current->nsproxy = NULL;
+       task_unlock(current);
+
        start_jiffies = jiffies;
        while (!kthread_should_stop()) {
                enum sq_ret ret = 0;
@@ -6845,7 +6891,7 @@ static int io_sq_thread(void *data)
 
                        ret |= __io_sq_thread(ctx, start_jiffies, cap_entries);
 
-                       io_sq_thread_drop_mm();
+                       io_sq_thread_drop_mm_files();
                }
 
                if (ret & SQT_SPIN) {
@@ -6870,6 +6916,11 @@ static int io_sq_thread(void *data)
        if (old_cred)
                revert_creds(old_cred);
 
+       task_lock(current);
+       current->files = old_files;
+       current->nsproxy = old_nsproxy;
+       task_unlock(current);
+
        kthread_parkme();
 
        return 0;
@@ -9415,7 +9466,7 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP |
                        IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS |
                        IORING_FEAT_CUR_PERSONALITY | IORING_FEAT_FAST_POLL |
-                       IORING_FEAT_POLL_32BITS;
+                       IORING_FEAT_POLL_32BITS | IORING_FEAT_SQPOLL_NONFIXED;
 
        if (copy_to_user(params, p, sizeof(*p))) {
                ret = -EFAULT;
index e943bf0..2301c37 100644 (file)
@@ -254,6 +254,7 @@ struct io_uring_params {
 #define IORING_FEAT_CUR_PERSONALITY    (1U << 4)
 #define IORING_FEAT_FAST_POLL          (1U << 5)
 #define IORING_FEAT_POLL_32BITS        (1U << 6)
+#define IORING_FEAT_SQPOLL_NONFIXED    (1U << 7)
 
 /*
  * io_uring_register(2) opcodes and arguments