Merge tag 'parisc-for-6.8-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git/delle...
[linux-2.6-microblaze.git] / crypto / lskcipher.c
index 9edc897..0b6dd8a 100644 (file)
@@ -88,8 +88,9 @@ EXPORT_SYMBOL_GPL(crypto_lskcipher_setkey);
 static int crypto_lskcipher_crypt_unaligned(
        struct crypto_lskcipher *tfm, const u8 *src, u8 *dst, unsigned len,
        u8 *iv, int (*crypt)(struct crypto_lskcipher *tfm, const u8 *src,
-                            u8 *dst, unsigned len, u8 *iv, bool final))
+                            u8 *dst, unsigned len, u8 *iv, u32 flags))
 {
+       unsigned statesize = crypto_lskcipher_statesize(tfm);
        unsigned ivsize = crypto_lskcipher_ivsize(tfm);
        unsigned bs = crypto_lskcipher_blocksize(tfm);
        unsigned cs = crypto_lskcipher_chunksize(tfm);
@@ -104,7 +105,7 @@ static int crypto_lskcipher_crypt_unaligned(
        if (!tiv)
                return -ENOMEM;
 
-       memcpy(tiv, iv, ivsize);
+       memcpy(tiv, iv, ivsize + statesize);
 
        p = kmalloc(PAGE_SIZE, GFP_ATOMIC);
        err = -ENOMEM;
@@ -119,7 +120,7 @@ static int crypto_lskcipher_crypt_unaligned(
                        chunk &= ~(cs - 1);
 
                memcpy(p, src, chunk);
-               err = crypt(tfm, p, p, chunk, tiv, true);
+               err = crypt(tfm, p, p, chunk, tiv, CRYPTO_LSKCIPHER_FLAG_FINAL);
                if (err)
                        goto out;
 
@@ -132,7 +133,7 @@ static int crypto_lskcipher_crypt_unaligned(
        err = len ? -EINVAL : 0;
 
 out:
-       memcpy(iv, tiv, ivsize);
+       memcpy(iv, tiv, ivsize + statesize);
        kfree_sensitive(p);
        kfree_sensitive(tiv);
        return err;
@@ -143,7 +144,7 @@ static int crypto_lskcipher_crypt(struct crypto_lskcipher *tfm, const u8 *src,
                                  int (*crypt)(struct crypto_lskcipher *tfm,
                                               const u8 *src, u8 *dst,
                                               unsigned len, u8 *iv,
-                                              bool final))
+                                              u32 flags))
 {
        unsigned long alignmask = crypto_lskcipher_alignmask(tfm);
        struct lskcipher_alg *alg = crypto_lskcipher_alg(tfm);
@@ -156,7 +157,7 @@ static int crypto_lskcipher_crypt(struct crypto_lskcipher *tfm, const u8 *src,
                goto out;
        }
 
-       ret = crypt(tfm, src, dst, len, iv, true);
+       ret = crypt(tfm, src, dst, len, iv, CRYPTO_LSKCIPHER_FLAG_FINAL);
 
 out:
        return crypto_lskcipher_errstat(alg, ret);
@@ -197,23 +198,45 @@ EXPORT_SYMBOL_GPL(crypto_lskcipher_decrypt);
 static int crypto_lskcipher_crypt_sg(struct skcipher_request *req,
                                     int (*crypt)(struct crypto_lskcipher *tfm,
                                                  const u8 *src, u8 *dst,
-                                                 unsigned len, u8 *iv,
-                                                 bool final))
+                                                 unsigned len, u8 *ivs,
+                                                 u32 flags))
 {
        struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
        struct crypto_lskcipher **ctx = crypto_skcipher_ctx(skcipher);
+       u8 *ivs = skcipher_request_ctx(req);
        struct crypto_lskcipher *tfm = *ctx;
        struct skcipher_walk walk;
+       unsigned ivsize;
+       u32 flags;
        int err;
 
+       ivsize = crypto_lskcipher_ivsize(tfm);
+       ivs = PTR_ALIGN(ivs, crypto_skcipher_alignmask(skcipher) + 1);
+
+       flags = req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP;
+
+       if (req->base.flags & CRYPTO_SKCIPHER_REQ_CONT)
+               flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
+       else
+               memcpy(ivs, req->iv, ivsize);
+
+       if (!(req->base.flags & CRYPTO_SKCIPHER_REQ_NOTFINAL))
+               flags |= CRYPTO_LSKCIPHER_FLAG_FINAL;
+
        err = skcipher_walk_virt(&walk, req, false);
 
        while (walk.nbytes) {
                err = crypt(tfm, walk.src.virt.addr, walk.dst.virt.addr,
-                           walk.nbytes, walk.iv, walk.nbytes == walk.total);
+                           walk.nbytes, ivs,
+                           flags & ~(walk.nbytes == walk.total ?
+                           0 : CRYPTO_LSKCIPHER_FLAG_FINAL));
                err = skcipher_walk_done(&walk, err);
+               flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
        }
 
+       if (flags & CRYPTO_LSKCIPHER_FLAG_FINAL)
+               memcpy(req->iv, ivs, ivsize);
+
        return err;
 }
 
@@ -276,6 +299,7 @@ static void __maybe_unused crypto_lskcipher_show(
        seq_printf(m, "max keysize  : %u\n", skcipher->co.max_keysize);
        seq_printf(m, "ivsize       : %u\n", skcipher->co.ivsize);
        seq_printf(m, "chunksize    : %u\n", skcipher->co.chunksize);
+       seq_printf(m, "statesize    : %u\n", skcipher->co.statesize);
 }
 
 static int __maybe_unused crypto_lskcipher_report(
@@ -618,6 +642,7 @@ struct lskcipher_instance *lskcipher_alloc_instance_simple(
        inst->alg.co.min_keysize = cipher_alg->co.min_keysize;
        inst->alg.co.max_keysize = cipher_alg->co.max_keysize;
        inst->alg.co.ivsize = cipher_alg->co.base.cra_blocksize;
+       inst->alg.co.statesize = cipher_alg->co.statesize;
 
        /* Use struct crypto_lskcipher * by default, can be overridden */
        inst->alg.co.base.cra_ctxsize = sizeof(struct crypto_lskcipher *);