crypto: aead - remove useless setting of type flags
[linux-2.6-microblaze.git] / arch / s390 / crypto / aes_s390.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * Cryptographic API.
4  *
5  * s390 implementation of the AES Cipher Algorithm.
6  *
7  * s390 Version:
8  *   Copyright IBM Corp. 2005, 2017
9  *   Author(s): Jan Glauber (jang@de.ibm.com)
10  *              Sebastian Siewior <sebastian@breakpoint.cc> SW-Fallback
11  *              Patrick Steuer <patrick.steuer@de.ibm.com>
12  *              Harald Freudenberger <freude@de.ibm.com>
13  *
14  * Derived from "crypto/aes_generic.c"
15  */
16
17 #define KMSG_COMPONENT "aes_s390"
18 #define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
19
20 #include <crypto/aes.h>
21 #include <crypto/algapi.h>
22 #include <crypto/ghash.h>
23 #include <crypto/internal/aead.h>
24 #include <crypto/internal/skcipher.h>
25 #include <crypto/scatterwalk.h>
26 #include <linux/err.h>
27 #include <linux/module.h>
28 #include <linux/cpufeature.h>
29 #include <linux/init.h>
30 #include <linux/spinlock.h>
31 #include <linux/fips.h>
32 #include <linux/string.h>
33 #include <crypto/xts.h>
34 #include <asm/cpacf.h>
35
36 static u8 *ctrblk;
37 static DEFINE_SPINLOCK(ctrblk_lock);
38
39 static cpacf_mask_t km_functions, kmc_functions, kmctr_functions,
40                     kma_functions;
41
/*
 * Per-transform context used by the single-block "aes" cipher, by the
 * ECB/CBC/CTR block modes and by GCM.
 */
struct s390_aes_ctx {
        u8 key[AES_MAX_KEY_SIZE];       /* raw key, copied in by the setkey paths */
        int key_len;                    /* number of valid bytes in key */
        unsigned long fc;               /* CPACF function code; 0 = use the software fallback
                                         * (cipher/ECB/CBC/CTR paths) */
        union {
                struct crypto_skcipher *blk;    /* fallback for ECB/CBC/CTR */
                struct crypto_cipher *cip;      /* fallback for the "aes" cipher */
        } fallback;
};
51
/* Per-transform context for xts(aes). */
struct s390_xts_ctx {
        u8 key[32];                     /* first half of the XTS key (data key) */
        u8 pcc_key[32];                 /* second half of the XTS key, fed to PCC */
        int key_len;                    /* length of each key half in bytes (16 or 32) */
        unsigned long fc;               /* CPACF KM-XTS function code; 0 = use the fallback */
        struct crypto_skcipher *fallback;
};
59
/*
 * Scatterlist walk state used by the GCM code.  gcm_sg_walk_start() zeroes
 * the whole structure and records the total length to walk.
 */
