Merge tag 'for-5.16-rc2-tag' of git://git.kernel.org/pub/scm/linux/kernel/git/kdave...
[linux-2.6-microblaze.git] / fs / io-wq.c
index 422a7ed..88202de 100644 (file)
@@ -14,6 +14,7 @@
 #include <linux/rculist_nulls.h>
 #include <linux/cpu.h>
 #include <linux/tracehook.h>
+#include <linux/audit.h>
 #include <uapi/linux/io_uring.h>
 
 #include "io-wq.h"
@@ -140,6 +141,7 @@ static void io_wqe_dec_running(struct io_worker *worker);
 static bool io_acct_cancel_pending_work(struct io_wqe *wqe,
                                        struct io_wqe_acct *acct,
                                        struct io_cb_cancel_data *match);
+static void create_worker_cb(struct callback_head *cb);
 
 static bool io_worker_get(struct io_worker *worker)
 {
@@ -174,12 +176,46 @@ static void io_worker_ref_put(struct io_wq *wq)
                complete(&wq->worker_done);
 }
 
+static void io_worker_cancel_cb(struct io_worker *worker)
+{
+       struct io_wqe_acct *acct = io_wqe_get_acct(worker);
+       struct io_wqe *wqe = worker->wqe;
+       struct io_wq *wq = wqe->wq;
+
+       atomic_dec(&acct->nr_running);
+       raw_spin_lock(&worker->wqe->lock);
+       acct->nr_workers--;
+       raw_spin_unlock(&worker->wqe->lock);
+       io_worker_ref_put(wq);
+       clear_bit_unlock(0, &worker->create_state);
+       io_worker_release(worker);
+}
+
+static bool io_task_worker_match(struct callback_head *cb, void *data)
+{
+       struct io_worker *worker;
+
+       if (cb->func != create_worker_cb)
+               return false;
+       worker = container_of(cb, struct io_worker, create_work);
+       return worker == data;
+}
+
 static void io_worker_exit(struct io_worker *worker)
 {
        struct io_wqe *wqe = worker->wqe;
+       struct io_wq *wq = wqe->wq;
 
-       if (refcount_dec_and_test(&worker->ref))
-               complete(&worker->ref_done);
+       while (1) {
+               struct callback_head *cb = task_work_cancel_match(wq->task,
+                                               io_task_worker_match, worker);
+
+               if (!cb)
+                       break;
+               io_worker_cancel_cb(worker);
+       }
+
+       io_worker_release(worker);
        wait_for_completion(&worker->ref_done);
 
        raw_spin_lock(&wqe->lock);
@@ -323,8 +359,10 @@ static bool io_queue_worker_create(struct io_worker *worker,
 
        init_task_work(&worker->create_work, func);
        worker->create_index = acct->index;
-       if (!task_work_add(wq->task, &worker->create_work, TWA_SIGNAL))
+       if (!task_work_add(wq->task, &worker->create_work, TWA_SIGNAL)) {
+               clear_bit_unlock(0, &worker->create_state);
                return true;
+       }
        clear_bit_unlock(0, &worker->create_state);
 fail_release:
        io_worker_release(worker);
@@ -385,9 +423,10 @@ static inline unsigned int io_get_work_hash(struct io_wq_work *work)
        return work->flags >> IO_WQ_HASH_SHIFT;
 }
 
-static void io_wait_on_hash(struct io_wqe *wqe, unsigned int hash)
+static bool io_wait_on_hash(struct io_wqe *wqe, unsigned int hash)
 {
        struct io_wq *wq = wqe->wq;
+       bool ret = false;
 
        spin_lock_irq(&wq->hash->wait.lock);
        if (list_empty(&wqe->wait.entry)) {
@@ -395,9 +434,11 @@ static void io_wait_on_hash(struct io_wqe *wqe, unsigned int hash)
                if (!test_bit(hash, &wq->hash->map)) {
                        __set_current_state(TASK_RUNNING);
                        list_del_init(&wqe->wait.entry);
+                       ret = true;
                }
        }
        spin_unlock_irq(&wq->hash->wait.lock);
+       return ret;
 }
 
 static struct io_wq_work *io_get_next_work(struct io_wqe_acct *acct,
@@ -437,14 +478,21 @@ static struct io_wq_work *io_get_next_work(struct io_wqe_acct *acct,
        }
 
        if (stall_hash != -1U) {
+               bool unstalled;
+
                /*
                 * Set this before dropping the lock to avoid racing with new
                 * work being added and clearing the stalled bit.
                 */
                set_bit(IO_ACCT_STALLED_BIT, &acct->flags);
                raw_spin_unlock(&wqe->lock);
-               io_wait_on_hash(wqe, stall_hash);
+               unstalled = io_wait_on_hash(wqe, stall_hash);
                raw_spin_lock(&wqe->lock);
+               if (unstalled) {
+                       clear_bit(IO_ACCT_STALLED_BIT, &acct->flags);
+                       if (wq_has_sleeper(&wqe->wq->hash->wait))
+                               wake_up(&wqe->wq->hash->wait);
+               }
        }
 
        return NULL;
@@ -526,8 +574,11 @@ get_next:
                                io_wqe_enqueue(wqe, linked);
 
                        if (hash != -1U && !next_hashed) {
+                               /* serialize hash clear with wake_up() */
+                               spin_lock_irq(&wq->hash->wait.lock);
                                clear_bit(hash, &wq->hash->map);
                                clear_bit(IO_ACCT_STALLED_BIT, &acct->flags);
+                               spin_unlock_irq(&wq->hash->wait.lock);
                                if (wq_has_sleeper(&wq->hash->wait))
                                        wake_up(&wq->hash->wait);
                                raw_spin_lock(&wqe->lock);
@@ -556,6 +607,8 @@ static int io_wqe_worker(void *data)
        snprintf(buf, sizeof(buf), "iou-wrk-%d", wq->task->pid);
        set_task_comm(current, buf);
 
+       audit_alloc_kernel(current);
+
        while (!test_bit(IO_WQ_BIT_EXIT, &wq->state)) {
                long ret;
 
@@ -594,6 +647,7 @@ loop:
                io_worker_handle_work(worker);
        }
 
+       audit_free(current);
        io_worker_exit(worker);
        return 0;
 }
@@ -716,11 +770,8 @@ static void io_workqueue_create(struct work_struct *work)
        struct io_worker *worker = container_of(work, struct io_worker, work);
        struct io_wqe_acct *acct = io_wqe_get_acct(worker);
 
-       if (!io_queue_worker_create(worker, acct, create_worker_cont)) {
-               clear_bit_unlock(0, &worker->create_state);
-               io_worker_release(worker);
+       if (!io_queue_worker_create(worker, acct, create_worker_cont))
                kfree(worker);
-       }
 }
 
 static bool create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
@@ -1150,17 +1201,9 @@ static void io_wq_exit_workers(struct io_wq *wq)
 
        while ((cb = task_work_cancel_match(wq->task, io_task_work_match, wq)) != NULL) {
                struct io_worker *worker;
-               struct io_wqe_acct *acct;
 
                worker = container_of(cb, struct io_worker, create_work);
-               acct = io_wqe_get_acct(worker);
-               atomic_dec(&acct->nr_running);
-               raw_spin_lock(&worker->wqe->lock);
-               acct->nr_workers--;
-               raw_spin_unlock(&worker->wqe->lock);
-               io_worker_ref_put(wq);
-               clear_bit_unlock(0, &worker->create_state);
-               io_worker_release(worker);
+               io_worker_cancel_cb(worker);
        }
 
        rcu_read_lock();
@@ -1278,7 +1321,9 @@ int io_wq_cpu_affinity(struct io_wq *wq, cpumask_var_t mask)
  */
 int io_wq_max_workers(struct io_wq *wq, int *new_count)
 {
-       int i, node, prev = 0;
+       int prev[IO_WQ_ACCT_NR];
+       bool first_node = true;
+       int i, node;
 
        BUILD_BUG_ON((int) IO_WQ_ACCT_BOUND   != (int) IO_WQ_BOUND);
        BUILD_BUG_ON((int) IO_WQ_ACCT_UNBOUND != (int) IO_WQ_UNBOUND);
@@ -1289,6 +1334,9 @@ int io_wq_max_workers(struct io_wq *wq, int *new_count)
                        new_count[i] = task_rlimit(current, RLIMIT_NPROC);
        }
 
+       for (i = 0; i < IO_WQ_ACCT_NR; i++)
+               prev[i] = 0;
+
        rcu_read_lock();
        for_each_node(node) {
                struct io_wqe *wqe = wq->wqes[node];
@@ -1297,14 +1345,19 @@ int io_wq_max_workers(struct io_wq *wq, int *new_count)
                raw_spin_lock(&wqe->lock);
                for (i = 0; i < IO_WQ_ACCT_NR; i++) {
                        acct = &wqe->acct[i];
-                       prev = max_t(int, acct->max_workers, prev);
+                       if (first_node)
+                               prev[i] = max_t(int, acct->max_workers, prev[i]);
                        if (new_count[i])
                                acct->max_workers = new_count[i];
-                       new_count[i] = prev;
                }
                raw_spin_unlock(&wqe->lock);
+               first_node = false;
        }
        rcu_read_unlock();
+
+       for (i = 0; i < IO_WQ_ACCT_NR; i++)
+               new_count[i] = prev[i];
+
        return 0;
 }