crypto: skcipher - pass instance to crypto_grab_skcipher()
authorEric Biggers <ebiggers@google.com>
Fri, 3 Jan 2020 03:58:45 +0000 (19:58 -0800)
committerHerbert Xu <herbert@gondor.apana.org.au>
Thu, 9 Jan 2020 03:30:54 +0000 (11:30 +0800)
Initializing a crypto_skcipher_spawn currently requires:

1. Set spawn->base.inst to point to the instance.
2. Call crypto_grab_skcipher().

But there's no reason for these steps to be separate, and in fact this
unneeded complication has caused at least one bug, the one fixed by
commit 6db43410179b ("crypto: adiantum - initialize crypto_spawn::inst")

So just make crypto_grab_skcipher() take the instance as an argument.

To keep the function calls from getting too unwieldy due to this extra
argument, also introduce a 'mask' variable into the affected places
which weren't already using one.

Signed-off-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
14 files changed:
crypto/adiantum.c
crypto/authenc.c
crypto/authencesn.c
crypto/ccm.c
crypto/chacha20poly1305.c
crypto/cryptd.c
crypto/ctr.c
crypto/cts.c
crypto/essiv.c
crypto/gcm.c
crypto/lrw.c
crypto/skcipher.c
crypto/xts.c
include/crypto/internal/skcipher.h