struct gcm_sg_walk {
        struct scatter_walk walk;       /* position in the underlying scatterlist */
        unsigned int walk_bytes;        /* presumably bytes mapped at walk_ptr - TODO confirm */
        u8 *walk_ptr;                   /* presumably the currently mapped chunk - TODO confirm */
        unsigned int walk_bytes_remain; /* length still to be walked (initially len) */
        u8 buf[AES_BLOCK_SIZE];         /* staging buffer of up to one AES block */
        unsigned int buf_bytes;         /* number of bytes currently held in buf */
        u8 *ptr;                        /* data pointer handed out by gcm_sg_walk_go() */
        unsigned int nbytes;            /* number of bytes available at ptr */
};
70
71 static int setkey_fallback_cip(struct crypto_tfm *tfm, const u8 *in_key,
72                 unsigned int key_len)
73 {
74         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
75         int ret;
76
77         sctx->fallback.cip->base.crt_flags &= ~CRYPTO_TFM_REQ_MASK;
78         sctx->fallback.cip->base.crt_flags |= (tfm->crt_flags &
79                         CRYPTO_TFM_REQ_MASK);
80
81         ret = crypto_cipher_setkey(sctx->fallback.cip, in_key, key_len);
82         if (ret) {
83                 tfm->crt_flags &= ~CRYPTO_TFM_RES_MASK;
84                 tfm->crt_flags |= (sctx->fallback.cip->base.crt_flags &
85                                 CRYPTO_TFM_RES_MASK);
86         }
87         return ret;
88 }
89
90 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
91                        unsigned int key_len)
92 {
93         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
94         unsigned long fc;
95
96         /* Pick the correct function code based on the key length */
97         fc = (key_len == 16) ? CPACF_KM_AES_128 :
98              (key_len == 24) ? CPACF_KM_AES_192 :
99              (key_len == 32) ? CPACF_KM_AES_256 : 0;
100
101         /* Check if the function code is available */
102         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
103         if (!sctx->fc)
104                 return setkey_fallback_cip(tfm, in_key, key_len);
105
106         sctx->key_len = key_len;
107         memcpy(sctx->key, in_key, key_len);
108         return 0;
109 }
110
111 static void aes_encrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
112 {
113         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
114
115         if (unlikely(!sctx->fc)) {
116                 crypto_cipher_encrypt_one(sctx->fallback.cip, out, in);
117                 return;
118         }
119         cpacf_km(sctx->fc, &sctx->key, out, in, AES_BLOCK_SIZE);
120 }
121
122 static void aes_decrypt(struct crypto_tfm *tfm, u8 *out, const u8 *in)
123 {
124         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
125
126         if (unlikely(!sctx->fc)) {
127                 crypto_cipher_decrypt_one(sctx->fallback.cip, out, in);
128                 return;
129         }
130         cpacf_km(sctx->fc | CPACF_DECRYPT,
131                  &sctx->key, out, in, AES_BLOCK_SIZE);
132 }
133
134 static int fallback_init_cip(struct crypto_tfm *tfm)
135 {
136         const char *name = tfm->__crt_alg->cra_name;
137         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
138
139         sctx->fallback.cip = crypto_alloc_cipher(name, 0,
140                         CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK);
141
142         if (IS_ERR(sctx->fallback.cip)) {
143                 pr_err("Allocating AES fallback algorithm %s failed\n",
144                        name);
145                 return PTR_ERR(sctx->fallback.cip);
146         }
147
148         return 0;
149 }
150
151 static void fallback_exit_cip(struct crypto_tfm *tfm)
152 {
153         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
154
155         crypto_free_cipher(sctx->fallback.cip);
156         sctx->fallback.cip = NULL;
157 }
158
159 static struct crypto_alg aes_alg = {
160         .cra_name               =       "aes",
161         .cra_driver_name        =       "aes-s390",
162         .cra_priority           =       300,
163         .cra_flags              =       CRYPTO_ALG_TYPE_CIPHER |
164                                         CRYPTO_ALG_NEED_FALLBACK,
165         .cra_blocksize          =       AES_BLOCK_SIZE,
166         .cra_ctxsize            =       sizeof(struct s390_aes_ctx),
167         .cra_module             =       THIS_MODULE,
168         .cra_init               =       fallback_init_cip,
169         .cra_exit               =       fallback_exit_cip,
170         .cra_u                  =       {
171                 .cipher = {
172                         .cia_min_keysize        =       AES_MIN_KEY_SIZE,
173                         .cia_max_keysize        =       AES_MAX_KEY_SIZE,
174                         .cia_setkey             =       aes_set_key,
175                         .cia_encrypt            =       aes_encrypt,
176                         .cia_decrypt            =       aes_decrypt,
177                 }
178         }
179 };
180
181 static int setkey_fallback_blk(struct crypto_tfm *tfm, const u8 *key,
182                 unsigned int len)
183 {
184         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
185         unsigned int ret;
186
187         crypto_skcipher_clear_flags(sctx->fallback.blk, CRYPTO_TFM_REQ_MASK);
188         crypto_skcipher_set_flags(sctx->fallback.blk, tfm->crt_flags &
189                                                       CRYPTO_TFM_REQ_MASK);
190
191         ret = crypto_skcipher_setkey(sctx->fallback.blk, key, len);
192
193         tfm->crt_flags &= ~CRYPTO_TFM_RES_MASK;
194         tfm->crt_flags |= crypto_skcipher_get_flags(sctx->fallback.blk) &
195                           CRYPTO_TFM_RES_MASK;
196
197         return ret;
198 }
199
200 static int fallback_blk_dec(struct blkcipher_desc *desc,
201                 struct scatterlist *dst, struct scatterlist *src,
202                 unsigned int nbytes)
203 {
204         unsigned int ret;
205         struct crypto_blkcipher *tfm = desc->tfm;
206         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(tfm);
207         SKCIPHER_REQUEST_ON_STACK(req, sctx->fallback.blk);
208
209         skcipher_request_set_tfm(req, sctx->fallback.blk);
210         skcipher_request_set_callback(req, desc->flags, NULL, NULL);
211         skcipher_request_set_crypt(req, src, dst, nbytes, desc->info);
212
213         ret = crypto_skcipher_decrypt(req);
214
215         skcipher_request_zero(req);
216         return ret;
217 }
218
219 static int fallback_blk_enc(struct blkcipher_desc *desc,
220                 struct scatterlist *dst, struct scatterlist *src,
221                 unsigned int nbytes)
222 {
223         unsigned int ret;
224         struct crypto_blkcipher *tfm = desc->tfm;
225         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(tfm);
226         SKCIPHER_REQUEST_ON_STACK(req, sctx->fallback.blk);
227
228         skcipher_request_set_tfm(req, sctx->fallback.blk);
229         skcipher_request_set_callback(req, desc->flags, NULL, NULL);
230         skcipher_request_set_crypt(req, src, dst, nbytes, desc->info);
231
232         ret = crypto_skcipher_encrypt(req);
233         return ret;
234 }
235
236 static int ecb_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
237                            unsigned int key_len)
238 {
239         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
240         unsigned long fc;
241
242         /* Pick the correct function code based on the key length */
243         fc = (key_len == 16) ? CPACF_KM_AES_128 :
244              (key_len == 24) ? CPACF_KM_AES_192 :
245              (key_len == 32) ? CPACF_KM_AES_256 : 0;
246
247         /* Check if the function code is available */
248         sctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
249         if (!sctx->fc)
250                 return setkey_fallback_blk(tfm, in_key, key_len);
251
252         sctx->key_len = key_len;
253         memcpy(sctx->key, in_key, key_len);
254         return 0;
255 }
256
257 static int ecb_aes_crypt(struct blkcipher_desc *desc, unsigned long modifier,
258                          struct blkcipher_walk *walk)
259 {
260         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
261         unsigned int nbytes, n;
262         int ret;
263
264         ret = blkcipher_walk_virt(desc, walk);
265         while ((nbytes = walk->nbytes) >= AES_BLOCK_SIZE) {
266                 /* only use complete blocks */
267                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
268                 cpacf_km(sctx->fc | modifier, sctx->key,
269                          walk->dst.virt.addr, walk->src.virt.addr, n);
270                 ret = blkcipher_walk_done(desc, walk, nbytes - n);
271         }
272
273         return ret;
274 }
275
276 static int ecb_aes_encrypt(struct blkcipher_desc *desc,
277                            struct scatterlist *dst, struct scatterlist *src,
278                            unsigned int nbytes)
279 {
280         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
281         struct blkcipher_walk walk;
282
283         if (unlikely(!sctx->fc))
284                 return fallback_blk_enc(desc, dst, src, nbytes);
285
286         blkcipher_walk_init(&walk, dst, src, nbytes);
287         return ecb_aes_crypt(desc, 0, &walk);
288 }
289
290 static int ecb_aes_decrypt(struct blkcipher_desc *desc,
291                            struct scatterlist *dst, struct scatterlist *src,
292                            unsigned int nbytes)
293 {
294         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
295         struct blkcipher_walk walk;
296
297         if (unlikely(!sctx->fc))
298                 return fallback_blk_dec(desc, dst, src, nbytes);
299
300         blkcipher_walk_init(&walk, dst, src, nbytes);
301         return ecb_aes_crypt(desc, CPACF_DECRYPT, &walk);
302 }
303
304 static int fallback_init_blk(struct crypto_tfm *tfm)
305 {
306         const char *name = tfm->__crt_alg->cra_name;
307         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
308
309         sctx->fallback.blk = crypto_alloc_skcipher(name, 0,
310                                                    CRYPTO_ALG_ASYNC |
311                                                    CRYPTO_ALG_NEED_FALLBACK);
312
313         if (IS_ERR(sctx->fallback.blk)) {
314                 pr_err("Allocating AES fallback algorithm %s failed\n",
315                        name);
316                 return PTR_ERR(sctx->fallback.blk);
317         }
318
319         return 0;
320 }
321
322 static void fallback_exit_blk(struct crypto_tfm *tfm)
323 {
324         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
325
326         crypto_free_skcipher(sctx->fallback.blk);
327 }
328
329 static struct crypto_alg ecb_aes_alg = {
330         .cra_name               =       "ecb(aes)",
331         .cra_driver_name        =       "ecb-aes-s390",
332         .cra_priority           =       401,    /* combo: aes + ecb + 1 */
333         .cra_flags              =       CRYPTO_ALG_TYPE_BLKCIPHER |
334                                         CRYPTO_ALG_NEED_FALLBACK,
335         .cra_blocksize          =       AES_BLOCK_SIZE,
336         .cra_ctxsize            =       sizeof(struct s390_aes_ctx),
337         .cra_type               =       &crypto_blkcipher_type,
338         .cra_module             =       THIS_MODULE,
339         .cra_init               =       fallback_init_blk,
340         .cra_exit               =       fallback_exit_blk,
341         .cra_u                  =       {
342                 .blkcipher = {
343                         .min_keysize            =       AES_MIN_KEY_SIZE,
344                         .max_keysize            =       AES_MAX_KEY_SIZE,
345                         .setkey                 =       ecb_aes_set_key,
346                         .encrypt                =       ecb_aes_encrypt,
347                         .decrypt                =       ecb_aes_decrypt,
348                 }
349         }
350 };
351
352 static int cbc_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
353                            unsigned int key_len)
354 {
355         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
356         unsigned long fc;
357
358         /* Pick the correct function code based on the key length */
359         fc = (key_len == 16) ? CPACF_KMC_AES_128 :
360              (key_len == 24) ? CPACF_KMC_AES_192 :
361              (key_len == 32) ? CPACF_KMC_AES_256 : 0;
362
363         /* Check if the function code is available */
364         sctx->fc = (fc && cpacf_test_func(&kmc_functions, fc)) ? fc : 0;
365         if (!sctx->fc)
366                 return setkey_fallback_blk(tfm, in_key, key_len);
367
368         sctx->key_len = key_len;
369         memcpy(sctx->key, in_key, key_len);
370         return 0;
371 }
372
373 static int cbc_aes_crypt(struct blkcipher_desc *desc, unsigned long modifier,
374                          struct blkcipher_walk *walk)
375 {
376         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
377         unsigned int nbytes, n;
378         int ret;
379         struct {
380                 u8 iv[AES_BLOCK_SIZE];
381                 u8 key[AES_MAX_KEY_SIZE];
382         } param;
383
384         ret = blkcipher_walk_virt(desc, walk);
385         memcpy(param.iv, walk->iv, AES_BLOCK_SIZE);
386         memcpy(param.key, sctx->key, sctx->key_len);
387         while ((nbytes = walk->nbytes) >= AES_BLOCK_SIZE) {
388                 /* only use complete blocks */
389                 n = nbytes & ~(AES_BLOCK_SIZE - 1);
390                 cpacf_kmc(sctx->fc | modifier, &param,
391                           walk->dst.virt.addr, walk->src.virt.addr, n);
392                 ret = blkcipher_walk_done(desc, walk, nbytes - n);
393         }
394         memcpy(walk->iv, param.iv, AES_BLOCK_SIZE);
395         return ret;
396 }
397
398 static int cbc_aes_encrypt(struct blkcipher_desc *desc,
399                            struct scatterlist *dst, struct scatterlist *src,
400                            unsigned int nbytes)
401 {
402         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
403         struct blkcipher_walk walk;
404
405         if (unlikely(!sctx->fc))
406                 return fallback_blk_enc(desc, dst, src, nbytes);
407
408         blkcipher_walk_init(&walk, dst, src, nbytes);
409         return cbc_aes_crypt(desc, 0, &walk);
410 }
411
412 static int cbc_aes_decrypt(struct blkcipher_desc *desc,
413                            struct scatterlist *dst, struct scatterlist *src,
414                            unsigned int nbytes)
415 {
416         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
417         struct blkcipher_walk walk;
418
419         if (unlikely(!sctx->fc))
420                 return fallback_blk_dec(desc, dst, src, nbytes);
421
422         blkcipher_walk_init(&walk, dst, src, nbytes);
423         return cbc_aes_crypt(desc, CPACF_DECRYPT, &walk);
424 }
425
426 static struct crypto_alg cbc_aes_alg = {
427         .cra_name               =       "cbc(aes)",
428         .cra_driver_name        =       "cbc-aes-s390",
429         .cra_priority           =       402,    /* ecb-aes-s390 + 1 */
430         .cra_flags              =       CRYPTO_ALG_TYPE_BLKCIPHER |
431                                         CRYPTO_ALG_NEED_FALLBACK,
432         .cra_blocksize          =       AES_BLOCK_SIZE,
433         .cra_ctxsize            =       sizeof(struct s390_aes_ctx),
434         .cra_type               =       &crypto_blkcipher_type,
435         .cra_module             =       THIS_MODULE,
436         .cra_init               =       fallback_init_blk,
437         .cra_exit               =       fallback_exit_blk,
438         .cra_u                  =       {
439                 .blkcipher = {
440                         .min_keysize            =       AES_MIN_KEY_SIZE,
441                         .max_keysize            =       AES_MAX_KEY_SIZE,
442                         .ivsize                 =       AES_BLOCK_SIZE,
443                         .setkey                 =       cbc_aes_set_key,
444                         .encrypt                =       cbc_aes_encrypt,
445                         .decrypt                =       cbc_aes_decrypt,
446                 }
447         }
448 };
449
450 static int xts_fallback_setkey(struct crypto_tfm *tfm, const u8 *key,
451                                    unsigned int len)
452 {
453         struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
454         unsigned int ret;
455
456         crypto_skcipher_clear_flags(xts_ctx->fallback, CRYPTO_TFM_REQ_MASK);
457         crypto_skcipher_set_flags(xts_ctx->fallback, tfm->crt_flags &
458                                                      CRYPTO_TFM_REQ_MASK);
459
460         ret = crypto_skcipher_setkey(xts_ctx->fallback, key, len);
461
462         tfm->crt_flags &= ~CRYPTO_TFM_RES_MASK;
463         tfm->crt_flags |= crypto_skcipher_get_flags(xts_ctx->fallback) &
464                           CRYPTO_TFM_RES_MASK;
465
466         return ret;
467 }
468
469 static int xts_fallback_decrypt(struct blkcipher_desc *desc,
470                 struct scatterlist *dst, struct scatterlist *src,
471                 unsigned int nbytes)
472 {
473         struct crypto_blkcipher *tfm = desc->tfm;
474         struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(tfm);
475         SKCIPHER_REQUEST_ON_STACK(req, xts_ctx->fallback);
476         unsigned int ret;
477
478         skcipher_request_set_tfm(req, xts_ctx->fallback);
479         skcipher_request_set_callback(req, desc->flags, NULL, NULL);
480         skcipher_request_set_crypt(req, src, dst, nbytes, desc->info);
481
482         ret = crypto_skcipher_decrypt(req);
483
484         skcipher_request_zero(req);
485         return ret;
486 }
487
488 static int xts_fallback_encrypt(struct blkcipher_desc *desc,
489                 struct scatterlist *dst, struct scatterlist *src,
490                 unsigned int nbytes)
491 {
492         struct crypto_blkcipher *tfm = desc->tfm;
493         struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(tfm);
494         SKCIPHER_REQUEST_ON_STACK(req, xts_ctx->fallback);
495         unsigned int ret;
496
497         skcipher_request_set_tfm(req, xts_ctx->fallback);
498         skcipher_request_set_callback(req, desc->flags, NULL, NULL);
499         skcipher_request_set_crypt(req, src, dst, nbytes, desc->info);
500
501         ret = crypto_skcipher_encrypt(req);
502
503         skcipher_request_zero(req);
504         return ret;
505 }
506
507 static int xts_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
508                            unsigned int key_len)
509 {
510         struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
511         unsigned long fc;
512         int err;
513
514         err = xts_check_key(tfm, in_key, key_len);
515         if (err)
516                 return err;
517
518         /* In fips mode only 128 bit or 256 bit keys are valid */
519         if (fips_enabled && key_len != 32 && key_len != 64) {
520                 tfm->crt_flags |= CRYPTO_TFM_RES_BAD_KEY_LEN;
521                 return -EINVAL;
522         }
523
524         /* Pick the correct function code based on the key length */
525         fc = (key_len == 32) ? CPACF_KM_XTS_128 :
526              (key_len == 64) ? CPACF_KM_XTS_256 : 0;
527
528         /* Check if the function code is available */
529         xts_ctx->fc = (fc && cpacf_test_func(&km_functions, fc)) ? fc : 0;
530         if (!xts_ctx->fc)
531                 return xts_fallback_setkey(tfm, in_key, key_len);
532
533         /* Split the XTS key into the two subkeys */
534         key_len = key_len / 2;
535         xts_ctx->key_len = key_len;
536         memcpy(xts_ctx->key, in_key, key_len);
537         memcpy(xts_ctx->pcc_key, in_key + key_len, key_len);
538         return 0;
539 }
540
/*
 * Run XTS over all complete blocks of @walk with CPACF.
 *
 * First, PCC is run with the second key half (pcc_key) and the IV as
 * tweak.  Its pcc_param.xts output field is zeroed beforehand and then
 * copied into xts_param.init, presumably the initial XTS tweak value -
 * TODO confirm against the CPACF PCC description.  KM-XTS then processes
 * the data with the first key half.
 *
 * The key fields are 32 bytes wide.  For 128-bit halves (key_len == 16)
 * offset is 16, so the key sits in the upper half of the field and the
 * parameter block handed to CPACF starts at that offset.  For 256-bit
 * halves offset is 0.  Bytes before offset are left uninitialized and are
 * not part of the block passed in.
 *
 * @modifier is ORed into the function code (0 or CPACF_DECRYPT).  Only
 * whole AES blocks are handed to CPACF.  Returns the last status from the
 * blkcipher walk.
 */
