io_uring: don't assume mm is constant across submits
[linux-2.6-microblaze.git] / crypto / sm2.c
1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * SM2 asymmetric public-key algorithm
4  * as specified by OSCCA GM/T 0003.1-2012 -- 0003.5-2012 SM2 and
5  * described at https://tools.ietf.org/html/draft-shen-sm2-ecdsa-02
6  *
7  * Copyright (c) 2020, Alibaba Group.
8  * Authors: Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9  */
10
11 #include <linux/module.h>
12 #include <linux/mpi.h>
13 #include <crypto/internal/akcipher.h>
14 #include <crypto/akcipher.h>
15 #include <crypto/hash.h>
16 #include <crypto/sm3_base.h>
17 #include <crypto/rng.h>
18 #include <crypto/sm2.h>
19 #include "sm2signature.asn1.h"
20
21 #define MPI_NBYTES(m)   ((mpi_get_nbits(m) + 7) / 8)
22
23 struct ecc_domain_parms {
24         const char *desc;           /* Description of the curve.  */
25         unsigned int nbits;         /* Number of bits.  */
26         unsigned int fips:1; /* True if this is a FIPS140-2 approved curve */
27
28         /* The model describing this curve.  This is mainly used to select
29          * the group equation.
30          */
31         enum gcry_mpi_ec_models model;
32
33         /* The actual ECC dialect used.  This is used for curve specific
34          * optimizations and to select encodings etc.
35          */
36         enum ecc_dialects dialect;
37
38         const char *p;              /* The prime defining the field.  */
39         const char *a, *b;          /* The coefficients.  For Twisted Edwards
40                                      * Curves b is used for d.  For Montgomery
41                                      * Curves (a,b) has ((A-2)/4,B^-1).
42                                      */
43         const char *n;              /* The order of the base point.  */
44         const char *g_x, *g_y;      /* Base point.  */
45         unsigned int h;             /* Cofactor.  */
46 };
47
48 static const struct ecc_domain_parms sm2_ecp = {
49         .desc = "sm2p256v1",
50         .nbits = 256,
51         .fips = 0,
52         .model = MPI_EC_WEIERSTRASS,
53         .dialect = ECC_DIALECT_STANDARD,
54         .p   = "0xfffffffeffffffffffffffffffffffffffffffff00000000ffffffffffffffff",
55         .a   = "0xfffffffeffffffffffffffffffffffffffffffff00000000fffffffffffffffc",
56         .b   = "0x28e9fa9e9d9f5e344d5a9e4bcf6509a7f39789f515ab8f92ddbcbd414d940e93",
57         .n   = "0xfffffffeffffffffffffffffffffffff7203df6b21c6052b53bbf40939d54123",
58         .g_x = "0x32c4ae2c1f1981195f9904466a39c9948fe30bbff2660be1715a4589334c74c7",
59         .g_y = "0xbc3736a2f4f6779c59bdcee36b692153d0a9877cc62a474002df32e52139f0a0",
60         .h = 1
61 };
62
63 static int sm2_ec_ctx_init(struct mpi_ec_ctx *ec)
64 {
65         const struct ecc_domain_parms *ecp = &sm2_ecp;
66         MPI p, a, b;
67         MPI x, y;
68         int rc = -EINVAL;
69
70         p = mpi_scanval(ecp->p);
71         a = mpi_scanval(ecp->a);
72         b = mpi_scanval(ecp->b);
73         if (!p || !a || !b)
74                 goto free_p;
75
76         x = mpi_scanval(ecp->g_x);
77         y = mpi_scanval(ecp->g_y);
78         if (!x || !y)
79                 goto free;
80
81         rc = -ENOMEM;
82         /* mpi_ec_setup_elliptic_curve */
83         ec->G = mpi_point_new(0);
84         if (!ec->G)
85                 goto free;
86
87         mpi_set(ec->G->x, x);
88         mpi_set(ec->G->y, y);
89         mpi_set_ui(ec->G->z, 1);
90
91         rc = -EINVAL;
92         ec->n = mpi_scanval(ecp->n);
93         if (!ec->n) {
94                 mpi_point_release(ec->G);
95                 goto free;
96         }
97
98         ec->h = ecp->h;
99         ec->name = ecp->desc;
100         mpi_ec_init(ec, ecp->model, ecp->dialect, 0, p, a, b);
101
102         rc = 0;
103
104 free:
105         mpi_free(x);
106         mpi_free(y);
107 free_p:
108         mpi_free(p);
109         mpi_free(a);
110         mpi_free(b);
111
112         return rc;
113 }
114
115 static void sm2_ec_ctx_deinit(struct mpi_ec_ctx *ec)
116 {
117         mpi_ec_deinit(ec);
118
119         memset(ec, 0, sizeof(*ec));
120 }
121
122 /* RESULT must have been initialized and is set on success to the
123  * point given by VALUE.
124  */
125 static int sm2_ecc_os2ec(MPI_POINT result, MPI value)
126 {
127         int rc;
128         size_t n;
129         unsigned char *buf;
130         MPI x, y;
131
132         n = MPI_NBYTES(value);
133         buf = kmalloc(n, GFP_KERNEL);
134         if (!buf)
135                 return -ENOMEM;
136
137         rc = mpi_print(GCRYMPI_FMT_USG, buf, n, &n, value);
138         if (rc)
139                 goto err_freebuf;
140
141         rc = -EINVAL;
142         if (n < 1 || ((n - 1) % 2))
143                 goto err_freebuf;
144         /* No support for point compression */
145         if (*buf != 0x4)
146                 goto err_freebuf;
147
148         rc = -ENOMEM;
149         n = (n - 1) / 2;
150         x = mpi_read_raw_data(buf + 1, n);
151         if (!x)
152                 goto err_freebuf;
153         y = mpi_read_raw_data(buf + 1 + n, n);
154         if (!y)
155                 goto err_freex;
156
157         mpi_normalize(x);
158         mpi_normalize(y);
159         mpi_set(result->x, x);
160         mpi_set(result->y, y);
161         mpi_set_ui(result->z, 1);
162
163         rc = 0;
164
165         mpi_free(y);
166 err_freex:
167         mpi_free(x);
168 err_freebuf:
169         kfree(buf);
170         return rc;
171 }
172
173 struct sm2_signature_ctx {
174         MPI sig_r;
175         MPI sig_s;
176 };
177
178 int sm2_get_signature_r(void *context, size_t hdrlen, unsigned char tag,
179                                 const void *value, size_t vlen)
180 {
181         struct sm2_signature_ctx *sig = context;
182
183         if (!value || !vlen)
184                 return -EINVAL;
185
186         sig->sig_r = mpi_read_raw_data(value, vlen);
187         if (!sig->sig_r)
188                 return -ENOMEM;
189
190         return 0;
191 }
192
193 int sm2_get_signature_s(void *context, size_t hdrlen, unsigned char tag,
194                                 const void *value, size_t vlen)
195 {
196         struct sm2_signature_ctx *sig = context;
197
198         if (!value || !vlen)
199                 return -EINVAL;
200
201         sig->sig_s = mpi_read_raw_data(value, vlen);
202         if (!sig->sig_s)
203                 return -ENOMEM;
204
205         return 0;
206 }
207
208 static int sm2_z_digest_update(struct shash_desc *desc,
209                         MPI m, unsigned int pbytes)
210 {
211         static const unsigned char zero[32];
212         unsigned char *in;
213         unsigned int inlen;
214
215         in = mpi_get_buffer(m, &inlen, NULL);
216         if (!in)
217                 return -EINVAL;
218
219         if (inlen < pbytes) {
220                 /* padding with zero */
221                 crypto_sm3_update(desc, zero, pbytes - inlen);
222                 crypto_sm3_update(desc, in, inlen);
223         } else if (inlen > pbytes) {
224                 /* skip the starting zero */
225                 crypto_sm3_update(desc, in + inlen - pbytes, pbytes);
226         } else {
227                 crypto_sm3_update(desc, in, inlen);
228         }
229
230         kfree(in);
231         return 0;
232 }
233
234 static int sm2_z_digest_update_point(struct shash_desc *desc,
235                 MPI_POINT point, struct mpi_ec_ctx *ec, unsigned int pbytes)
236 {
237         MPI x, y;
238         int ret = -EINVAL;
239
240         x = mpi_new(0);
241         y = mpi_new(0);
242
243         if (!mpi_ec_get_affine(x, y, point, ec) &&
244                 !sm2_z_digest_update(desc, x, pbytes) &&
245                 !sm2_z_digest_update(desc, y, pbytes))
246                 ret = 0;
247
248         mpi_free(x);
249         mpi_free(y);
250         return ret;
251 }
252
253 int sm2_compute_z_digest(struct crypto_akcipher *tfm,
254                         const unsigned char *id, size_t id_len,
255                         unsigned char dgst[SM3_DIGEST_SIZE])
256 {
257         struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm);
258         uint16_t bits_len;
259         unsigned char entl[2];
260         SHASH_DESC_ON_STACK(desc, NULL);
261         unsigned int pbytes;
262
263         if (id_len > (USHRT_MAX / 8) || !ec->Q)
264                 return -EINVAL;
265
266         bits_len = (uint16_t)(id_len * 8);
267         entl[0] = bits_len >> 8;
268         entl[1] = bits_len & 0xff;
269
270         pbytes = MPI_NBYTES(ec->p);
271
272         /* ZA = H256(ENTLA | IDA | a | b | xG | yG | xA | yA) */
273         sm3_base_init(desc);
274         crypto_sm3_update(desc, entl, 2);
275         crypto_sm3_update(desc, id, id_len);
276
277         if (sm2_z_digest_update(desc, ec->a, pbytes) ||
278                 sm2_z_digest_update(desc, ec->b, pbytes) ||
279                 sm2_z_digest_update_point(desc, ec->G, ec, pbytes) ||
280                 sm2_z_digest_update_point(desc, ec->Q, ec, pbytes))
281                 return -EINVAL;
282
283         crypto_sm3_final(desc, dgst);
284         return 0;
285 }
286 EXPORT_SYMBOL(sm2_compute_z_digest);
287
288 static int _sm2_verify(struct mpi_ec_ctx *ec, MPI hash, MPI sig_r, MPI sig_s)
289 {
290         int rc = -EINVAL;
291         struct gcry_mpi_point sG, tP;
292         MPI t = NULL;
293         MPI x1 = NULL, y1 = NULL;
294
295         mpi_point_init(&sG);
296         mpi_point_init(&tP);
297         x1 = mpi_new(0);
298         y1 = mpi_new(0);
299         t = mpi_new(0);
300
301         /* r, s in [1, n-1] */
302         if (mpi_cmp_ui(sig_r, 1) < 0 || mpi_cmp(sig_r, ec->n) > 0 ||
303                 mpi_cmp_ui(sig_s, 1) < 0 || mpi_cmp(sig_s, ec->n) > 0) {
304                 goto leave;
305         }
306
307         /* t = (r + s) % n, t == 0 */
308         mpi_addm(t, sig_r, sig_s, ec->n);
309         if (mpi_cmp_ui(t, 0) == 0)
310                 goto leave;
311
312         /* sG + tP = (x1, y1) */
313         rc = -EBADMSG;
314         mpi_ec_mul_point(&sG, sig_s, ec->G, ec);
315         mpi_ec_mul_point(&tP, t, ec->Q, ec);
316         mpi_ec_add_points(&sG, &sG, &tP, ec);
317         if (mpi_ec_get_affine(x1, y1, &sG, ec))
318                 goto leave;
319
320         /* R = (e + x1) % n */
321         mpi_addm(t, hash, x1, ec->n);
322
323         /* check R == r */
324         rc = -EKEYREJECTED;
325         if (mpi_cmp(t, sig_r))
326                 goto leave;
327
328         rc = 0;
329
330 leave:
331         mpi_point_free_parts(&sG);
332         mpi_point_free_parts(&tP);
333         mpi_free(x1);
334         mpi_free(y1);
335         mpi_free(t);
336
337         return rc;
338 }
339
340 static int sm2_verify(struct akcipher_request *req)
341 {
342         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
343         struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm);
344         unsigned char *buffer;
345         struct sm2_signature_ctx sig;
346         MPI hash;
347         int ret;
348
349         if (unlikely(!ec->Q))
350                 return -EINVAL;
351
352         buffer = kmalloc(req->src_len + req->dst_len, GFP_KERNEL);
353         if (!buffer)
354                 return -ENOMEM;
355
356         sg_pcopy_to_buffer(req->src,
357                 sg_nents_for_len(req->src, req->src_len + req->dst_len),
358                 buffer, req->src_len + req->dst_len, 0);
359
360         sig.sig_r = NULL;
361         sig.sig_s = NULL;
362         ret = asn1_ber_decoder(&sm2signature_decoder, &sig,
363                                 buffer, req->src_len);
364         if (ret)
365                 goto error;
366
367         ret = -ENOMEM;
368         hash = mpi_read_raw_data(buffer + req->src_len, req->dst_len);
369         if (!hash)
370                 goto error;
371
372         ret = _sm2_verify(ec, hash, sig.sig_r, sig.sig_s);
373
374         mpi_free(hash);
375 error:
376         mpi_free(sig.sig_r);
377         mpi_free(sig.sig_s);
378         kfree(buffer);
379         return ret;
380 }
381
382 static int sm2_set_pub_key(struct crypto_akcipher *tfm,
383                         const void *key, unsigned int keylen)
384 {
385         struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm);
386         MPI a;
387         int rc;
388
389         ec->Q = mpi_point_new(0);
390         if (!ec->Q)
391                 return -ENOMEM;
392
393         /* include the uncompressed flag '0x04' */
394         rc = -ENOMEM;
395         a = mpi_read_raw_data(key, keylen);
396         if (!a)
397                 goto error;
398
399         mpi_normalize(a);
400         rc = sm2_ecc_os2ec(ec->Q, a);
401         mpi_free(a);
402         if (rc)
403                 goto error;
404
405         return 0;
406
407 error:
408         mpi_point_release(ec->Q);
409         ec->Q = NULL;
410         return rc;
411 }
412
413 static unsigned int sm2_max_size(struct crypto_akcipher *tfm)
414 {
415         /* Unlimited max size */
416         return PAGE_SIZE;
417 }
418
419 static int sm2_init_tfm(struct crypto_akcipher *tfm)
420 {
421         struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm);
422
423         return sm2_ec_ctx_init(ec);
424 }
425
426 static void sm2_exit_tfm(struct crypto_akcipher *tfm)
427 {
428         struct mpi_ec_ctx *ec = akcipher_tfm_ctx(tfm);
429
430         sm2_ec_ctx_deinit(ec);
431 }
432
433 static struct akcipher_alg sm2 = {
434         .verify = sm2_verify,
435         .set_pub_key = sm2_set_pub_key,
436         .max_size = sm2_max_size,
437         .init = sm2_init_tfm,
438         .exit = sm2_exit_tfm,
439         .base = {
440                 .cra_name = "sm2",
441                 .cra_driver_name = "sm2-generic",
442                 .cra_priority = 100,
443                 .cra_module = THIS_MODULE,
444                 .cra_ctxsize = sizeof(struct mpi_ec_ctx),
445         },
446 };
447
448 static int sm2_init(void)
449 {
450         return crypto_register_akcipher(&sm2);
451 }
452
453 static void sm2_exit(void)
454 {
455         crypto_unregister_akcipher(&sm2);
456 }
457
458 subsys_initcall(sm2_init);
459 module_exit(sm2_exit);
460
461 MODULE_LICENSE("GPL");
462 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
463 MODULE_DESCRIPTION("SM2 generic algorithm");
464 MODULE_ALIAS_CRYPTO("sm2-generic");