arch/x86/crypto/sm4_aesni_avx_glue.c
/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
 * As specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE   (SM4_BLOCK_SIZE * 8)

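/*
 * Assembly helpers from the accompanying AES-NI/AVX implementation:
 * crypt4/crypt8 process up to 4 or 8 blocks with the given round keys
 * (encryption or decryption is selected by which key schedule is passed),
 * while the *_blk8 variants are used here for full 8-block chunks of
 * CTR/CBC/CFB work, with the chaining value or counter passed via @iv.
 */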
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
                                const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
                                const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
                                const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
                                const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
                                const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
                        unsigned int key_len)
{
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_expandkey(ctx, key, key_len);
}

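/*
 * ECB bulk path: with the FPU held, process 8 blocks per crypt8 call and
 * hand the remaining blocks (at most seven) to the 4-block variant in
 * chunks of up to four.
 */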
static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                kernel_fpu_begin();
                while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
                        sm4_aesni_avx_crypt8(rkey, dst, src, 8);
                        dst += SM4_CRYPT8_BLOCK_SIZE;
                        src += SM4_CRYPT8_BLOCK_SIZE;
                        nbytes -= SM4_CRYPT8_BLOCK_SIZE;
                }
                while (nbytes >= SM4_BLOCK_SIZE) {
                        unsigned int nblocks = min(nbytes >> 4, 4u);

                        sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
                        dst += nblocks * SM4_BLOCK_SIZE;
                        src += nblocks * SM4_BLOCK_SIZE;
                        nbytes -= nblocks * SM4_BLOCK_SIZE;
                }
                kernel_fpu_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

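/*
 * CBC encryption is inherently serial (each block chains into the next),
 * so it is done one block at a time with the generic sm4_crypt_block()
 * and does not need kernel_fpu_begin().
 */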
int sm4_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *iv = walk.iv;
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                while (nbytes >= SM4_BLOCK_SIZE) {
                        crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
                        sm4_crypt_block(ctx->rkey_enc, dst, dst);
                        iv = dst;
                        src += SM4_BLOCK_SIZE;
                        dst += SM4_BLOCK_SIZE;
                        nbytes -= SM4_BLOCK_SIZE;
                }
                if (iv != walk.iv)
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

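/*
 * CBC decryption parallelizes: full 8-block chunks go through the
 * caller-supplied bulk routine, and the remaining blocks are decrypted
 * with one crypt8 call and chained by hand below.
 */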
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
                        unsigned int bsize, sm4_crypt_func func)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                kernel_fpu_begin();

                while (nbytes >= bsize) {
                        func(ctx->rkey_dec, dst, src, walk.iv);
                        dst += bsize;
                        src += bsize;
                        nbytes -= bsize;
                }

                while (nbytes >= SM4_BLOCK_SIZE) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        u8 iv[SM4_BLOCK_SIZE];
                        unsigned int nblocks = min(nbytes >> 4, 8u);
                        int i;

                        /* Block-decrypt the remaining 1..8 blocks in one call. */
                        sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
                                                src, nblocks);

                        /*
                         * Save the last ciphertext block as the next IV,
                         * then XOR each decrypted block with the previous
                         * ciphertext block, walking backwards so that
                         * in-place (dst == src) requests are not corrupted.
                         */
                        src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
                        dst += (nblocks - 1) * SM4_BLOCK_SIZE;
                        memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

                        for (i = nblocks - 1; i > 0; i--) {
                                crypto_xor_cpy(dst, src,
                                        &keystream[i * SM4_BLOCK_SIZE],
                                        SM4_BLOCK_SIZE);
                                src -= SM4_BLOCK_SIZE;
                                dst -= SM4_BLOCK_SIZE;
                        }
                        /* The first block chains to the IV of this walk step. */
                        crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
                        dst += nblocks * SM4_BLOCK_SIZE;
                        src += (nblocks + 1) * SM4_BLOCK_SIZE;
                        nbytes -= nblocks * SM4_BLOCK_SIZE;
                }

                kernel_fpu_end();
                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
        return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
                                sm4_aesni_avx_cbc_dec_blk8);
}

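/*
 * Like CBC encryption, CFB encryption chains on the previous ciphertext
 * block and is therefore processed serially with the generic block
 * function; a final partial block is handled as a stream-cipher tail.
 */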
int sm4_cfb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                u8 keystream[SM4_BLOCK_SIZE];
                const u8 *iv = walk.iv;
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                while (nbytes >= SM4_BLOCK_SIZE) {
                        sm4_crypt_block(ctx->rkey_enc, keystream, iv);
                        crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
                        iv = dst;
                        src += SM4_BLOCK_SIZE;
                        dst += SM4_BLOCK_SIZE;
                        nbytes -= SM4_BLOCK_SIZE;
                }
                if (iv != walk.iv)
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}
EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);

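/*
 * CFB decryption only encrypts already-known ciphertext blocks, so it
 * parallelizes: full 8-block chunks go to the bulk helper, the remaining
 * full blocks are keystream-generated with crypt8 and XORed below, and a
 * trailing partial block falls back to a single generic encryption.
 */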
int sm4_avx_cfb_decrypt(struct skcipher_request *req,
                        unsigned int bsize, sm4_crypt_func func)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                kernel_fpu_begin();

                while (nbytes >= bsize) {
                        func(ctx->rkey_enc, dst, src, walk.iv);
                        dst += bsize;
                        src += bsize;
                        nbytes -= bsize;
                }

                while (nbytes >= SM4_BLOCK_SIZE) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        unsigned int nblocks = min(nbytes >> 4, 8u);

                        /*
                         * Build the keystream inputs: the current IV
                         * followed by the first nblocks - 1 ciphertext
                         * blocks. Save the last ciphertext block as the
                         * next IV before an in-place request can
                         * overwrite it.
                         */
                        memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
                        if (nblocks > 1)
                                memcpy(&keystream[SM4_BLOCK_SIZE], src,
                                        (nblocks - 1) * SM4_BLOCK_SIZE);
                        memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
                                SM4_BLOCK_SIZE);

                        sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
                                                keystream, nblocks);

                        crypto_xor_cpy(dst, src, keystream,
                                        nblocks * SM4_BLOCK_SIZE);
                        dst += nblocks * SM4_BLOCK_SIZE;
                        src += nblocks * SM4_BLOCK_SIZE;
                        nbytes -= nblocks * SM4_BLOCK_SIZE;
                }

                kernel_fpu_end();

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);

static int cfb_decrypt(struct skcipher_request *req)
{
        return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
                                sm4_aesni_avx_cfb_dec_blk8);
}

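/*
 * CTR mode: full 8-block chunks are handled by the bulk helper, leftover
 * full blocks by expanding the counter into a keystream buffer for
 * crypt8, and a final partial block by one generic block encryption.
 */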
int sm4_avx_ctr_crypt(struct skcipher_request *req,
                        unsigned int bsize, sm4_crypt_func func)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                kernel_fpu_begin();

                while (nbytes >= bsize) {
                        func(ctx->rkey_enc, dst, src, walk.iv);
                        dst += bsize;
                        src += bsize;
                        nbytes -= bsize;
                }

                while (nbytes >= SM4_BLOCK_SIZE) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        unsigned int nblocks = min(nbytes >> 4, 8u);
                        int i;

                        /* Expand the counter into nblocks keystream inputs. */
                        for (i = 0; i < nblocks; i++) {
                                memcpy(&keystream[i * SM4_BLOCK_SIZE],
                                        walk.iv, SM4_BLOCK_SIZE);
                                crypto_inc(walk.iv, SM4_BLOCK_SIZE);
                        }
                        sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
                                        keystream, nblocks);

                        crypto_xor_cpy(dst, src, keystream,
                                        nblocks * SM4_BLOCK_SIZE);
                        dst += nblocks * SM4_BLOCK_SIZE;
                        src += nblocks * SM4_BLOCK_SIZE;
                        nbytes -= nblocks * SM4_BLOCK_SIZE;
                }

                kernel_fpu_end();

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
                        crypto_inc(walk.iv, SM4_BLOCK_SIZE);

                        sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        dst += nbytes;
                        src += nbytes;
                        nbytes = 0;
                }

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
        return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
                                sm4_aesni_avx_ctr_enc_blk8);
}

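/*
 * These algorithms are internal-only ("__" prefix, CRYPTO_ALG_INTERNAL);
 * sm4_init() wraps them with the crypto SIMD helper, which registers the
 * usual names with the "__" stripped. Illustratively (assuming this
 * module is available), other kernel code would then reach them through
 * the generic API, e.g.:
 *
 *	crypto_alloc_skcipher("ctr(sm4)", 0, 0);
 */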
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
        {
                .base = {
                        .cra_name               = "__ecb(sm4)",
                        .cra_driver_name        = "__ecb-sm4-aesni-avx",
                        .cra_priority           = 400,
                        .cra_flags              = CRYPTO_ALG_INTERNAL,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .walksize       = 8 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
                .encrypt        = sm4_avx_ecb_encrypt,
                .decrypt        = sm4_avx_ecb_decrypt,
        }, {
                .base = {
                        .cra_name               = "__cbc(sm4)",
                        .cra_driver_name        = "__cbc-sm4-aesni-avx",
                        .cra_priority           = 400,
                        .cra_flags              = CRYPTO_ALG_INTERNAL,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .walksize       = 8 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
                .encrypt        = sm4_cbc_encrypt,
                .decrypt        = cbc_decrypt,
        }, {
                .base = {
                        .cra_name               = "__cfb(sm4)",
                        .cra_driver_name        = "__cfb-sm4-aesni-avx",
                        .cra_priority           = 400,
                        .cra_flags              = CRYPTO_ALG_INTERNAL,
                        .cra_blocksize          = 1,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .chunksize      = SM4_BLOCK_SIZE,
                .walksize       = 8 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
                .encrypt        = sm4_cfb_encrypt,
                .decrypt        = cfb_decrypt,
        }, {
                .base = {
                        .cra_name               = "__ctr(sm4)",
                        .cra_driver_name        = "__ctr-sm4-aesni-avx",
                        .cra_priority           = 400,
                        .cra_flags              = CRYPTO_ALG_INTERNAL,
                        .cra_blocksize          = 1,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .chunksize      = SM4_BLOCK_SIZE,
                .walksize       = 8 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
                .encrypt        = ctr_crypt,
                .decrypt        = ctr_crypt,
        }
};

static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];

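/*
 * Registration requires AVX, AES-NI and OS-enabled XSAVE support for the
 * XMM/YMM state; otherwise the module refuses to load with -ENODEV.
 */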
static int __init sm4_init(void)
{
        const char *feature_name;

        if (!boot_cpu_has(X86_FEATURE_AVX) ||
            !boot_cpu_has(X86_FEATURE_AES) ||
            !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
                pr_info("AVX or AES-NI instructions are not detected.\n");
                return -ENODEV;
        }

        if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
                                &feature_name)) {
                pr_info("CPU feature '%s' is not supported.\n", feature_name);
                return -ENODEV;
        }

        return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
                                        ARRAY_SIZE(sm4_aesni_avx_skciphers),
                                        simd_sm4_aesni_avx_skciphers);
}

static void __exit sm4_exit(void)
{
        simd_unregister_skciphers(sm4_aesni_avx_skciphers,
                                        ARRAY_SIZE(sm4_aesni_avx_skciphers),
                                        simd_sm4_aesni_avx_skciphers);
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");