Merge tag 'pm-5.20-rc1-2' of git://git.kernel.org/pub/scm/linux/kernel/git/rafael...
[linux-2.6-microblaze.git] / arch / arm64 / crypto / aes-glue.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7
8 #include <asm/neon.h>
9 #include <asm/hwcap.h>
10 #include <asm/simd.h>
11 #include <crypto/aes.h>
12 #include <crypto/ctr.h>
13 #include <crypto/sha2.h>
14 #include <crypto/internal/hash.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/scatterwalk.h>
18 #include <linux/module.h>
19 #include <linux/cpufeature.h>
20 #include <crypto/xts.h>
21
22 #include "aes-ce-setkey.h"
23
/*
 * This file provides either the ARMv8 Crypto Extensions implementation or
 * the plain NEON implementation, selected by USE_V8_CRYPTO_EXTENSIONS.
 * The generic aes_* names used throughout the file are mapped onto the
 * corresponding ce_* or neon_* assembly routines here. MODE is appended to
 * every cra_driver_name and PRIO is used as the cra_priority in aes_algs[].
 */
24 #ifdef USE_V8_CRYPTO_EXTENSIONS
25 #define MODE                    "ce"
26 #define PRIO                    300
27 #define aes_expandkey           ce_aes_expandkey
28 #define aes_ecb_encrypt         ce_aes_ecb_encrypt
29 #define aes_ecb_decrypt         ce_aes_ecb_decrypt
30 #define aes_cbc_encrypt         ce_aes_cbc_encrypt
31 #define aes_cbc_decrypt         ce_aes_cbc_decrypt
32 #define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
33 #define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
34 #define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
35 #define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
36 #define aes_ctr_encrypt         ce_aes_ctr_encrypt
37 #define aes_xctr_encrypt        ce_aes_xctr_encrypt
38 #define aes_xts_encrypt         ce_aes_xts_encrypt
39 #define aes_xts_decrypt         ce_aes_xts_decrypt
40 #define aes_mac_update          ce_aes_mac_update
41 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
42 #else
/*
 * NEON build: aes_expandkey is deliberately not remapped here, so the
 * non-CE aes_expandkey() is used (presumably the library version declared
 * in <crypto/aes.h> — confirm).
 */
43 #define MODE                    "neon"
44 #define PRIO                    200
45 #define aes_ecb_encrypt         neon_aes_ecb_encrypt
46 #define aes_ecb_decrypt         neon_aes_ecb_decrypt
47 #define aes_cbc_encrypt         neon_aes_cbc_encrypt
48 #define aes_cbc_decrypt         neon_aes_cbc_decrypt
49 #define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
50 #define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
51 #define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
52 #define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
53 #define aes_ctr_encrypt         neon_aes_ctr_encrypt
54 #define aes_xctr_encrypt        neon_aes_xctr_encrypt
55 #define aes_xts_encrypt         neon_aes_xts_encrypt
56 #define aes_xts_decrypt         neon_aes_xts_decrypt
57 #define aes_mac_update          neon_aes_mac_update
58 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
59 #endif
/*
 * The plain ecb/cbc/ctr/xts/xctr aliases are only claimed when this module
 * actually registers those algorithms. That is always the case for the CE
 * build. For the NEON build it holds only when CONFIG_CRYPTO_AES_ARM64_BS
 * (presumably the bit-sliced NEON driver) is disabled. This condition must
 * stay identical to the #if guarding the first entries of aes_algs[].
 */
