// SPDX-License-Identifier: GPL-2.0-only
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/sha.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"

#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE                    "ce"
#define PRIO                    300
#define aes_expandkey           ce_aes_expandkey
#define aes_ecb_encrypt         ce_aes_ecb_encrypt
#define aes_ecb_decrypt         ce_aes_ecb_decrypt
#define aes_cbc_encrypt         ce_aes_cbc_encrypt
#define aes_cbc_decrypt         ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         ce_aes_ctr_encrypt
#define aes_xts_encrypt         ce_aes_xts_encrypt
#define aes_xts_decrypt         ce_aes_xts_decrypt
#define aes_mac_update          ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE                    "neon"
#define PRIO                    200
#define aes_ecb_encrypt         neon_aes_ecb_encrypt
#define aes_ecb_decrypt         neon_aes_ecb_decrypt
#define aes_cbc_encrypt         neon_aes_cbc_encrypt
#define aes_cbc_decrypt         neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         neon_aes_ctr_encrypt
#define aes_xts_encrypt         neon_aes_xts_encrypt
#define aes_xts_decrypt         neon_aes_xts_decrypt
#define aes_mac_update          neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
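
/*
 * This translation unit is built both with and without
 * USE_V8_CRYPTO_EXTENSIONS (for the Crypto Extensions and plain NEON
 * modules respectively), so the macros above select the backend symbols
 * and the "-ce"/"-neon" driver-name suffix at compile time.
 */
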
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 ctr[]);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);

asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                               int blocks, u8 dg[], int enc_before,
                               int enc_after);

struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

struct crypto_aes_essiv_cbc_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
        struct crypto_shash *hash;
};

struct mac_tfm_ctx {
        struct crypto_aes_ctx key;
        u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
        unsigned int len;
        u8 dg[AES_BLOCK_SIZE];
};

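/*
 * Throughout this file the AES round count is derived from the expanded
 * key length in bytes as rounds = 6 + key_length / 4:
 *
 *   AES-128: 6 + 16/4 = 10 rounds
 *   AES-192: 6 + 24/4 = 12 rounds
 *   AES-256: 6 + 32/4 = 14 rounds
 */
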
static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                               unsigned int key_len)
{
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

        return aes_expandkey(ctx, in_key, key_len);
}

static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
                                      const u8 *in_key, unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = xts_verify_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                    key_len / 2);
        return ret;
}

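/*
 * xts_set_key() above treats the supplied key as two equally sized AES
 * keys back to back: a 64-byte key, for instance, yields an AES-256
 * schedule for the data (key1) and another for the tweak (key2).
 * xts_verify_key() rejects odd key lengths and, in FIPS mode, keys
 * whose two halves are identical.
 */
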
static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
                                            const u8 *in_key,
                                            unsigned int key_len)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        SHASH_DESC_ON_STACK(desc, ctx->hash);
        u8 digest[SHA256_DIGEST_SIZE];
        int ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len);
        if (ret)
                return ret;

        desc->tfm = ctx->hash;
        crypto_shash_digest(desc, in_key, key_len, digest);

        return aes_expandkey(&ctx->key2, digest, sizeof(digest));
}

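/*
 * For ESSIV the caller supplies only one key; essiv_cbc_set_key() above
 * derives the second as the SHA-256 digest of the first, so that the
 * per-sector IV can be formed by encrypting the sector number under a
 * key an attacker cannot predict.
 */
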
static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_dec, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_encrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_encrypt_walk(req, &walk);
}

static int cbc_decrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_dec, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_decrypt_walk(req, &walk);
}

static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_encrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

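/*
 * Sizing example for the two CTS passes above: with req->cryptlen == 36,
 * cbc_blocks == DIV_ROUND_UP(36, 16) - 2 == 1, so the first 16 bytes go
 * through the plain CBC path and the remaining 20 bytes (one full block
 * plus a 4-byte tail) are handed to the ciphertext stealing helper in a
 * single call.
 */
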
static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_decrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        ctx->hash = crypto_alloc_shash("sha256", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->hash);
}

static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_free_shash(ctx->hash);
}

static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_enc, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_encrypt_walk(req, &walk);
}

static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_dec, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_decrypt_walk(req, &walk);
}

static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        if (walk.nbytes) {
                u8 __aligned(8) tail[AES_BLOCK_SIZE];
                unsigned int nbytes = walk.nbytes;
                u8 *tdst = walk.dst.virt.addr;
                u8 *tsrc = walk.src.virt.addr;

                /*
                 * A negative block count tells aes_ctr_encrypt() to
                 * generate a single block of keystream into 'tail';
                 * the partial tail of the request is then XOR-ed with
                 * it below.
                 */
                blocks = -1;

                kernel_neon_begin();
                aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
                                blocks, walk.iv);
                kernel_neon_end();
                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
                err = skcipher_walk_done(&walk, 0);
        }

        return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned long flags;

        /*
         * Temporarily disable interrupts to avoid races where
         * cachelines are evicted when the CPU is interrupted
         * to do something else.
         */
        local_irq_save(flags);
        aes_encrypt(ctx, dst, src);
        local_irq_restore(flags);
}

static int __maybe_unused ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

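/*
 * Illustrative sketch (not part of this driver): requesting the
 * "ctr(aes)" algorithm registered below through the generic skcipher
 * API. Error handling and the async completion path are omitted, and
 * 'key', 'buf' and 'buflen' are placeholder names.
 *
 *      struct crypto_skcipher *tfm;
 *      struct skcipher_request *req;
 *      struct scatterlist sg;
 *      u8 iv[AES_BLOCK_SIZE];
 *
 *      tfm = crypto_alloc_skcipher("ctr(aes)", 0, 0);
 *      crypto_skcipher_setkey(tfm, key, AES_KEYSIZE_128);
 *      req = skcipher_request_alloc(tfm, GFP_KERNEL);
 *      get_random_bytes(iv, sizeof(iv));
 *      sg_init_one(&sg, buf, buflen);
 *      skcipher_request_set_crypt(req, &sg, &sg, buflen, iv);
 *      crypto_skcipher_encrypt(req);
 *      skcipher_request_free(req);
 *      crypto_free_skcipher(tfm);
 */
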
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_enc, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

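/*
 * The asm XTS routines implement ciphertext stealing themselves when
 * given a byte count that is not a multiple of AES_BLOCK_SIZE; the
 * subrequest dance in xts_encrypt() above (and xts_decrypt() below) is
 * only needed when the walk cannot present the last two blocks of a
 * misaligned request in one contiguous chunk.
 */
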
static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_dec, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
        .base = {
                .cra_name               = "__ecb(aes)",
                .cra_driver_name        = "__ecb-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ecb_encrypt,
        .decrypt        = ecb_decrypt,
}, {
        .base = {
                .cra_name               = "__cbc(aes)",
                .cra_driver_name        = "__cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__ctr(aes)",
                .cra_driver_name        = "__ctr-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt,
        .decrypt        = ctr_encrypt,
}, {
        .base = {
                .cra_name               = "ctr(aes)",
                .cra_driver_name        = "ctr-aes-" MODE,
                .cra_priority           = PRIO - 1,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt_sync,
        .decrypt        = ctr_encrypt_sync,
}, {
        .base = {
                .cra_name               = "__xts(aes)",
                .cra_driver_name        = "__xts-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
}, {
#endif
        .base = {
                .cra_name               = "__cts(cbc(aes))",
                .cra_driver_name        = "__cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__essiv(cbc(aes),sha256)",
                .cra_driver_name        = "__essiv-cbc-aes-sha256-" MODE,
                .cra_priority           = PRIO + 1,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = essiv_cbc_set_key,
        .encrypt        = essiv_cbc_encrypt,
        .decrypt        = essiv_cbc_decrypt,
        .init           = essiv_cbc_init_tfm,
        .exit           = essiv_cbc_exit_tfm,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

        return aes_expandkey(&ctx->key, in_key, key_len);
}

static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
        u64 a = be64_to_cpu(x->a);
        u64 b = be64_to_cpu(x->b);

        y->a = cpu_to_be64((a << 1) | (b >> 63));
        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

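/*
 * cmac_gf128_mul_by_x() above doubles a 128-bit value in GF(2^128)
 * modulo x^128 + x^7 + x^2 + x + 1: shift left by one bit and, if a one
 * bit fell off the top, fold it back in by XOR-ing 0x87 into the least
 * significant byte. For example, doubling a value whose only set bit is
 * the MSB yields 0x87.
 */
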
static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
        int rounds = 6 + key_len / 4;
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        /* encrypt the zero vector */
        kernel_neon_begin();
        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
                        rounds, 1);
        kernel_neon_end();

        cmac_gf128_mul_by_x(consts, consts);
        cmac_gf128_mul_by_x(consts + 1, consts);

        return 0;
}

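/*
 * The constants computed by cmac_setkey() above are the standard CMAC
 * subkeys: consts[0..15] holds K1 = dbl(E_K(0^128)) for complete final
 * blocks, and consts[16..31] holds K2 = dbl(K1) for padded ones (see
 * cmac_final() below).
 */
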
static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        static u8 const ks[3][AES_BLOCK_SIZE] = {
                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
        };

        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int rounds = 6 + key_len / 4;
        u8 key[AES_BLOCK_SIZE];
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        kernel_neon_begin();
        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
        kernel_neon_end();

        return cbcmac_setkey(tfm, key, sizeof(key));
}

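/*
 * xcbc_setkey() above follows AES-XCBC-MAC (RFC 3566): K1 = E_K(0x01^16)
 * becomes the actual CBC-MAC key via the trailing cbcmac_setkey() call,
 * while K2 = E_K(0x02^16) and K3 = E_K(0x03^16) land in ctx->consts and
 * are mixed into the final block just like the CMAC subkeys.
 */
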
static int mac_init(struct shash_desc *desc)
{
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        memset(ctx->dg, 0, AES_BLOCK_SIZE);
        ctx->len = 0;

        return 0;
}

static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
                          u8 dg[], int enc_before, int enc_after)
{
        int rounds = 6 + ctx->key_length / 4;

        if (crypto_simd_usable()) {
                kernel_neon_begin();
                aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
                               enc_after);
                kernel_neon_end();
        } else {
                if (enc_before)
                        aes_encrypt(ctx, dg, dg);

                while (blocks--) {
                        crypto_xor(dg, in, AES_BLOCK_SIZE);
                        in += AES_BLOCK_SIZE;

                        if (blocks || enc_after)
                                aes_encrypt(ctx, dg, dg);
                }
        }
}

static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        while (len > 0) {
                unsigned int l;

                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
                    (ctx->len + len) > AES_BLOCK_SIZE) {

                        int blocks = len / AES_BLOCK_SIZE;

                        len %= AES_BLOCK_SIZE;

                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
                                      (ctx->len != 0), (len != 0));

                        p += blocks * AES_BLOCK_SIZE;

                        if (!len) {
                                ctx->len = AES_BLOCK_SIZE;
                                break;
                        }
                        ctx->len = 0;
                }

                l = min(len, AES_BLOCK_SIZE - ctx->len);

                if (l <= AES_BLOCK_SIZE) {
                        crypto_xor(ctx->dg + ctx->len, p, l);
                        ctx->len += l;
                        len -= l;
                        p += l;
                }
        }

        return 0;
}

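/*
 * Note that mac_update() above defers a final complete block: when the
 * input ends exactly on a block boundary the block is XOR-ed into dg but
 * left unencrypted and ctx->len is set to AES_BLOCK_SIZE, so that
 * cmac_final() can still fold the correct subkey in before the last
 * encryption.
 */
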
static int cbcmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static int cmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
        u8 *consts = tctx->consts;

        if (ctx->len != AES_BLOCK_SIZE) {
                ctx->dg[ctx->len] ^= 0x80;
                consts += AES_BLOCK_SIZE;
        }

        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static struct shash_alg mac_algs[] = { {
        .base.cra_name          = "cmac(aes)",
        .base.cra_driver_name   = "cmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = cmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "xcbc(aes)",
        .base.cra_driver_name   = "xcbc-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = xcbc_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "cbcmac(aes)",
        .base.cra_driver_name   = "cbcmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cbcmac_final,
        .setkey                 = cbcmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

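/*
 * The algorithms above flagged CRYPTO_ALG_INTERNAL carry a "__" prefix;
 * aes_init() below wraps each of them with simd_skcipher_create_compat(),
 * which skips the prefix to form the public names and produces a wrapper
 * that defers to an asynchronous helper (cryptd) whenever kernel-mode
 * NEON cannot be used in the caller's context.
 */
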
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        if (err)
                goto unregister_ciphers;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }

        return 0;

unregister_simds:
        aes_exit();
        return err;
unregister_ciphers:
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        return err;
}

#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);