Merge tag 'drm-msm-fixes-2021-04-02' of https://gitlab.freedesktop.org/drm/msm into...
[linux-2.6-microblaze.git] / arch / s390 / crypto / aes_s390.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * Cryptographic API.
4  *
5  * s390 implementation of the AES Cipher Algorithm.
6  *
7  * s390 Version:
8  *   Copyright IBM Corp. 2005, 2017
9  *   Author(s): Jan Glauber (jang@de.ibm.com)
 *              Sebastian Siewior <sebastian@breakpoint.cc> SW-Fallback
11  *              Patrick Steuer <patrick.steuer@de.ibm.com>
12  *              Harald Freudenberger <freude@de.ibm.com>
13  *
14  * Derived from "crypto/aes_generic.c"
15  */
16
/*
 * Message prefix for this file's pr_*() output (e.g. the pr_err() calls in
 * the fallback init paths): KMSG_COMPONENT followed by ": ".  Defined ahead
 * of the #includes, presumably so the printk helpers pick it up instead of
 * a default definition — TODO confirm.
 */
#define KMSG_COMPONENT "aes_s390"
#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
19
20 #include <crypto/aes.h>
21 #include <crypto/algapi.h>
22 #include <crypto/ghash.h>
23 #include <crypto/internal/aead.h>
24 #include <crypto/internal/cipher.h>
25 #include <crypto/internal/skcipher.h>
26 #include <crypto/scatterwalk.h>
27 #include <linux/err.h>
28 #include <linux/module.h>
29 #include <linux/cpufeature.h>
30 #include <linux/init.h>
31 #include <linux/mutex.h>
32 #include <linux/fips.h>
33 #include <linux/string.h>
34 #include <crypto/xts.h>
35 #include <asm/cpacf.h>
36
37 static u8 *ctrblk;
38 static DEFINE_MUTEX(ctrblk_lock);
39
40 static cpacf_mask_t km_functions, kmc_functions, kmctr_functions,
41                     kma_functions;
42
/*
 * Per-tfm context shared by the single-block "aes" cipher and the ecb, cbc,
 * ctr and gcm implementations in this file.
 */
struct s390_aes_ctx {
	u8 key[AES_MAX_KEY_SIZE];	/* raw AES key; first key_len bytes valid */
	int key_len;			/* 16, 24 or 32 once a CPACF key is stored */
	unsigned long fc;		/* CPACF function code; 0 = use the fallback */
	union {
		/* fallback for the skcipher modes (ecb/cbc/ctr) */
		struct crypto_skcipher *skcipher;
		/* fallback for the single-block "aes" cipher */
		struct crypto_cipher *cip;
	} fallback;
};
52
/* Per-tfm context for xts(aes). */
struct s390_xts_ctx {
	u8 key[32];		/* first XTS half-key, used for the data */
	u8 pcc_key[32];		/* second XTS half-key, fed to PCC for the tweak */
	int key_len;		/* length of ONE half-key: 16 or 32 */
	unsigned long fc;	/* CPACF KM-XTS function code; 0 = use fallback */
	struct crypto_skcipher *fallback;	/* allocated in xts_fallback_init() */
};
60
/*
 * Scatterlist walker used by the GCM code.  Data is handed to the caller
 * either directly from the mapped scatterlist (walk_ptr) or, when a mapped
 * span is shorter than required, through the small bounce buffer buf.
 */
