Merge tag 'block-6.5-2023-07-03' of git://git.kernel.dk/linux
[linux-2.6-microblaze.git] / crypto / sm2.c
index ed9307d..285b3cb 100644 (file)
 #include <crypto/internal/akcipher.h>
 #include <crypto/akcipher.h>
 #include <crypto/hash.h>
-#include <crypto/sm3.h>
 #include <crypto/rng.h>
 #include <crypto/sm2.h>
 #include "sm2signature.asn1.h"
 
+/* The default user id as specified in GM/T 0009-2012 */
+#define SM2_DEFAULT_USERID "1234567812345678"
+#define SM2_DEFAULT_USERID_LEN 16
+
 #define MPI_NBYTES(m)   ((mpi_get_nbits(m) + 7) / 8)
 
 struct ecc_domain_parms {
@@ -60,6 +63,9 @@ static const struct ecc_domain_parms sm2_ecp = {
        .h = 1
 };
 
+static int __sm2_set_pub_key(struct mpi_ec_ctx *ec,
+                            const void *key, unsigned int keylen);
+
 static int sm2_ec_ctx_init(struct mpi_ec_ctx *ec)
 {
        const struct ecc_domain_parms *ecp = &sm2_ecp;
@@ -213,12 +219,13 @@ int sm2_get_signature_s(void *context, size_t hdrlen, unsigned char tag,
        return 0;
 }
 
-static int sm2_z_digest_update(struct sm3_state *sctx,
-                       MPI m, unsigned int pbytes)
+static int sm2_z_digest_update(struct shash_desc *desc,
+                              MPI m, unsigned int pbytes)
 {
        static const unsigned char zero[32];
        unsigned char *in;
        unsigned int inlen;
+       int err;
 
        in = mpi_get_buffer(m, &inlen, NULL);
        if (!in)
@@ -226,21 +233,22 @@ static int sm2_z_digest_update(struct sm3_state *sctx,
 
        if (inlen < pbytes) {
                /* padding with zero */
-               sm3_update(sctx, zero, pbytes - inlen);
-               sm3_update(sctx, in, inlen);
+               err = crypto_shash_update(desc, zero, pbytes - inlen) ?:
+                     crypto_shash_update(desc, in, inlen);
        } else if (inlen > pbytes) {
                /* skip the starting zero */
-               sm3_update(sctx, in + inlen - pbytes, pbytes);
+               err = crypto_shash_update(desc, in + inlen - pbytes, pbytes);
        } else {
-               sm3_update(sctx, in, inlen);
+               err = crypto_shash_update(desc, in, inlen);
        }
 
        kfree(in);
-       return 0;
+       return err;
 }
 
-static int sm2_z_digest_update_point(struct sm3_state *sctx,
-               MPI_POINT point, struct mpi_ec_ctx *ec, unsigned int pbytes)
+static int sm2_z_digest_update_point(struct shash_desc *desc,
+                                    MPI_POINT point, struct mpi_ec_ctx *ec,
+                                    unsigned int pbytes)
 {
        MPI x, y;
        int ret = -EINVAL;
@@ -248,50 +256,68 @@ static int sm2_z_digest_update_point(struct sm3_state *sctx,
        x = mpi_new(0);
        y = mpi_new(0);
 
-       if (!mpi_ec_get_affine(x, y, point, ec) &&
-           !sm2_z_digest_update(sctx, x, pbytes) &&
-           !sm2_z_digest_update(sctx, y, pbytes))
-               ret = 0;
+       ret = mpi_ec_get_affine(x, y, point, ec) ? -EINVAL :
+             sm2_z_digest_update(desc, x, pbytes) ?:
+             sm2_z_digest_update(desc, y, pbytes);
 
        mpi_free(x);
        mpi_free(y);
        return ret;
 }
 
-int sm2_compute_z_digest(struct crypto_akcipher *tfm,
-                       const unsigned char *id, size_t id_len,
-                       unsigned char dgst[SM3_DIGEST_SIZE])
+int sm2_compute_z_digest(struct shash_desc *desc,
+                        const void *key, unsigned int keylen, void *dgst)
 {
-       struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm);
-       uint16_t bits_len;
-       unsigned char entl[2];
-       struct sm3_state sctx;
+       struct mpi_ec_ctx *ec;
+       unsigned int bits_len;
        unsigned int pbytes;
+       u8 entl[2];
+       int err;
 
-       if (id_len > (USHRT_MAX / 8) || !ec->Q)
-               return -EINVAL;
+       ec = kmalloc(sizeof(*ec), GFP_KERNEL);
+       if (!ec)
+               return -ENOMEM;
+
+       err = __sm2_set_pub_key(ec, key, keylen);
+       if (err)
+               goto out_free_ec;
 
-       bits_len = (uint16_t)(id_len * 8);
+       bits_len = SM2_DEFAULT_USERID_LEN * 8;
        entl[0] = bits_len >> 8;
        entl[1] = bits_len & 0xff;
 
        pbytes = MPI_NBYTES(ec->p);
 
        /* ZA = H256(ENTLA | IDA | a | b | xG | yG | xA | yA) */
-       sm3_init(&sctx);
-       sm3_update(&sctx, entl, 2);
-       sm3_update(&sctx, id, id_len);
-
-       if (sm2_z_digest_update(&sctx, ec->a, pbytes) ||
-           sm2_z_digest_update(&sctx, ec->b, pbytes) ||
-           sm2_z_digest_update_point(&sctx, ec->G, ec, pbytes) ||
-           sm2_z_digest_update_point(&sctx, ec->Q, ec, pbytes))
-               return -EINVAL;
+       err = crypto_shash_init(desc);
+       if (err)
+               goto out_deinit_ec;
 
-       sm3_final(&sctx, dgst);
-       return 0;
+       err = crypto_shash_update(desc, entl, 2);
+       if (err)
+               goto out_deinit_ec;
+
+       err = crypto_shash_update(desc, SM2_DEFAULT_USERID,
+                                 SM2_DEFAULT_USERID_LEN);
+       if (err)
+               goto out_deinit_ec;
+
+       err = sm2_z_digest_update(desc, ec->a, pbytes) ?:
+             sm2_z_digest_update(desc, ec->b, pbytes) ?:
+             sm2_z_digest_update_point(desc, ec->G, ec, pbytes) ?:
+             sm2_z_digest_update_point(desc, ec->Q, ec, pbytes);
+       if (err)
+               goto out_deinit_ec;
+
+       err = crypto_shash_final(desc, dgst);
+
+out_deinit_ec:
+       sm2_ec_ctx_deinit(ec);
+out_free_ec:
+       kfree(ec);
+       return err;
 }
-EXPORT_SYMBOL(sm2_compute_z_digest);
+EXPORT_SYMBOL_GPL(sm2_compute_z_digest);
 
 static int _sm2_verify(struct mpi_ec_ctx *ec, MPI hash, MPI sig_r, MPI sig_s)
 {
@@ -391,6 +417,14 @@ static int sm2_set_pub_key(struct crypto_akcipher *tfm,
                        const void *key, unsigned int keylen)
 {
        struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm);
+
+       return __sm2_set_pub_key(ec, key, keylen);
+
+}
+
+static int __sm2_set_pub_key(struct mpi_ec_ctx *ec,
+                            const void *key, unsigned int keylen)
+{
        MPI a;
        int rc;