Linux 6.9-rc1
[linux-2.6-microblaze.git] / crypto / rsa.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* RSA asymmetric public-key algorithm [RFC3447]
3  *
4  * Copyright (c) 2015, Intel Corporation
5  * Authors: Tadeusz Struk <tadeusz.struk@intel.com>
6  */
7
8 #include <linux/fips.h>
9 #include <linux/module.h>
10 #include <linux/mpi.h>
11 #include <crypto/internal/rsa.h>
12 #include <crypto/internal/akcipher.h>
13 #include <crypto/akcipher.h>
14 #include <crypto/algapi.h>
15
16 struct rsa_mpi_key {
17         MPI n;
18         MPI e;
19         MPI d;
20         MPI p;
21         MPI q;
22         MPI dp;
23         MPI dq;
24         MPI qinv;
25 };
26
27 static int rsa_check_payload(MPI x, MPI n)
28 {
29         MPI n1;
30
31         if (mpi_cmp_ui(x, 1) <= 0)
32                 return -EINVAL;
33
34         n1 = mpi_alloc(0);
35         if (!n1)
36                 return -ENOMEM;
37
38         if (mpi_sub_ui(n1, n, 1) || mpi_cmp(x, n1) >= 0) {
39                 mpi_free(n1);
40                 return -EINVAL;
41         }
42
43         mpi_free(n1);
44         return 0;
45 }
46
47 /*
48  * RSAEP function [RFC3447 sec 5.1.1]
49  * c = m^e mod n;
50  */
51 static int _rsa_enc(const struct rsa_mpi_key *key, MPI c, MPI m)
52 {
53         /*
54          * Even though (1) in RFC3447 only requires 0 <= m <= n - 1, we are
55          * slightly more conservative and require 1 < m < n - 1. This is in line
56          * with SP 800-56Br2, Section 7.1.1.
57          */
58         if (rsa_check_payload(m, key->n))
59                 return -EINVAL;
60
61         /* (2) c = m^e mod n */
62         return mpi_powm(c, m, key->e, key->n);
63 }
64
65 /*
66  * RSADP function [RFC3447 sec 5.1.2]
67  * m_1 = c^dP mod p;
68  * m_2 = c^dQ mod q;
69  * h = (m_1 - m_2) * qInv mod p;
70  * m = m_2 + q * h;
71  */
72 static int _rsa_dec_crt(const struct rsa_mpi_key *key, MPI m_or_m1_or_h, MPI c)
73 {
74         MPI m2, m12_or_qh;
75         int ret = -ENOMEM;
76
77         /*
78          * Even though (1) in RFC3447 only requires 0 <= c <= n - 1, we are
79          * slightly more conservative and require 1 < c < n - 1. This is in line
80          * with SP 800-56Br2, Section 7.1.2.
81          */
82         if (rsa_check_payload(c, key->n))
83                 return -EINVAL;
84
85         m2 = mpi_alloc(0);
86         m12_or_qh = mpi_alloc(0);
87         if (!m2 || !m12_or_qh)
88                 goto err_free_mpi;
89
90         /* (2i) m_1 = c^dP mod p */
91         ret = mpi_powm(m_or_m1_or_h, c, key->dp, key->p);
92         if (ret)
93                 goto err_free_mpi;
94
95         /* (2i) m_2 = c^dQ mod q */
96         ret = mpi_powm(m2, c, key->dq, key->q);
97         if (ret)
98                 goto err_free_mpi;
99
100         /* (2iii) h = (m_1 - m_2) * qInv mod p */
101         mpi_sub(m12_or_qh, m_or_m1_or_h, m2);
102         mpi_mulm(m_or_m1_or_h, m12_or_qh, key->qinv, key->p);
103
104         /* (2iv) m = m_2 + q * h */
105         mpi_mul(m12_or_qh, key->q, m_or_m1_or_h);
106         mpi_addm(m_or_m1_or_h, m2, m12_or_qh, key->n);
107
108         ret = 0;
109
110 err_free_mpi:
111         mpi_free(m12_or_qh);
112         mpi_free(m2);
113         return ret;
114 }
115
116 static inline struct rsa_mpi_key *rsa_get_key(struct crypto_akcipher *tfm)
117 {
118         return akcipher_tfm_ctx(tfm);
119 }
120
121 static int rsa_enc(struct akcipher_request *req)
122 {
123         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
124         const struct rsa_mpi_key *pkey = rsa_get_key(tfm);
125         MPI m, c = mpi_alloc(0);
126         int ret = 0;
127         int sign;
128
129         if (!c)
130                 return -ENOMEM;
131
132         if (unlikely(!pkey->n || !pkey->e)) {
133                 ret = -EINVAL;
134                 goto err_free_c;
135         }
136
137         ret = -ENOMEM;
138         m = mpi_read_raw_from_sgl(req->src, req->src_len);
139         if (!m)
140                 goto err_free_c;
141
142         ret = _rsa_enc(pkey, c, m);
143         if (ret)
144                 goto err_free_m;
145
146         ret = mpi_write_to_sgl(c, req->dst, req->dst_len, &sign);
147         if (ret)
148                 goto err_free_m;
149
150         if (sign < 0)
151                 ret = -EBADMSG;
152
153 err_free_m:
154         mpi_free(m);
155 err_free_c:
156         mpi_free(c);
157         return ret;
158 }
159
160 static int rsa_dec(struct akcipher_request *req)
161 {
162         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
163         const struct rsa_mpi_key *pkey = rsa_get_key(tfm);
164         MPI c, m = mpi_alloc(0);
165         int ret = 0;
166         int sign;
167
168         if (!m)
169                 return -ENOMEM;
170
171         if (unlikely(!pkey->n || !pkey->d)) {
172                 ret = -EINVAL;
173                 goto err_free_m;
174         }
175
176         ret = -ENOMEM;
177         c = mpi_read_raw_from_sgl(req->src, req->src_len);
178         if (!c)
179                 goto err_free_m;
180
181         ret = _rsa_dec_crt(pkey, m, c);
182         if (ret)
183                 goto err_free_c;
184
185         ret = mpi_write_to_sgl(m, req->dst, req->dst_len, &sign);
186         if (ret)
187                 goto err_free_c;
188
189         if (sign < 0)
190                 ret = -EBADMSG;
191 err_free_c:
192         mpi_free(c);
193 err_free_m:
194         mpi_free(m);
195         return ret;
196 }
197
198 static void rsa_free_mpi_key(struct rsa_mpi_key *key)
199 {
200         mpi_free(key->d);
201         mpi_free(key->e);
202         mpi_free(key->n);
203         mpi_free(key->p);
204         mpi_free(key->q);
205         mpi_free(key->dp);
206         mpi_free(key->dq);
207         mpi_free(key->qinv);
208         key->d = NULL;
209         key->e = NULL;
210         key->n = NULL;
211         key->p = NULL;
212         key->q = NULL;
213         key->dp = NULL;
214         key->dq = NULL;
215         key->qinv = NULL;
216 }
217
218 static int rsa_check_key_length(unsigned int len)
219 {
220         switch (len) {
221         case 512:
222         case 1024:
223         case 1536:
224                 if (fips_enabled)
225                         return -EINVAL;
226                 fallthrough;
227         case 2048:
228         case 3072:
229         case 4096:
230                 return 0;
231         }
232
233         return -EINVAL;
234 }
235
236 static int rsa_check_exponent_fips(MPI e)
237 {
238         MPI e_max = NULL;
239
240         /* check if odd */
241         if (!mpi_test_bit(e, 0)) {
242                 return -EINVAL;
243         }
244
245         /* check if 2^16 < e < 2^256. */
246         if (mpi_cmp_ui(e, 65536) <= 0) {
247                 return -EINVAL;
248         }
249
250         e_max = mpi_alloc(0);
251         if (!e_max)
252                 return -ENOMEM;
253         mpi_set_bit(e_max, 256);
254
255         if (mpi_cmp(e, e_max) >= 0) {
256                 mpi_free(e_max);
257                 return -EINVAL;
258         }
259
260         mpi_free(e_max);
261         return 0;
262 }
263
264 static int rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key,
265                            unsigned int keylen)
266 {
267         struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm);
268         struct rsa_key raw_key = {0};
269         int ret;
270
271         /* Free the old MPI key if any */
272         rsa_free_mpi_key(mpi_key);
273
274         ret = rsa_parse_pub_key(&raw_key, key, keylen);
275         if (ret)
276                 return ret;
277
278         mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz);
279         if (!mpi_key->e)
280                 goto err;
281
282         mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz);
283         if (!mpi_key->n)
284                 goto err;
285
286         if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) {
287                 rsa_free_mpi_key(mpi_key);
288                 return -EINVAL;
289         }
290
291         if (fips_enabled && rsa_check_exponent_fips(mpi_key->e)) {
292                 rsa_free_mpi_key(mpi_key);
293                 return -EINVAL;
294         }
295
296         return 0;
297
298 err:
299         rsa_free_mpi_key(mpi_key);
300         return -ENOMEM;
301 }
302
303 static int rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key,
304                             unsigned int keylen)
305 {
306         struct rsa_mpi_key *mpi_key = akcipher_tfm_ctx(tfm);
307         struct rsa_key raw_key = {0};
308         int ret;
309
310         /* Free the old MPI key if any */
311         rsa_free_mpi_key(mpi_key);
312
313         ret = rsa_parse_priv_key(&raw_key, key, keylen);
314         if (ret)
315                 return ret;
316
317         mpi_key->d = mpi_read_raw_data(raw_key.d, raw_key.d_sz);
318         if (!mpi_key->d)
319                 goto err;
320
321         mpi_key->e = mpi_read_raw_data(raw_key.e, raw_key.e_sz);
322         if (!mpi_key->e)
323                 goto err;
324
325         mpi_key->n = mpi_read_raw_data(raw_key.n, raw_key.n_sz);
326         if (!mpi_key->n)
327                 goto err;
328
329         mpi_key->p = mpi_read_raw_data(raw_key.p, raw_key.p_sz);
330         if (!mpi_key->p)
331                 goto err;
332
333         mpi_key->q = mpi_read_raw_data(raw_key.q, raw_key.q_sz);
334         if (!mpi_key->q)
335                 goto err;
336
337         mpi_key->dp = mpi_read_raw_data(raw_key.dp, raw_key.dp_sz);
338         if (!mpi_key->dp)
339                 goto err;
340
341         mpi_key->dq = mpi_read_raw_data(raw_key.dq, raw_key.dq_sz);
342         if (!mpi_key->dq)
343                 goto err;
344
345         mpi_key->qinv = mpi_read_raw_data(raw_key.qinv, raw_key.qinv_sz);
346         if (!mpi_key->qinv)
347                 goto err;
348
349         if (rsa_check_key_length(mpi_get_size(mpi_key->n) << 3)) {
350                 rsa_free_mpi_key(mpi_key);
351                 return -EINVAL;
352         }
353
354         if (fips_enabled && rsa_check_exponent_fips(mpi_key->e)) {
355                 rsa_free_mpi_key(mpi_key);
356                 return -EINVAL;
357         }
358
359         return 0;
360
361 err:
362         rsa_free_mpi_key(mpi_key);
363         return -ENOMEM;
364 }
365
366 static unsigned int rsa_max_size(struct crypto_akcipher *tfm)
367 {
368         struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm);
369
370         return mpi_get_size(pkey->n);
371 }
372
373 static void rsa_exit_tfm(struct crypto_akcipher *tfm)
374 {
375         struct rsa_mpi_key *pkey = akcipher_tfm_ctx(tfm);
376
377         rsa_free_mpi_key(pkey);
378 }
379
380 static struct akcipher_alg rsa = {
381         .encrypt = rsa_enc,
382         .decrypt = rsa_dec,
383         .set_priv_key = rsa_set_priv_key,
384         .set_pub_key = rsa_set_pub_key,
385         .max_size = rsa_max_size,
386         .exit = rsa_exit_tfm,
387         .base = {
388                 .cra_name = "rsa",
389                 .cra_driver_name = "rsa-generic",
390                 .cra_priority = 100,
391                 .cra_module = THIS_MODULE,
392                 .cra_ctxsize = sizeof(struct rsa_mpi_key),
393         },
394 };
395
396 static int __init rsa_init(void)
397 {
398         int err;
399
400         err = crypto_register_akcipher(&rsa);
401         if (err)
402                 return err;
403
404         err = crypto_register_template(&rsa_pkcs1pad_tmpl);
405         if (err) {
406                 crypto_unregister_akcipher(&rsa);
407                 return err;
408         }
409
410         return 0;
411 }
412
413 static void __exit rsa_exit(void)
414 {
415         crypto_unregister_template(&rsa_pkcs1pad_tmpl);
416         crypto_unregister_akcipher(&rsa);
417 }
418
419 subsys_initcall(rsa_init);
420 module_exit(rsa_exit);
421 MODULE_ALIAS_CRYPTO("rsa");
422 MODULE_LICENSE("GPL");
423 MODULE_DESCRIPTION("RSA generic algorithm");