Linux 6.9-rc1
[linux-2.6-microblaze.git] / crypto / algif_hash.c
index 1d017ec..e24c829 100644 (file)
@@ -63,122 +63,117 @@ static void hash_free_result(struct sock *sk, struct hash_ctx *ctx)
 static int hash_sendmsg(struct socket *sock, struct msghdr *msg,
                        size_t ignored)
 {
-       int limit = ALG_MAX_PAGES * PAGE_SIZE;
        struct sock *sk = sock->sk;
        struct alg_sock *ask = alg_sk(sk);
        struct hash_ctx *ctx = ask->private;
-       long copied = 0;
+       ssize_t copied = 0;
+       size_t len, max_pages, npages;
+       bool continuing, need_init = false;
        int err;
 
-       if (limit > sk->sk_sndbuf)
-               limit = sk->sk_sndbuf;
+       max_pages = min_t(size_t, ALG_MAX_PAGES,
+                         DIV_ROUND_UP(sk->sk_sndbuf, PAGE_SIZE));
 
        lock_sock(sk);
-       if (!ctx->more) {
-               if ((msg->msg_flags & MSG_MORE))
-                       hash_free_result(sk, ctx);
+       continuing = ctx->more;
 
-               err = crypto_wait_req(crypto_ahash_init(&ctx->req), &ctx->wait);
-               if (err)
-                       goto unlock;
-       }
-
-       ctx->more = false;
-
-       while (msg_data_left(msg)) {
-               int len = msg_data_left(msg);
-
-               if (len > limit)
-                       len = limit;
-
-               len = af_alg_make_sg(&ctx->sgl, &msg->msg_iter, len);
-               if (len < 0) {
-                       err = copied ? 0 : len;
-                       goto unlock;
-               }
-
-               ahash_request_set_crypt(&ctx->req, ctx->sgl.sg, NULL, len);
-
-               err = crypto_wait_req(crypto_ahash_update(&ctx->req),
-                                     &ctx->wait);
-               af_alg_free_sg(&ctx->sgl);
-               if (err) {
-                       iov_iter_revert(&msg->msg_iter, len);
-                       goto unlock;
+       if (!continuing) {
+               /* Discard a previous request that wasn't marked MSG_MORE. */
+               hash_free_result(sk, ctx);
+               if (!msg_data_left(msg))
+                       goto done; /* Zero-length; don't start new req */
+               need_init = true;
+       } else if (!msg_data_left(msg)) {
+               /*
+                * No data - finalise the prev req if MSG_MORE so any error
+                * comes out here.
+                */
+               if (!(msg->msg_flags & MSG_MORE)) {
+                       err = hash_alloc_result(sk, ctx);
+                       if (err)
+                               goto unlock_free_result;
+                       ahash_request_set_crypt(&ctx->req, NULL,
+                                               ctx->result, 0);
+                       err = crypto_wait_req(crypto_ahash_final(&ctx->req),
+                                             &ctx->wait);
+                       if (err)
+                               goto unlock_free_result;
                }
-
-               copied += len;
-       }
-
-       err = 0;
-
-       ctx->more = msg->msg_flags & MSG_MORE;
-       if (!ctx->more) {
-               err = hash_alloc_result(sk, ctx);
-               if (err)
-                       goto unlock;
-
-               ahash_request_set_crypt(&ctx->req, NULL, ctx->result, 0);
-               err = crypto_wait_req(crypto_ahash_final(&ctx->req),
-                                     &ctx->wait);
+               goto done_more;
        }
 
-unlock:
-       release_sock(sk);
+       while (msg_data_left(msg)) {
+               ctx->sgl.sgt.sgl = ctx->sgl.sgl;
+               ctx->sgl.sgt.nents = 0;
+               ctx->sgl.sgt.orig_nents = 0;
 
-       return err ?: copied;
-}
+               err = -EIO;
+               npages = iov_iter_npages(&msg->msg_iter, max_pages);
+               if (npages == 0)
+                       goto unlock_free;
 
-static ssize_t hash_sendpage(struct socket *sock, struct page *page,
-                            int offset, size_t size, int flags)
-{
-       struct sock *sk = sock->sk;
-       struct alg_sock *ask = alg_sk(sk);
-       struct hash_ctx *ctx = ask->private;
-       int err;
+               sg_init_table(ctx->sgl.sgl, npages);
 
-       if (flags & MSG_SENDPAGE_NOTLAST)
-               flags |= MSG_MORE;
+               ctx->sgl.need_unpin = iov_iter_extract_will_pin(&msg->msg_iter);
 
-       lock_sock(sk);
-       sg_init_table(ctx->sgl.sg, 1);
-       sg_set_page(ctx->sgl.sg, page, size, offset);
+               err = extract_iter_to_sg(&msg->msg_iter, LONG_MAX,
+                                        &ctx->sgl.sgt, npages, 0);
+               if (err < 0)
+                       goto unlock_free;
+               len = err;
+               sg_mark_end(ctx->sgl.sgt.sgl + ctx->sgl.sgt.nents - 1);
 
-       if (!(flags & MSG_MORE)) {
-               err = hash_alloc_result(sk, ctx);
-               if (err)
-                       goto unlock;
-       } else if (!ctx->more)
-               hash_free_result(sk, ctx);
+               if (!msg_data_left(msg)) {
+                       err = hash_alloc_result(sk, ctx);
+                       if (err)
+                               goto unlock_free;
+               }
 
-       ahash_request_set_crypt(&ctx->req, ctx->sgl.sg, ctx->result, size);
+               ahash_request_set_crypt(&ctx->req, ctx->sgl.sgt.sgl,
+                                       ctx->result, len);
 
-       if (!(flags & MSG_MORE)) {
-               if (ctx->more)
-                       err = crypto_ahash_finup(&ctx->req);
-               else
+               if (!msg_data_left(msg) && !continuing &&
+                   !(msg->msg_flags & MSG_MORE)) {
                        err = crypto_ahash_digest(&ctx->req);
-       } else {
-               if (!ctx->more) {
-                       err = crypto_ahash_init(&ctx->req);
-                       err = crypto_wait_req(err, &ctx->wait);
-                       if (err)
-                               goto unlock;
+               } else {
+                       if (need_init) {
+                               err = crypto_wait_req(
+                                       crypto_ahash_init(&ctx->req),
+                                       &ctx->wait);
+                               if (err)
+                                       goto unlock_free;
+                               need_init = false;
+                       }
+
+                       if (msg_data_left(msg) || (msg->msg_flags & MSG_MORE))
+                               err = crypto_ahash_update(&ctx->req);
+                       else
+                               err = crypto_ahash_finup(&ctx->req);
+                       continuing = true;
                }
 
-               err = crypto_ahash_update(&ctx->req);
-       }
-
-       err = crypto_wait_req(err, &ctx->wait);
-       if (err)
-               goto unlock;
+               err = crypto_wait_req(err, &ctx->wait);
+               if (err)
+                       goto unlock_free;
 
-       ctx->more = flags & MSG_MORE;
+               copied += len;
+               af_alg_free_sg(&ctx->sgl);
+       }
 
+done_more:
+       ctx->more = msg->msg_flags & MSG_MORE;
+done:
+       err = 0;
 unlock:
        release_sock(sk);
+       return copied ?: err;
 
-       return err ?: size;
+unlock_free:
+       af_alg_free_sg(&ctx->sgl);
+unlock_free_result:
+       hash_free_result(sk, ctx);
+       ctx->more = false;
+       goto unlock;
 }
 
 static int hash_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
@@ -235,24 +230,31 @@ static int hash_accept(struct socket *sock, struct socket *newsock, int flags,
        struct alg_sock *ask = alg_sk(sk);
        struct hash_ctx *ctx = ask->private;
        struct ahash_request *req = &ctx->req;
-       char state[HASH_MAX_STATESIZE];
+       struct crypto_ahash *tfm;
        struct sock *sk2;
        struct alg_sock *ask2;
        struct hash_ctx *ctx2;
+       char *state;
        bool more;
        int err;
 
+       tfm = crypto_ahash_reqtfm(req);
+       state = kmalloc(crypto_ahash_statesize(tfm), GFP_KERNEL);
+       err = -ENOMEM;
+       if (!state)
+               goto out;
+
        lock_sock(sk);
        more = ctx->more;
        err = more ? crypto_ahash_export(req, state) : 0;
        release_sock(sk);
 
        if (err)
-               return err;
+               goto out_free_state;
 
        err = af_alg_accept(ask->parent, newsock, kern);
        if (err)
-               return err;
+               goto out_free_state;
 
        sk2 = newsock->sk;
        ask2 = alg_sk(sk2);
@@ -260,7 +262,7 @@ static int hash_accept(struct socket *sock, struct socket *newsock, int flags,
        ctx2->more = more;
 
        if (!more)
-               return err;
+               goto out_free_state;
 
        err = crypto_ahash_import(&ctx2->req, state);
        if (err) {
@@ -268,6 +270,10 @@ static int hash_accept(struct socket *sock, struct socket *newsock, int flags,
                sock_put(sk2);
        }
 
+out_free_state:
+       kfree_sensitive(state);
+
+out:
        return err;
 }
 
@@ -285,7 +291,6 @@ static struct proto_ops algif_hash_ops = {
 
        .release        =       af_alg_release,
        .sendmsg        =       hash_sendmsg,
-       .sendpage       =       hash_sendpage,
        .recvmsg        =       hash_recvmsg,
        .accept         =       hash_accept,
 };
@@ -337,18 +342,6 @@ static int hash_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
        return hash_sendmsg(sock, msg, size);
 }
 
-static ssize_t hash_sendpage_nokey(struct socket *sock, struct page *page,
-                                  int offset, size_t size, int flags)
-{
-       int err;
-
-       err = hash_check_key(sock);
-       if (err)
-               return err;
-
-       return hash_sendpage(sock, page, offset, size, flags);
-}
-
 static int hash_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
                              size_t ignored, int flags)
 {
@@ -387,7 +380,6 @@ static struct proto_ops algif_hash_ops_nokey = {
 
        .release        =       af_alg_release,
        .sendmsg        =       hash_sendmsg_nokey,
-       .sendpage       =       hash_sendpage_nokey,
        .recvmsg        =       hash_recvmsg_nokey,
        .accept         =       hash_accept_nokey,
 };