static int xts_aes_crypt(struct blkcipher_desc *desc, unsigned long modifier,
                         struct blkcipher_walk *walk)
{
        struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(desc->tfm);
        unsigned int offset, nbytes, n;
        int ret;
        /* parameter block for PCC; field order must match what CPACF expects */
        struct {
                u8 key[32];
                u8 tweak[16];
                u8 block[16];
                u8 bit[16];
                u8 xts[16];
        } pcc_param;
        /* parameter block for KM-XTS: key followed by the initial XTS value */
        struct {
                u8 key[32];
                u8 init[16];
        } xts_param;

        ret = blkcipher_walk_virt(desc, walk);
        /* 16 for 128-bit key halves, 0 for 256-bit key halves */
        offset = xts_ctx->key_len & 0x10;
        memset(pcc_param.block, 0, sizeof(pcc_param.block));
        memset(pcc_param.bit, 0, sizeof(pcc_param.bit));
        memset(pcc_param.xts, 0, sizeof(pcc_param.xts));
        memcpy(pcc_param.tweak, walk->iv, sizeof(pcc_param.tweak));
        memcpy(pcc_param.key + offset, xts_ctx->pcc_key, xts_ctx->key_len);
        cpacf_pcc(xts_ctx->fc, pcc_param.key + offset);

