crypto: sha - split sha.h into sha1.h and sha2.h
[linux-2.6-microblaze.git] arch/arm64/crypto/aes-glue.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/sha2.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"

#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE                    "ce"
#define PRIO                    300
#define aes_expandkey           ce_aes_expandkey
#define aes_ecb_encrypt         ce_aes_ecb_encrypt
#define aes_ecb_decrypt         ce_aes_ecb_decrypt
#define aes_cbc_encrypt         ce_aes_cbc_encrypt
#define aes_cbc_decrypt         ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         ce_aes_ctr_encrypt
#define aes_xts_encrypt         ce_aes_xts_encrypt
#define aes_xts_decrypt         ce_aes_xts_decrypt
#define aes_mac_update          ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE                    "neon"
#define PRIO                    200
#define aes_ecb_encrypt         neon_aes_ecb_encrypt
#define aes_ecb_decrypt         neon_aes_ecb_decrypt
#define aes_cbc_encrypt         neon_aes_cbc_encrypt
#define aes_cbc_decrypt         neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         neon_aes_ctr_encrypt
#define aes_xts_encrypt         neon_aes_xts_encrypt
#define aes_xts_decrypt         neon_aes_xts_decrypt
#define aes_mac_update          neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
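
/*
 * Note: the generic ecb/cbc/ctr/xts aliases below appear to be advertised
 * only when this module is the one providing those modes, i.e. for the
 * Crypto Extensions build or when the bit-sliced NEON implementation is
 * not configured (same condition that guards their entries in aes_algs[]).
 */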
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 ctr[]);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);

asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                               int blocks, u8 dg[], int enc_before,
                               int enc_after);

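/* XTS uses two expanded keys: key1 for the data, key2 for the tweak */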
struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

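/*
 * For ESSIV, key2 is not supplied by the user but derived from a SHA-256
 * digest of key1 in essiv_cbc_set_key(); 'hash' holds the sha256 shash
 * allocated in essiv_cbc_init_tfm() for that purpose.
 */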
struct crypto_aes_essiv_cbc_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
        struct crypto_shash *hash;
};

struct mac_tfm_ctx {
        struct crypto_aes_ctx key;
        u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
        unsigned int len;
        u8 dg[AES_BLOCK_SIZE];
};

static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                               unsigned int key_len)
{
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

        return aes_expandkey(ctx, in_key, key_len);
}

static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
                                      const u8 *in_key, unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = xts_verify_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                    key_len / 2);
        return ret;
}

static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
                                            const u8 *in_key,
                                            unsigned int key_len)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        u8 digest[SHA256_DIGEST_SIZE];
        int ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len);
        if (ret)
                return ret;

        crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);

        return aes_expandkey(&ctx->key2, digest, sizeof(digest));
}

static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_dec, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_encrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_encrypt_walk(req, &walk);
}

static int cbc_decrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_dec, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_decrypt_walk(req, &walk);
}

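/*
 * CBC with ciphertext stealing: everything up to the last two blocks is
 * handled by the plain CBC path above, and the remaining full block plus
 * the partial tail is passed to the cbc_cts assembly helper in a second
 * walk over a forwarded scatterlist.
 */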
static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_encrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_decrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        ctx->hash = crypto_alloc_shash("sha256", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->hash);
}

static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_free_shash(ctx->hash);
}

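/*
 * The first walk step goes through the ESSIV helper, which is expected to
 * turn req->iv into the actual IV by encrypting it with key2 before doing
 * CBC; any data left after that step is plain CBC chained off walk.iv.
 */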
static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_enc, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_encrypt_walk(req, &walk);
}

static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_dec, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_decrypt_walk(req, &walk);
}

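/*
 * CTR: full blocks are handled by the NEON helper; a trailing partial
 * block is handled by generating one block of keystream into a stack
 * buffer (blocks == -1, see below) and XOR-ing it into the destination.
 */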
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        if (walk.nbytes) {
                u8 __aligned(8) tail[AES_BLOCK_SIZE];
                unsigned int nbytes = walk.nbytes;
                u8 *tdst = walk.dst.virt.addr;
                u8 *tsrc = walk.src.virt.addr;

                /*
                 * Tell aes_ctr_encrypt() to process a tail block.
                 */
                blocks = -1;

                kernel_neon_begin();
                aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
                                blocks, walk.iv);
                kernel_neon_end();
                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
                err = skcipher_walk_done(&walk, 0);
        }

        return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned long flags;

        /*
         * Temporarily disable interrupts to avoid races where
         * cachelines are evicted when the CPU is interrupted
         * to do something else.
         */
        local_irq_save(flags);
        aes_encrypt(ctx, dst, src);
        local_irq_restore(flags);
}

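/*
 * Synchronous CTR entry point: fall back to the generic CTR walk using the
 * scalar AES library cipher when the NEON unit may not be used (e.g. when
 * called from interrupt context).
 */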
static int __maybe_unused ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

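/*
 * XTS with ciphertext stealing: when the request is not a whole number of
 * blocks, everything up to the last two blocks is processed first, and the
 * final block-and-a-bit is handled in a separate pass over a forwarded
 * scatterlist.
 */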
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_enc, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_dec, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

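/*
 * Entries flagged CRYPTO_ALG_INTERNAL (the "__" prefixed names) are not
 * exposed directly; aes_init() wraps them with the crypto_simd helper so
 * that users get an asynchronous version which can defer work when the
 * NEON unit cannot be used in the current context.
 */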
static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
        .base = {
                .cra_name               = "__ecb(aes)",
                .cra_driver_name        = "__ecb-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ecb_encrypt,
        .decrypt        = ecb_decrypt,
}, {
        .base = {
                .cra_name               = "__cbc(aes)",
                .cra_driver_name        = "__cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__ctr(aes)",
                .cra_driver_name        = "__ctr-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt,
        .decrypt        = ctr_encrypt,
}, {
        .base = {
                .cra_name               = "ctr(aes)",
                .cra_driver_name        = "ctr-aes-" MODE,
                .cra_priority           = PRIO - 1,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt_sync,
        .decrypt        = ctr_encrypt_sync,
}, {
        .base = {
                .cra_name               = "__xts(aes)",
                .cra_driver_name        = "__xts-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
}, {
#endif
        .base = {
                .cra_name               = "__cts(cbc(aes))",
                .cra_driver_name        = "__cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__essiv(cbc(aes),sha256)",
                .cra_driver_name        = "__essiv-cbc-aes-sha256-" MODE,
                .cra_priority           = PRIO + 1,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = essiv_cbc_set_key,
        .encrypt        = essiv_cbc_encrypt,
        .decrypt        = essiv_cbc_decrypt,
        .init           = essiv_cbc_init_tfm,
        .exit           = essiv_cbc_exit_tfm,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

        return aes_expandkey(&ctx->key, in_key, key_len);
}

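/*
 * Double a 128-bit value in GF(2^128), reducing modulo
 * x^128 + x^7 + x^2 + x + 1 (hence the 0x87 constant).
 */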
static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
        u64 a = be64_to_cpu(x->a);
        u64 b = be64_to_cpu(x->b);

        y->a = cpu_to_be64((a << 1) | (b >> 63));
        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

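/*
 * CMAC subkey derivation: encrypt the all-zero block, then double it twice
 * in GF(2^128) to obtain K1 (consts[0]) and K2 (consts[1]).
 */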
static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
        int rounds = 6 + key_len / 4;
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        /* encrypt the zero vector */
        kernel_neon_begin();
        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
                        rounds, 1);
        kernel_neon_end();

        cmac_gf128_mul_by_x(consts, consts);
        cmac_gf128_mul_by_x(consts + 1, consts);

        return 0;
}

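/*
 * XCBC key derivation: K1..K3 are the encryptions of the constant blocks
 * 0x01.., 0x02.. and 0x03..; K1 replaces the AES key itself, while K2 and
 * K3 end up in ctx->consts[].
 */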
static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        static u8 const ks[3][AES_BLOCK_SIZE] = {
                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
        };

        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int rounds = 6 + key_len / 4;
        u8 key[AES_BLOCK_SIZE];
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        kernel_neon_begin();
        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
        kernel_neon_end();

        return cbcmac_setkey(tfm, key, sizeof(key));
}

static int mac_init(struct shash_desc *desc)
{
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        memset(ctx->dg, 0, AES_BLOCK_SIZE);
        ctx->len = 0;

        return 0;
}

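/*
 * Feed 'blocks' blocks into the CBC-MAC state in dg[]. The NEON helper is
 * used when SIMD is usable; otherwise fall back to a scalar XOR+encrypt
 * loop. enc_before/enc_after control whether dg[] is encrypted before the
 * first and after the last block, respectively.
 */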
static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
                          u8 dg[], int enc_before, int enc_after)
{
        int rounds = 6 + ctx->key_length / 4;

        if (crypto_simd_usable()) {
                kernel_neon_begin();
                aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
                               enc_after);
                kernel_neon_end();
        } else {
                if (enc_before)
                        aes_encrypt(ctx, dg, dg);

                while (blocks--) {
                        crypto_xor(dg, in, AES_BLOCK_SIZE);
                        in += AES_BLOCK_SIZE;

                        if (blocks || enc_after)
                                aes_encrypt(ctx, dg, dg);
                }
        }
}

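/*
 * Buffer partial input in ctx->dg. Note that the last full block is only
 * XOR-ed into the state, not yet encrypted (signalled by
 * ctx->len == AES_BLOCK_SIZE), so that cmac_final() can still fold in the
 * correct subkey before the final encryption.
 */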
static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        while (len > 0) {
                unsigned int l;

                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
                    (ctx->len + len) > AES_BLOCK_SIZE) {

                        int blocks = len / AES_BLOCK_SIZE;

                        len %= AES_BLOCK_SIZE;

                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
                                      (ctx->len != 0), (len != 0));

                        p += blocks * AES_BLOCK_SIZE;

                        if (!len) {
                                ctx->len = AES_BLOCK_SIZE;
                                break;
                        }
                        ctx->len = 0;
                }

                l = min(len, AES_BLOCK_SIZE - ctx->len);

                if (l <= AES_BLOCK_SIZE) {
                        crypto_xor(ctx->dg + ctx->len, p, l);
                        ctx->len += l;
                        len -= l;
                        p += l;
                }
        }

        return 0;
}

static int cbcmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

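/*
 * CMAC finalisation: a complete final block is masked with K1, an
 * incomplete one is padded with 0x80 00.. and masked with K2, before the
 * last encryption of the digest.
 */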
static int cmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
        u8 *consts = tctx->consts;

        if (ctx->len != AES_BLOCK_SIZE) {
                ctx->dg[ctx->len] ^= 0x80;
                consts += AES_BLOCK_SIZE;
        }

        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static struct shash_alg mac_algs[] = { {
        .base.cra_name          = "cmac(aes)",
        .base.cra_driver_name   = "cmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = cmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "xcbc(aes)",
        .base.cra_driver_name   = "xcbc-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = xcbc_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "cbcmac(aes)",
        .base.cra_driver_name   = "cbcmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cbcmac_final,
        .setkey                 = cbcmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

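/*
 * Register the shash and skcipher algorithms, then create a crypto_simd
 * wrapper for every internal skcipher so that the plain "ecb(aes)" etc.
 * names resolve to an async implementation usable from any context.
 */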
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        if (err)
                goto unregister_ciphers;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }

        return 0;

unregister_simds:
        aes_exit();
        return err;
unregister_ciphers:
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        return err;
}

#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);