struct gcm_sg_walk {
	struct scatter_walk walk;	/* current scatterlist position */
	unsigned int walk_bytes;	/* bytes available at walk_ptr */
	u8 *walk_ptr;			/* currently mapped address, or NULL */
	unsigned int walk_bytes_remain;	/* bytes still to be walked in total */
	u8 buf[AES_BLOCK_SIZE];		/* bounce buffer for short spans */
	unsigned int buf_bytes;		/* valid bytes in buf (input walks) */
	u8 *ptr;			/* what the caller uses: buf, walk_ptr or NULL */
	unsigned int nbytes;		/* bytes available at ptr */
};
71
72 static int setkey_fallback_cip(struct crypto_tfm *tfm, const u8 *in_key,
73                 unsigned int key_len)
74 {
75         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
76
77         sctx->fallback.cip->base.crt_flags &= ~CRYPTO_TFM_REQ_MASK;
78         sctx->fallback.cip->base.crt_flags |= (tfm->crt_flags &
79                         CRYPTO_TFM_REQ_MASK);
80
81         return crypto_cipher_setkey(sctx->fallback.cip, in_key, key_len);
82 }
83
84 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
85                        unsigned int key_len)
86 {
87         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
88         unsigned long fc;
89
90         /* Pick the correct function code based on the key length */
91         fc = (key_len == 16) ? CPACF_KM_AES_128 :
92              (key_len == 24) ? CPACF_KM_AES_192 :
93              (key_len == 32) ? CPACF_KM_AES_256 : 0;
94
95         /* Check if the function code is available */
96         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
97         if (!sctx->fc)
98                 return setkey_fallback_cip(tfm, in_key, key_len);
99
100         sctx->key_len = key_len;
101         memcpy(sctx->key, in_key, key_len);
102         return 0;
103 }
104
105 static void crypto_aes_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
106 {
107         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
108
109         if (unlikely(!sctx->fc)) {
110                 crypto_cipher_encrypt_one(sctx->fallback.cip, out, in);
111                 return;
112         }
113         cpacf_km(sctx->fc, &sctx->key, out, in, AES_BLOCK_SIZE);
114 }
115
116 static void crypto_aes_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
117 {
118         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
119
120         if (unlikely(!sctx->fc)) {
121                 crypto_cipher_decrypt_one(sctx->fallback.cip, out, in);
122                 return;
123         }
124         cpacf_km(sctx->fc | CPACF_DECRYPT,
125                  &sctx->key, out, in, AES_BLOCK_SIZE);
126 }
127
128 static int fallback_init_cip(struct crypto_tfm *tfm)
129 {
130         const char *name = tfm->__crt_alg->cra_name;
131         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
132
133         sctx->fallback.cip = crypto_alloc_cipher(name, 0,
134                                                  CRYPTO_ALG_NEED_FALLBACK);
135
136         if (IS_ERR(sctx->fallback.cip)) {
137                 pr_err("Allocating AES fallback algorithm %s failed\n",
138                        name);
139                 return PTR_ERR(sctx->fallback.cip);
140         }
141
142         return 0;
143 }
144
145 static void fallback_exit_cip(struct crypto_tfm *tfm)
146 {
147         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
148
149         crypto_free_cipher(sctx->fallback.cip);
150         sctx->fallback.cip = NULL;
151 }
152
153 static struct crypto_alg aes_alg = {
154         .cra_name               =       "aes",
155         .cra_driver_name        =       "aes-s390",
156         .cra_priority           =       300,
157         .cra_flags              =       CRYPTO_ALG_TYPE_CIPHER |
158                                         CRYPTO_ALG_NEED_FALLBACK,
159         .cra_blocksize          =       AES_BLOCK_SIZE,
160         .cra_ctxsize            =       sizeof(struct s390_aes_ctx),
161         .cra_module             =       THIS_MODULE,
162         .cra_init               =       fallback_init_cip,
163         .cra_exit               =       fallback_exit_cip,
164         .cra_u                  =       {
165                 .cipher = {
166                         .cia_min_keysize        =       AES_MIN_KEY_SIZE,
167                         .cia_max_keysize        =       AES_MAX_KEY_SIZE,
168                         .cia_setkey             =       aes_set_key,
169                         .cia_encrypt            =       crypto_aes_encrypt,
170                         .cia_decrypt            =       crypto_aes_decrypt,
171                 }
172         }
173 };
174
175 static int setkey_fallback_skcipher(struct crypto_skcipher *tfm, const u8 *key,
176                                     unsigned int len)
177 {
178         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
179
180         crypto_skcipher_clear_flags(sctx->fallback.skcipher,
181                                     CRYPTO_TFM_REQ_MASK);
182         crypto_skcipher_set_flags(sctx->fallback.skcipher,
183                                   crypto_skcipher_get_flags(tfm) &
184                                   CRYPTO_TFM_REQ_MASK);
185         return crypto_skcipher_setkey(sctx->fallback.skcipher, key, len);
186 }
187
188 static int fallback_skcipher_crypt(struct s390_aes_ctx *sctx,
189                                    struct skcipher_request *req,
190                                    unsigned long modifier)
191 {
192         struct skcipher_request *subreq = skcipher_request_ctx(req);
193
194         *subreq = *req;
195         skcipher_request_set_tfm(subreq, sctx->fallback.skcipher);
196         return (modifier & CPACF_DECRYPT) ?
197                 crypto_skcipher_decrypt(subreq) :
198                 crypto_skcipher_encrypt(subreq);
199 }
200
201 static int ecb_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
202                            unsigned int key_len)
203 {
204         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
205         unsigned long fc;
206
207         /* Pick the correct function code based on the key length */
208         fc = (key_len == 16) ? CPACF_KM_AES_128 :
209              (key_len == 24) ? CPACF_KM_AES_192 :
210              (key_len == 32) ? CPACF_KM_AES_256 : 0;
211
212         /* Check if the function code is available */
213         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
214         if (!sctx->fc)
215                 return setkey_fallback_skcipher(tfm, in_key, key_len);
216
217         sctx->key_len = key_len;
218         memcpy(sctx->key, in_key, key_len);
219         return 0;
220 }
221
222 static int ecb_aes_crypt(struct skcipher_request *req, unsigned long modifier)
223 {
224         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
225         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
226         struct skcipher_walk walk;
227         unsigned int nbytes, n;
228         int ret;
229
230         if (unlikely(!sctx->fc))
231                 return fallback_skcipher_crypt(sctx, req, modifier);
232
233         ret = skcipher_walk_virt(&walk, req, false);
234         while ((nbytes = walk.nbytes) != 0) {
235                 /* only use complete blocks */
236                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
237                 cpacf_km(sctx->fc | modifier, sctx->key,
238                          walk.dst.virt.addr, walk.src.virt.addr, n);
239                 ret = skcipher_walk_done(&walk, nbytes - n);
240         }
241         return ret;
242 }
243
244 static int ecb_aes_encrypt(struct skcipher_request *req)
245 {
246         return ecb_aes_crypt(req, 0);
247 }
248
249 static int ecb_aes_decrypt(struct skcipher_request *req)
250 {
251         return ecb_aes_crypt(req, CPACF_DECRYPT);
252 }
253
254 static int fallback_init_skcipher(struct crypto_skcipher *tfm)
255 {
256         const char *name = crypto_tfm_alg_name(&tfm->base);
257         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
258
259         sctx->fallback.skcipher = crypto_alloc_skcipher(name, 0,
260                                 CRYPTO_ALG_NEED_FALLBACK | CRYPTO_ALG_ASYNC);
261
262         if (IS_ERR(sctx->fallback.skcipher)) {
263                 pr_err("Allocating AES fallback algorithm %s failed\n",
264                        name);
265                 return PTR_ERR(sctx->fallback.skcipher);
266         }
267
268         crypto_skcipher_set_reqsize(tfm, sizeof(struct skcipher_request) +
269                                     crypto_skcipher_reqsize(sctx->fallback.skcipher));
270         return 0;
271 }
272
273 static void fallback_exit_skcipher(struct crypto_skcipher *tfm)
274 {
275         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
276
277         crypto_free_skcipher(sctx->fallback.skcipher);
278 }
279
280 static struct skcipher_alg ecb_aes_alg = {
281         .base.cra_name          =       "ecb(aes)",
282         .base.cra_driver_name   =       "ecb-aes-s390",
283         .base.cra_priority      =       401,    /* combo: aes + ecb + 1 */
284         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
285         .base.cra_blocksize     =       AES_BLOCK_SIZE,
286         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
287         .base.cra_module        =       THIS_MODULE,
288         .init                   =       fallback_init_skcipher,
289         .exit                   =       fallback_exit_skcipher,
290         .min_keysize            =       AES_MIN_KEY_SIZE,
291         .max_keysize            =       AES_MAX_KEY_SIZE,
292         .setkey                 =       ecb_aes_set_key,
293         .encrypt                =       ecb_aes_encrypt,
294         .decrypt                =       ecb_aes_decrypt,
295 };
296
297 static int cbc_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
298                            unsigned int key_len)
299 {
300         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
301         unsigned long fc;
302
303         /* Pick the correct function code based on the key length */
304         fc = (key_len == 16) ? CPACF_KMC_AES_128 :
305              (key_len == 24) ? CPACF_KMC_AES_192 :
306              (key_len == 32) ? CPACF_KMC_AES_256 : 0;
307
308         /* Check if the function code is available */
309         sctx->fc = (fc && cpacf_test_func(&kmc_functions, fc)) ? fc : 0;
310         if (!sctx->fc)
311                 return setkey_fallback_skcipher(tfm, in_key, key_len);
312
313         sctx->key_len = key_len;
314         memcpy(sctx->key, in_key, key_len);
315         return 0;
316 }
317
318 static int cbc_aes_crypt(struct skcipher_request *req, unsigned long modifier)
319 {
320         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
321         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
322         struct skcipher_walk walk;
323         unsigned int nbytes, n;
324         int ret;
325         struct {
326                 u8 iv[AES_BLOCK_SIZE];
327                 u8 key[AES_MAX_KEY_SIZE];
328         } param;
329
330         if (unlikely(!sctx->fc))
331                 return fallback_skcipher_crypt(sctx, req, modifier);
332
333         ret = skcipher_walk_virt(&walk, req, false);
334         if (ret)
335                 return ret;
336         memcpy(param.iv, walk.iv, AES_BLOCK_SIZE);
337         memcpy(param.key, sctx->key, sctx->key_len);
338         while ((nbytes = walk.nbytes) != 0) {
339                 /* only use complete blocks */
340                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
341                 cpacf_kmc(sctx->fc | modifier, &param,
342                           walk.dst.virt.addr, walk.src.virt.addr, n);
343                 memcpy(walk.iv, param.iv, AES_BLOCK_SIZE);
344                 ret = skcipher_walk_done(&walk, nbytes - n);
345         }
346         memzero_explicit(&param, sizeof(param));
347         return ret;
348 }
349
350 static int cbc_aes_encrypt(struct skcipher_request *req)
351 {
352         return cbc_aes_crypt(req, 0);
353 }
354
355 static int cbc_aes_decrypt(struct skcipher_request *req)
356 {
357         return cbc_aes_crypt(req, CPACF_DECRYPT);
358 }
359
360 static struct skcipher_alg cbc_aes_alg = {
361         .base.cra_name          =       "cbc(aes)",
362         .base.cra_driver_name   =       "cbc-aes-s390",
363         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
364         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
365         .base.cra_blocksize     =       AES_BLOCK_SIZE,
366         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
367         .base.cra_module        =       THIS_MODULE,
368         .init                   =       fallback_init_skcipher,
369         .exit                   =       fallback_exit_skcipher,
370         .min_keysize            =       AES_MIN_KEY_SIZE,
371         .max_keysize            =       AES_MAX_KEY_SIZE,
372         .ivsize                 =       AES_BLOCK_SIZE,
373         .setkey                 =       cbc_aes_set_key,
374         .encrypt                =       cbc_aes_encrypt,
375         .decrypt                =       cbc_aes_decrypt,
376 };
377
378 static int xts_fallback_setkey(struct crypto_skcipher *tfm, const u8 *key,
379                                unsigned int len)
380 {
381         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
382
383         crypto_skcipher_clear_flags(xts_ctx->fallback, CRYPTO_TFM_REQ_MASK);
384         crypto_skcipher_set_flags(xts_ctx->fallback,
385                                   crypto_skcipher_get_flags(tfm) &
386                                   CRYPTO_TFM_REQ_MASK);
387         return crypto_skcipher_setkey(xts_ctx->fallback, key, len);
388 }
389
390 static int xts_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
391                            unsigned int key_len)
392 {
393         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
394         unsigned long fc;
395         int err;
396
397         err = xts_fallback_setkey(tfm, in_key, key_len);
398         if (err)
399                 return err;
400
401         /* In fips mode only 128 bit or 256 bit keys are valid */
402         if (fips_enabled && key_len != 32 && key_len != 64)
403                 return -EINVAL;
404
405         /* Pick the correct function code based on the key length */
406         fc = (key_len == 32) ? CPACF_KM_XTS_128 :
407              (key_len == 64) ? CPACF_KM_XTS_256 : 0;
408
409         /* Check if the function code is available */
410         xts_ctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
411         if (!xts_ctx->fc)
412                 return 0;
413
414         /* Split the XTS key into the two subkeys */
415         key_len = key_len / 2;
416         xts_ctx->key_len = key_len;
417         memcpy(xts_ctx->key, in_key, key_len);
418         memcpy(xts_ctx->pcc_key, in_key + key_len, key_len);
419         return 0;
420 }
421
/*
 * XTS en-/decryption (CPACF_DECRYPT in modifier selects decryption).
 *
 * Requests shorter than one block are rejected with -EINVAL.  Requests the
 * CPACF path does not take (no KM-XTS function for this key size, or a
 * length that is not a multiple of AES_BLOCK_SIZE) go to the fallback.
 * NOTE(review): the latter presumably because KM-XTS cannot do ciphertext
 * stealing — confirm.
 *
 * CPACF path: PCC is run on the second half-key and the IV, and its "xts"
 * output becomes the "init" field of the KM-XTS parameter block.  KM then
 * processes the data with the first half-key.
 */
