// SPDX-License-Identifier: GPL-2.0-only
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/sha2.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"

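/*
 * This file is compiled twice: once with USE_V8_CRYPTO_EXTENSIONS defined
 * to produce the Crypto Extensions ("ce") flavour, and once without it to
 * produce the plain NEON fallback. The macros below select the asm entry
 * points, driver-name suffix and priority for each flavour.
 */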
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE                    "ce"
#define PRIO                    300
#define STRIDE                  5
#define aes_expandkey           ce_aes_expandkey
#define aes_ecb_encrypt         ce_aes_ecb_encrypt
#define aes_ecb_decrypt         ce_aes_ecb_decrypt
#define aes_cbc_encrypt         ce_aes_cbc_encrypt
#define aes_cbc_decrypt         ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         ce_aes_ctr_encrypt
#define aes_xts_encrypt         ce_aes_xts_encrypt
#define aes_xts_decrypt         ce_aes_xts_decrypt
#define aes_mac_update          ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE                    "neon"
#define PRIO                    200
#define STRIDE                  4
#define aes_ecb_encrypt         neon_aes_ecb_encrypt
#define aes_ecb_decrypt         neon_aes_ecb_decrypt
#define aes_cbc_encrypt         neon_aes_cbc_encrypt
#define aes_cbc_decrypt         neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         neon_aes_ctr_encrypt
#define aes_xts_encrypt         neon_aes_xts_encrypt
#define aes_xts_decrypt         neon_aes_xts_decrypt
#define aes_mac_update          neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 ctr[], u8 finalbuf[]);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);

asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                              int blocks, u8 dg[], int enc_before,
                              int enc_after);

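/*
 * XTS carries two independent AES keys: key1 en/decrypts the data while
 * key2 encrypts the IV to derive the initial tweak.
 */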
struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

struct crypto_aes_essiv_cbc_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
        struct crypto_shash *hash;
};

struct mac_tfm_ctx {
        struct crypto_aes_ctx key;
        u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
        unsigned int len;
        u8 dg[AES_BLOCK_SIZE];
};

static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                               unsigned int key_len)
{
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

        return aes_expandkey(ctx, in_key, key_len);
}

static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
                                      const u8 *in_key, unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = xts_verify_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                    key_len / 2);
        return ret;
}

static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
                                            const u8 *in_key,
                                            unsigned int key_len)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        u8 digest[SHA256_DIGEST_SIZE];
        int ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len);
        if (ret)
                return ret;

        crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);

        return aes_expandkey(&ctx->key2, digest, sizeof(digest));
}

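/*
 * AES uses 10/12/14 rounds for 128/192/256-bit keys; with key_length in
 * bytes, 6 + key_length / 4 yields exactly those values.
 */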
static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_dec, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

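/*
 * The CBC walk helpers below are shared: cbc_encrypt()/cbc_decrypt() call
 * them directly, and the CTS and ESSIV paths reuse them for the bulk CBC
 * portion of a request.
 */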
static int cbc_encrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_encrypt_walk(req, &walk);
}

static int cbc_decrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_dec, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_decrypt_walk(req, &walk);
}

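/*
 * CBC with ciphertext stealing: all but the last two blocks go through
 * plain CBC via a subrequest, then the final full and partial block pair
 * is handled by the aes_cbc_cts_{en,de}crypt() asm routines, so requests
 * need not be padded to a whole number of blocks.
 */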
static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_encrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_decrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

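/*
 * ESSIV: the IV is encrypted under key2, which essiv_cbc_set_key() derives
 * as the SHA-256 digest of the main key, before the CBC layer sees it.
 * Only the first walk step needs the ESSIV asm routine; the remainder of
 * the request is ordinary CBC.
 */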
static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        ctx->hash = crypto_alloc_shash("sha256", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->hash);
}

static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_free_shash(ctx->hash);
}

static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_enc, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_encrypt_walk(req, &walk);
}

static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_dec, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_decrypt_walk(req, &walk);
}

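/*
 * CTR handles arbitrary byte counts: whole blocks are processed in bulk,
 * while a trailing partial block is bounced through the stack buffer
 * buf[], since the asm routine cannot return it with an overlapping
 * store (see the comment at the memcpy() below).
 */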
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                const u8 *src = walk.src.virt.addr;
                unsigned int nbytes = walk.nbytes;
                u8 *dst = walk.dst.virt.addr;
                u8 buf[AES_BLOCK_SIZE];
                unsigned int tail;

                if (unlikely(nbytes < AES_BLOCK_SIZE))
                        src = memcpy(buf, src, nbytes);
                else if (nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
                                walk.iv, buf);
                kernel_neon_end();

                tail = nbytes % (STRIDE * AES_BLOCK_SIZE);
                if (tail > 0 && tail < AES_BLOCK_SIZE)
                        /*
                         * The final partial block could not be returned using
                         * an overlapping store, so it was passed via buf[]
                         * instead.
                         */
                        memcpy(dst + nbytes - tail, buf, tail);

                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned long flags;

        /*
         * Temporarily disable interrupts to avoid races where
         * cachelines are evicted when the CPU is interrupted
         * to do something else.
         */
        local_irq_save(flags);
        aes_encrypt(ctx, dst, src);
        local_irq_restore(flags);
}

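/*
 * Synchronous fallback for callers running in contexts where the NEON
 * unit may not be usable: defer to the generic CTR walker driven by the
 * scalar AES library via ctr_encrypt_one().
 */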
static int __maybe_unused ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

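/*
 * XTS: requests whose length is not a multiple of AES_BLOCK_SIZE are
 * split; the full blocks are processed first via a subrequest, and the
 * last full block plus the tail are then handled together using
 * ciphertext stealing.
 */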
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_enc, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_dec, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

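/*
 * Algorithms flagged CRYPTO_ALG_INTERNAL (named with a "__" prefix) are
 * not reachable by users directly; aes_init() wraps each of them in an
 * async simd helper that defers to a worker thread whenever the NEON unit
 * cannot be used in the caller's context.
 */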
static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
        .base = {
                .cra_name               = "__ecb(aes)",
                .cra_driver_name        = "__ecb-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ecb_encrypt,
        .decrypt        = ecb_decrypt,
}, {
        .base = {
                .cra_name               = "__cbc(aes)",
                .cra_driver_name        = "__cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__ctr(aes)",
                .cra_driver_name        = "__ctr-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt,
        .decrypt        = ctr_encrypt,
}, {
        .base = {
                .cra_name               = "ctr(aes)",
                .cra_driver_name        = "ctr-aes-" MODE,
                .cra_priority           = PRIO - 1,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt_sync,
        .decrypt        = ctr_encrypt_sync,
}, {
        .base = {
                .cra_name               = "__xts(aes)",
                .cra_driver_name        = "__xts-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
}, {
#endif
        .base = {
                .cra_name               = "__cts(cbc(aes))",
                .cra_driver_name        = "__cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__essiv(cbc(aes),sha256)",
                .cra_driver_name        = "__essiv-cbc-aes-sha256-" MODE,
                .cra_priority           = PRIO + 1,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = essiv_cbc_set_key,
        .encrypt        = essiv_cbc_encrypt,
        .decrypt        = essiv_cbc_decrypt,
        .init           = essiv_cbc_init_tfm,
        .exit           = essiv_cbc_exit_tfm,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

        return aes_expandkey(&ctx->key, in_key, key_len);
}

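/*
 * Doubling in GF(2^128) as specified for CMAC subkey derivation (NIST SP
 * 800-38B): shift left by one bit and, if the top bit was set, reduce by
 * the field polynomial, whose low coefficients are 0x87.
 */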
static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
        u64 a = be64_to_cpu(x->a);
        u64 b = be64_to_cpu(x->b);

        y->a = cpu_to_be64((a << 1) | (b >> 63));
        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
        int rounds = 6 + key_len / 4;
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        /* encrypt the zero vector */
        kernel_neon_begin();
        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
                        rounds, 1);
        kernel_neon_end();

        cmac_gf128_mul_by_x(consts, consts);
        cmac_gf128_mul_by_x(consts + 1, consts);

        return 0;
}

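/*
 * XCBC (RFC 3566) derives three keys by encrypting the constant blocks
 * 0x01..01, 0x02..02 and 0x03..03 under the user key: K1 becomes the new
 * CBC-MAC key, while K2 and K3 are kept in consts[] for finalisation.
 */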
static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        static u8 const ks[3][AES_BLOCK_SIZE] = {
                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
        };

        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int rounds = 6 + key_len / 4;
        u8 key[AES_BLOCK_SIZE];
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        kernel_neon_begin();
        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
        kernel_neon_end();

        return cbcmac_setkey(tfm, key, sizeof(key));
}

static int mac_init(struct shash_desc *desc)
{
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        memset(ctx->dg, 0, AES_BLOCK_SIZE);
        ctx->len = 0;

        return 0;
}

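/*
 * Feed blocks into the CBC-MAC state in dg[]. The asm routine may return
 * before consuming all blocks (it gives up the NEON unit periodically),
 * so loop until none remain; without SIMD, fall back to a scalar
 * xor-and-encrypt loop using the AES library.
 */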
static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
                          u8 dg[], int enc_before, int enc_after)
{
        int rounds = 6 + ctx->key_length / 4;

        if (crypto_simd_usable()) {
                int rem;

                do {
                        kernel_neon_begin();
                        rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
                                             dg, enc_before, enc_after);
                        kernel_neon_end();
                        in += (blocks - rem) * AES_BLOCK_SIZE;
                        blocks = rem;
                        enc_before = 0;
                } while (blocks);
        } else {
                if (enc_before)
                        aes_encrypt(ctx, dg, dg);

                while (blocks--) {
                        crypto_xor(dg, in, AES_BLOCK_SIZE);
                        in += AES_BLOCK_SIZE;

                        if (blocks || enc_after)
                                aes_encrypt(ctx, dg, dg);
                }
        }
}

static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        while (len > 0) {
                unsigned int l;

                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
                    (ctx->len + len) > AES_BLOCK_SIZE) {

                        int blocks = len / AES_BLOCK_SIZE;

                        len %= AES_BLOCK_SIZE;

                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
                                      (ctx->len != 0), (len != 0));

                        p += blocks * AES_BLOCK_SIZE;

                        if (!len) {
                                ctx->len = AES_BLOCK_SIZE;
                                break;
                        }
                        ctx->len = 0;
                }

                l = min(len, AES_BLOCK_SIZE - ctx->len);

                if (l <= AES_BLOCK_SIZE) {
                        crypto_xor(ctx->dg + ctx->len, p, l);
                        ctx->len += l;
                        len -= l;
                        p += l;
                }
        }

        return 0;
}

static int cbcmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static int cmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
        u8 *consts = tctx->consts;

        if (ctx->len != AES_BLOCK_SIZE) {
                ctx->dg[ctx->len] ^= 0x80;
                consts += AES_BLOCK_SIZE;
        }

        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static struct shash_alg mac_algs[] = { {
        .base.cra_name          = "cmac(aes)",
        .base.cra_driver_name   = "cmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = cmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "xcbc(aes)",
        .base.cra_driver_name   = "xcbc-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = xcbc_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "cbcmac(aes)",
        .base.cra_driver_name   = "cbcmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cbcmac_final,
        .setkey                 = cbcmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

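/*
 * Register the skcipher and shash algorithms, then create a simd wrapper
 * for every internal algorithm (dropping the "__" prefix from its names)
 * so users get an async variant that is safe to request from any context.
 */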
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        if (err)
                goto unregister_ciphers;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }

        return 0;

unregister_simds:
        aes_exit();
        return err;
unregister_ciphers:
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        return err;
}

#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);