        memcpy(xts_param.key + offset, xts_ctx->key, xts_ctx->key_len);
        memcpy(xts_param.init, pcc_param.xts, 16);

        while ((nbytes = walk->nbytes) >= AES_BLOCK_SIZE) {
                /* only use complete blocks */
                n = nbytes & ~(AES_BLOCK_SIZE - 1);
                cpacf_km(xts_ctx->fc | modifier, xts_param.key + offset,
                         walk->dst.virt.addr, walk->src.virt.addr, n);
                ret = blkcipher_walk_done(desc, walk, nbytes - n);
        }
        return ret;
}
580
581 static int xts_aes_encrypt(struct blkcipher_desc *desc,
582                            struct scatterlist *dst, struct scatterlist *src,
583                            unsigned int nbytes)
584 {
585         struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(desc->tfm);
586         struct blkcipher_walk walk;
587
588         if (unlikely(!xts_ctx->fc))
589                 return xts_fallback_encrypt(desc, dst, src, nbytes);
590
591         blkcipher_walk_init(&walk, dst, src, nbytes);
592         return xts_aes_crypt(desc, 0, &walk);
593 }
594
595 static int xts_aes_decrypt(struct blkcipher_desc *desc,
596                            struct scatterlist *dst, struct scatterlist *src,
597                            unsigned int nbytes)
598 {
599         struct s390_xts_ctx *xts_ctx = crypto_blkcipher_ctx(desc->tfm);
600         struct blkcipher_walk walk;
601
602         if (unlikely(!xts_ctx->fc))
603                 return xts_fallback_decrypt(desc, dst, src, nbytes);
604
605         blkcipher_walk_init(&walk, dst, src, nbytes);
606         return xts_aes_crypt(desc, CPACF_DECRYPT, &walk);
607 }
608
609 static int xts_fallback_init(struct crypto_tfm *tfm)
610 {
611         const char *name = tfm->__crt_alg->cra_name;
612         struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
613
614         xts_ctx->fallback = crypto_alloc_skcipher(name, 0,
615                                                   CRYPTO_ALG_ASYNC |
616                                                   CRYPTO_ALG_NEED_FALLBACK);
617
618         if (IS_ERR(xts_ctx->fallback)) {
619                 pr_err("Allocating XTS fallback algorithm %s failed\n",
620                        name);
621                 return PTR_ERR(xts_ctx->fallback);
622         }
623         return 0;
624 }
625
626 static void xts_fallback_exit(struct crypto_tfm *tfm)
627 {
628         struct s390_xts_ctx *xts_ctx = crypto_tfm_ctx(tfm);
629
630         crypto_free_skcipher(xts_ctx->fallback);
631 }
632
633 static struct crypto_alg xts_aes_alg = {
634         .cra_name               =       "xts(aes)",
635         .cra_driver_name        =       "xts-aes-s390",
636         .cra_priority           =       402,    /* ecb-aes-s390 + 1 */
637         .cra_flags              =       CRYPTO_ALG_TYPE_BLKCIPHER |
638                                         CRYPTO_ALG_NEED_FALLBACK,
639         .cra_blocksize          =       AES_BLOCK_SIZE,
640         .cra_ctxsize            =       sizeof(struct s390_xts_ctx),
641         .cra_type               =       &crypto_blkcipher_type,
642         .cra_module             =       THIS_MODULE,
643         .cra_init               =       xts_fallback_init,
644         .cra_exit               =       xts_fallback_exit,
645         .cra_u                  =       {
646                 .blkcipher = {
647                         .min_keysize            =       2 * AES_MIN_KEY_SIZE,
648                         .max_keysize            =       2 * AES_MAX_KEY_SIZE,
649                         .ivsize                 =       AES_BLOCK_SIZE,
650                         .setkey                 =       xts_aes_set_key,
651                         .encrypt                =       xts_aes_encrypt,
652                         .decrypt                =       xts_aes_decrypt,
653                 }
654         }
655 };
656
657 static int ctr_aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
658                            unsigned int key_len)
659 {
660         struct s390_aes_ctx *sctx = crypto_tfm_ctx(tfm);
661         unsigned long fc;
662
663         /* Pick the correct function code based on the key length */
664         fc = (key_len == 16) ? CPACF_KMCTR_AES_128 :
665              (key_len == 24) ? CPACF_KMCTR_AES_192 :
666              (key_len == 32) ? CPACF_KMCTR_AES_256 : 0;
667
668         /* Check if the function code is available */
669         sctx->fc = (fc && cpacf_test_func(&kmctr_functions, fc)) ? fc : 0;
670         if (!sctx->fc)
671                 return setkey_fallback_blk(tfm, in_key, key_len);
672
673         sctx->key_len = key_len;
674         memcpy(sctx->key, in_key, key_len);
675         return 0;
676 }
677
678 static unsigned int __ctrblk_init(u8 *ctrptr, u8 *iv, unsigned int nbytes)
679 {
680         unsigned int i, n;
681
682         /* only use complete blocks, max. PAGE_SIZE */
683         memcpy(ctrptr, iv, AES_BLOCK_SIZE);
684         n = (nbytes > PAGE_SIZE) ? PAGE_SIZE : nbytes & ~(AES_BLOCK_SIZE - 1);
685         for (i = (n / AES_BLOCK_SIZE) - 1; i > 0; i--) {
686                 memcpy(ctrptr + AES_BLOCK_SIZE, ctrptr, AES_BLOCK_SIZE);
687                 crypto_inc(ctrptr + AES_BLOCK_SIZE, AES_BLOCK_SIZE);
688                 ctrptr += AES_BLOCK_SIZE;
689         }
690         return n;
691 }
692
/*
 * Run CTR over @walk with CPACF KMCTR.
 *
 * If the shared ctrblk buffer can be taken without spinning (spin_trylock),
 * chunks of two or more blocks are done in one KMCTR call.  The counter
 * values for such a chunk are prepared in ctrblk by __ctrblk_init().
 * Otherwise one block is processed per call, using walk->iv directly as
 * the counter.  A final partial block is encrypted into a stack buffer,
 * and only the remaining nbytes are copied out.
 *
 * @modifier is ORed into the function code (0 or CPACF_DECRYPT).  Returns
 * the last status from the blkcipher walk.
 */
