From bc15f7a0c7f7ff64cdcfd310f1c7172c6c58edc0 Mon Sep 17 00:00:00 2001 From: Zhi Guan Date: Sat, 27 Apr 2024 12:08:35 +0800 Subject: [PATCH] Update SM4 API --- CMakeLists.txt | 11 +++++- include/gmssl/sm4.h | 7 +++- src/sm4_ctr.c | 89 +++++++++++++++++++++++++++++++++++++++++++++ src/sm4_ecb.c | 6 +-- src/sm4_gcm.c | 79 ---------------------------------------- tests/sm4_ctrtest.c | 33 +++++++++++++++++ tests/sm4_ecbtest.c | 8 ++-- tests/sm4test.c | 37 ++++++------------- 8 files changed, 157 insertions(+), 113 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a08a308d6..9e6b8ad98 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,8 @@ include_directories(include) add_compile_options(-O3) +option(ENABLE_TEST_SPEED "Enable test speed" ON) + option(ENABLE_SM4_TBOX "Enable SM4 merged S-Box implementation" ON) option(ENABLE_SM4_AARCH64 "Enable SM4 AARCH64 assembly implementation" OFF) option(ENABLE_SM4_CTR_AESNI_AVX "Enable SM4 CTR AESNI+AVX assembly implementation" OFF) @@ -30,7 +32,7 @@ option(ENABLE_GMUL_AARCH64 "Enable GF(2^128) Multiplication AArch64 assembly" OF set(src src/version.c src/debug.c - src/sm4.c + src/sm4_arm_neon.c src/sm4_cbc.c src/sm4_ctr.c src/sm4_gcm.c @@ -84,6 +86,7 @@ set(src src/tls12.c src/tls13.c src/file.c + src/file.c ) set(tools @@ -245,6 +248,12 @@ include(CheckSymbolExists) # when an option has been enabled, `cmake ..` will not refresh the value # use `cmake .. -DENABLE_XXX=OFF` to disable the option + +if (ENABLE_TEST_SPEED) + message(STATUS "ENABLE_TEST_SPEED is ON") + add_definitions(-DENABLE_TEST_SPEED) +endif() + option(ENABLE_SM2_ALGOR_ID_ENCODE_NULL "Enable AlgorithmIdenifier with algorithm sm2sign_with_sm3 encode a NULL object as parameters" OFF) if (ENABLE_SM2_ALGOR_ID_ENCODE_NULL) message(STATUS "ENABLE_SM2_ALGOR_ID_ENCODE_NULL is ON") diff --git a/include/gmssl/sm4.h b/include/gmssl/sm4.h index 08224bec3..6ed6d9a9a 100644 --- a/include/gmssl/sm4.h +++ b/include/gmssl/sm4.h @@ -65,6 +65,8 @@ _gmssl_export int sm4_cbc_decrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[SM4_BLOCK_SIZE], const uint8_t *in, size_t inlen, uint8_t *out); +void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[SM4_BLOCK_SIZE], + const uint8_t *in, size_t inlen, uint8_t *out); typedef struct { union { @@ -79,6 +81,9 @@ typedef struct { _gmssl_export int sm4_ctr_encrypt_init(SM4_CTR_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]); _gmssl_export int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); _gmssl_export int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen); +int sm4_ctr32_encrypt_init(SM4_CTR_CTX *ctx, const uint8_t key[SM4_KEY_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]); +int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int sm4_ctr32_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen); #define NIST_SP800_GCM_MAX_IV_SIZE (((uint64_t)1 << (64-3)) - 1) // 2305843009213693951 @@ -134,7 +139,7 @@ _gmssl_export int sm4_gcm_decrypt_finish(SM4_GCM_CTX *ctx, #ifdef ENABLE_SM4_ECB // call `sm4_set_decrypt_key` before decrypt -void sm4_ecb_encrypt(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out); +void sm4_ecb_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out); typedef struct { SM4_KEY sm4_key; diff --git a/src/sm4_ctr.c b/src/sm4_ctr.c index b2a914869..52aa9561b 100644 --- a/src/sm4_ctr.c +++ b/src/sm4_ctr.c @@ -100,3 +100,92 @@ int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) *outlen = ctx->block_nbytes; return 1; } + +// inc32() in nist-sp800-38d +static void ctr32_incr(uint8_t a[16]) +{ + int i; + for (i = 15; i >= 12; i--) { + a[i]++; + if (a[i]) break; + } +} + +void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) +{ + uint8_t block[16]; + size_t len; + + while (inlen) { + len = inlen < 16 ? inlen : 16; + sm4_encrypt(key, ctr, block); + gmssl_memxor(out, in, block, len); + ctr32_incr(ctr); + in += len; + out += len; + inlen -= len; + } +} + +int sm4_ctr32_encrypt_init(SM4_CTR_CTX *ctx, + const uint8_t key[SM4_BLOCK_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]) +{ + sm4_set_encrypt_key(&ctx->sm4_key, key); + memcpy(ctx->ctr, ctr, SM4_BLOCK_SIZE); + memset(ctx->block, 0, SM4_BLOCK_SIZE); + ctx->block_nbytes = 0; + return 1; +} + +int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + size_t left; + size_t nblocks; + size_t len; + + if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { + error_print(); + return -1; + } + *outlen = 0; + if (ctx->block_nbytes) { + left = SM4_BLOCK_SIZE - ctx->block_nbytes; + if (inlen < left) { + memcpy(ctx->block + ctx->block_nbytes, in, inlen); + ctx->block_nbytes += inlen; + return 1; + } + memcpy(ctx->block + ctx->block_nbytes, in, left); + sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, SM4_BLOCK_SIZE, out); + in += left; + inlen -= left; + out += SM4_BLOCK_SIZE; + *outlen += SM4_BLOCK_SIZE; + } + if (inlen >= SM4_BLOCK_SIZE) { + nblocks = inlen / SM4_BLOCK_SIZE; + len = nblocks * SM4_BLOCK_SIZE; + sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, in, len, out); + in += len; + inlen -= len; + out += len; + *outlen += len; + } + if (inlen) { + memcpy(ctx->block, in, inlen); + } + ctx->block_nbytes = inlen; + return 1; +} + +int sm4_ctr32_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) +{ + if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { + error_print(); + return -1; + } + sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, ctx->block_nbytes, out); + *outlen = ctx->block_nbytes; + return 1; +} diff --git a/src/sm4_ecb.c b/src/sm4_ecb.c index 7a80f5cd6..394d5daf0 100644 --- a/src/sm4_ecb.c +++ b/src/sm4_ecb.c @@ -13,7 +13,7 @@ #include -void sm4_ecb_encrypt(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out) +void sm4_ecb_encrypt_blocks(const SM4_KEY *key, const uint8_t *in, size_t nblocks, uint8_t *out) { while (nblocks--) { sm4_encrypt(key, in, out); @@ -50,7 +50,7 @@ int sm4_ecb_encrypt_update(SM4_ECB_CTX *ctx, return 1; } memcpy(ctx->block + ctx->block_nbytes, in, left); - sm4_ecb_encrypt(&ctx->sm4_key, ctx->block, 1, out); + sm4_ecb_encrypt_blocks(&ctx->sm4_key, ctx->block, 1, out); in += left; inlen -= left; out += SM4_BLOCK_SIZE; @@ -59,7 +59,7 @@ int sm4_ecb_encrypt_update(SM4_ECB_CTX *ctx, if (inlen >= SM4_BLOCK_SIZE) { nblocks = inlen / SM4_BLOCK_SIZE; len = nblocks * SM4_BLOCK_SIZE; - sm4_ecb_encrypt(&ctx->sm4_key, in, nblocks, out); + sm4_ecb_encrypt_blocks(&ctx->sm4_key, in, nblocks, out); in += len; inlen -= len; out += len; diff --git a/src/sm4_gcm.c b/src/sm4_gcm.c index 92c22ee37..45cca1d84 100644 --- a/src/sm4_gcm.c +++ b/src/sm4_gcm.c @@ -25,22 +25,6 @@ static void ctr32_incr(uint8_t a[16]) } } -static void sm4_ctr32_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out) -{ - uint8_t block[16]; - size_t len; - - while (inlen) { - len = inlen < 16 ? inlen : 16; - sm4_encrypt(key, ctr, block); - gmssl_memxor(out, in, block, len); - ctr32_incr(ctr); - in += len; - out += len; - inlen -= len; - } -} - int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen, uint8_t *out, size_t taglen, uint8_t *tag) @@ -108,69 +92,6 @@ int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen, return 1; } -static int sm4_ctr32_encrypt_init(SM4_CTR_CTX *ctx, - const uint8_t key[SM4_BLOCK_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE]) -{ - sm4_set_encrypt_key(&ctx->sm4_key, key); - memcpy(ctx->ctr, ctr, SM4_BLOCK_SIZE); - memset(ctx->block, 0, SM4_BLOCK_SIZE); - ctx->block_nbytes = 0; - return 1; -} - -static int sm4_ctr32_encrypt_update(SM4_CTR_CTX *ctx, - const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) -{ - size_t left; - size_t nblocks; - size_t len; - - if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { - error_print(); - return -1; - } - *outlen = 0; - if (ctx->block_nbytes) { - left = SM4_BLOCK_SIZE - ctx->block_nbytes; - if (inlen < left) { - memcpy(ctx->block + ctx->block_nbytes, in, inlen); - ctx->block_nbytes += inlen; - return 1; - } - memcpy(ctx->block + ctx->block_nbytes, in, left); - sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, SM4_BLOCK_SIZE, out); - in += left; - inlen -= left; - out += SM4_BLOCK_SIZE; - *outlen += SM4_BLOCK_SIZE; - } - if (inlen >= SM4_BLOCK_SIZE) { - nblocks = inlen / SM4_BLOCK_SIZE; - len = nblocks * SM4_BLOCK_SIZE; - sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, in, len, out); - in += len; - inlen -= len; - out += len; - *outlen += len; - } - if (inlen) { - memcpy(ctx->block, in, inlen); - } - ctx->block_nbytes = inlen; - return 1; -} - -static int sm4_ctr32_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen) -{ - if (ctx->block_nbytes >= SM4_BLOCK_SIZE) { - error_print(); - return -1; - } - sm4_ctr32_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, ctx->block_nbytes, out); - *outlen = ctx->block_nbytes; - return 1; -} - int sm4_gcm_encrypt_init(SM4_GCM_CTX *ctx, const uint8_t *key, size_t keylen, const uint8_t *iv, size_t ivlen, const uint8_t *aad, size_t aadlen, size_t taglen) diff --git a/tests/sm4_ctrtest.c b/tests/sm4_ctrtest.c index 73b4bfc1e..41514e07b 100644 --- a/tests/sm4_ctrtest.c +++ b/tests/sm4_ctrtest.c @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -357,6 +358,37 @@ static int test_sm4_ctr_ctx_multi_updates(void) return 1; } +extern void sm4_ctr32_encrypt_4blocks(const SM4_KEY *key, uint8_t iv[16], const uint8_t *in, size_t num_4blocks, uint8_t *out); + + +static int test_sm4_ctr32_encrypt_blocks_speed(void) +{ + SM4_KEY sm4_key; + uint8_t key[16]; + uint8_t ctr[16]; + uint8_t buf[4096]; + clock_t begin, end; + double seconds; + int i; + + memset(key, 0, sizeof(key)); + memset(ctr, 0, sizeof(ctr)); + memset(buf, 0, sizeof(buf)); + + sm4_set_encrypt_key(&sm4_key, key); + + begin = clock(); + for (i = 0; i < 4096; i++) { + sm4_ctr32_encrypt_4blocks(&sm4_key, ctr, buf, sizeof(buf)/64, buf); + } + end = clock(); + + seconds = (double)(end - begin)/CLOCKS_PER_SEC; + printf("%s: %f MiB per second\n", __FUNCTION__, 16/seconds); + + return 1; +} + int main(void) { if (test_sm4_ctr() != 1) goto err; @@ -365,6 +397,7 @@ int main(void) if (test_sm4_ctr_ctx() != 1) goto err; if (test_sm4_ctr_ctx_multi_updates() != 1) goto err; + if (test_sm4_ctr32_encrypt_blocks_speed() != 1) goto err; printf("%s all tests passed\n", __FILE__); return 0; err: diff --git a/tests/sm4_ecbtest.c b/tests/sm4_ecbtest.c index 6b05308f8..920950fb3 100644 --- a/tests/sm4_ecbtest.c +++ b/tests/sm4_ecbtest.c @@ -32,10 +32,10 @@ static int test_sm4_ecb(void) rand_bytes(plaintext, sizeof(plaintext)); sm4_set_encrypt_key(&sm4_key, key); - sm4_ecb_encrypt(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); + sm4_ecb_encrypt_blocks(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); sm4_set_decrypt_key(&sm4_key, key); - sm4_ecb_encrypt(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); + sm4_ecb_encrypt_blocks(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); if (memcmp(decrypted, plaintext, sizeof(plaintext)) != 0) { error_print(); @@ -69,7 +69,7 @@ static int test_sm4_ecb_test_vectors(void) uint8_t decrypted[sizeof(plaintext)] = {0}; sm4_set_encrypt_key(&sm4_key, key); - sm4_ecb_encrypt(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); + sm4_ecb_encrypt_blocks(&sm4_key, plaintext, sizeof(plaintext)/16, encrypted); format_bytes(stderr, 0, 0, "", encrypted, sizeof(encrypted)); @@ -79,7 +79,7 @@ static int test_sm4_ecb_test_vectors(void) } sm4_set_decrypt_key(&sm4_key, key); - sm4_ecb_encrypt(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); + sm4_ecb_encrypt_blocks(&sm4_key, encrypted, sizeof(encrypted)/16, decrypted); if (memcmp(decrypted, plaintext, sizeof(plaintext)) != 0) { error_print(); diff --git a/tests/sm4test.c b/tests/sm4test.c index f1668813c..48b67a747 100644 --- a/tests/sm4test.c +++ b/tests/sm4test.c @@ -91,49 +91,36 @@ static int test_sm4(void) return 1; } - -static int test_sm4_speed(void) +static int test_sm4_encrypt_speed(void) { SM4_KEY sm4_key; uint8_t key[16] = {0}; - uint8_t block[16] = {0}; - clock_t start, end; + uint8_t buf[16]; + size_t nbytes = 16 * 1024 * 1024; + clock_t begin, end; double seconds; int i; sm4_set_encrypt_key(&sm4_key, key); - start = clock(); - for (i = 0; i < 1024*1024; i++) { - sm4_encrypt(&sm4_key, block, block); + begin = clock(); + for (i = 0; i < nbytes/sizeof(buf); i++) { + sm4_encrypt(&sm4_key, buf, buf); } end = clock(); - seconds = (double)(end - start)/ CLOCKS_PER_SEC; - - fprintf(stderr, "sm4_encrypt: %f-MiB per seconds\n", 16/seconds); + seconds = (double)(end - begin)/ CLOCKS_PER_SEC; + fprintf(stderr, "sm4_encrypt: %f MiB per second\n", nbytes/(1024 * 1024 *seconds)); return 1; } - int main(void) { if (test_sm4() != 1) goto err; - if (test_sm4_speed() != 1) goto err; - - - /* - if (test_sm4_cbc() != 1) goto err; - if (test_sm4_cbc_padding() != 1) goto err; - - if (test_sm4_gcm() != 1) goto err; - if (test_sm4_gcm_gbt36624_1() != 1) goto err; - if (test_sm4_gcm_gbt36624_2() != 1) goto err; - if (test_sm4_cbc_update() != 1) goto err; - - - */ +#if ENABLE_TEST_SPEED + if (test_sm4_encrypt_speed() != 1) goto err; +#endif printf("%s all tests passed\n", __FILE__); return 0; err: