io_uring: keep table of pointers to ubufs
authorPavel Begunkov <asml.silence@gmail.com>
Sun, 25 Apr 2021 13:32:23 +0000 (14:32 +0100)
committerJens Axboe <axboe@kernel.dk>
Sun, 25 Apr 2021 16:14:04 +0000 (10:14 -0600)
Instead of keeping a table of ubufs convert them into pointers to ubuf,
so we can atomically read one pointer and be sure that the content of
ubuf won't change.

Because it was already dynamically allocating imu->bvec, throw both
imu and bvec into a single structure so they can be allocated together.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Link: https://lore.kernel.org/r/b96efa4c5febadeccf41d0e849ac099f4c83b0d3.1619356238.git.asml.silence@gmail.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c

index a4c37a2..b6ec14d 100644 (file)
@@ -195,9 +195,9 @@ enum io_uring_cmd_flags {
 struct io_mapped_ubuf {
        u64             ubuf;
        u64             ubuf_end;
-       struct          bio_vec *bvec;
        unsigned int    nr_bvecs;
        unsigned long   acct_pages;
+       struct bio_vec  bvec[];
 };
 
 struct io_ring_ctx;
@@ -405,7 +405,7 @@ struct io_ring_ctx {
 
        /* if used, fixed mapped user buffers */
        unsigned                nr_user_bufs;
-       struct io_mapped_ubuf   *user_bufs;
+       struct io_mapped_ubuf   **user_bufs;
 
        struct user_struct      *user;
 
@@ -2760,7 +2760,7 @@ static int io_import_fixed(struct io_kiocb *req, int rw, struct iov_iter *iter)
        if (unlikely(buf_index >= ctx->nr_user_bufs))
                return -EFAULT;
        index = array_index_nospec(buf_index, ctx->nr_user_bufs);
-       imu = &ctx->user_bufs[index];
+       imu = ctx->user_bufs[index];
        buf_addr = req->rw.addr;
 
        if (unlikely(check_add_overflow(buf_addr, (u64)len, &buf_end)))
@@ -8081,16 +8081,17 @@ static unsigned long rings_size(unsigned sq_entries, unsigned cq_entries,
        return off;
 }
 
-static void io_buffer_unmap(struct io_ring_ctx *ctx, struct io_mapped_ubuf *imu)
+static void io_buffer_unmap(struct io_ring_ctx *ctx, struct io_mapped_ubuf **slot)
 {
+       struct io_mapped_ubuf *imu = *slot;
        unsigned int i;
 
        for (i = 0; i < imu->nr_bvecs; i++)
                unpin_user_page(imu->bvec[i].bv_page);
        if (imu->acct_pages)
                io_unaccount_mem(ctx, imu->acct_pages);
-       kvfree(imu->bvec);
-       imu->nr_bvecs = 0;
+       kvfree(imu);
+       *slot = NULL;
 }
 
 static int io_sqe_buffers_unregister(struct io_ring_ctx *ctx)
@@ -8157,7 +8158,7 @@ static bool headpage_already_acct(struct io_ring_ctx *ctx, struct page **pages,
 
        /* check previously registered pages */
        for (i = 0; i < ctx->nr_user_bufs; i++) {
-               struct io_mapped_ubuf *imu = &ctx->user_bufs[i];
+               struct io_mapped_ubuf *imu = ctx->user_bufs[i];
 
                for (j = 0; j < imu->nr_bvecs; j++) {
                        if (!PageCompound(imu->bvec[j].bv_page))
@@ -8202,9 +8203,10 @@ static int io_buffer_account_pin(struct io_ring_ctx *ctx, struct page **pages,
 }
 
 static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
-                                 struct io_mapped_ubuf *imu,
+                                 struct io_mapped_ubuf **pimu,
                                  struct page **last_hpage)
 {
+       struct io_mapped_ubuf *imu = NULL;
        struct vm_area_struct **vmas = NULL;
        struct page **pages = NULL;
        unsigned long off, start, end, ubuf;
@@ -8216,6 +8218,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
        start = ubuf >> PAGE_SHIFT;
        nr_pages = end - start;
 
+       *pimu = NULL;
        ret = -ENOMEM;
 
        pages = kvmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL);
@@ -8227,8 +8230,7 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
        if (!vmas)
                goto done;
 
-       imu->bvec = kvmalloc_array(nr_pages, sizeof(struct bio_vec),
-                                  GFP_KERNEL);
+       imu = kvmalloc(struct_size(imu, bvec, nr_pages), GFP_KERNEL);
        if (!imu->bvec)
                goto done;
 
@@ -8258,14 +8260,12 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
                 */
                if (pret > 0)
                        unpin_user_pages(pages, pret);
-               kvfree(imu->bvec);
                goto done;
        }
 
        ret = io_buffer_account_pin(ctx, pages, pret, imu, last_hpage);
        if (ret) {
                unpin_user_pages(pages, pret);
-               kvfree(imu->bvec);
                goto done;
        }
 
@@ -8285,8 +8285,11 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
        imu->ubuf = ubuf;
        imu->ubuf_end = ubuf + iov->iov_len;
        imu->nr_bvecs = nr_pages;
+       *pimu = imu;
        ret = 0;
 done:
+       if (ret)
+               kvfree(imu);
        kvfree(pages);
        kvfree(vmas);
        return ret;
@@ -8336,15 +8339,15 @@ static int io_sqe_buffers_register(struct io_ring_ctx *ctx, void __user *arg,
                return ret;
 
        for (i = 0; i < nr_args; i++, ctx->nr_user_bufs++) {
-               struct io_mapped_ubuf *imu = &ctx->user_bufs[i];
-
                ret = io_copy_iov(ctx, &iov, arg, i);
                if (ret)
                        break;
                ret = io_buffer_validate(&iov);
                if (ret)
                        break;
-               ret = io_sqe_buffer_register(ctx, &iov, imu, &last_hpage);
+
+               ret = io_sqe_buffer_register(ctx, &iov, &ctx->user_bufs[i],
+                                            &last_hpage);
                if (ret)
                        break;
        }
@@ -9291,7 +9294,7 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
        }
        seq_printf(m, "UserBufs:\t%u\n", ctx->nr_user_bufs);
        for (i = 0; has_lock && i < ctx->nr_user_bufs; i++) {
-               struct io_mapped_ubuf *buf = &ctx->user_bufs[i];
+               struct io_mapped_ubuf *buf = ctx->user_bufs[i];
                unsigned int len = buf->ubuf_end - buf->ubuf;
 
                seq_printf(m, "%5u: 0x%llx/%u\n", i, buf->ubuf, len);