static int ctr_aes_crypt(struct blkcipher_desc *desc, unsigned long modifier,
                         struct blkcipher_walk *walk)
{
        struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
        u8 buf[AES_BLOCK_SIZE], *ctrptr;
        unsigned int n, nbytes;
        int ret, locked;

        /* Never wait for the shared counter buffer; fall back to single blocks. */
        locked = spin_trylock(&ctrblk_lock);

        ret = blkcipher_walk_virt_block(desc, walk, AES_BLOCK_SIZE);
        while ((nbytes = walk->nbytes) >= AES_BLOCK_SIZE) {
                n = AES_BLOCK_SIZE;
                if (nbytes >= 2*AES_BLOCK_SIZE && locked)
                        n = __ctrblk_init(ctrblk, walk->iv, nbytes);
                ctrptr = (n > AES_BLOCK_SIZE) ? ctrblk : walk->iv;
                cpacf_kmctr(sctx->fc | modifier, sctx->key,
                            walk->dst.virt.addr, walk->src.virt.addr,
                            n, ctrptr);
                /* resume from the last counter value used in ctrblk */
                if (ctrptr == ctrblk)
                        memcpy(walk->iv, ctrptr + n - AES_BLOCK_SIZE,
                               AES_BLOCK_SIZE);
                /* advance walk->iv to the next unused counter value */
                crypto_inc(walk->iv, AES_BLOCK_SIZE);
                ret = blkcipher_walk_done(desc, walk, nbytes - n);
        }
        if (locked)
                spin_unlock(&ctrblk_lock);
        /*
         * final block may be < AES_BLOCK_SIZE, copy only nbytes
         */
        if (nbytes) {
                cpacf_kmctr(sctx->fc | modifier, sctx->key,
                            buf, walk->src.virt.addr,
                            AES_BLOCK_SIZE, walk->iv);
                memcpy(walk->dst.virt.addr, buf, nbytes);
                crypto_inc(walk->iv, AES_BLOCK_SIZE);
                ret = blkcipher_walk_done(desc, walk, 0);
        }