static int xts_aes_crypt(struct skcipher_request *req, unsigned long modifier)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int offset, nbytes, n;
	int ret;
	/* PCC parameter block (passed starting at key + offset) */
	struct {
		u8 key[32];
		u8 tweak[16];
		u8 block[16];
		u8 bit[16];
		u8 xts[16];
	} pcc_param;
	/* KM-XTS parameter block (passed starting at key + offset) */
	struct {
		u8 key[32];
		u8 init[16];
	} xts_param;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(!xts_ctx->fc || (req->cryptlen % AES_BLOCK_SIZE) != 0)) {
		struct skcipher_request *subreq = skcipher_request_ctx(req);

		/* clone req onto the fallback tfm (reqsize set in xts_fallback_init) */
		*subreq = *req;
		skcipher_request_set_tfm(subreq, xts_ctx->fallback);
		return (modifier & CPACF_DECRYPT) ?
			crypto_skcipher_decrypt(subreq) :
			crypto_skcipher_encrypt(subreq);
	}

	ret = skcipher_walk_virt(&walk, req, false);
	if (ret)
		return ret;
	/*
	 * offset is 16 for 16-byte half-keys and 0 for 32-byte ones.  The key
	 * then always ends at byte 32 of the key[] array, so the fields after
	 * it directly follow the key in the block handed to CPACF.
	 */
	offset = xts_ctx->key_len & 0x10;
	memset(pcc_param.block, 0, sizeof(pcc_param.block));
	memset(pcc_param.bit, 0, sizeof(pcc_param.bit));
	memset(pcc_param.xts, 0, sizeof(pcc_param.xts));
	memcpy(pcc_param.tweak, walk.iv, sizeof(pcc_param.tweak));
	memcpy(pcc_param.key + offset, xts_ctx->pcc_key, xts_ctx->key_len);
	cpacf_pcc(xts_ctx->fc, pcc_param.key + offset);

	/* data key plus the PCC result as KM-XTS's initial value */
	memcpy(xts_param.key + offset, xts_ctx->key, xts_ctx->key_len);
	memcpy(xts_param.init, pcc_param.xts, 16);

	while ((nbytes = walk.nbytes) != 0) {
		/* only use complete blocks */
		n = nbytes & ~(AES_BLOCK_SIZE - 1);
		cpacf_km(xts_ctx->fc | modifier, xts_param.key + offset,
			 walk.dst.virt.addr, walk.src.virt.addr, n);
		ret = skcipher_walk_done(&walk, nbytes - n);
	}
	/* wipe key material from the stack */
	memzero_explicit(&pcc_param, sizeof(pcc_param));
	memzero_explicit(&xts_param, sizeof(xts_param));
	return ret;
}
479
480 static int xts_aes_encrypt(struct skcipher_request *req)
481 {
482         return xts_aes_crypt(req, 0);
483 }
484
485 static int xts_aes_decrypt(struct skcipher_request *req)
486 {
487         return xts_aes_crypt(req, CPACF_DECRYPT);
488 }
489
490 static int xts_fallback_init(struct crypto_skcipher *tfm)
491 {
492         const char *name = crypto_tfm_alg_name(&tfm->base);
493         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
494
495         xts_ctx->fallback = crypto_alloc_skcipher(name, 0,
496                                 CRYPTO_ALG_NEED_FALLBACK | CRYPTO_ALG_ASYNC);
497
498         if (IS_ERR(xts_ctx->fallback)) {
499                 pr_err("Allocating XTS fallback algorithm %s failed\n",
500                        name);
501                 return PTR_ERR(xts_ctx->fallback);
502         }
503         crypto_skcipher_set_reqsize(tfm, sizeof(struct skcipher_request) +
504                                     crypto_skcipher_reqsize(xts_ctx->fallback));
505         return 0;
506 }
507
508 static void xts_fallback_exit(struct crypto_skcipher *tfm)
509 {
510         struct s390_xts_ctx *xts_ctx = crypto_skcipher_ctx(tfm);
511
512         crypto_free_skcipher(xts_ctx->fallback);
513 }
514
515 static struct skcipher_alg xts_aes_alg = {
516         .base.cra_name          =       "xts(aes)",
517         .base.cra_driver_name   =       "xts-aes-s390",
518         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
519         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
520         .base.cra_blocksize     =       AES_BLOCK_SIZE,
521         .base.cra_ctxsize       =       sizeof(struct s390_xts_ctx),
522         .base.cra_module        =       THIS_MODULE,
523         .init                   =       xts_fallback_init,
524         .exit                   =       xts_fallback_exit,
525         .min_keysize            =       2 * AES_MIN_KEY_SIZE,
526         .max_keysize            =       2 * AES_MAX_KEY_SIZE,
527         .ivsize                 =       AES_BLOCK_SIZE,
528         .setkey                 =       xts_aes_set_key,
529         .encrypt                =       xts_aes_encrypt,
530         .decrypt                =       xts_aes_decrypt,
531 };
532
533 static int ctr_aes_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
534                            unsigned int key_len)
535 {
536         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
537         unsigned long fc;
538
539         /* Pick the correct function code based on the key length */
540         fc = (key_len == 16) ? CPACF_KMCTR_AES_128 :
541              (key_len == 24) ? CPACF_KMCTR_AES_192 :
542              (key_len == 32) ? CPACF_KMCTR_AES_256 : 0;
543
544         /* Check if the function code is available */
545         sctx->fc = (fc && cpacf_test_func(&kmctr_functions, fc)) ? fc : 0;
546         if (!sctx->fc)
547                 return setkey_fallback_skcipher(tfm, in_key, key_len);
548
549         sctx->key_len = key_len;
550         memcpy(sctx->key, in_key, key_len);
551         return 0;
552 }
553
554 static unsigned int __ctrblk_init(u8 *ctrptr, u8 *iv, unsigned int nbytes)
555 {
556         unsigned int i, n;
557
558         /* only use complete blocks, max. PAGE_SIZE */
559         memcpy(ctrptr, iv, AES_BLOCK_SIZE);
560         n = (nbytes > PAGE_SIZE) ? PAGE_SIZE : nbytes & ~(AES_BLOCK_SIZE - 1);
561         for (i = (n / AES_BLOCK_SIZE) - 1; i > 0; i--) {
562                 memcpy(ctrptr + AES_BLOCK_SIZE, ctrptr, AES_BLOCK_SIZE);
563                 crypto_inc(ctrptr + AES_BLOCK_SIZE, AES_BLOCK_SIZE);
564                 ctrptr += AES_BLOCK_SIZE;
565         }
566         return n;
567 }
568
569 static int ctr_aes_crypt(struct skcipher_request *req)
570 {
571         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
572         struct s390_aes_ctx *sctx = crypto_skcipher_ctx(tfm);
573         u8 buf[AES_BLOCK_SIZE], *ctrptr;
574         struct skcipher_walk walk;
575         unsigned int n, nbytes;
576         int ret, locked;
577
578         if (unlikely(!sctx->fc))
579                 return fallback_skcipher_crypt(sctx, req, 0);
580
581         locked = mutex_trylock(&ctrblk_lock);
582
583         ret = skcipher_walk_virt(&walk, req, false);
584         while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
585                 n = AES_BLOCK_SIZE;
586
587                 if (nbytes >= 2*AES_BLOCK_SIZE && locked)
588                         n = __ctrblk_init(ctrblk, walk.iv, nbytes);
589                 ctrptr = (n > AES_BLOCK_SIZE) ? ctrblk : walk.iv;
590                 cpacf_kmctr(sctx->fc, sctx->key, walk.dst.virt.addr,
591                             walk.src.virt.addr, n, ctrptr);
592                 if (ctrptr == ctrblk)
593                         memcpy(walk.iv, ctrptr + n - AES_BLOCK_SIZE,
594                                AES_BLOCK_SIZE);
595                 crypto_inc(walk.iv, AES_BLOCK_SIZE);
596                 ret = skcipher_walk_done(&walk, nbytes - n);
597         }
598         if (locked)
599                 mutex_unlock(&ctrblk_lock);
600         /*
601          * final block may be < AES_BLOCK_SIZE, copy only nbytes
602          */
603         if (nbytes) {
604                 cpacf_kmctr(sctx->fc, sctx->key, buf, walk.src.virt.addr,
605                             AES_BLOCK_SIZE, walk.iv);
606                 memcpy(walk.dst.virt.addr, buf, nbytes);
607                 crypto_inc(walk.iv, AES_BLOCK_SIZE);
608                 ret = skcipher_walk_done(&walk, 0);
609         }
610
611         return ret;
612 }
613
614 static struct skcipher_alg ctr_aes_alg = {
615         .base.cra_name          =       "ctr(aes)",
616         .base.cra_driver_name   =       "ctr-aes-s390",
617         .base.cra_priority      =       402,    /* ecb-aes-s390 + 1 */
618         .base.cra_flags         =       CRYPTO_ALG_NEED_FALLBACK,
619         .base.cra_blocksize     =       1,
620         .base.cra_ctxsize       =       sizeof(struct s390_aes_ctx),
621         .base.cra_module        =       THIS_MODULE,
622         .init                   =       fallback_init_skcipher,
623         .exit                   =       fallback_exit_skcipher,
624         .min_keysize            =       AES_MIN_KEY_SIZE,
625         .max_keysize            =       AES_MAX_KEY_SIZE,
626         .ivsize                 =       AES_BLOCK_SIZE,
627         .setkey                 =       ctr_aes_set_key,
628         .encrypt                =       ctr_aes_crypt,
629         .decrypt                =       ctr_aes_crypt,
630         .chunksize              =       AES_BLOCK_SIZE,
631 };
632
633 static int gcm_aes_setkey(struct crypto_aead *tfm, const u8 *key,
634                           unsigned int keylen)
635 {
636         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
637
638         switch (keylen) {
639         case AES_KEYSIZE_128:
640                 ctx->fc = CPACF_KMA_GCM_AES_128;
641                 break;
642         case AES_KEYSIZE_192:
643                 ctx->fc = CPACF_KMA_GCM_AES_192;
644                 break;
645         case AES_KEYSIZE_256:
646                 ctx->fc = CPACF_KMA_GCM_AES_256;
647                 break;
648         default:
649                 return -EINVAL;
650         }
651
652         memcpy(ctx->key, key, keylen);
653         ctx->key_len = keylen;
654         return 0;
655 }
656
/*
 * Accept only the GCM tag lengths supported here: 4, 8, or 12 through 16
 * bytes.  The tfm itself is not consulted.
 */
