crypto: x86/sm4 - export reusable AESNI/AVX functions
authorTianjia Zhang <tianjia.zhang@linux.alibaba.com>
Wed, 18 Aug 2021 03:31:16 +0000 (11:31 +0800)
committerHerbert Xu <herbert@gondor.apana.org.au>
Fri, 27 Aug 2021 08:30:18 +0000 (16:30 +0800)
Export the reusable functions in the SM4 AESNI/AVX implementation,
mainly public functions, which are used to develop the SM4 AESNI/AVX2
implementation, and eliminate unnecessary duplication of code.

At the same time, in order to make the public function universal,
minor fixes was added.

Signed-off-by: Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/x86/crypto/sm4-avx.h [new file with mode: 0644]
arch/x86/crypto/sm4_aesni_avx_glue.c

diff --git a/arch/x86/crypto/sm4-avx.h b/arch/x86/crypto/sm4-avx.h
new file mode 100644 (file)
index 0000000..1bceab7
--- /dev/null
@@ -0,0 +1,24 @@
+/* SPDX-License-Identifier: GPL-2.0-or-later */
+#ifndef ASM_X86_SM4_AVX_H
+#define ASM_X86_SM4_AVX_H
+
+#include <linux/types.h>
+#include <crypto/sm4.h>
+
+typedef void (*sm4_crypt_func)(const u32 *rk, u8 *dst, const u8 *src, u8 *iv);
+
+int sm4_avx_ecb_encrypt(struct skcipher_request *req);
+int sm4_avx_ecb_decrypt(struct skcipher_request *req);
+
+int sm4_cbc_encrypt(struct skcipher_request *req);
+int sm4_avx_cbc_decrypt(struct skcipher_request *req,
+                       unsigned int bsize, sm4_crypt_func func);
+
+int sm4_cfb_encrypt(struct skcipher_request *req);
+int sm4_avx_cfb_decrypt(struct skcipher_request *req,
+                       unsigned int bsize, sm4_crypt_func func);
+
+int sm4_avx_ctr_crypt(struct skcipher_request *req,
+                       unsigned int bsize, sm4_crypt_func func);
+
+#endif
index c1f5728..7800f77 100644 (file)
@@ -15,6 +15,7 @@
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
 #include <crypto/sm4.h>
+#include "sm4-avx.h"
 
 #define SM4_CRYPT8_BLOCK_SIZE  (SM4_BLOCK_SIZE * 8)
 
@@ -71,23 +72,25 @@ static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
        return err;
 }
 
-static int ecb_encrypt(struct skcipher_request *req)
+int sm4_avx_ecb_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
 
        return ecb_do_crypt(req, ctx->rkey_enc);
 }
+EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
 
-static int ecb_decrypt(struct skcipher_request *req)
+int sm4_avx_ecb_decrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
 
        return ecb_do_crypt(req, ctx->rkey_dec);
 }
+EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
 
-static int cbc_encrypt(struct skcipher_request *req)
+int sm4_cbc_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -118,8 +121,10 @@ static int cbc_encrypt(struct skcipher_request *req)
 
        return err;
 }
+EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
 
-static int cbc_decrypt(struct skcipher_request *req)
+int sm4_avx_cbc_decrypt(struct skcipher_request *req,
+                       unsigned int bsize, sm4_crypt_func func)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -135,15 +140,14 @@ static int cbc_decrypt(struct skcipher_request *req)
 
                kernel_fpu_begin();
 
-               while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
-                       sm4_aesni_avx_cbc_dec_blk8(ctx->rkey_dec, dst,
-                                               src, walk.iv);
-                       dst += SM4_CRYPT8_BLOCK_SIZE;
-                       src += SM4_CRYPT8_BLOCK_SIZE;
-                       nbytes -= SM4_CRYPT8_BLOCK_SIZE;
+               while (nbytes >= bsize) {
+                       func(ctx->rkey_dec, dst, src, walk.iv);
+                       dst += bsize;
+                       src += bsize;
+                       nbytes -= bsize;
                }
 
-               if (nbytes >= SM4_BLOCK_SIZE) {
+               while (nbytes >= SM4_BLOCK_SIZE) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        u8 iv[SM4_BLOCK_SIZE];
                        unsigned int nblocks = min(nbytes >> 4, 8u);
@@ -165,6 +169,8 @@ static int cbc_decrypt(struct skcipher_request *req)
                        }
                        crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
                        memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
+                       dst += nblocks * SM4_BLOCK_SIZE;
+                       src += (nblocks + 1) * SM4_BLOCK_SIZE;
                        nbytes -= nblocks * SM4_BLOCK_SIZE;
                }
 