        return ret;
}
734
735 static int ctr_aes_encrypt(struct blkcipher_desc *desc,
736                            struct scatterlist *dst, struct scatterlist *src,
737                            unsigned int nbytes)
738 {
739         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
740         struct blkcipher_walk walk;
741
742         if (unlikely(!sctx->fc))
743                 return fallback_blk_enc(desc, dst, src, nbytes);
744
745         blkcipher_walk_init(&walk, dst, src, nbytes);
746         return ctr_aes_crypt(desc, 0, &walk);
747 }
748
749 static int ctr_aes_decrypt(struct blkcipher_desc *desc,
750                            struct scatterlist *dst, struct scatterlist *src,
751                            unsigned int nbytes)
752 {
753         struct s390_aes_ctx *sctx = crypto_blkcipher_ctx(desc->tfm);
754         struct blkcipher_walk walk;
755
756         if (unlikely(!sctx->fc))
757                 return fallback_blk_dec(desc, dst, src, nbytes);
758
759         blkcipher_walk_init(&walk, dst, src, nbytes);
760         return ctr_aes_crypt(desc, CPACF_DECRYPT, &walk);
761 }
762
763 static struct crypto_alg ctr_aes_alg = {
764         .cra_name               =       "ctr(aes)",
765         .cra_driver_name        =       "ctr-aes-s390",
766         .cra_priority           =       402,    /* ecb-aes-s390 + 1 */
767         .cra_flags              =       CRYPTO_ALG_TYPE_BLKCIPHER |
768                                         CRYPTO_ALG_NEED_FALLBACK,
769         .cra_blocksize          =       1,
770         .cra_ctxsize            =       sizeof(struct s390_aes_ctx),
771         .cra_type               =       &crypto_blkcipher_type,
772         .cra_module             =       THIS_MODULE,
773         .cra_init               =       fallback_init_blk,
774         .cra_exit               =       fallback_exit_blk,
775         .cra_u                  =       {
776                 .blkcipher = {
777                         .min_keysize            =       AES_MIN_KEY_SIZE,
778                         .max_keysize            =       AES_MAX_KEY_SIZE,
779                         .ivsize                 =       AES_BLOCK_SIZE,
780                         .setkey                 =       ctr_aes_set_key,
781                         .encrypt                =       ctr_aes_encrypt,
782                         .decrypt                =       ctr_aes_decrypt,
783                 }
784         }
785 };
786
787 static int gcm_aes_setkey(struct crypto_aead *tfm, const u8 *key,
788                           unsigned int keylen)
789 {
790         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
791
792         switch (keylen) {
793         case AES_KEYSIZE_128:
794                 ctx->fc = CPACF_KMA_GCM_AES_128;
795                 break;
796         case AES_KEYSIZE_192:
797                 ctx->fc = CPACF_KMA_GCM_AES_192;
798                 break;
799         case AES_KEYSIZE_256:
800                 ctx->fc = CPACF_KMA_GCM_AES_256;
801                 break;
802         default:
803                 return -EINVAL;
804         }
805
806         memcpy(ctx->key, key, keylen);
807         ctx->key_len = keylen;
808         return 0;
809 }
810
/*
 * gcm_aes_setauthsize() - validate the requested GCM tag length.
 * @tfm:      the AEAD transform (not used)
 * @authsize: requested tag length in bytes
 *
 * Accepts 4 and 8 byte tags and every length from 12 to 16 bytes, the tag
 * sizes NIST SP 800-38D permits for GCM.  Nothing is stored here;
 * gcm_aes_crypt() reads the length back via crypto_aead_authsize().
 *
 * Return: 0 if @authsize is acceptable, -EINVAL otherwise.
 */