index 30cffb4..5670714 100644 (file)
@@ -493,6 +493,7 @@ static bool adiantum_supported_algorithms(struct skcipher_alg *streamcipher_alg,
 static int adiantum_create(struct crypto_template *tmpl, struct rtattr **tb)
 {
        struct crypto_attr_type *algt;
+       u32 mask;
        const char *streamcipher_name;
        const char *blockcipher_name;
        const char *nhpoly1305_name;
@@ -511,6 +512,8 @@ static int adiantum_create(struct crypto_template *tmpl, struct rtattr **tb)
        if ((algt->type ^ CRYPTO_ALG_TYPE_SKCIPHER) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        streamcipher_name = crypto_attr_alg_name(tb[1]);
        if (IS_ERR(streamcipher_name))
                return PTR_ERR(streamcipher_name);
@@ -531,11 +534,9 @@ static int adiantum_create(struct crypto_template *tmpl, struct rtattr **tb)
        ictx = skcipher_instance_ctx(inst);
 
        /* Stream cipher, e.g. "xchacha12" */
-       crypto_set_skcipher_spawn(&ictx->streamcipher_spawn,
-                                 skcipher_crypto_instance(inst));
-       err = crypto_grab_skcipher(&ictx->streamcipher_spawn, streamcipher_name,
-                                  0, crypto_requires_sync(algt->type,
-                                                          algt->mask));
+       err = crypto_grab_skcipher(&ictx->streamcipher_spawn,
+                                  skcipher_crypto_instance(inst),
+                                  streamcipher_name, 0, mask);
        if (err)
                goto out_free_inst;
        streamcipher_alg = crypto_spawn_skcipher_alg(&ictx->streamcipher_spawn);
index 15aaddd..e31bcec 100644 (file)
@@ -373,6 +373,7 @@ static int crypto_authenc_create(struct crypto_template *tmpl,
                                 struct rtattr **tb)
 {
        struct crypto_attr_type *algt;
+       u32 mask;
        struct aead_instance *inst;
        struct hash_alg_common *auth;
        struct crypto_alg *auth_base;
@@ -388,9 +389,10 @@ static int crypto_authenc_create(struct crypto_template *tmpl,
        if ((algt->type ^ CRYPTO_ALG_TYPE_AEAD) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        auth = ahash_attr_alg(tb[1], CRYPTO_ALG_TYPE_HASH,
-                             CRYPTO_ALG_TYPE_AHASH_MASK |
-                             crypto_requires_sync(algt->type, algt->mask));
+                             CRYPTO_ALG_TYPE_AHASH_MASK | mask);
        if (IS_ERR(auth))
                return PTR_ERR(auth);
 
@@ -413,10 +415,8 @@ static int crypto_authenc_create(struct crypto_template *tmpl,
        if (err)
                goto err_free_inst;
 
-       crypto_set_skcipher_spawn(&ctx->enc, aead_crypto_instance(inst));
-       err = crypto_grab_skcipher(&ctx->enc, enc_name, 0,
-                                  crypto_requires_sync(algt->type,
-                                                       algt->mask));
+       err = crypto_grab_skcipher(&ctx->enc, aead_crypto_instance(inst),
+                                  enc_name, 0, mask);
        if (err)
                goto err_drop_auth;
 
index fc81324..83bda7f 100644 (file)
@@ -391,6 +391,7 @@ static int crypto_authenc_esn_create(struct crypto_template *tmpl,
                                     struct rtattr **tb)
 {
        struct crypto_attr_type *algt;
+       u32 mask;
        struct aead_instance *inst;
        struct hash_alg_common *auth;
        struct crypto_alg *auth_base;
@@ -406,9 +407,10 @@ static int crypto_authenc_esn_create(struct crypto_template *tmpl,
        if ((algt->type ^ CRYPTO_ALG_TYPE_AEAD) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        auth = ahash_attr_alg(tb[1], CRYPTO_ALG_TYPE_HASH,
-                             CRYPTO_ALG_TYPE_AHASH_MASK |
-                             crypto_requires_sync(algt->type, algt->mask));
+                             CRYPTO_ALG_TYPE_AHASH_MASK | mask);
        if (IS_ERR(auth))
                return PTR_ERR(auth);
 
@@ -431,10 +433,8 @@ static int crypto_authenc_esn_create(struct crypto_template *tmpl,
        if (err)
                goto err_free_inst;
 
-       crypto_set_skcipher_spawn(&ctx->enc, aead_crypto_instance(inst));
-       err = crypto_grab_skcipher(&ctx->enc, enc_name, 0,
-                                  crypto_requires_sync(algt->type,
-                                                       algt->mask));
+       err = crypto_grab_skcipher(&ctx->enc, aead_crypto_instance(inst),
+                                  enc_name, 0, mask);
        if (err)
                goto err_drop_auth;
 
index 4410452..4414f0d 100644 (file)
@@ -450,6 +450,7 @@ static int crypto_ccm_create_common(struct crypto_template *tmpl,
                                    const char *mac_name)
 {
        struct crypto_attr_type *algt;
+       u32 mask;
        struct aead_instance *inst;
        struct skcipher_alg *ctr;
        struct crypto_alg *mac_alg;
@@ -464,6 +465,8 @@ static int crypto_ccm_create_common(struct crypto_template *tmpl,
        if ((algt->type ^ CRYPTO_ALG_TYPE_AEAD) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        mac_alg = crypto_find_alg(mac_name, &crypto_ahash_type,
                                  CRYPTO_ALG_TYPE_HASH,
                                  CRYPTO_ALG_TYPE_AHASH_MASK |
@@ -488,10 +491,8 @@ static int crypto_ccm_create_common(struct crypto_template *tmpl,
        if (err)
                goto err_free_inst;
 
-       crypto_set_skcipher_spawn(&ictx->ctr, aead_crypto_instance(inst));
-       err = crypto_grab_skcipher(&ictx->ctr, ctr_name, 0,
-                                  crypto_requires_sync(algt->type,
-                                                       algt->mask));
+       err = crypto_grab_skcipher(&ictx->ctr, aead_crypto_instance(inst),
+                                  ctr_name, 0, mask);
        if (err)
                goto err_drop_mac;
 
index 88cbdab..09d5a34 100644 (file)
@@ -558,6 +558,7 @@ static int chachapoly_create(struct crypto_template *tmpl, struct rtattr **tb,
                             const char *name, unsigned int ivsize)
 {
        struct crypto_attr_type *algt;
+       u32 mask;
        struct aead_instance *inst;
        struct skcipher_alg *chacha;
        struct crypto_alg *poly;
@@ -576,6 +577,8 @@ static int chachapoly_create(struct crypto_template *tmpl, struct rtattr **tb,
        if ((algt->type ^ CRYPTO_ALG_TYPE_AEAD) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        chacha_name = crypto_attr_alg_name(tb[1]);
        if (IS_ERR(chacha_name))
                return PTR_ERR(chacha_name);
@@ -585,9 +588,7 @@ static int chachapoly_create(struct crypto_template *tmpl, struct rtattr **tb,
 
        poly = crypto_find_alg(poly_name, &crypto_ahash_type,
                               CRYPTO_ALG_TYPE_HASH,
-                              CRYPTO_ALG_TYPE_AHASH_MASK |
-                              crypto_requires_sync(algt->type,
-                                                   algt->mask));
+                              CRYPTO_ALG_TYPE_AHASH_MASK | mask);
        if (IS_ERR(poly))
                return PTR_ERR(poly);
        poly_hash = __crypto_hash_alg_common(poly);
@@ -608,10 +609,8 @@ static int chachapoly_create(struct crypto_template *tmpl, struct rtattr **tb,
        if (err)
                goto err_free_inst;
 
-       crypto_set_skcipher_spawn(&ctx->chacha, aead_crypto_instance(inst));
-       err = crypto_grab_skcipher(&ctx->chacha, chacha_name, 0,
-                                  crypto_requires_sync(algt->type,
-                                                       algt->mask));
+       err = crypto_grab_skcipher(&ctx->chacha, aead_crypto_instance(inst),
+                                  chacha_name, 0, mask);
        if (err)
                goto err_drop_poly;
 
index cd94243..a0fe106 100644 (file)
@@ -416,8 +416,8 @@ static int cryptd_create_skcipher(struct crypto_template *tmpl,
        ctx = skcipher_instance_ctx(inst);
        ctx->queue = queue;
 
-       crypto_set_skcipher_spawn(&ctx->spawn, skcipher_crypto_instance(inst));
-       err = crypto_grab_skcipher(&ctx->spawn, name, type, mask);
+       err = crypto_grab_skcipher(&ctx->spawn, skcipher_crypto_instance(inst),
+                                  name, type, mask);
        if (err)
                goto out_free_inst;
 
index b63b19d..a8feab6 100644 (file)
@@ -286,8 +286,8 @@ static int crypto_rfc3686_create(struct crypto_template *tmpl,
 
        spawn = skcipher_instance_ctx(inst);
 
-       crypto_set_skcipher_spawn(spawn, skcipher_crypto_instance(inst));
-       err = crypto_grab_skcipher(spawn, cipher_name, 0, mask);
+       err = crypto_grab_skcipher(spawn, skcipher_crypto_instance(inst),
+                                  cipher_name, 0, mask);
        if (err)
                goto err_free_inst;
 
index a0bb994..48188ad 100644 (file)
@@ -328,6 +328,7 @@ static int crypto_cts_create(struct crypto_template *tmpl, struct rtattr **tb)
        struct crypto_attr_type *algt;
        struct skcipher_alg *alg;
        const char *cipher_name;
+       u32 mask;
        int err;
 
        algt = crypto_get_attr_type(tb);
@@ -337,6 +338,8 @@ static int crypto_cts_create(struct crypto_template *tmpl, struct rtattr **tb)
        if ((algt->type ^ CRYPTO_ALG_TYPE_SKCIPHER) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        cipher_name = crypto_attr_alg_name(tb[1]);
        if (IS_ERR(cipher_name))
                return PTR_ERR(cipher_name);
@@ -347,10 +350,8 @@ static int crypto_cts_create(struct crypto_template *tmpl, struct rtattr **tb)
 
        spawn = skcipher_instance_ctx(inst);
 
-       crypto_set_skcipher_spawn(spawn, skcipher_crypto_instance(inst));
-       err = crypto_grab_skcipher(spawn, cipher_name, 0,
-                                  crypto_requires_sync(algt->type,
-                                                       algt->mask));
+       err = crypto_grab_skcipher(spawn, skcipher_crypto_instance(inst),
+                                  cipher_name, 0, mask);
        if (err)
                goto err_free_inst;
 
index 61d9000..0e45f5b 100644 (file)
@@ -452,6 +452,7 @@ static int essiv_create(struct crypto_template *tmpl, struct rtattr **tb)
        struct shash_alg *hash_alg;
        int ivsize;
        u32 type;
+       u32 mask;
        int err;
 
        algt = crypto_get_attr_type(tb);
@@ -467,6 +468,7 @@ static int essiv_create(struct crypto_template *tmpl, struct rtattr **tb)
                return PTR_ERR(shash_name);
 
        type = algt->type & algt->mask;
+       mask = crypto_requires_sync(algt->type, algt->mask);
 
        switch (type) {
        case CRYPTO_ALG_TYPE_SKCIPHER:
@@ -479,11 +481,8 @@ static int essiv_create(struct crypto_template *tmpl, struct rtattr **tb)
                ictx = crypto_instance_ctx(inst);
 
                /* Symmetric cipher, e.g., "cbc(aes)" */
-               crypto_set_skcipher_spawn(&ictx->u.skcipher_spawn, inst);
-               err = crypto_grab_skcipher(&ictx->u.skcipher_spawn,
-                                          inner_cipher_name, 0,
-                                          crypto_requires_sync(algt->type,
-                                                               algt->mask));
+               err = crypto_grab_skcipher(&ictx->u.skcipher_spawn, inst,
+                                          inner_cipher_name, 0, mask);
                if (err)
                        goto out_free_inst;
                skcipher_alg = crypto_spawn_skcipher_alg(&ictx->u.skcipher_spawn);
@@ -503,9 +502,7 @@ static int essiv_create(struct crypto_template *tmpl, struct rtattr **tb)
                /* AEAD cipher, e.g., "authenc(hmac(sha256),cbc(aes))" */
                crypto_set_aead_spawn(&ictx->u.aead_spawn, inst);
                err = crypto_grab_aead(&ictx->u.aead_spawn,
-                                      inner_cipher_name, 0,
-                                      crypto_requires_sync(algt->type,
-                                                           algt->mask));
+                                      inner_cipher_name, 0, mask);
                if (err)
                        goto out_free_inst;
                aead_alg = crypto_spawn_aead_alg(&ictx->u.aead_spawn);
index 7041cb1..887f472 100644 (file)
@@ -580,6 +580,7 @@ static int crypto_gcm_create_common(struct crypto_template *tmpl,
                                    const char *ghash_name)
 {
        struct crypto_attr_type *algt;
+       u32 mask;
        struct aead_instance *inst;
        struct skcipher_alg *ctr;
        struct crypto_alg *ghash_alg;
@@ -594,11 +595,11 @@ static int crypto_gcm_create_common(struct crypto_template *tmpl,
        if ((algt->type ^ CRYPTO_ALG_TYPE_AEAD) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        ghash_alg = crypto_find_alg(ghash_name, &crypto_ahash_type,
                                    CRYPTO_ALG_TYPE_HASH,
-                                   CRYPTO_ALG_TYPE_AHASH_MASK |
-                                   crypto_requires_sync(algt->type,
-                                                        algt->mask));
+                                   CRYPTO_ALG_TYPE_AHASH_MASK | mask);
        if (IS_ERR(ghash_alg))
                return PTR_ERR(ghash_alg);
 
@@ -620,10 +621,8 @@ static int crypto_gcm_create_common(struct crypto_template *tmpl,
            ghash->digestsize != 16)
                goto err_drop_ghash;
 
-       crypto_set_skcipher_spawn(&ctx->ctr, aead_crypto_instance(inst));
-       err = crypto_grab_skcipher(&ctx->ctr, ctr_name, 0,
-                                  crypto_requires_sync(algt->type,
-                                                       algt->mask));
+       err = crypto_grab_skcipher(&ctx->ctr, aead_crypto_instance(inst),
+                                  ctr_name, 0, mask);
        if (err)
                goto err_drop_ghash;
 
index 8ebd792..63c485c 100644 (file)
@@ -301,6 +301,7 @@ static int create(struct crypto_template *tmpl, struct rtattr **tb)
        struct skcipher_alg *alg;
        const char *cipher_name;
        char ecb_name[CRYPTO_MAX_ALG_NAME];
+       u32 mask;
        int err;
 
        algt = crypto_get_attr_type(tb);
@@ -310,6 +311,8 @@ static int create(struct crypto_template *tmpl, struct rtattr **tb)
        if ((algt->type ^ CRYPTO_ALG_TYPE_SKCIPHER) & algt->mask)
                return -EINVAL;
 
+       mask = crypto_requires_sync(algt->type, algt->mask);
+
        cipher_name = crypto_attr_alg_name(tb[1]);
        if (IS_ERR(cipher_name))
                return PTR_ERR(cipher_name);
@@ -320,19 +323,17 @@ static int create(struct crypto_template *tmpl, struct rtattr **tb)
 
        spawn = skcipher_instance_ctx(inst);
 
-       crypto_set_skcipher_spawn(spawn, skcipher_crypto_instance(inst));
-       err = crypto_grab_skcipher(spawn, cipher_name, 0,
-                                  crypto_requires_sync(algt->type,
-                                                       algt->mask));
+       err = crypto_grab_skcipher(spawn, skcipher_crypto_instance(inst),
+                                  cipher_name, 0, mask);
        if (err == -ENOENT) {
                err = -ENAMETOOLONG;
                if (snprintf(ecb_name, CRYPTO_MAX_ALG_NAME, "ecb(%s)",
                             cipher_name) >= CRYPTO_MAX_ALG_NAME)
                        goto err_free_inst;
 
-               err = crypto_grab_skcipher(spawn, ecb_name, 0,
-                                          crypto_requires_sync(algt->type,
-                                                               algt->mask));
+               err = crypto_grab_skcipher(spawn,
+                                          skcipher_crypto_instance(inst),
+                                          ecb_name, 0, mask);
        }
 
        if (err)
index 89137a1..8759d47 100644 (file)
@@ -747,8 +747,10 @@ static const struct crypto_type crypto_skcipher_type = {
 };
 
 int crypto_grab_skcipher(struct crypto_skcipher_spawn *spawn,
-                         const char *name, u32 type, u32 mask)
+                        struct crypto_instance *inst,
+                        const char *name, u32 type, u32 mask)
 {
+       spawn->base.inst = inst;
        spawn->base.frontend = &crypto_skcipher_type;
        return crypto_grab_spawn(&spawn->base, name, type, mask);
 }
index 19d5548..29efa15 100644 (file)
@@ -355,20 +355,21 @@ static int create(struct crypto_template *tmpl, struct rtattr **tb)
 
        ctx = skcipher_instance_ctx(inst);
 
-       crypto_set_skcipher_spawn(&ctx->spawn, skcipher_crypto_instance(inst));
-
        mask = crypto_requires_off(algt->type, algt->mask,
                                   CRYPTO_ALG_NEED_FALLBACK |
                                   CRYPTO_ALG_ASYNC);
 
-       err = crypto_grab_skcipher(&ctx->spawn, cipher_name, 0, mask);
+       err = crypto_grab_skcipher(&ctx->spawn, skcipher_crypto_instance(inst),
+                                  cipher_name, 0, mask);
        if (err == -ENOENT) {
                err = -ENAMETOOLONG;
                if (snprintf(ctx->name, CRYPTO_MAX_ALG_NAME, "ecb(%s)",
                             cipher_name) >= CRYPTO_MAX_ALG_NAME)
                        goto err_free_inst;
 
-               err = crypto_grab_skcipher(&ctx->spawn, ctx->name, 0, mask);
+               err = crypto_grab_skcipher(&ctx->spawn,
+                                          skcipher_crypto_instance(inst),
+                                          ctx->name, 0, mask);
        }
 
        if (err)
index df4fdea..e387424 100644 (file)
@@ -88,14 +88,9 @@ static inline void skcipher_request_complete(struct skcipher_request *req, int e
        req->base.complete(&req->base, err);
 }
 
-static inline void crypto_set_skcipher_spawn(
-       struct crypto_skcipher_spawn *spawn, struct crypto_instance *inst)
-{
-       crypto_set_spawn(&spawn->base, inst);
-}
-
-int crypto_grab_skcipher(struct crypto_skcipher_spawn *spawn, const char *name,
-                        u32 type, u32 mask);
+int crypto_grab_skcipher(struct crypto_skcipher_spawn *spawn,
+                        struct crypto_instance *inst,
+                        const char *name, u32 type, u32 mask);
 
 static inline void crypto_drop_skcipher(struct crypto_skcipher_spawn *spawn)
 {