60 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
61 MODULE_ALIAS_CRYPTO("ecb(aes)");
62 MODULE_ALIAS_CRYPTO("cbc(aes)");
63 MODULE_ALIAS_CRYPTO("ctr(aes)");
64 MODULE_ALIAS_CRYPTO("xts(aes)");
65 MODULE_ALIAS_CRYPTO("xctr(aes)");
66 #endif
/* The remaining aliases are claimed in both builds. */
67 MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
68 MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
69 MODULE_ALIAS_CRYPTO("cmac(aes)");
70 MODULE_ALIAS_CRYPTO("xcbc(aes)");
71 MODULE_ALIAS_CRYPTO("cbcmac(aes)");
72
73 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
74 MODULE_LICENSE("GPL v2");
75
76 /* defined in aes-modes.S */
77 asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
78                                 int rounds, int blocks);
79 asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
80                                 int rounds, int blocks);
81
82 asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
83                                 int rounds, int blocks, u8 iv[]);
84 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
85                                 int rounds, int blocks, u8 iv[]);
86
87 asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
88                                 int rounds, int bytes, u8 const iv[]);
89 asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
90                                 int rounds, int bytes, u8 const iv[]);
91
92 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
93                                 int rounds, int bytes, u8 ctr[]);
94
95 asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
96                                  int rounds, int bytes, u8 ctr[], int byte_ctr);
97
98 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
99                                 int rounds, int bytes, u32 const rk2[], u8 iv[],
100                                 int first);
101 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
102                                 int rounds, int bytes, u32 const rk2[], u8 iv[],
103                                 int first);
104
105 asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
106                                       int rounds, int blocks, u8 iv[],
107                                       u32 const rk2[]);
108 asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
109                                       int rounds, int blocks, u8 iv[],
110                                       u32 const rk2[]);
111
112 asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
113                               int blocks, u8 dg[], int enc_before,
114                               int enc_after);
115
/*
 * Transform context for xts(aes). xts_set_key() fills key1 from the first
 * half of the user key and key2 from the second half. The XTS routines
 * use key1 for the data, and key2's encryption schedule is passed to the
 * assembly as rk2.
 */
116 struct crypto_aes_xts_ctx {
117         struct crypto_aes_ctx key1;
118         struct crypto_aes_ctx __aligned(8) key2;
119 };
120
/*
 * Transform context for essiv(cbc(aes),sha256).
 * key1 must remain the first member: cbc_encrypt_walk()/cbc_decrypt_walk()
 * read this context as a plain struct crypto_aes_ctx, i.e. as key1.
 * key2 is expanded from the SHA-256 digest of the user key, and hash is the
 * "sha256" shash allocated in essiv_cbc_init_tfm().
 */
121 struct crypto_aes_essiv_cbc_ctx {
122         struct crypto_aes_ctx key1;
123         struct crypto_aes_ctx __aligned(8) key2;
124         struct crypto_shash *hash;
125 };
126
/*
 * Per-transform context for the MAC algorithms: an expanded AES key
 * followed by a variable-length constants area.
 * NOTE(review): the users of consts are outside this chunk; presumably it
 * holds derived subkeys for cmac/xcbc — confirm.
 */
127 struct mac_tfm_ctx {
128         struct crypto_aes_ctx key;
129         u8 __aligned(8) consts[];
130 };
131
/*
 * Per-request MAC state: a one-block digest buffer and a length counter.
 * NOTE(review): the exact meaning of len is defined by the MAC code
 * outside this chunk.
 */
