crypto: arm64/aes-neon-ctr - improve handling of single tail block
author: Ard Biesheuvel <ardb@kernel.org>
Thu, 27 Jan 2022 09:52:11 +0000 (10:52 +0100)
committer: Herbert Xu <herbert@gondor.apana.org.au>
Sat, 5 Feb 2022 04:10:51 +0000 (15:10 +1100)
Instead of falling back to C code to do a memcpy of the output of the
last block, handle this in the asm code directly if possible, which is
the case if the entire input is longer than 16 bytes.

Cc: Nathan Huckleberry <nhuck@google.com>
Cc: Eric Biggers <ebiggers@google.com>
Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/arm64/crypto/aes-glue.c
arch/arm64/crypto/aes-modes.S

index 30b7cc6..7d66f8b 100644 (file)
@@ -24,7 +24,6 @@
 #ifdef USE_V8_CRYPTO_EXTENSIONS
 #define MODE                   "ce"
 #define PRIO                   300
-#define STRIDE                 5
 #define aes_expandkey          ce_aes_expandkey
 #define aes_ecb_encrypt                ce_aes_ecb_encrypt
 #define aes_ecb_decrypt                ce_aes_ecb_decrypt
@@ -42,7 +41,6 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #else
 #define MODE                   "neon"
 #define PRIO                   200
-#define STRIDE                 4
 #define aes_ecb_encrypt                neon_aes_ecb_encrypt
 #define aes_ecb_decrypt                neon_aes_ecb_decrypt
 #define aes_cbc_encrypt                neon_aes_cbc_encrypt
@@ -89,7 +87,7 @@ asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
 
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
-                               int rounds, int bytes, u8 ctr[], u8 finalbuf[]);
+                               int rounds, int bytes, u8 ctr[]);
 
 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
@@ -458,26 +456,21 @@ static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
                unsigned int nbytes = walk.nbytes;
                u8 *dst = walk.dst.virt.addr;
                u8 buf[AES_BLOCK_SIZE];
-               unsigned int tail;
 
                if (unlikely(nbytes < AES_BLOCK_SIZE))
-                       src = memcpy(buf, src, nbytes);
+                       src = dst = memcpy(buf + sizeof(buf) - nbytes,
+                                          src, nbytes);
                else if (nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);
 
                kernel_neon_begin();
                aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
-                               walk.iv, buf);
+                               walk.iv);
                kernel_neon_end();
 
-               tail = nbytes % (STRIDE * AES_BLOCK_SIZE);
-               if (tail > 0 && tail < AES_BLOCK_SIZE)
-                       /*
-                        * The final partial block could not be returned using
-                        * an overlapping store, so it was passed via buf[]
-                        * instead.
-                        */
-                       memcpy(dst + nbytes - tail, buf, tail);
+               if (unlikely(nbytes < AES_BLOCK_SIZE))
+                       memcpy(walk.dst.virt.addr,
+                              buf + sizeof(buf) - nbytes, nbytes);
 
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }
index ff01f01..dc35eb0 100644 (file)
@@ -321,7 +321,7 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
 
        /*
         * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-        *                 int bytes, u8 ctr[], u8 finalbuf[])
+        *                 int bytes, u8 ctr[])
         */
 
 AES_FUNC_START(aes_ctr_encrypt)
@@ -414,8 +414,8 @@ ST5(        st1             {v4.16b}, [x0], #16             )
 .Lctrtail:
        /* XOR up to MAX_STRIDE * 16 - 1 bytes of in/output with v0 ... v3/v4 */
        mov             x16, #16
-       ands            x13, x4, #0xf
-       csel            x13, x13, x16, ne
+       ands            x6, x4, #0xf
+       csel            x13, x6, x16, ne
 
 ST5(   cmp             w4, #64 - (MAX_STRIDE << 4)     )
 ST5(   csel            x14, x16, xzr, gt               )
@@ -424,10 +424,10 @@ ST5(      csel            x14, x16, xzr, gt               )
        cmp             w4, #32 - (MAX_STRIDE << 4)
        csel            x16, x16, xzr, gt
        cmp             w4, #16 - (MAX_STRIDE << 4)
-       ble             .Lctrtail1x
 
        adr_l           x12, .Lcts_permute_table
        add             x12, x12, x13
+       ble             .Lctrtail1x
 
 ST5(   ld1             {v5.16b}, [x1], x14             )
        ld1             {v6.16b}, [x1], x15
@@ -462,11 +462,19 @@ ST5(      st1             {v5.16b}, [x0], x14             )
        b               .Lctrout
 
 .Lctrtail1x:
-       csel            x0, x0, x6, eq          // use finalbuf if less than a full block
+       sub             x7, x6, #16
+       csel            x6, x6, x7, eq
+       add             x1, x1, x6
+       add             x0, x0, x6
        ld1             {v5.16b}, [x1]
+       ld1             {v6.16b}, [x0]
 ST5(   mov             v3.16b, v4.16b                  )
        encrypt_block   v3, w3, x2, x8, w7
+       ld1             {v10.16b-v11.16b}, [x12]
+       tbl             v3.16b, {v3.16b}, v10.16b
+       sshr            v11.16b, v11.16b, #7
        eor             v5.16b, v5.16b, v3.16b
+       bif             v5.16b, v6.16b, v11.16b
        st1             {v5.16b}, [x0]
        b               .Lctrout
 AES_FUNC_END(aes_ctr_encrypt)