static int gcm_aes_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
{
	if (authsize == 4 || authsize == 8)
		return 0;
	if (authsize >= 12 && authsize <= 16)
		return 0;
	return -EINVAL;
}
674
675 static void gcm_walk_start(struct gcm_sg_walk *gw, struct scatterlist *sg,
676                            unsigned int len)
677 {
678         memset(gw, 0, sizeof(*gw));
679         gw->walk_bytes_remain = len;
680         scatterwalk_start(&gw->walk, sg);
681 }
682
683 static inline unsigned int _gcm_sg_clamp_and_map(struct gcm_sg_walk *gw)
684 {
685         struct scatterlist *nextsg;
686
687         gw->walk_bytes = scatterwalk_clamp(&gw->walk, gw->walk_bytes_remain);
688         while (!gw->walk_bytes) {
689                 nextsg = sg_next(gw->walk.sg);
690                 if (!nextsg)
691                         return 0;
692                 scatterwalk_start(&gw->walk, nextsg);
693                 gw->walk_bytes = scatterwalk_clamp(&gw->walk,
694                                                    gw->walk_bytes_remain);
695         }
696         gw->walk_ptr = scatterwalk_map(&gw->walk);
697         return gw->walk_bytes;
698 }
699
700 static inline void _gcm_sg_unmap_and_advance(struct gcm_sg_walk *gw,
701                                              unsigned int nbytes)
702 {
703         gw->walk_bytes_remain -= nbytes;
704         scatterwalk_unmap(&gw->walk);
705         scatterwalk_advance(&gw->walk, nbytes);
706         scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
707         gw->walk_ptr = NULL;
708 }
709
/*
 * Make input available at gw->ptr (gw->nbytes bytes), at least
 * minbytesneeded of them when possible.  Data is used in place from the
 * mapped scatterlist when the bounce buffer is empty and the mapped span is
 * long enough.  Otherwise bytes are gathered into gw->buf, up to
 * AES_BLOCK_SIZE in total.  Returns gw->nbytes.  When the minimum cannot be
 * met, gw->ptr is NULL and 0 is returned, and any gathered bytes stay in
 * gw->buf.
 * NOTE(review): minbytesneeded is presumably <= AES_BLOCK_SIZE.  With a
 * larger value and a full buf, the gathering loop makes no progress.
 * Confirm at the call sites.
 */