static int gcm_aes_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
{
	if (authsize == 4 || authsize == 8)
		return 0;
	if (authsize >= 12 && authsize <= 16)
		return 0;
	return -EINVAL;
}
828
829 static void gcm_sg_walk_start(struct gcm_sg_walk *gw, struct scatterlist *sg,
830                               unsigned int len)
831 {
832         memset(gw, 0, sizeof(*gw));
833         gw->walk_bytes_remain = len;
834         scatterwalk_start(&gw->walk, sg);
835 }
836
837 static int gcm_sg_walk_go(struct gcm_sg_walk *gw, unsigned int minbytesneeded)
838 {
839         int n;
840
841         /* minbytesneeded <= AES_BLOCK_SIZE */
842         if (gw->buf_bytes && gw->buf_bytes >= minbytesneeded) {
843                 gw->ptr = gw->buf;
844                 gw->nbytes = gw->buf_bytes;
845                 goto out;
846         }
847
848         if (gw->walk_bytes_remain == 0) {
849                 gw->ptr = NULL;
850                 gw->nbytes = 0;
851                 goto out;
852         }
853
854         gw->walk_bytes = scatterwalk_clamp(&gw->walk, gw->walk_bytes_remain);
855         if (!gw->walk_bytes) {
856                 scatterwalk_start(&gw->walk, sg_next(gw->walk.sg));
857                 gw->walk_bytes = scatterwalk_clamp(&gw->walk,
858                                                    gw->walk_bytes_remain);
859         }
860         gw->walk_ptr = scatterwalk_map(&gw->walk);
861
862         if (!gw->buf_bytes && gw->walk_bytes >= minbytesneeded) {
863                 gw->ptr = gw->walk_ptr;
864                 gw->nbytes = gw->walk_bytes;
865                 goto out;
866         }
867
868         while (1) {
869                 n = min(gw->walk_bytes, AES_BLOCK_SIZE - gw->buf_bytes);
870                 memcpy(gw->buf + gw->buf_bytes, gw->walk_ptr, n);
871                 gw->buf_bytes += n;
872                 gw->walk_bytes_remain -= n;
873                 scatterwalk_unmap(&gw->walk);
874                 scatterwalk_advance(&gw->walk, n);
875                 scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
876
877                 if (gw->buf_bytes >= minbytesneeded) {
878                         gw->ptr = gw->buf;
879                         gw->nbytes = gw->buf_bytes;
880                         goto out;
881                 }
882
883                 gw->walk_bytes = scatterwalk_clamp(&gw->walk,
884                                                    gw->walk_bytes_remain);
885                 if (!gw->walk_bytes) {
886                         scatterwalk_start(&gw->walk, sg_next(gw->walk.sg));
887                         gw->walk_bytes = scatterwalk_clamp(&gw->walk,
888                                                         gw->walk_bytes_remain);
889                 }
890                 gw->walk_ptr = scatterwalk_map(&gw->walk);
891         }
892
893 out:
894         return gw->nbytes;
895 }
896
897 static void gcm_sg_walk_done(struct gcm_sg_walk *gw, unsigned int bytesdone)
898 {
899         int n;
900
901         if (gw->ptr == NULL)
902                 return;
903
904         if (gw->ptr == gw->buf) {
905                 n = gw->buf_bytes - bytesdone;
906                 if (n > 0) {
907                         memmove(gw->buf, gw->buf + bytesdone, n);
908                         gw->buf_bytes -= n;
909                 } else
910                         gw->buf_bytes = 0;
911         } else {
912                 gw->walk_bytes_remain -= bytesdone;
913                 scatterwalk_unmap(&gw->walk);
914                 scatterwalk_advance(&gw->walk, bytesdone);
915                 scatterwalk_done(&gw->walk, 0, gw->walk_bytes_remain);
916         }
917 }
918
919 static int gcm_aes_crypt(struct aead_request *req, unsigned int flags)
920 {
921         struct crypto_aead *tfm = crypto_aead_reqtfm(req);
922         struct s390_aes_ctx *ctx = crypto_aead_ctx(tfm);
923         unsigned int ivsize = crypto_aead_ivsize(tfm);
924         unsigned int taglen = crypto_aead_authsize(tfm);
925         unsigned int aadlen = req->assoclen;
926         unsigned int pclen = req->cryptlen;
927         int ret = 0;
928
929         unsigned int len, in_bytes, out_bytes,
930                      min_bytes, bytes, aad_bytes, pc_bytes;
931         struct gcm_sg_walk gw_in, gw_out;
932         u8 tag[GHASH_DIGEST_SIZE];
933
934         struct {
935                 u32 _[3];               /* reserved */
936                 u32 cv;                 /* Counter Value */
937                 u8 t[GHASH_DIGEST_SIZE];/* Tag */
938                 u8 h[AES_BLOCK_SIZE];   /* Hash-subkey */
939                 u64 taadl;              /* Total AAD Length */
940                 u64 tpcl;               /* Total Plain-/Cipher-text Length */
941                 u8 j0[GHASH_BLOCK_SIZE];/* initial counter value */
942                 u8 k[AES_MAX_KEY_SIZE]; /* Key */
943         } param;
944
945         /*
946          * encrypt
947          *   req->src: aad||plaintext
948          *   req->dst: aad||ciphertext||tag
949          * decrypt
950          *   req->src: aad||ciphertext||tag
951          *   req->dst: aad||plaintext, return 0 or -EBADMSG
952          * aad, plaintext and ciphertext may be empty.
953          */
954         if (flags & CPACF_DECRYPT)
955                 pclen -= taglen;
956         len = aadlen + pclen;
957
958         memset(&param, 0, sizeof(param));
959         param.cv = 1;
960         param.taadl = aadlen * 8;
961         param.tpcl = pclen * 8;
962         memcpy(param.j0, req->iv, ivsize);
963         *(u32 *)(param.j0 + ivsize) = 1;
964         memcpy(param.k, ctx->key, ctx->key_len);
965
966         gcm_sg_walk_start(&gw_in, req->src, len);
967         gcm_sg_walk_start(&gw_out, req->dst, len);
968
969         do {
970                 min_bytes = min_t(unsigned int,
971                                   aadlen > 0 ? aadlen : pclen, AES_BLOCK_SIZE);
972                 in_bytes = gcm_sg_walk_go(&gw_in, min_bytes);
973                 out_bytes = gcm_sg_walk_go(&gw_out, min_bytes);
974                 bytes = min(in_bytes, out_bytes);
975
976                 if (aadlen + pclen <= bytes) {
977                         aad_bytes = aadlen;
978                         pc_bytes = pclen;
979                         flags |= CPACF_KMA_LAAD | CPACF_KMA_LPC;
980                 } else {
981                         if (aadlen <= bytes) {
982                                 aad_bytes = aadlen;
983                                 pc_bytes = (bytes - aadlen) &
984                                            ~(AES_BLOCK_SIZE - 1);
985                                 flags |= CPACF_KMA_LAAD;
986                         } else {
987                                 aad_bytes = bytes & ~(AES_BLOCK_SIZE - 1);
988                                 pc_bytes = 0;
989                         }
990                 }
991
992                 if (aad_bytes > 0)
993                         memcpy(gw_out.ptr, gw_in.ptr, aad_bytes);
994
995                 cpacf_kma(ctx->fc | flags, &param,
996                           gw_out.ptr + aad_bytes,
997                           gw_in.ptr + aad_bytes, pc_bytes,
998                           gw_in.ptr, aad_bytes);
999
1000                 gcm_sg_walk_done(&gw_in, aad_bytes + pc_bytes);
1001                 gcm_sg_walk_done(&gw_out, aad_bytes + pc_bytes);
1002                 aadlen -= aad_bytes;
1003                 pclen -= pc_bytes;
1004         } while (aadlen + pclen > 0);
1005
1006         if (flags & CPACF_DECRYPT) {
1007                 scatterwalk_map_and_copy(tag, req->src, len, taglen, 0);
1008                 if (crypto_memneq(tag, param.t, taglen))
1009                         ret = -EBADMSG;
1010         } else
1011                 scatterwalk_map_and_copy(param.t, req->dst, len, taglen, 1);
1012
1013         memzero_explicit(&param, sizeof(param));
1014         return ret;
1015 }
1016
/* AEAD .encrypt hook: GCM encryption of @req via gcm_aes_crypt(). */
static int gcm_aes_encrypt(struct aead_request *req)
{
	return gcm_aes_crypt(req, CPACF_ENCRYPT);
}

