arch/arm64/crypto/aes-neonbs-glue.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);
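
/*
 * Note: these non-bit-sliced helpers are used below where eight-way bit
 * slicing does not pay off: CBC encryption (which is inherently serial),
 * the XTS ciphertext stealing tail, and short final XTS chunks.
 */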

struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
        int     rounds;
} __aligned(AES_BLOCK_SIZE);
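
/*
 * Note: the bit-sliced NEON code operates on eight AES blocks at a time,
 * which is why the key schedule above is laid out in 8 * AES_BLOCK_SIZE
 * sized chunks and why the skcipher walksize below is 8 * AES_BLOCK_SIZE.
 */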

struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
        u32                     enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
        struct crypto_aes_ctx   cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}
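
/*
 * Note: 6 + key_len / 4 yields the standard AES round counts: 10, 12 and 14
 * rounds for 16, 24 and 32 byte keys. The conventional schedule produced by
 * aes_expandkey() is only a stepping stone here; aesbs_convert_key() turns
 * it into the bit-sliced layout consumed by the NEON core.
 */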

static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}
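
/*
 * Note: on all but the final chunk of a request, the block count is rounded
 * down to a multiple of walk.stride / AES_BLOCK_SIZE (eight blocks here), so
 * the bit-sliced core only ever sees full eight-block batches until the end
 * of the walk.
 */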

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();
        memzero_explicit(&rk, sizeof(rk));

        return 0;
}
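
/*
 * Note: CBC keeps two key schedules: the conventional encryption schedule in
 * ctx->enc, consumed by the plain NEON fallback used on the encrypt path,
 * and the bit-sliced schedule in ctx->key.rk used on the decrypt path. The
 * temporary expansion on the stack is wiped with memzero_explicit().
 */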

static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                /* fall back to the non-bitsliced NEON implementation */
                kernel_neon_begin();
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                     ctx->enc, ctx->key.rounds, blocks,
                                     walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}
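
/*
 * Note: CBC encryption is a serial chain (each block's input is the
 * plaintext XORed with the previous ciphertext block), so the eight-way
 * bit-sliced core cannot be applied and the non-bit-sliced NEON helper is
 * used instead.
 */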

static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}
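
/*
 * Note: CBC decryption has no such serial dependency (each plaintext block
 * depends only on two ciphertext blocks), so the bit-sliced core can decrypt
 * eight blocks in parallel here.
 */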

static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        final = NULL;
                }

                kernel_neon_begin();
                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);
                kernel_neon_end();

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        crypto_xor_cpy(dst, src, final,
                                       walk.total % AES_BLOCK_SIZE);

                        err = skcipher_walk_done(&walk, 0);
                        break;
                }
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        return err;
}
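
/*
 * Note: `final` is only non-NULL on the last chunk of a request whose length
 * is not a whole number of blocks. In that case the assembly is expected to
 * leave the keystream for the trailing partial block in `buf`, and
 * crypto_xor_cpy() XORs it into the remaining source bytes. Because CTR is a
 * stream mode, the algorithm below registers with cra_blocksize = 1.
 */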

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = aes_expandkey(&ctx->cts, in_key, key_len);
        if (err)
                return err;

        err = aes_expandkey(&rk, in_key + key_len, key_len);
        if (err)
                return err;

        memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

        return aesbs_setkey(tfm, in_key, key_len);
}
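
/*
 * Note: XTS takes a double-length key. The second half is expanded into
 * ctx->twkey and used only to encrypt the IV into the initial tweak; the
 * first half is expanded twice: into ctx->cts for the plain NEON ciphertext
 * stealing path, and into bit-sliced form (via aesbs_setkey()) for the bulk
 * of the data.
 */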

static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        int nbytes, err;
        int first = 1;
        u8 *out, *in;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        /* ensure that the cts tail is covered by a single step */
        if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
        } else {
                tail = 0;
        }

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                out = walk.dst.virt.addr;
                in = walk.src.virt.addr;
                nbytes = walk.nbytes;

                kernel_neon_begin();
                if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
                        if (first)
                                neon_aes_ecb_encrypt(walk.iv, walk.iv,
                                                     ctx->twkey,
                                                     ctx->key.rounds, 1);
                        first = 0;

                        fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
                           walk.iv);

                        out += blocks * AES_BLOCK_SIZE;
                        in += blocks * AES_BLOCK_SIZE;
                        nbytes -= blocks * AES_BLOCK_SIZE;
                }

                if (walk.nbytes == walk.total && nbytes > 0)
                        goto xts_tail;

                kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* handle ciphertext stealing */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        out = walk.dst.virt.addr;
        in = walk.src.virt.addr;
        nbytes = walk.nbytes;

        kernel_neon_begin();
xts_tail:
        if (encrypt)
                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
        else
                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first ?: 2);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}
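
/*
 * Note: two fallbacks to plain NEON happen above. Short final chunks (six
 * blocks or fewer) skip the bit-sliced core entirely, and requests whose
 * length is not a multiple of the block size are shortened so that the last
 * full block plus the partial tail are handed to neon_aes_xts_{en,de}crypt(),
 * which performs the ciphertext stealing. The `first ?: 2` argument appears
 * to tell those helpers whether the tweak still needs to be derived from
 * walk.iv; the exact encoding is a contract with the aes-neon-blk assembly.
 */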

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, false, aesbs_xts_decrypt);
}

static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "ecb(aes)",
        .base.cra_driver_name   = "ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "cbc(aes)",
        .base.cra_driver_name   = "cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "xts(aes)",
        .base.cra_driver_name   = "xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };
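
/*
 * Illustrative sketch only (not part of the original driver): users do not
 * call the functions above directly but go through the generic skcipher API,
 * roughly along these lines:
 *
 *	struct crypto_skcipher *tfm;
 *	int err;
 *
 *	tfm = crypto_alloc_skcipher("xts(aes)", 0, 0);
 *	if (IS_ERR(tfm))
 *		return PTR_ERR(tfm);
 *	err = crypto_skcipher_setkey(tfm, key, keylen);
 *	...
 *	crypto_free_skcipher(tfm);
 *
 * The crypto core then picks the highest-priority registered "xts(aes)"
 * implementation, which may be this driver: at cra_priority 250 it is
 * preferred over the generic C code but will typically be outranked by the
 * ARMv8 Crypto Extensions driver where that is available.
 */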

static void aes_exit(void)
{
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

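/*
 * Note: registration is gated on Advanced SIMD support below, since every
 * code path in this file does its work inside a kernel_neon_begin()/
 * kernel_neon_end() section.
 */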
static int __init aes_init(void)
{
        if (!cpu_have_named_feature(ASIMD))
                return -ENODEV;

        return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);