static int gcm_in_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
{
	int n;

	/* enough left over in the bounce buffer from the previous round */
	if (gw->buf_bytes && gw->buf_bytes >= minbytesneeded) {
		gw->ptr = gw->buf;
		gw->nbytes = gw->buf_bytes;
		goto out;
	}

	if (gw->walk_bytes_remain == 0) {
		gw->ptr = NULL;
		gw->nbytes = 0;
		goto out;
	}

	if (!_gcm_sg_clamp_and_map(gw)) {
		gw->ptr = NULL;
		gw->nbytes = 0;
		goto out;
	}

	/* use the mapped span in place; gcm_in_walk_done() unmaps it */
	if (!gw->buf_bytes && gw->walk_bytes >= minbytesneeded) {
		gw->ptr = gw->walk_ptr;
		gw->nbytes = gw->walk_bytes;
		goto out;
	}

	/* gather into buf until minbytesneeded bytes are collected */
	while (1) {
		n = min(gw->walk_bytes, AES_BLOCK_SIZE - gw->buf_bytes);
		memcpy(gw->buf + gw->buf_bytes, gw->walk_ptr, n);
		gw->buf_bytes += n;
		_gcm_sg_unmap_and_advance(gw, n);
		if (gw->buf_bytes >= minbytesneeded) {
			gw->ptr = gw->buf;
			gw->nbytes = gw->buf_bytes;
			goto out;
		}
		if (!_gcm_sg_clamp_and_map(gw)) {
			gw->ptr = NULL;
			gw->nbytes = 0;
			goto out;
		}
	}

out:
	return gw->nbytes;
}
758
759 static int gcm_out_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
760 {
761         if (gw->walk_bytes_remain == 0) {
762                 gw->ptr = NULL;
763                 gw->nbytes = 0;
764                 goto out;
765         }
766
767         if (!_gcm_sg_clamp_and_map(gw)) {
768                 gw->ptr = NULL;
769                 gw->nbytes = 0;
770                 goto out;
771         }
772
773         if (gw->walk_bytes >= minbytesneeded) {
774                 gw->ptr = gw->walk_ptr;
775                 gw->nbytes = gw->walk_bytes;
776                 goto out;
777         }
778
779         scatterwalk_unmap(&gw->walk);
780         gw->walk_ptr = NULL;
781
782         gw->ptr = gw->buf;
783         gw->nbytes = sizeof(gw->buf);
784
785 out:
786         return gw->nbytes;
787 }
788
789 static int gcm_in_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
790 {
791         if (gw->ptr == NULL)
792                 return 0;
793
794         if (gw->ptr == gw->buf) {
795                 int n = gw->buf_bytes - bytesdone;
796                 if (n > 0) {
797                         memmove(gw->buf, gw->buf + bytesdone, n);
798                         gw->buf_bytes = n;
799                 } else
800                         gw->buf_bytes = 0;
801         } else
802                 _gcm_sg_unmap_and_advance(gw, bytesdone);
803
804         return bytesdone;
805 }
806
807 static int gcm_out_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
808 {
809         int i, n;
810
811         if (gw->ptr == NULL)
812                 return 0;
813
814         if (gw->ptr == gw->buf) {
815                 for (i = 0; i < bytesdone; i += n) {
816                         if (!_gcm_sg_clamp_and_map(gw))
817                                 return i;
818                         n = min(gw->walk_bytes, bytesdone - i);
819                         memcpy(gw->walk_ptr, gw->buf + i, n);
820                         _gcm_sg_unmap_and_advance(gw, n);
821                 }
822         } else
823                 _gcm_sg_unmap_and_advance(gw, bytesdone);
824
825         return bytesdone;
826 }
827
828 static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
829 {
830         struct crypto_aead *tfm = crypto_aead_reqtfm(req);
831         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
832         unsigned int ivsize = crypto_aead_ivsize(tfm);
833         unsigned int taglen = crypto_aead_authsize(tfm);
834         unsigned int aadlen = req->assoclen;
835         unsigned int pclen = req->cryptlen;
836         int ret = 0;
837
838         unsigned int n, len, in_bytes, out_bytes,
839                      min_bytes, bytes, aad_bytes, pc_bytes;
840         struct gcm_sg_walk gw_in, gw_out;
841         u8 tag[GHASH_DIGEST_SIZE];
842
843         struct {
844                 u32 _[3];               /* reserved */
845                 u32 cv;                 /* Counter Value */
846                 u8 t[GHASH_DIGEST_SIZE];/* Tag */
847                 u8 h[AES_BLOCK_SIZE];   /* Hash-subkey */
848                 u64 taadl;              /* Total AAD Length */
849                 u64 tpcl;               /* Total Plain-/Cipher-text Length */
850                 u8 j0[GHASH_BLOCK_SIZE];/* initial counter value */
851                 u8 k[AES_MAX_KEY_SIZE]; /* Key */
852         } param;
853
854         /*
855          * encrypt
856          *   req->src: aad||plaintext
857          *   req->dst: aad||ciphertext||tag
858          * decrypt
859          *   req->src: aad||ciphertext||tag
860          *   req->dst: aad||plaintext, return 0 or -EBADMSG
861          * aad, plaintext and ciphertext may be empty.
862          */
863         if (flags & CPACF_DECRYPT)
864                 pclen -= taglen;
865         len = aadlen + pclen;
866
867         memset(&param, 0, sizeof(param));
868         param.cv = 1;
869         param.taadl = aadlen * 8;
870         param.tpcl = pclen * 8;
871         memcpy(param.j0, req->iv, ivsize);
872         *(u32 *)(param.j0 + ivsize) = 1;
873         memcpy(param.k, ctx->key, ctx->key_len);
874
875         gcm_walk_start(&gw_in, req->src, len);
876         gcm_walk_start(&gw_out, req->dst, len);
877
878         do {
879                 min_bytes = min_t(unsigned int,
880                                   aadlen > 0 ? aadlen : pclen, AES_BLOCK_SIZE);
881                 in_bytes = gcm_in_walk_go(&gw_in, min_bytes);
882                 out_bytes = gcm_out_walk_go(&gw_out, min_bytes);
883                 bytes = min(in_bytes, out_bytes);
884
885                 if (aadlen + pclen <= bytes) {
886                         aad_bytes = aadlen;
887                         pc_bytes = pclen;
888                         flags |= CPACF_KMA_LAAD | CPACF_KMA_LPC;
889                 } else {
890                         if (aadlen <= bytes) {
891                                 aad_bytes = aadlen;
892                                 pc_bytes = (bytes - aadlen) &
893                                            ~(AES_BLOCK_SIZE - 1);
894                                 flags |= CPACF_KMA_LAAD;
895                         } else {
896                                 aad_bytes = bytes & ~(AES_BLOCK_SIZE - 1);
897                                 pc_bytes = 0;
898                         }
899                 }
900
901                 if (aad_bytes > 0)
902                         memcpy(gw_out.ptr, gw_in.ptr, aad_bytes);
903
904                 cpacf_kma(ctx->fc | flags, &param,
905                           gw_out.ptr + aad_bytes,
906                           gw_in.ptr + aad_bytes, pc_bytes,
907                           gw_in.ptr, aad_bytes);
908
909                 n = aad_bytes + pc_bytes;
910                 if (gcm_in_walk_done(&gw_in, n) != n)
911                         return -ENOMEM;
912                 if (gcm_out_walk_done(&gw_out, n) != n)
913                         return -ENOMEM;
914                 aadlen -= aad_bytes;
915                 pclen -= pc_bytes;
916         } while (aadlen + pclen > 0);
917
918         if (flags & CPACF_DECRYPT) {
919                 scatterwalk_map_and_copy(tag, req->src, len, taglen, 0);
920                 if (crypto_memneq(tag, param.t, taglen))
921                         ret = -EBADMSG;
922         } else
923                 scatterwalk_map_and_copy(param.t, req->dst, len, taglen, 1);
924
925         memzero_explicit(&param, sizeof(param));
926         return ret;
927 }
928
929 static int gcm_aes_encrypt(struct aead_request *req)
930 {
931         return gcm_aes_crypt(req, CPACF_ENCRYPT);
932 }
933
934 static int gcm_aes_decrypt(struct aead_request *req)
935 {
936         return gcm_aes_crypt(req, CPACF_DECRYPT);
937 }
938
939 static struct aead_alg gcm_aes_aead = {
940         .setkey                 = gcm_aes_setkey,
941         .setauthsize            = gcm_aes_setauthsize,
942         .encrypt                = gcm_aes_encrypt,
943         .decrypt                = gcm_aes_decrypt,
944
945         .ivsize                 = GHASH_BLOCK_SIZE - sizeof(u32),
946         .maxauthsize            = GHASH_DIGEST_SIZE,
947         .chunksize              = AES_BLOCK_SIZE,
948
949         .base                   = {
950                 .cra_blocksize          = 1,
951                 .cra_ctxsize            = sizeof(struct s390_aes_ctx),
952                 .cra_priority           = 900,
953                 .cra_name               = "gcm(aes)",
954                 .cra_driver_name        = "gcm-aes-s390",
955                 .cra_module             = THIS_MODULE,
956         },
957 };
958
/*
 * Record of what aes_s390_init() successfully registered, so that
 * aes_s390_fini() can undo exactly that, including after a partial
 * init failure.
 */