/*
 * AEAD .decrypt hook: GCM decryption of @req via gcm_aes_crypt(), which
 * returns -EBADMSG when the received tag does not match the computed one.
 */
static int gcm_aes_decrypt(struct aead_request *req)
{
	return gcm_aes_crypt(req, CPACF_DECRYPT);
}
1026
1027 static struct aead_alg gcm_aes_aead = {
1028         .setkey                 = gcm_aes_setkey,
1029         .setauthsize            = gcm_aes_setauthsize,
1030         .encrypt                = gcm_aes_encrypt,
1031         .decrypt                = gcm_aes_decrypt,
1032
1033         .ivsize                 = GHASH_BLOCK_SIZE - sizeof(u32),
1034         .maxauthsize            = GHASH_DIGEST_SIZE,
1035         .chunksize              = AES_BLOCK_SIZE,
1036
1037         .base                   = {
1038                 .cra_blocksize          = 1,
1039                 .cra_ctxsize            = sizeof(struct s390_aes_ctx),
1040                 .cra_priority           = 900,
1041                 .cra_name               = "gcm(aes)",
1042                 .cra_driver_name        = "gcm-aes-s390",
1043                 .cra_module             = THIS_MODULE,
1044         },
1045 };
1046
/*
 * crypto_algs successfully registered through aes_s390_register_alg(), in
 * registration order.  aes_s390_init() registers at most five of them
 * (aes_alg, ecb_aes_alg, cbc_aes_alg, xts_aes_alg, ctr_aes_alg).
 */
static struct crypto_alg *aes_s390_algs_ptr[5];
/* Number of valid entries in aes_s390_algs_ptr. */
static int aes_s390_algs_num;
/* Points to gcm_aes_aead once crypto_register_aead() succeeded. */
static struct aead_alg *aes_s390_aead_alg;
1050
1051 static int aes_s390_register_alg(struct crypto_alg *alg)
1052 {
1053         int ret;
1054
1055         ret = crypto_register_alg(alg);
1056         if (!ret)
1057                 aes_s390_algs_ptr[aes_s390_algs_num++] = alg;
1058         return ret;
1059 }
1060
1061 static void aes_s390_fini(void)
1062 {
1063         while (aes_s390_algs_num--)
1064                 crypto_unregister_alg(aes_s390_algs_ptr[aes_s390_algs_num]);
1065         if (ctrblk)
1066                 free_page((unsigned long) ctrblk);
1067
1068         if (aes_s390_aead_alg)
1069                 crypto_unregister_aead(aes_s390_aead_alg);
1070 }
1071
1072 static int __init aes_s390_init(void)
1073 {
1074         int ret;
1075
1076         /* Query available functions for KM, KMC, KMCTR and KMA */
1077         cpacf_query(CPACF_KM, &km_functions);
1078         cpacf_query(CPACF_KMC, &kmc_functions);
1079         cpacf_query(CPACF_KMCTR, &kmctr_functions);
1080         cpacf_query(CPACF_KMA, &kma_functions);
1081
1082         if (cpacf_test_func(&km_functions, CPACF_KM_AES_128) ||
1083             cpacf_test_func(&km_functions, CPACF_KM_AES_192) ||
1084             cpacf_test_func(&km_functions, CPACF_KM_AES_256)) {
1085                 ret = aes_s390_register_alg(&aes_alg);
1086                 if (ret)
1087                         goto out_err;
1088                 ret = aes_s390_register_alg(&ecb_aes_alg);
1089                 if (ret)
1090                         goto out_err;
1091         }
1092
1093         if (cpacf_test_func(&kmc_functions, CPACF_KMC_AES_128) ||
1094             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_192) ||
1095             cpacf_test_func(&kmc_functions, CPACF_KMC_AES_256)) {
1096                 ret = aes_s390_register_alg(&cbc_aes_alg);
1097                 if (ret)
1098                         goto out_err;
1099         }
1100
1101         if (cpacf_test_func(&km_functions, CPACF_KM_XTS_128) ||
1102             cpacf_test_func(&km_functions, CPACF_KM_XTS_256)) {
1103                 ret = aes_s390_register_alg(&xts_aes_alg);
1104                 if (ret)
1105                         goto out_err;
1106         }
1107
1108         if (cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_128) ||
1109             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_192) ||
1110             cpacf_test_func(&kmctr_functions, CPACF_KMCTR_AES_256)) {
1111                 ctrblk = (u8 *) __get_free_page(GFP_KERNEL);
1112                 if (!ctrblk) {
1113                         ret = -ENOMEM;
1114                         goto out_err;
1115                 }
1116                 ret = aes_s390_register_alg(&ctr_aes_alg);
1117                 if (ret)
1118                         goto out_err;
1119         }
1120
1121         if (cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_128) ||
1122             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_192) ||
1123             cpacf_test_func(&kma_functions, CPACF_KMA_GCM_AES_256)) {
1124                 ret = crypto_register_aead(&gcm_aes_aead);
1125                 if (ret)
1126                         goto out_err;
1127                 aes_s390_aead_alg = &gcm_aes_aead;
1128         }
1129
1130         return 0;
1131 out_err:
1132         aes_s390_fini();
1133         return ret;
1134 }
1135
/*
 * Module entry and exit points.  module_cpu_feature_match() ties the init
 * routine to the MSA CPU feature (presumably so it only runs on machines
 * with the message-security assist - confirm against the macro definition).
 */
module_cpu_feature_match(MSA, aes_s390_init);
module_exit(aes_s390_fini);

MODULE_ALIAS_CRYPTO("aes-all");

MODULE_DESCRIPTION("Rijndael (AES) Cipher Algorithm");
MODULE_LICENSE("GPL");