@@ -174,8 +180,15 @@ static int cbc_decrypt(struct skcipher_request *req)
 
        return err;
 }
+EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
+
+static int cbc_decrypt(struct skcipher_request *req)
+{
+       return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
+                               sm4_aesni_avx_cbc_dec_blk8);
+}
 
-static int cfb_encrypt(struct skcipher_request *req)
+int sm4_cfb_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -214,8 +227,10 @@ static int cfb_encrypt(struct skcipher_request *req)
 
        return err;
 }
+EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);
 
-static int cfb_decrypt(struct skcipher_request *req)
+int sm4_avx_cfb_decrypt(struct skcipher_request *req,
+                       unsigned int bsize, sm4_crypt_func func)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -231,15 +246,14 @@ static int cfb_decrypt(struct skcipher_request *req)
 
                kernel_fpu_begin();
 
-               while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
-                       sm4_aesni_avx_cfb_dec_blk8(ctx->rkey_enc, dst,
-                                               src, walk.iv);
-                       dst += SM4_CRYPT8_BLOCK_SIZE;
-                       src += SM4_CRYPT8_BLOCK_SIZE;
-                       nbytes -= SM4_CRYPT8_BLOCK_SIZE;
+               while (nbytes >= bsize) {
+                       func(ctx->rkey_enc, dst, src, walk.iv);
+                       dst += bsize;
+                       src += bsize;
+                       nbytes -= bsize;
                }
 
-               if (nbytes >= SM4_BLOCK_SIZE) {
+               while (nbytes >= SM4_BLOCK_SIZE) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        unsigned int nblocks = min(nbytes >> 4, 8u);
 
@@ -276,8 +290,16 @@ static int cfb_decrypt(struct skcipher_request *req)
 
        return err;
 }
+EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);
 
-static int ctr_crypt(struct skcipher_request *req)
+static int cfb_decrypt(struct skcipher_request *req)
+{
+       return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
+                               sm4_aesni_avx_cfb_dec_blk8);
+}
+
+int sm4_avx_ctr_crypt(struct skcipher_request *req,
+                       unsigned int bsize, sm4_crypt_func func)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -293,15 +315,14 @@ static int ctr_crypt(struct skcipher_request *req)
 
                kernel_fpu_begin();
 
-               while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
-                       sm4_aesni_avx_ctr_enc_blk8(ctx->rkey_enc, dst,
-                                               src, walk.iv);
-                       dst += SM4_CRYPT8_BLOCK_SIZE;
-                       src += SM4_CRYPT8_BLOCK_SIZE;
-                       nbytes -= SM4_CRYPT8_BLOCK_SIZE;
+               while (nbytes >= bsize) {
+                       func(ctx->rkey_enc, dst, src, walk.iv);
+                       dst += bsize;
+                       src += bsize;
+                       nbytes -= bsize;
                }
 
-               if (nbytes >= SM4_BLOCK_SIZE) {
+               while (nbytes >= SM4_BLOCK_SIZE) {
                        u8 keystream[SM4_BLOCK_SIZE * 8];
                        unsigned int nblocks = min(nbytes >> 4, 8u);
                        int i;
@@ -343,6 +364,13 @@ static int ctr_crypt(struct skcipher_request *req)
 
        return err;
 }
+EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
+
+static int ctr_crypt(struct skcipher_request *req)
+{
+       return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
+                               sm4_aesni_avx_ctr_enc_blk8);
+}
 
 static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
        {
@@ -359,8 +387,8 @@ static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
                .max_keysize    = SM4_KEY_SIZE,
                .walksize       = 8 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
-               .encrypt        = ecb_encrypt,
-               .decrypt        = ecb_decrypt,
+               .encrypt        = sm4_avx_ecb_encrypt,
+               .decrypt        = sm4_avx_ecb_decrypt,
        }, {
                .base = {
                        .cra_name               = "__cbc(sm4)",
@@ -376,7 +404,7 @@ static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
                .ivsize         = SM4_BLOCK_SIZE,
                .walksize       = 8 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
-               .encrypt        = cbc_encrypt,
+               .encrypt        = sm4_cbc_encrypt,
                .decrypt        = cbc_decrypt,
        }, {
                .base = {
@@ -394,7 +422,7 @@ static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
                .chunksize      = SM4_BLOCK_SIZE,
                .walksize       = 8 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
-               .encrypt        = cfb_encrypt,
+               .encrypt        = sm4_cfb_encrypt,
                .decrypt        = cfb_decrypt,
        }, {
                .base = {