static struct aead_alg *aes_s390_aead_alg;		/* gcm(aes), if registered */
static struct crypto_alg *aes_s390_alg;			/* single-block cipher */
static struct skcipher_alg *aes_s390_skcipher_algs[4];	/* registered skciphers */
static int aes_s390_skciphers_num;			/* used slots above */
963
964 static int aes_s390_register_skcipher(struct skcipher_alg *alg)
965 {
966         int ret;
967
968         ret = crypto_register_skcipher(alg);
969         if (!ret)
970                 aes_s390_skcipher_algs[aes_s390_skciphers_num++] = alg;
971         return ret;
972 }
973
974 static void aes_s390_fini(void)
975 {
976         if (aes_s390_alg)
977                 crypto_unregister_alg(aes_s390_alg);
978         while (aes_s390_skciphers_num--)
979                 crypto_unregister_skcipher(aes_s390_skcipher_algs[aes_s390_skciphers_num]);
980         if (ctrblk)
981                 free_page((unsigned long) ctrblk);
982
983         if (aes_s390_aead_alg)
984                 crypto_unregister_aead(aes_s390_aead_alg);
985 }
986
987 static int __init aes_s390_init(void)
988 {
989         int ret;
990
991         /* Query available functions for KM, KMC, KMCTR and KMA */
992         cpacf_query(CPACF_KM, &km_functions);
993         cpacf_query(CPACF_KMC, &kmc_functions);
994         cpacf_query(CPACF_KMCTR, &kmctr_functions);
995         cpacf_query(CPACF_KMA, &kma_functions);
996
997         if (cpacf_test_func(&km_functions, CPACF_KM_AES_128) ||
998             cpacf_test_func(&km_functions, CPACF_KM_AES_192) ||
999             cpacf_test_func(&km_functions, CPACF_KM_AES_256)) {
1000                 ret = crypto_register_alg(&aes_alg);
1001                 if (ret)
1002                         goto out_err;
1003                 aes_s390_alg = &aes_alg;
1004                 ret = aes_s390_register_skcipher(&ecb_aes_alg);
1005                 if (ret)
1006                         goto out_err;
1007         }
1008
1009         if (cpacf_test_func(&kmc_functions, CPACF_KMC_AES_128) ||
1010             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_192) ||
1011             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_256)) {
1012                 ret = aes_s390_register_skcipher(&cbc_aes_alg);
1013                 if (ret)
1014                         goto out_err;
1015         }
1016
1017         if (cpacf_test_func(&km_functions, CPACF_KM_XTS_128) ||
1018             cpacf_test_func(&km_functions, CPACF_KM_XTS_256)) {
1019                 ret = aes_s390_register_skcipher(&xts_aes_alg);
1020                 if (ret)
1021                         goto out_err;
1022         }
1023
1024         if (cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_128) ||
1025             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_192) ||
1026             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_256)) {
1027                 ctrblk = (u8 *) __get_free_page(GFP_KERNEL);
1028                 if (!ctrblk) {
1029                         ret = -ENOMEM;
1030                         goto out_err;
1031                 }
1032                 ret = aes_s390_register_skcipher(&ctr_aes_alg);
1033                 if (ret)
1034                         goto out_err;
1035         }
1036
1037         if (cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_128) ||
1038             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_192) ||
1039             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_256)) {
1040                 ret = crypto_register_aead(&gcm_aes_aead);
1041                 if (ret)
1042                         goto out_err;
1043                 aes_s390_aead_alg = &gcm_aes_aead;
1044         }
1045
1046         return 0;
1047 out_err:
1048         aes_s390_fini();
1049         return ret;
1050 }
1051
1052 module_cpu_feature_match(MSA, aes_s390_init);
1053 module_exit(aes_s390_fini);
1054
1055 MODULE_ALIAS_CRYPTO("aes-all");
1056
1057 MODULE_DESCRIPTION("Rijndael (AES) Cipher Algorithm");
1058 MODULE_LICENSE("GPL");
1059 MODULE_IMPORT_NS(CRYPTO_INTERNAL);