io_uring: fix linked command file table usage
[linux-2.6-microblaze.git] / fs / io-wq.c
index 541c8a3..cb60a42 100644 (file)
@@ -56,7 +56,8 @@ struct io_worker {
 
        struct rcu_head rcu;
        struct mm_struct *mm;
-       const struct cred *creds;
+       const struct cred *cur_creds;
+       const struct cred *saved_creds;
        struct files_struct *restore_files;
 };
 
@@ -109,10 +110,10 @@ struct io_wq {
 
        struct task_struct *manager;
        struct user_struct *user;
-       const struct cred *creds;
-       struct mm_struct *mm;
        refcount_t refs;
        struct completion done;
+
+       refcount_t use_refs;
 };
 
 static bool io_worker_get(struct io_worker *worker)
@@ -135,9 +136,9 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
 {
        bool dropped_lock = false;
 
-       if (worker->creds) {
-               revert_creds(worker->creds);
-               worker->creds = NULL;
+       if (worker->saved_creds) {
+               revert_creds(worker->saved_creds);
+               worker->cur_creds = worker->saved_creds = NULL;
        }
 
        if (current->files != worker->restore_files) {
@@ -396,6 +397,43 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
        return NULL;
 }
 
+static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work)
+{
+       if (worker->mm) {
+               unuse_mm(worker->mm);
+               mmput(worker->mm);
+               worker->mm = NULL;
+       }
+       if (!work->mm) {
+               set_fs(KERNEL_DS);
+               return;
+       }
+       if (mmget_not_zero(work->mm)) {
+               use_mm(work->mm);
+               if (!worker->mm)
+                       set_fs(USER_DS);
+               worker->mm = work->mm;
+               /* hang on to this mm */
+               work->mm = NULL;
+               return;
+       }
+
+       /* failed grabbing mm, ensure work gets cancelled */
+       work->flags |= IO_WQ_WORK_CANCEL;
+}
+
+static void io_wq_switch_creds(struct io_worker *worker,
+                              struct io_wq_work *work)
+{
+       const struct cred *old_creds = override_creds(work->creds);
+
+       worker->cur_creds = work->creds;
+       if (worker->saved_creds)
+               put_cred(old_creds); /* creds set by previous switch */
+       else
+               worker->saved_creds = old_creds;
+}
+
 static void io_worker_handle_work(struct io_worker *worker)
        __releases(wqe->lock)
 {
@@ -438,20 +476,19 @@ next:
                if (work->flags & IO_WQ_WORK_CB)
                        work->func(&work);
 
-               if ((work->flags & IO_WQ_WORK_NEEDS_FILES) &&
-                   current->files != work->files) {
+               if (work->files && current->files != work->files) {
                        task_lock(current);
                        current->files = work->files;
                        task_unlock(current);
                }
-               if ((work->flags & IO_WQ_WORK_NEEDS_USER) && !worker->mm &&
-                   wq->mm && mmget_not_zero(wq->mm)) {
-                       use_mm(wq->mm);
-                       set_fs(USER_DS);
-                       worker->mm = wq->mm;
-               }
-               if (!worker->creds)
-                       worker->creds = override_creds(wq->creds);
+               if (work->mm != worker->mm)
+                       io_wq_switch_mm(worker, work);
+               if (worker->cur_creds != work->creds)
+                       io_wq_switch_creds(worker, work);
+               /*
+                * OK to set IO_WQ_WORK_CANCEL even for uncancellable work,
+                * the worker function will do the right thing.
+                */
                if (test_bit(IO_WQ_BIT_CANCEL, &wq->state))
                        work->flags |= IO_WQ_WORK_CANCEL;
                if (worker->mm)
@@ -716,6 +753,7 @@ static bool io_wq_can_queue(struct io_wqe *wqe, struct io_wqe_acct *acct,
 static void io_wqe_enqueue(struct io_wqe *wqe, struct io_wq_work *work)
 {
        struct io_wqe_acct *acct = io_work_get_acct(wqe, work);
+       int work_flags;
        unsigned long flags;
 
        /*
@@ -730,12 +768,14 @@ static void io_wqe_enqueue(struct io_wqe *wqe, struct io_wq_work *work)
                return;
        }
 
+       work_flags = work->flags;
        spin_lock_irqsave(&wqe->lock, flags);
        wq_list_add_tail(&work->list, &wqe->work_list);
        wqe->flags &= ~IO_WQE_FLAG_STALLED;
        spin_unlock_irqrestore(&wqe->lock, flags);
 
-       if (!atomic_read(&acct->nr_running))
+       if ((work_flags & IO_WQ_WORK_CONCURRENT) ||
+           !atomic_read(&acct->nr_running))
                io_wqe_wake_worker(wqe, acct);
 }
 
@@ -824,6 +864,7 @@ static bool io_work_cancel(struct io_worker *worker, void *cancel_data)
         */
        spin_lock_irqsave(&worker->lock, flags);
        if (worker->cur_work &&
+           !(worker->cur_work->flags & IO_WQ_WORK_NO_CANCEL) &&
            data->cancel(worker->cur_work, data->caller_data)) {
                send_sig(SIGINT, worker->task, 1);
                ret = true;
@@ -898,7 +939,8 @@ static bool io_wq_worker_cancel(struct io_worker *worker, void *data)
                return false;
 
        spin_lock_irqsave(&worker->lock, flags);
-       if (worker->cur_work == work) {
+       if (worker->cur_work == work &&
+           !(worker->cur_work->flags & IO_WQ_WORK_NO_CANCEL)) {
                send_sig(SIGINT, worker->task, 1);
                ret = true;
        }
@@ -1022,7 +1064,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 
        /* caller must already hold a reference to this */
        wq->user = data->user;
-       wq->creds = data->creds;
 
        for_each_node(node) {
                struct io_wqe *wqe;
@@ -1049,9 +1090,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 
        init_completion(&wq->done);
 
-       /* caller must have already done mmgrab() on this mm */
-       wq->mm = data->mm;
-
        wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager");
        if (!IS_ERR(wq->manager)) {
                wake_up_process(wq->manager);
@@ -1060,6 +1098,7 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
                        ret = -ENOMEM;
                        goto err;
                }
+               refcount_set(&wq->use_refs, 1);
                reinit_completion(&wq->done);
                return wq;
        }
@@ -1074,13 +1113,21 @@ err:
        return ERR_PTR(ret);
 }
 
+bool io_wq_get(struct io_wq *wq, struct io_wq_data *data)
+{
+       if (data->get_work != wq->get_work || data->put_work != wq->put_work)
+               return false;
+
+       return refcount_inc_not_zero(&wq->use_refs);
+}
+
 static bool io_wq_worker_wake(struct io_worker *worker, void *data)
 {
        wake_up_process(worker->task);
        return false;
 }
 
-void io_wq_destroy(struct io_wq *wq)
+static void __io_wq_destroy(struct io_wq *wq)
 {
        int node;
 
@@ -1100,3 +1147,9 @@ void io_wq_destroy(struct io_wq *wq)
        kfree(wq->wqes);
        kfree(wq);
 }
+
+void io_wq_destroy(struct io_wq *wq)
+{
+       if (refcount_dec_and_test(&wq->use_refs))
+               __io_wq_destroy(wq);
+}