/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/b128ops.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/internal/hash.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <crypto/sm4.h>

#define BYTES2BLKS(nbytes)      ((nbytes) >> 4)

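/*
 * Prototypes for the SM4 Crypto Extensions routines implemented in the
 * accompanying assembly.  All of them operate on NEON registers and must
 * be called between kernel_neon_begin() and kernel_neon_end().
 */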
asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
                                  const u32 *fk, const u32 *ck);
asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
                             unsigned int nblks);
asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
                               u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
                               u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_cts_enc(const u32 *rkey, u8 *dst, const u8 *src,
                                   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_cbc_cts_dec(const u32 *rkey, u8 *dst, const u8 *src,
                                   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
                               u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
                               u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
                               u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_xts_enc(const u32 *rkey1, u8 *dst, const u8 *src,
                               u8 *tweak, unsigned int nbytes,
                               const u32 *rkey2_enc);
asmlinkage void sm4_ce_xts_dec(const u32 *rkey1, u8 *dst, const u8 *src,
                               u8 *tweak, unsigned int nbytes,
                               const u32 *rkey2_enc);
asmlinkage void sm4_ce_mac_update(const u32 *rkey_enc, u8 *digest,
                                  const u8 *src, unsigned int nblocks,
                                  bool enc_before, bool enc_after);

EXPORT_SYMBOL(sm4_ce_expand_key);
EXPORT_SYMBOL(sm4_ce_crypt_block);
EXPORT_SYMBOL(sm4_ce_cbc_enc);
EXPORT_SYMBOL(sm4_ce_cfb_enc);

struct sm4_xts_ctx {
        struct sm4_ctx key1;
        struct sm4_ctx key2;
};

struct sm4_mac_tfm_ctx {
        struct sm4_ctx key;
        u8 __aligned(8) consts[];
};

struct sm4_mac_desc_ctx {
        unsigned int len;
        u8 digest[SM4_BLOCK_SIZE];
};

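/*
 * Expand the 128-bit user key into encryption and decryption round keys
 * with the CE helper; the NEON unit must be claimed around the call.
 */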
static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
                      unsigned int key_len)
{
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;

        kernel_neon_begin();
        sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
                          crypto_sm4_fk, crypto_sm4_ck);
        kernel_neon_end();
        return 0;
}

static int sm4_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
                          unsigned int key_len)
{
        struct sm4_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        if (key_len != SM4_KEY_SIZE * 2)
                return -EINVAL;

        ret = xts_verify_key(tfm, key, key_len);
        if (ret)
                return ret;

        kernel_neon_begin();
        sm4_ce_expand_key(key, ctx->key1.rkey_enc,
                          ctx->key1.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
        sm4_ce_expand_key(&key[SM4_KEY_SIZE], ctx->key2.rkey_enc,
                          ctx->key2.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
        kernel_neon_end();

        return 0;
}

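/*
 * ECB helper shared by encrypt and decrypt: walk the scatterlists and
 * process every run of full blocks in each chunk with the CE routine.
 */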
static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        sm4_ce_crypt(rkey, dst, src, nblks);
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

static int sm4_cbc_crypt(struct skcipher_request *req,
                         struct sm4_ctx *ctx, bool encrypt)
{
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblocks;

                nblocks = nbytes / SM4_BLOCK_SIZE;
                if (nblocks) {
                        kernel_neon_begin();

                        if (encrypt)
                                sm4_ce_cbc_enc(ctx->rkey_enc, dst, src,
                                               walk.iv, nblocks);
                        else
                                sm4_ce_cbc_dec(ctx->rkey_dec, dst, src,
                                               walk.iv, nblocks);

                        kernel_neon_end();
                }

                err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
        }

        return err;
}

static int sm4_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_cbc_crypt(req, ctx, true);
}

static int sm4_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

        return sm4_cbc_crypt(req, ctx, false);
}

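/*
 * CBC with ciphertext stealing, as expected by the cts(cbc()) template:
 * everything up to the last two blocks is run as plain CBC via a
 * subrequest, then the final full-plus-partial block pair is handled by
 * the dedicated CE stealing routines.
 */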
static int sm4_cbc_cts_crypt(struct skcipher_request *req, bool encrypt)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct scatterlist *src = req->src;
        struct scatterlist *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;
        int cbc_blocks;
        int err;

        if (req->cryptlen < SM4_BLOCK_SIZE)
                return -EINVAL;

        if (req->cryptlen == SM4_BLOCK_SIZE)
                return sm4_cbc_crypt(req, ctx, encrypt);

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        /* handle the CBC cryption part */
        cbc_blocks = DIV_ROUND_UP(req->cryptlen, SM4_BLOCK_SIZE) - 2;
        if (cbc_blocks) {
                skcipher_request_set_crypt(&subreq, src, dst,
                                           cbc_blocks * SM4_BLOCK_SIZE,
                                           req->iv);

                err = sm4_cbc_crypt(&subreq, ctx, encrypt);
                if (err)
                        return err;

                dst = src = scatterwalk_ffwd(sg_src, src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * SM4_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();

        if (encrypt)
                sm4_ce_cbc_cts_enc(ctx->rkey_enc, walk.dst.virt.addr,
                                   walk.src.virt.addr, walk.iv, walk.nbytes);
        else
                sm4_ce_cbc_cts_dec(ctx->rkey_dec, walk.dst.virt.addr,
                                   walk.src.virt.addr, walk.iv, walk.nbytes);

        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int sm4_cbc_cts_encrypt(struct skcipher_request *req)
{
        return sm4_cbc_cts_crypt(req, true);
}

static int sm4_cbc_cts_decrypt(struct skcipher_request *req)
{
        return sm4_cbc_cts_crypt(req, false);
}

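/*
 * CFB: full blocks are processed by the CE routine; a trailing partial
 * block is handled by encrypting the IV once more and XORing the
 * resulting keystream with the remaining bytes.
 */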
static int sm4_cfb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

static int sm4_cfb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

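/*
 * CTR: the CE routine handles full blocks and advances the counter; a
 * partial tail gets one more keystream block, generated from the current
 * counter value which is then incremented by hand.
 */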
static int sm4_ctr_crypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                unsigned int nblks;

                kernel_neon_begin();

                nblks = BYTES2BLKS(nbytes);
                if (nblks) {
                        sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
                        dst += nblks * SM4_BLOCK_SIZE;
                        src += nblks * SM4_BLOCK_SIZE;
                        nbytes -= nblks * SM4_BLOCK_SIZE;
                }

                /* tail */
                if (walk.nbytes == walk.total && nbytes > 0) {
                        u8 keystream[SM4_BLOCK_SIZE];

                        sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
                        crypto_inc(walk.iv, SM4_BLOCK_SIZE);
                        crypto_xor_cpy(dst, src, keystream, nbytes);
                        nbytes = 0;
                }

                kernel_neon_end();

                err = skcipher_walk_done(&walk, nbytes);
        }

        return err;
}

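/*
 * XTS with ciphertext stealing.  When cryptlen is not a multiple of the
 * block size, the walk is restarted on a subrequest that covers all but
 * the last two blocks, and the remaining block-plus-tail is processed in
 * a final call.  The second key (rkey2_enc) is only passed on the first
 * call so the assembly can derive the initial tweak from the IV; later
 * calls pass NULL and continue from the saved tweak.
 */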
static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % SM4_BLOCK_SIZE;
        const u32 *rkey2_enc = ctx->key2.rkey_enc;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        if (req->cryptlen < SM4_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int nblocks = DIV_ROUND_UP(req->cryptlen, SM4_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           nblocks * SM4_BLOCK_SIZE, req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false);
                if (err)
                        return err;
        } else {
                tail = 0;
        }

        while ((nbytes = walk.nbytes) >= SM4_BLOCK_SIZE) {
                if (nbytes < walk.total)
                        nbytes &= ~(SM4_BLOCK_SIZE - 1);

                kernel_neon_begin();

                if (encrypt)
                        sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
                                       walk.src.virt.addr, walk.iv, nbytes,
                                       rkey2_enc);
                else
                        sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
                                       walk.src.virt.addr, walk.iv, nbytes,
                                       rkey2_enc);

                kernel_neon_end();

                rkey2_enc = NULL;

                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
                if (err)
                        return err;
        }

        if (likely(tail == 0))
                return 0;

        /* handle ciphertext stealing */

        dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, subreq.cryptlen);

        skcipher_request_set_crypt(&subreq, src, dst, SM4_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();

        if (encrypt)
                sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
                               walk.src.virt.addr, walk.iv, walk.nbytes,
                               rkey2_enc);
        else
                sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
                               walk.src.virt.addr, walk.iv, walk.nbytes,
                               rkey2_enc);

        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int sm4_xts_encrypt(struct skcipher_request *req)
{
        return sm4_xts_crypt(req, true);
}

static int sm4_xts_decrypt(struct skcipher_request *req)
{
        return sm4_xts_crypt(req, false);
}

static struct skcipher_alg sm4_algs[] = {
        {
                .base = {
                        .cra_name               = "ecb(sm4)",
                        .cra_driver_name        = "ecb-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_ecb_encrypt,
                .decrypt        = sm4_ecb_decrypt,
        }, {
                .base = {
                        .cra_name               = "cbc(sm4)",
                        .cra_driver_name        = "cbc-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_cbc_encrypt,
                .decrypt        = sm4_cbc_decrypt,
        }, {
                .base = {
                        .cra_name               = "cfb(sm4)",
                        .cra_driver_name        = "cfb-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = 1,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .chunksize      = SM4_BLOCK_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_cfb_encrypt,
                .decrypt        = sm4_cfb_decrypt,
        }, {
                .base = {
                        .cra_name               = "ctr(sm4)",
                        .cra_driver_name        = "ctr-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = 1,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .chunksize      = SM4_BLOCK_SIZE,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_ctr_crypt,
                .decrypt        = sm4_ctr_crypt,
        }, {
                .base = {
                        .cra_name               = "cts(cbc(sm4))",
                        .cra_driver_name        = "cts-cbc-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .walksize       = SM4_BLOCK_SIZE * 2,
                .setkey         = sm4_setkey,
                .encrypt        = sm4_cbc_cts_encrypt,
                .decrypt        = sm4_cbc_cts_decrypt,
        }, {
                .base = {
                        .cra_name               = "xts(sm4)",
                        .cra_driver_name        = "xts-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_xts_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE * 2,
                .max_keysize    = SM4_KEY_SIZE * 2,
                .ivsize         = SM4_BLOCK_SIZE,
                .walksize       = SM4_BLOCK_SIZE * 2,
                .setkey         = sm4_xts_setkey,
                .encrypt        = sm4_xts_encrypt,
                .decrypt        = sm4_xts_decrypt,
        }
};

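/*
 * Illustrative only, not part of this driver: kernel users reach these
 * implementations through the generic crypto API rather than by calling
 * the helpers above directly, roughly along the lines of:
 *
 *      struct crypto_skcipher *tfm;
 *
 *      tfm = crypto_alloc_skcipher("cbc(sm4)", 0, 0);
 *      if (IS_ERR(tfm))
 *              return PTR_ERR(tfm);
 *      crypto_skcipher_setkey(tfm, key, SM4_KEY_SIZE);
 *      // ... set up and submit an skcipher_request ...
 *      crypto_free_skcipher(tfm);
 *
 * With this module loaded, "cbc-sm4-ce" (priority 400) is preferred over
 * lower-priority generic implementations of "cbc(sm4)".
 */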
static int sm4_cbcmac_setkey(struct crypto_shash *tfm, const u8 *key,
                             unsigned int key_len)
{
        struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;

        kernel_neon_begin();
        sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
                          crypto_sm4_fk, crypto_sm4_ck);
        kernel_neon_end();

        return 0;
}

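/*
 * CMAC key setup: expand the cipher key, encrypt an all-zero block, and
 * derive the two subkeys by doubling that value in GF(2^128) (shift left
 * by one bit and conditionally XOR in the reduction constant 0x87).
 * The subkeys are stored big-endian in ctx->consts.
 */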
static int sm4_cmac_setkey(struct crypto_shash *tfm, const u8 *key,
                           unsigned int key_len)
{
        struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
        u64 a, b;

        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;

        memset(consts, 0, SM4_BLOCK_SIZE);

        kernel_neon_begin();

        sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
                          crypto_sm4_fk, crypto_sm4_ck);

        /* encrypt the zero block */
        sm4_ce_crypt_block(ctx->key.rkey_enc, (u8 *)consts, (const u8 *)consts);

        kernel_neon_end();

        /* gf(2^128) multiply zero-ciphertext with u and u^2 */
        a = be64_to_cpu(consts[0].a);
        b = be64_to_cpu(consts[0].b);
        consts[0].a = cpu_to_be64((a << 1) | (b >> 63));
        consts[0].b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));

        a = be64_to_cpu(consts[0].a);
        b = be64_to_cpu(consts[0].b);
        consts[1].a = cpu_to_be64((a << 1) | (b >> 63));
        consts[1].b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));

        return 0;
}

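/*
 * XCBC key setup: K1 = E_K(0x01...01) replaces the cipher key used for
 * the CBC-MAC core, while K2 = E_K(0x02...02) and K3 = E_K(0x03...03)
 * are kept in ctx->consts for the final-block tweak.
 */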
static int sm4_xcbc_setkey(struct crypto_shash *tfm, const u8 *key,
                           unsigned int key_len)
{
        struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        u8 __aligned(8) key2[SM4_BLOCK_SIZE];
        static u8 const ks[3][SM4_BLOCK_SIZE] = {
                { [0 ... SM4_BLOCK_SIZE - 1] = 0x1},
                { [0 ... SM4_BLOCK_SIZE - 1] = 0x2},
                { [0 ... SM4_BLOCK_SIZE - 1] = 0x3},
        };

        if (key_len != SM4_KEY_SIZE)
                return -EINVAL;

        kernel_neon_begin();

        sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
                          crypto_sm4_fk, crypto_sm4_ck);

        sm4_ce_crypt_block(ctx->key.rkey_enc, key2, ks[0]);
        sm4_ce_crypt(ctx->key.rkey_enc, ctx->consts, ks[1], 2);

        sm4_ce_expand_key(key2, ctx->key.rkey_enc, ctx->key.rkey_dec,
                          crypto_sm4_fk, crypto_sm4_ck);

        kernel_neon_end();

        return 0;
}

static int sm4_mac_init(struct shash_desc *desc)
{
        struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);

        memset(ctx->digest, 0, SM4_BLOCK_SIZE);
        ctx->len = 0;

        return 0;
}

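/*
 * Shared CBC-MAC style update: partial input is XORed into ctx->digest
 * and buffered; once a full block has accumulated and more data is
 * pending, it is run through the cipher.  Encryption of the very last
 * full block is deferred so that the final() routines can fold in the
 * CMAC/XCBC tweak first.
 */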
static int sm4_mac_update(struct shash_desc *desc, const u8 *p,
                          unsigned int len)
{
        struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
        unsigned int l, nblocks;

        if (len == 0)
                return 0;

        if (ctx->len || ctx->len + len < SM4_BLOCK_SIZE) {
                l = min(len, SM4_BLOCK_SIZE - ctx->len);

                crypto_xor(ctx->digest + ctx->len, p, l);
                ctx->len += l;
                len -= l;
                p += l;
        }

        if (len && (ctx->len % SM4_BLOCK_SIZE) == 0) {
                kernel_neon_begin();

                if (len < SM4_BLOCK_SIZE && ctx->len == SM4_BLOCK_SIZE) {
                        sm4_ce_crypt_block(tctx->key.rkey_enc,
                                           ctx->digest, ctx->digest);
                        ctx->len = 0;
                } else {
                        nblocks = len / SM4_BLOCK_SIZE;
                        len %= SM4_BLOCK_SIZE;

                        sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, p,
                                          nblocks, (ctx->len == SM4_BLOCK_SIZE),
                                          (len != 0));

                        p += nblocks * SM4_BLOCK_SIZE;

                        if (len == 0)
                                ctx->len = SM4_BLOCK_SIZE;
                }

                kernel_neon_end();

                if (len) {
                        crypto_xor(ctx->digest, p, len);
                        ctx->len = len;
                }
        }

        return 0;
}

static int sm4_cmac_final(struct shash_desc *desc, u8 *out)
{
        struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
        const u8 *consts = tctx->consts;

        if (ctx->len != SM4_BLOCK_SIZE) {
                ctx->digest[ctx->len] ^= 0x80;
                consts += SM4_BLOCK_SIZE;
        }

        kernel_neon_begin();
        sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, consts, 1,
                          false, true);
        kernel_neon_end();

        memcpy(out, ctx->digest, SM4_BLOCK_SIZE);

        return 0;
}

static int sm4_cbcmac_final(struct shash_desc *desc, u8 *out)
{
        struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);

        if (ctx->len) {
                kernel_neon_begin();
                sm4_ce_crypt_block(tctx->key.rkey_enc, ctx->digest,
                                   ctx->digest);
                kernel_neon_end();
        }

        memcpy(out, ctx->digest, SM4_BLOCK_SIZE);

        return 0;
}

static struct shash_alg sm4_mac_algs[] = {
        {
                .base = {
                        .cra_name               = "cmac(sm4)",
                        .cra_driver_name        = "cmac-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_mac_tfm_ctx)
                                                        + SM4_BLOCK_SIZE * 2,
                        .cra_module             = THIS_MODULE,
                },
                .digestsize     = SM4_BLOCK_SIZE,
                .init           = sm4_mac_init,
                .update         = sm4_mac_update,
                .final          = sm4_cmac_final,
                .setkey         = sm4_cmac_setkey,
                .descsize       = sizeof(struct sm4_mac_desc_ctx),
        }, {
                .base = {
                        .cra_name               = "xcbc(sm4)",
                        .cra_driver_name        = "xcbc-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_mac_tfm_ctx)
                                                        + SM4_BLOCK_SIZE * 2,
                        .cra_module             = THIS_MODULE,
                },
                .digestsize     = SM4_BLOCK_SIZE,
                .init           = sm4_mac_init,
                .update         = sm4_mac_update,
                .final          = sm4_cmac_final,
                .setkey         = sm4_xcbc_setkey,
                .descsize       = sizeof(struct sm4_mac_desc_ctx),
        }, {
                .base = {
                        .cra_name               = "cbcmac(sm4)",
                        .cra_driver_name        = "cbcmac-sm4-ce",
                        .cra_priority           = 400,
                        .cra_blocksize          = 1,
                        .cra_ctxsize            = sizeof(struct sm4_mac_tfm_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .digestsize     = SM4_BLOCK_SIZE,
                .init           = sm4_mac_init,
                .update         = sm4_mac_update,
                .final          = sm4_cbcmac_final,
                .setkey         = sm4_cbcmac_setkey,
                .descsize       = sizeof(struct sm4_mac_desc_ctx),
        }
};

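/*
 * Register the skcipher and MAC algorithms.  The module is only bound on
 * CPUs that advertise the SM4 Crypto Extensions, via
 * module_cpu_feature_match() below.
 */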
static int __init sm4_init(void)
{
        int err;

        err = crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
        if (err)
                return err;

        err = crypto_register_shashes(sm4_mac_algs, ARRAY_SIZE(sm4_mac_algs));
        if (err)
                goto out_err;

        return 0;

out_err:
        crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
        return err;
}

static void __exit sm4_exit(void)
{
        crypto_unregister_shashes(sm4_mac_algs, ARRAY_SIZE(sm4_mac_algs));
        crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_cpu_feature_match(SM4, sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR/XTS using ARMv8 Crypto Extensions");
MODULE_ALIAS_CRYPTO("sm4-ce");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("cfb(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_ALIAS_CRYPTO("cts(cbc(sm4))");
MODULE_ALIAS_CRYPTO("xts(sm4)");
MODULE_ALIAS_CRYPTO("cmac(sm4)");
MODULE_ALIAS_CRYPTO("xcbc(sm4)");
MODULE_ALIAS_CRYPTO("cbcmac(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");