132 struct mac_desc_ctx {
133         unsigned int len;
134         u8 dg[AES_BLOCK_SIZE];
135 };
136
137 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
138                                unsigned int key_len)
139 {
140         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
141
142         return aes_expandkey(ctx, in_key, key_len);
143 }
144
145 static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
146                                       const u8 *in_key, unsigned int key_len)
147 {
148         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
149         int ret;
150
151         ret = xts_verify_key(tfm, in_key, key_len);
152         if (ret)
153                 return ret;
154
155         ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
156         if (!ret)
157                 ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
158                                     key_len / 2);
159         return ret;
160 }
161
162 static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
163                                             const u8 *in_key,
164                                             unsigned int key_len)
165 {
166         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
167         u8 digest[SHA256_DIGEST_SIZE];
168         int ret;
169
170         ret = aes_expandkey(&ctx->key1, in_key, key_len);
171         if (ret)
172                 return ret;
173
174         crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
175
176         return aes_expandkey(&ctx->key2, digest, sizeof(digest));
177 }
178
179 static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
180 {
181         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
182         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
183         int err, rounds = 6 + ctx->key_length / 4;
184         struct skcipher_walk walk;
185         unsigned int blocks;
186
187         err = skcipher_walk_virt(&walk, req, false);
188
189         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
190                 kernel_neon_begin();
191                 aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
192                                 ctx->key_enc, rounds, blocks);
193                 kernel_neon_end();
194                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
195         }
196         return err;
197 }
198
199 static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
200 {
201         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
202         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
203         int err, rounds = 6 + ctx->key_length / 4;
204         struct skcipher_walk walk;
205         unsigned int blocks;
206
207         err = skcipher_walk_virt(&walk, req, false);
208
209         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
210                 kernel_neon_begin();
211                 aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
212                                 ctx->key_dec, rounds, blocks);
213                 kernel_neon_end();
214                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
215         }
216         return err;
217 }
218
219 static int cbc_encrypt_walk(struct skcipher_request *req,
220                             struct skcipher_walk *walk)
221 {
222         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
223         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
224         int err = 0, rounds = 6 + ctx->key_length / 4;
225         unsigned int blocks;
226
227         while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
228                 kernel_neon_begin();
229                 aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
230                                 ctx->key_enc, rounds, blocks, walk->iv);
231                 kernel_neon_end();
232                 err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
233         }
234         return err;
235 }
236
237 static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
238 {
239         struct skcipher_walk walk;
240         int err;
241
242         err = skcipher_walk_virt(&walk, req, false);
243         if (err)
244                 return err;
245         return cbc_encrypt_walk(req, &walk);
246 }
247
248 static int cbc_decrypt_walk(struct skcipher_request *req,
249                             struct skcipher_walk *walk)
250 {
251         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
252         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
253         int err = 0, rounds = 6 + ctx->key_length / 4;
254         unsigned int blocks;
255
256         while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
257                 kernel_neon_begin();
258                 aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
259                                 ctx->key_dec, rounds, blocks, walk->iv);
260                 kernel_neon_end();
261                 err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
262         }
263         return err;
264 }
265
266 static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
267 {
268         struct skcipher_walk walk;
269         int err;
270
271         err = skcipher_walk_virt(&walk, req, false);
272         if (err)
273                 return err;
274         return cbc_decrypt_walk(req, &walk);
275 }
276
277 static int cts_cbc_encrypt(struct skcipher_request *req)
278 {
279         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
280         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
281         int err, rounds = 6 + ctx->key_length / 4;
282         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
283         struct scatterlist *src = req->src, *dst = req->dst;
284         struct scatterlist sg_src[2], sg_dst[2];
285         struct skcipher_request subreq;
286         struct skcipher_walk walk;
287
288         skcipher_request_set_tfm(&subreq, tfm);
289         skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
290                                       NULL, NULL);
291
292         if (req->cryptlen <= AES_BLOCK_SIZE) {
293                 if (req->cryptlen < AES_BLOCK_SIZE)
294                         return -EINVAL;
295                 cbc_blocks = 1;
296         }
297
298         if (cbc_blocks > 0) {
299                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
300                                            cbc_blocks * AES_BLOCK_SIZE,
301                                            req->iv);
302
303                 err = skcipher_walk_virt(&walk, &subreq, false) ?:
304                       cbc_encrypt_walk(&subreq, &walk);
305                 if (err)
306                         return err;
307
308                 if (req->cryptlen == AES_BLOCK_SIZE)
309                         return 0;
310
311                 dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
312                 if (req->dst != req->src)
313                         dst = scatterwalk_ffwd(sg_dst, req->dst,
314                                                subreq.cryptlen);
315         }
316
317         /* handle ciphertext stealing */
318         skcipher_request_set_crypt(&subreq, src, dst,
319                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
320                                    req->iv);
321
322         err = skcipher_walk_virt(&walk, &subreq, false);
323         if (err)
324                 return err;
325
326         kernel_neon_begin();
327         aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
328                             ctx->key_enc, rounds, walk.nbytes, walk.iv);
329         kernel_neon_end();
330
331         return skcipher_walk_done(&walk, 0);
332 }
333
334 static int cts_cbc_decrypt(struct skcipher_request *req)
335 {
336         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
337         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
338         int err, rounds = 6 + ctx->key_length / 4;
339         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
340         struct scatterlist *src = req->src, *dst = req->dst;
341         struct scatterlist sg_src[2], sg_dst[2];
342         struct skcipher_request subreq;
343         struct skcipher_walk walk;
344
345         skcipher_request_set_tfm(&subreq, tfm);
346         skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
347                                       NULL, NULL);
348
349         if (req->cryptlen <= AES_BLOCK_SIZE) {
350                 if (req->cryptlen < AES_BLOCK_SIZE)
351                         return -EINVAL;
352                 cbc_blocks = 1;
353         }
354
355         if (cbc_blocks > 0) {
356                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
357                                            cbc_blocks * AES_BLOCK_SIZE,
358                                            req->iv);
359
360                 err = skcipher_walk_virt(&walk, &subreq, false) ?:
361                       cbc_decrypt_walk(&subreq, &walk);
362                 if (err)
363                         return err;
364
365                 if (req->cryptlen == AES_BLOCK_SIZE)
366                         return 0;
367
368                 dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
369                 if (req->dst != req->src)
370                         dst = scatterwalk_ffwd(sg_dst, req->dst,
371                                                subreq.cryptlen);
372         }
373
374         /* handle ciphertext stealing */
375         skcipher_request_set_crypt(&subreq, src, dst,
376                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
377                                    req->iv);
378
379         err = skcipher_walk_virt(&walk, &subreq, false);
380         if (err)
381                 return err;
382
383         kernel_neon_begin();
384         aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
385                             ctx->key_dec, rounds, walk.nbytes, walk.iv);
386         kernel_neon_end();
387
388         return skcipher_walk_done(&walk, 0);
389 }
390
391 static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
392 {
393         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
394
395         ctx->hash = crypto_alloc_shash("sha256", 0, 0);
396
397         return PTR_ERR_OR_ZERO(ctx->hash);
398 }
399
400 static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
401 {
402         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
403
404         crypto_free_shash(ctx->hash);
405 }
406
407 static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
408 {
409         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
410         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
411         int err, rounds = 6 + ctx->key1.key_length / 4;
412         struct skcipher_walk walk;
413         unsigned int blocks;
414
415         err = skcipher_walk_virt(&walk, req, false);
416
417         blocks = walk.nbytes / AES_BLOCK_SIZE;
418         if (blocks) {
419                 kernel_neon_begin();
420                 aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
421                                       ctx->key1.key_enc, rounds, blocks,
422                                       req->iv, ctx->key2.key_enc);
423                 kernel_neon_end();
424                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
425         }
426         return err ?: cbc_encrypt_walk(req, &walk);
427 }
428
429 static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
430 {
431         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
432         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
433         int err, rounds = 6 + ctx->key1.key_length / 4;
434         struct skcipher_walk walk;
435         unsigned int blocks;
436
437         err = skcipher_walk_virt(&walk, req, false);
438
439         blocks = walk.nbytes / AES_BLOCK_SIZE;
440         if (blocks) {
441                 kernel_neon_begin();
442                 aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
443                                       ctx->key1.key_dec, rounds, blocks,
444                                       req->iv, ctx->key2.key_enc);
445                 kernel_neon_end();
446                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
447         }
448         return err ?: cbc_decrypt_walk(req, &walk);
449 }
450
451 static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
452 {
453         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
454         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
455         int err, rounds = 6 + ctx->key_length / 4;
456         struct skcipher_walk walk;
457         unsigned int byte_ctr = 0;
458
459         err = skcipher_walk_virt(&walk, req, false);
460
461         while (walk.nbytes > 0) {
462                 const u8 *src = walk.src.virt.addr;
463                 unsigned int nbytes = walk.nbytes;
464                 u8 *dst = walk.dst.virt.addr;
465                 u8 buf[AES_BLOCK_SIZE];
466
467                 /*
468                  * If given less than 16 bytes, we must copy the partial block
469                  * into a temporary buffer of 16 bytes to avoid out of bounds
470                  * reads and writes.  Furthermore, this code is somewhat unusual
471                  * in that it expects the end of the data to be at the end of
472                  * the temporary buffer, rather than the start of the data at
473                  * the start of the temporary buffer.
474                  */
475                 if (unlikely(nbytes < AES_BLOCK_SIZE))
476                         src = dst = memcpy(buf + sizeof(buf) - nbytes,
477                                            src, nbytes);
478                 else if (nbytes < walk.total)
479                         nbytes &= ~(AES_BLOCK_SIZE - 1);
480
481                 kernel_neon_begin();
482                 aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
483                                                  walk.iv, byte_ctr);
484                 kernel_neon_end();
485
486                 if (unlikely(nbytes < AES_BLOCK_SIZE))
487                         memcpy(walk.dst.virt.addr,
488                                buf + sizeof(buf) - nbytes, nbytes);
489                 byte_ctr += nbytes;
490
491                 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
492         }
493
494         return err;
495 }
496
497 static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
498 {
499         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
500         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
501         int err, rounds = 6 + ctx->key_length / 4;
502         struct skcipher_walk walk;
503
504         err = skcipher_walk_virt(&walk, req, false);
505
506         while (walk.nbytes > 0) {
507                 const u8 *src = walk.src.virt.addr;
508                 unsigned int nbytes = walk.nbytes;
509                 u8 *dst = walk.dst.virt.addr;
510                 u8 buf[AES_BLOCK_SIZE];
511
512                 /*
513                  * If given less than 16 bytes, we must copy the partial block
514                  * into a temporary buffer of 16 bytes to avoid out of bounds
515                  * reads and writes.  Furthermore, this code is somewhat unusual
516                  * in that it expects the end of the data to be at the end of
517                  * the temporary buffer, rather than the start of the data at
518                  * the start of the temporary buffer.
519                  */
520                 if (unlikely(nbytes < AES_BLOCK_SIZE))
521                         src = dst = memcpy(buf + sizeof(buf) - nbytes,
522                                            src, nbytes);
523                 else if (nbytes < walk.total)
524                         nbytes &= ~(AES_BLOCK_SIZE - 1);
525
526                 kernel_neon_begin();
527                 aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
528                                 walk.iv);
529                 kernel_neon_end();
530
531                 if (unlikely(nbytes < AES_BLOCK_SIZE))
532                         memcpy(walk.dst.virt.addr,
533                                buf + sizeof(buf) - nbytes, nbytes);
534
535                 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
536         }
537
538         return err;
539 }
540
541 static int __maybe_unused xts_encrypt(struct skcipher_request *req)
542 {
543         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
544         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
545         int err, first, rounds = 6 + ctx->key1.key_length / 4;
546         int tail = req->cryptlen % AES_BLOCK_SIZE;
547         struct scatterlist sg_src[2], sg_dst[2];
548         struct skcipher_request subreq;
549         struct scatterlist *src, *dst;
550         struct skcipher_walk walk;
551
552         if (req->cryptlen < AES_BLOCK_SIZE)
553                 return -EINVAL;
554
555         err = skcipher_walk_virt(&walk, req, false);
556
557         if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
558                 int xts_blocks = DIV_ROUND_UP(req->cryptlen,
559                                               AES_BLOCK_SIZE) - 2;
560
561                 skcipher_walk_abort(&walk);
562
563                 skcipher_request_set_tfm(&subreq, tfm);
564                 skcipher_request_set_callback(&subreq,
565                                               skcipher_request_flags(req),
566                                               NULL, NULL);
567                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
568                                            xts_blocks * AES_BLOCK_SIZE,
569                                            req->iv);
570                 req = &subreq;
571                 err = skcipher_walk_virt(&walk, req, false);
572         } else {
573                 tail = 0;
574         }
575
576         for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
577                 int nbytes = walk.nbytes;
578
579                 if (walk.nbytes < walk.total)
580                         nbytes &= ~(AES_BLOCK_SIZE - 1);
581
582                 kernel_neon_begin();
583                 aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
584                                 ctx->key1.key_enc, rounds, nbytes,
585                                 ctx->key2.key_enc, walk.iv, first);
586                 kernel_neon_end();
587                 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
588         }
589
590         if (err || likely(!tail))
591                 return err;
592
593         dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
594         if (req->dst != req->src)
595                 dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
596
597         skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
598                                    req->iv);
599
600         err = skcipher_walk_virt(&walk, &subreq, false);
601         if (err)
602                 return err;
603
604         kernel_neon_begin();
605         aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
606                         ctx->key1.key_enc, rounds, walk.nbytes,
607                         ctx->key2.key_enc, walk.iv, first);
608         kernel_neon_end();
609
610         return skcipher_walk_done(&walk, 0);
611 }
612
613 static int __maybe_unused xts_decrypt(struct skcipher_request *req)
614 {
615         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
616         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
617         int err, first, rounds = 6 + ctx->key1.key_length / 4;
618         int tail = req->cryptlen % AES_BLOCK_SIZE;
619         struct scatterlist sg_src[2], sg_dst[2];
620         struct skcipher_request subreq;
621         struct scatterlist *src, *dst;
622         struct skcipher_walk walk;
623
624         if (req->cryptlen < AES_BLOCK_SIZE)
625                 return -EINVAL;
626
627         err = skcipher_walk_virt(&walk, req, false);
628
629         if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
630                 int xts_blocks = DIV_ROUND_UP(req->cryptlen,
631                                               AES_BLOCK_SIZE) - 2;
632
633                 skcipher_walk_abort(&walk);
634
635                 skcipher_request_set_tfm(&subreq, tfm);
636                 skcipher_request_set_callback(&subreq,
637                                               skcipher_request_flags(req),
638                                               NULL, NULL);
639                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
640                                            xts_blocks * AES_BLOCK_SIZE,
641                                            req->iv);
642                 req = &subreq;
643                 err = skcipher_walk_virt(&walk, req, false);
644         } else {
645                 tail = 0;
646         }
647
648         for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
649                 int nbytes = walk.nbytes;
650
651                 if (walk.nbytes < walk.total)
652                         nbytes &= ~(AES_BLOCK_SIZE - 1);
653
654                 kernel_neon_begin();
655                 aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
656                                 ctx->key1.key_dec, rounds, nbytes,
657                                 ctx->key2.key_enc, walk.iv, first);
658                 kernel_neon_end();
659                 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
660         }
661
662         if (err || likely(!tail))
663                 return err;
664
665         dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
666         if (req->dst != req->src)
667                 dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
668
669         skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
670                                    req->iv);
671
672         err = skcipher_walk_virt(&walk, &subreq, false);
673         if (err)
674                 return err;
675
676
677         kernel_neon_begin();
678         aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
679                         ctx->key1.key_dec, rounds, walk.nbytes,
680                         ctx->key2.key_enc, walk.iv, first);
681         kernel_neon_end();
682
683         return skcipher_walk_done(&walk, 0);
684 }
685
686 static struct skcipher_alg aes_algs[] = { {
687 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
688         .base = {
689                 .cra_name               = "ecb(aes)",
690                 .cra_driver_name        = "ecb-aes-" MODE,
691                 .cra_priority           = PRIO,
692                 .cra_blocksize          = AES_BLOCK_SIZE,
693                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
694                 .cra_module             = THIS_MODULE,
695         },
696         .min_keysize    = AES_MIN_KEY_SIZE,
697         .max_keysize    = AES_MAX_KEY_SIZE,
698         .setkey         = skcipher_aes_setkey,
699         .encrypt        = ecb_encrypt,
700         .decrypt        = ecb_decrypt,
701 }, {
702         .base = {
703                 .cra_name               = "cbc(aes)",
704                 .cra_driver_name        = "cbc-aes-" MODE,
705                 .cra_priority           = PRIO,
706                 .cra_blocksize          = AES_BLOCK_SIZE,
707                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
708                 .cra_module             = THIS_MODULE,
709         },
710         .min_keysize    = AES_MIN_KEY_SIZE,
711         .max_keysize    = AES_MAX_KEY_SIZE,
712         .ivsize         = AES_BLOCK_SIZE,
713         .setkey         = skcipher_aes_setkey,
714         .encrypt        = cbc_encrypt,
715         .decrypt        = cbc_decrypt,
716 }, {
717         .base = {
718                 .cra_name               = "ctr(aes)",
719                 .cra_driver_name        = "ctr-aes-" MODE,
720                 .cra_priority           = PRIO,
721                 .cra_blocksize          = 1,
722                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
723                 .cra_module             = THIS_MODULE,
724         },
725         .min_keysize    = AES_MIN_KEY_SIZE,
726         .max_keysize    = AES_MAX_KEY_SIZE,
727         .ivsize         = AES_BLOCK_SIZE,
728         .chunksize      = AES_BLOCK_SIZE,
729         .setkey         = skcipher_aes_setkey,
730         .encrypt        = ctr_encrypt,
731         .decrypt        = ctr_encrypt,
732 }, {
733         .base = {
734                 .cra_name               = "xctr(aes)",
735                 .cra_driver_name        = "xctr-aes-" MODE,
736                 .cra_priority           = PRIO,
737                 .cra_blocksize          = 1,
738                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
739                 .cra_module             = THIS_MODULE,
740         },
741         .min_keysize    = AES_MIN_KEY_SIZE,
742         .max_keysize    = AES_MAX_KEY_SIZE,
743         .ivsize         = AES_BLOCK_SIZE,
744         .chunksize      = AES_BLOCK_SIZE,
745         .setkey         = skcipher_aes_setkey,
746         .encrypt        = xctr_encrypt,
747         .decrypt        = xctr_encrypt,
748 }, {
749         .base = {
750                 .cra_name               = "xts(aes)",
751                 .cra_driver_name        = "xts-aes-" MODE,
752                 .cra_priority           = PRIO,
753                 .cra_blocksize          = AES_BLOCK_SIZE,
754                 .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
755                 .cra_module             = THIS_MODULE,
756         },
757         .min_keysize    = 2 * AES_MIN_KEY_SIZE,
758         .max_keysize    = 2 * AES_MAX_KEY_SIZE,
759         .ivsize         = AES_BLOCK_SIZE,
760         .walksize       = 2 * AES_BLOCK_SIZE,
761         .setkey         = xts_set_key,
762         .encrypt        = xts_encrypt,
763         .decrypt        = xts_decrypt,
764 }, {
765 #endif
766         .base = {
767                 .cra_name               = "cts(cbc(aes))",
768                 .cra_driver_name        = "cts-cbc-aes-" MODE,
769                 .cra_priority           = PRIO,
770                 .cra_blocksize          = AES_BLOCK_SIZE,
771                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
772                 .cra_module             = THIS_MODULE,
773         },
774         .min_keysize    = AES_MIN_KEY_SIZE,
775         .max_keysize    = AES_MAX_KEY_SIZE,
776         .ivsize         = AES_BLOCK_SIZE,
777         .walksize       = 2 * AES_BLOCK_SIZE,
778         .setkey         = skcipher_aes_setkey,
779         .encrypt        = cts_cbc_encrypt,
780         .decrypt        = cts_cbc_decrypt,
781 }, {
782         .base = {
783                 .cra_name               = "essiv(cbc(aes),sha256)",
784                 .cra_driver_name        = "essiv-cbc-aes-sha256-" MODE,
785                 .cra_priority           = PRIO + 1,
786                 .cra_blocksize          = AES_BLOCK_SIZE,
787                 .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
788                 .cra_module             = THIS_MODULE,
789         },
790         .min_keysize    = AES_MIN_KEY_SIZE,
791         .max_keysize    = AES_MAX_KEY_SIZE,
792         .ivsize         = AES_BLOCK_SIZE,
793         .setkey         = essiv_cbc_set_key,
794         .encrypt        = essiv_cbc_encrypt,
795         .decrypt        = essiv_cbc_decrypt,
796         .init           = essiv_cbc_init_tfm,
797         .exit           = essiv_cbc_exit_tfm,
798 } };
799
800 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
801                          unsigned int key_len)
802 {
803         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
804
805         return aes_expandkey(&ctx->key, in_key, key_len);
806 }
807
808 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
809 {
810         u64 a = be64_to_cpu(x->a);
811         u64 b = be64_to_cpu(x->b);
812
813         y->a = cpu_to_be64((a << 1) | (b >> 63));
814         y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
815 }
816
817 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
818                        unsigned int key_len)
819 {
820         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
821         be128 *consts = (be128 *)ctx->consts;
822         int rounds = 6 + key_len / 4;
823         int err;
824
825         err = cbcmac_setkey(tfm, in_key, key_len);
826         if (err)
827                 return err;
828
829         /* encrypt the zero vector */
830         kernel_neon_begin();
831         aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
832                         rounds, 1);
833         kernel_neon_end();
834
835         cmac_gf128_mul_by_x(consts, consts);
836         cmac_gf128_mul_by_x(consts + 1, consts);
837
838         return 0;
839 }
840
841 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
842                        unsigned int key_len)
843 {
844         static u8 const ks[3][AES_BLOCK_SIZE] = {
845                 { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
846                 { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
847                 { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
848         };
849
850         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
851         int rounds = 6 + key_len / 4;
852         u8 key[AES_BLOCK_SIZE];
853         int err;
854
855         err = cbcmac_setkey(tfm, in_key, key_len);
856         if (err)
857                 return err;
858
859         kernel_neon_begin();
860         aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
861         aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
862         kernel_neon_end();
863
864         return cbcmac_setkey(tfm, key, sizeof(key));
865 }
866
867 static int mac_init(struct shash_desc *desc)
868 {
869         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
870
871         memset(ctx->dg, 0, AES_BLOCK_SIZE);
872         ctx->len = 0;
873
874         return 0;
875 }
876
877 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
878                           u8 dg[], int enc_before, int enc_after)
879 {
880         int rounds = 6 + ctx->key_length / 4;
881
882         if (crypto_simd_usable()) {
883                 int rem;
884
885                 do {
886                         kernel_neon_begin();
887                         rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
888                                              dg, enc_before, enc_after);
889                         kernel_neon_end();
890                         in += (blocks - rem) * AES_BLOCK_SIZE;
891                         blocks = rem;
892                         enc_before = 0;
893                 } while (blocks);
894         } else {
895                 if (enc_before)
896                         aes_encrypt(ctx, dg, dg);
897
898                 while (blocks--) {
899                         crypto_xor(dg, in, AES_BLOCK_SIZE);
900                         in += AES_BLOCK_SIZE;
901
902                         if (blocks || enc_after)
903                                 aes_encrypt(ctx, dg, dg);
904                 }
905         }
906 }
907
908 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
909 {
910         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
911         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
912
913         while (len > 0) {
914                 unsigned int l;
915
916                 if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
917                     (ctx->len + len) > AES_BLOCK_SIZE) {
918
919                         int blocks = len / AES_BLOCK_SIZE;
920
921                         len %= AES_BLOCK_SIZE;
922
923                         mac_do_update(&tctx->key, p, blocks, ctx->dg,
924                                       (ctx->len != 0), (len != 0));
925
926                         p += blocks * AES_BLOCK_SIZE;
927
928                         if (!len) {
929                                 ctx->len = AES_BLOCK_SIZE;
930                                 break;
931                         }
932                         ctx->len = 0;
933                 }
934
935                 l = min(len, AES_BLOCK_SIZE - ctx->len);
936
937                 if (l <= AES_BLOCK_SIZE) {
938                         crypto_xor(ctx->dg + ctx->len, p, l);
939                         ctx->len += l;
940                         len -= l;
941                         p += l;
942                 }
943         }
944
945         return 0;
946 }
947
948 static int cbcmac_final(struct shash_desc *desc, u8 *out)
949 {
950         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
951         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
952
953         mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
954
955         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
956
957         return 0;
958 }
959
960 static int cmac_final(struct shash_desc *desc, u8 *out)
961 {
962         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
963         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
964         u8 *consts = tctx->consts;
965
966         if (ctx->len != AES_BLOCK_SIZE) {
967                 ctx->dg[ctx->len] ^= 0x80;
968                 consts += AES_BLOCK_SIZE;
969         }
970
971         mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
972
973         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
974
975         return 0;
976 }
977
978 static struct shash_alg mac_algs[] = { {
979         .base.cra_name          = "cmac(aes)",
980         .base.cra_driver_name   = "cmac-aes-" MODE,
981         .base.cra_priority      = PRIO,
982         .base.cra_blocksize     = AES_BLOCK_SIZE,
983         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
984                                   2 * AES_BLOCK_SIZE,
985         .base.cra_module        = THIS_MODULE,
986
987         .digestsize             = AES_BLOCK_SIZE,
988         .init                   = mac_init,
989         .update                 = mac_update,
990         .final                  = cmac_final,
991         .setkey                 = cmac_setkey,
992         .descsize               = sizeof(struct mac_desc_ctx),
993 }, {
994         .base.cra_name          = "xcbc(aes)",
995         .base.cra_driver_name   = "xcbc-aes-" MODE,
996         .base.cra_priority      = PRIO,
997         .base.cra_blocksize     = AES_BLOCK_SIZE,
998         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
999                                   2 * AES_BLOCK_SIZE,
1000         .base.cra_module        = THIS_MODULE,
1001
1002         .digestsize             = AES_BLOCK_SIZE,
1003         .init                   = mac_init,
1004         .update                 = mac_update,
1005         .final                  = cmac_final,
1006         .setkey                 = xcbc_setkey,
1007         .descsize               = sizeof(struct mac_desc_ctx),
1008 }, {
1009         .base.cra_name          = "cbcmac(aes)",
1010         .base.cra_driver_name   = "cbcmac-aes-" MODE,
1011         .base.cra_priority      = PRIO,
1012         .base.cra_blocksize     = 1,
1013         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
1014         .base.cra_module        = THIS_MODULE,
1015
1016         .digestsize             = AES_BLOCK_SIZE,
1017         .init                   = mac_init,
1018         .update                 = mac_update,
1019         .final                  = cbcmac_final,
1020         .setkey                 = cbcmac_setkey,
1021         .descsize               = sizeof(struct mac_desc_ctx),
1022 } };
1023
1024 static void aes_exit(void)
1025 {
1026         crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1027         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1028 }
1029
1030 static int __init aes_init(void)
1031 {
1032         int err;
1033
1034         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1035         if (err)
1036                 return err;
1037
1038         err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1039         if (err)
1040                 goto unregister_ciphers;
1041
1042         return 0;
1043
1044 unregister_ciphers:
1045         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1046         return err;
1047 }
1048
/*
 * The NEON build always registers aes_init() as its init function and
 * exports some of its routines (presumably for reuse by other arm64 AES
 * drivers; confirm against the users). The Crypto Extensions build gates
 * aes_init() on the AES CPU feature via module_cpu_feature_match().
 */
#ifndef USE_V8_CRYPTO_EXTENSIONS
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#else
module_cpu_feature_match(AES, aes_init);
#endif
module_exit(aes_exit);