diff --git a/include/gmssl/sm4.h b/include/gmssl/sm4.h index 2cd9e0eb5..63476df32 100644 --- a/include/gmssl/sm4.h +++ b/include/gmssl/sm4.h @@ -212,11 +212,27 @@ int sm4_ccm_decrypt(const SM4_KEY *sm4_key, const uint8_t *iv, size_t ivlen, #ifdef ENABLE_SM4_XTS // call `sm4_set_encrypt_key` to set both `key1` and `key2` -int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, +int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16], const uint8_t *in, size_t inlen, uint8_t *out); // call `sm4_set_decrypt_key(key1)` and `sm4_set_encrypt_key(key2)` -int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, +int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16], const uint8_t *in, size_t inlen, uint8_t *out); + +typedef struct { + SM4_KEY key1; + SM4_KEY key2; + uint8_t tweak[16]; + size_t data_unit_size; + uint8_t *block; + size_t block_nbytes; +} SM4_XTS_CTX; + +int sm4_xts_encrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size); +int sm4_xts_encrypt_update(SM4_XTS_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int sm4_xts_encrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen); +int sm4_xts_decrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size); +int sm4_xts_decrypt_update(SM4_XTS_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +int sm4_xts_decrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen); #endif diff --git a/src/sm4_xts.c b/src/sm4_xts.c index d473df842..3d9a6a978 100644 --- a/src/sm4_xts.c +++ b/src/sm4_xts.c @@ -14,10 +14,10 @@ #include -int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, +int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16], const uint8_t *in, size_t inlen, uint8_t *out) { - uint8_t T[16] = {0}; + uint8_t T[16]; uint8_t block[16]; size_t nblocks, i; gf128_t a; @@ -28,10 +28,7 @@ int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, } nblocks = inlen / 16 + 1; - for (i = 0; i < 8; i++) { - T[i] = tweak & 0xff; - tweak >>= 8; - } + memcpy(T, tweak, 16); sm4_encrypt(key2, T, T); for (i = 0; i < nblocks - 2; i++) { @@ -76,10 +73,10 @@ int sm4_xts_encrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, return 1; } -int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, +int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, const uint8_t tweak[16], const uint8_t *in, size_t inlen, uint8_t *out) { - uint8_t T[16] = {0}; + uint8_t T[16]; uint8_t block[16]; size_t nblocks, i; gf128_t a; @@ -90,10 +87,7 @@ int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, } nblocks = inlen / 16 + 1; - for (i = 0; i < 8; i++) { - T[i] = tweak & 0xff; - tweak >>= 8; - } + memcpy(T, tweak, 16); sm4_encrypt(key2, T, T); for (i = 0; i < nblocks - 2; i++) { @@ -140,3 +134,188 @@ int sm4_xts_decrypt(const SM4_KEY *key1, const SM4_KEY *key2, size_t tweak, return 1; } +static void tweak_incr(uint8_t a[16]) +{ + int i; + for (i = 0; i < 16; i++) { + a[i]++; + if (a[i]) break; + } +} + +int sm4_xts_encrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size) +{ + if (data_unit_size < SM4_BLOCK_SIZE) { + error_print(); + return -1; + } + sm4_set_encrypt_key(&ctx->key1, key); + sm4_set_encrypt_key(&ctx->key2, key + 16); + memcpy(ctx->tweak, iv, 16); + ctx->data_unit_size = data_unit_size; + if (!(ctx->block = (uint8_t *)malloc(data_unit_size))) { + error_print(); + return -1; + } + ctx->block_nbytes = 0; + return 1; +} + +int sm4_xts_encrypt_update(SM4_XTS_CTX *ctx, + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + size_t DATA_UNIT_SIZE = ctx->data_unit_size; + + size_t left; + size_t nblocks; + size_t len; + + if (ctx->block_nbytes >= DATA_UNIT_SIZE) { + error_print(); + return -1; + } + *outlen = 0; + if (ctx->block_nbytes) { + left = DATA_UNIT_SIZE - ctx->block_nbytes; + if (inlen < left) { + memcpy(ctx->block + ctx->block_nbytes, in, inlen); + ctx->block_nbytes += inlen; + return 1; + } + memcpy(ctx->block + ctx->block_nbytes, in, left); + if (sm4_xts_encrypt(&ctx->key1, &ctx->key2, ctx->tweak, ctx->block, DATA_UNIT_SIZE, out) != 1) { + error_print(); + return -1; + } + tweak_incr(ctx->tweak); + in += left; + inlen -= left; + out += DATA_UNIT_SIZE; + *outlen += DATA_UNIT_SIZE; + } + while (inlen >= DATA_UNIT_SIZE) { + if (sm4_xts_encrypt(&ctx->key1, &ctx->key2, ctx->tweak, in, DATA_UNIT_SIZE, out) != 1) { + error_print(); + return -1; + } + tweak_incr(ctx->tweak); + in += DATA_UNIT_SIZE; + inlen -= DATA_UNIT_SIZE; + out += DATA_UNIT_SIZE; + *outlen += DATA_UNIT_SIZE; + } + if (inlen) { + memcpy(ctx->block, in, inlen); + } + ctx->block_nbytes = inlen; + return 1; +} + +int sm4_xts_encrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen) +{ + size_t DATA_UNIT_SIZE = ctx->data_unit_size; + if (ctx->block_nbytes >= DATA_UNIT_SIZE) { + error_print(); + return -1; + } + if (ctx->block) { + free(ctx->block); + ctx->block = NULL; + } + if (ctx->block_nbytes) { + error_puts("invalid total input length"); + return -1; + } + *outlen = 0; + return 1; +} + + +int sm4_xts_decrypt_init(SM4_XTS_CTX *ctx, const uint8_t key[32], const uint8_t iv[16], size_t data_unit_size) +{ + if (data_unit_size < SM4_BLOCK_SIZE) { + error_print(); + return -1; + } + sm4_set_decrypt_key(&ctx->key1, key); + sm4_set_encrypt_key(&ctx->key2, key + 16); + memcpy(ctx->tweak, iv, 16); + ctx->data_unit_size = data_unit_size; + if (!(ctx->block = (uint8_t *)malloc(data_unit_size))) { + error_print(); + return -1; + } + ctx->block_nbytes = 0; + return 1; +} + +int sm4_xts_decrypt_update(SM4_XTS_CTX *ctx, + const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + size_t DATA_UNIT_SIZE = ctx->data_unit_size; + + size_t left; + size_t nblocks; + size_t len; + + if (ctx->block_nbytes >= DATA_UNIT_SIZE) { + error_print(); + return -1; + } + *outlen = 0; + if (ctx->block_nbytes) { + left = DATA_UNIT_SIZE - ctx->block_nbytes; + if (inlen < left) { + memcpy(ctx->block + ctx->block_nbytes, in, inlen); + ctx->block_nbytes += inlen; + return 1; + } + memcpy(ctx->block + ctx->block_nbytes, in, left); + if (sm4_xts_decrypt(&ctx->key1, &ctx->key2, ctx->tweak, ctx->block, DATA_UNIT_SIZE, out) != 1) { + error_print(); + return -1; + } + tweak_incr(ctx->tweak); + in += left; + inlen -= left; + out += DATA_UNIT_SIZE; + *outlen += DATA_UNIT_SIZE; + } + while (inlen >= DATA_UNIT_SIZE) { + if (sm4_xts_decrypt(&ctx->key1, &ctx->key2, ctx->tweak, in, DATA_UNIT_SIZE, out) != 1) { + error_print(); + return -1; + } + tweak_incr(ctx->tweak); + in += DATA_UNIT_SIZE; + inlen -= DATA_UNIT_SIZE; + out += DATA_UNIT_SIZE; + *outlen += DATA_UNIT_SIZE; + } + if (inlen) { + memcpy(ctx->block, in, inlen); + } + ctx->block_nbytes = inlen; + return 1; +} + +int sm4_xts_decrypt_finish(SM4_XTS_CTX *ctx, uint8_t *out, size_t *outlen) +{ + size_t DATA_UNIT_SIZE = ctx->data_unit_size; + if (ctx->block_nbytes >= DATA_UNIT_SIZE) { + error_print(); + return -1; + } + if (ctx->block) { + free(ctx->block); + ctx->block = NULL; + } + if (ctx->block_nbytes) { + error_puts("invalid total input length"); + return -1; + } + *outlen = 0; + return 1; +} + + diff --git a/tests/sm4_xtstest.c b/tests/sm4_xtstest.c index 035f586cf..378e19b8d 100644 --- a/tests/sm4_xtstest.c +++ b/tests/sm4_xtstest.c @@ -27,10 +27,11 @@ static int test_sm4_xts(void) uint8_t plaintext[16 * 4]; uint8_t encrypted[sizeof(plaintext)]; uint8_t decrypted[sizeof(plaintext)]; - size_t tweak = 0x12345678; + uint8_t tweak[16]; size_t i; rand_bytes(key, sizeof(key)); + rand_bytes(tweak, sizeof(tweak)); rand_bytes(plaintext, sizeof(plaintext)); for (i = 0; i < sizeof(len)/sizeof(len[0]); i++) { diff --git a/tools/sm4.c b/tools/sm4.c index 96fce3b1f..9114af73a 100755 --- a/tools/sm4.c +++ b/tools/sm4.c @@ -19,25 +19,177 @@ #include -#define SM4_MODE_CBC 1 -#define SM4_MODE_CTR 2 -#define SM4_MODE_GCM 3 -#define SM4_MODE_CBC_SM3_HMAC 4 -#define SM4_MODE_CTR_SM3_HMAC 5 +enum { + SM4_MODE_ECB = 1, + SM4_MODE_CBC, + SM4_MODE_CFB, + SM4_MODE_OFB, + SM4_MODE_CTR, + SM4_MODE_XTS, + SM4_MODE_CCM, + SM4_MODE_GCM, + SM4_MODE_CBC_SM3_HMAC, + SM4_MODE_CTR_SM3_HMAC, +}; + +static uint8_t *read_content(FILE *infp, size_t *outlen, const char *prog) +{ + const size_t INITIAL_BUFFER_SIZE = 4096; + const size_t MAX_BUFFER_SIZE = 512 * 1024 * 1024; + uint8_t *buffer = NULL; + size_t buffer_size = INITIAL_BUFFER_SIZE; + size_t total_read = 0; + + if (!(buffer = (uint8_t *)malloc(INITIAL_BUFFER_SIZE))) { + fprintf(stderr, "%s: malloc failure\n", prog); + return NULL; + } + + while (1) { + size_t bytes_read; + + if (total_read == buffer_size) { + uint8_t *new_buffer; + + if (buffer_size >= MAX_BUFFER_SIZE) { + fprintf(stderr, "%s: input too long, should be less than %zu\n", prog, MAX_BUFFER_SIZE); + free(buffer); + return NULL; + } + buffer_size = buffer_size * 2; + if (buffer_size > MAX_BUFFER_SIZE) { + buffer_size = MAX_BUFFER_SIZE; + } + + if (!(new_buffer = (uint8_t *)realloc(buffer, buffer_size))) { + fprintf(stderr, "%s: realloc failure\n", prog); + free(buffer); + return NULL; + } + buffer = new_buffer; + } + + bytes_read = fread(buffer + total_read, 1, buffer_size - total_read, infp); + total_read += bytes_read; + + if (feof(infp)) { + break; + } + + if (ferror(infp)) { + fprintf(stderr, "%s: fread error\n", prog); + perror("error reading input"); + free(buffer); + return NULL; + } + } + + *outlen = total_read; + + return buffer; +} + +static int sm4_ccm_crypt(const uint8_t *key, size_t keylen, const uint8_t *iv, size_t ivlen, + const uint8_t *aad, size_t aadlen, size_t taglen, FILE *infp, FILE *outfp, int enc, + const char *prog) +{ + int ret = -1; + SM4_KEY sm4_key; + uint8_t *inbuf = NULL; + uint8_t *outbuf = NULL; + size_t inlen, outlen; + uint8_t *tag; + + if (keylen != 16) { + error_print(); + return -1; + } + if (ivlen < SM4_CCM_MIN_IV_SIZE || ivlen > SM4_CCM_MAX_IV_SIZE) { + fprintf(stderr, "%s: invalid SM4-CCM IV length, should be in [%d, %d]\n", + prog, SM4_CCM_MIN_IV_SIZE, SM4_CCM_MAX_IV_SIZE); + return -1; + } + if (taglen < SM4_CCM_MIN_MAC_SIZE || taglen > SM4_CCM_MAX_MAC_SIZE) { + fprintf(stderr, "%s: invalid SM4-CCM MAC tag length, should be in [%d, %d]\n", + prog, SM4_CCM_MIN_MAC_SIZE, SM4_CCM_MAX_MAC_SIZE); + return -1; + } + if (enc < 0) { + error_print(); + return -1; + } + + sm4_set_encrypt_key(&sm4_key, key); + + if (!(inbuf = read_content(infp, &inlen, prog))) { + goto end; + } + + if (enc) { + outlen = inlen + taglen; + if (!(outbuf = (uint8_t *)malloc(outlen))) { + fprintf(stderr, "%s: malloc failure\n", prog); + goto end; + } + tag = outbuf + inlen; + if (sm4_ccm_encrypt(&sm4_key, iv, ivlen, aad, aadlen, inbuf, inlen, outbuf, taglen, tag) != 1) { + error_print(); + goto end; + } + } else { + if (inlen < taglen) { + fprintf(stderr, "%s: input length (%zu bytes) shorter than tag length (%zu bytes)\n", + prog, inlen, taglen); + goto end; + } + outlen = inlen - taglen; + tag = inbuf + inlen - taglen; + if (!(outbuf = (uint8_t *)malloc(outlen))) { + fprintf(stderr, "%s: malloc failure\n", prog); + goto end; + } + if (sm4_ccm_decrypt(&sm4_key, iv, ivlen, aad, aadlen, inbuf, inlen - taglen, + tag, taglen, outbuf) != 1) { + error_print(); + goto end; + } + } + + if (fwrite(outbuf, 1, outlen, outfp) != outlen) { + fprintf(stderr, "%s: fwrite error\n", prog); + goto end; + } + + ret = 1; + +end: + gmssl_secure_clear(&sm4_key, sizeof(sm4_key)); + if (inbuf) free(inbuf); + if (outbuf) free(outbuf); + return ret; +} -static const char *usage = "(-cbc|-ctr|-gcm|-cbc_sm3_hmac|-ctr_sm3_hmac) {-encrypt|-decrypt} -key hex -iv hex [-aad str| -aad_hex hex] [-in file] [-out file]"; +static const char *usage = + "(-cbc|-ctr|-gcm|-cbc_sm3_hmac|-ctr_sm3_hmac)" + " {-encrypt|-decrypt} -key hex -iv hex" + " [-aad str| -aad_hex hex] [-taglen num] [-in file] [-out file]"; static const char *options = "Options\n" "\n" " Modes\n" "\n" +" -ecb ECB mode\n" " -cbc CBC mode with padding, need 16-byte key and 16-byte iv\n" +" -cfb CFB mode with padding, need 16-byte key and 16-byte iv\n" +" -ofb OFB mode with padding, need 16-byte key and 16-byte iv\n" " -ctr CTR mode, need 16-byte key and 16-byte iv\n" +" -ccm CCM mode, need 16-byte key and any iv length\n" " -gcm GCM mode, need 16-byte key and any iv length\n" " -cbc_sm3_hmac CBC mode with padding and HMAC-SM3 (encrypt-then-mac), need 48-byte key and 16-byte iv\n" " -ctr_sm3_hmac CTR mode with HMAC-SM3 (entrypt-then-mac), need 48-byte key and 16-byte iv\n" +" -xts XTS mode\n" "\n" " -encrypt Encrypt\n" " -decrypt Decrypt\n" @@ -65,6 +217,7 @@ int sm4_main(int argc, char **argv) { int ret = 1; char *prog = argv[0]; + char *keyhex = NULL; char *ivhex = NULL; uint8_t *aad = NULL; @@ -75,25 +228,44 @@ int sm4_main(int argc, char **argv) char *outfile = NULL; uint8_t key[48]; uint8_t iv[SM4_GCM_MAX_IV_SIZE]; - size_t keylen = sizeof(key); - size_t ivlen = sizeof(iv); + size_t keylen = 0; + size_t ivlen = 0; FILE *infp = stdin; FILE *outfp = stdout; int mode = 0; int enc = -1; int rv; + union { +#ifdef ENABLE_SM4_ECB + SM4_ECB_CTX ecb; +#endif SM4_CBC_CTX cbc; +#ifdef ENABLE_SM4_CFB + SM4_CFB_CTX cfb; +#endif +#ifdef ENABLE_SM4_OFB + SM4_OFB_CTX ofb; +#endif SM4_CTR_CTX ctr; +#ifdef ENABLE_SM4_XTS + SM4_XTS_CTX xts; +#endif + SM4_GCM_CTX gcm; SM4_CBC_SM3_HMAC_CTX cbc_sm3_hmac; SM4_CTR_SM3_HMAC_CTX ctr_sm3_hmac; - SM4_GCM_CTX gcm; } sm4_ctx; + + uint8_t inbuf[4096]; size_t inlen; - uint8_t outbuf[4196]; + + uint8_t outbuf[41960]; size_t outlen; + int taglen = -1; + int xts_data_unit_size = 0; + argc--; argv++; @@ -130,10 +302,21 @@ int sm4_main(int argc, char **argv) fprintf(stderr, "%s: invalid IV hex digits\n", prog); goto end; } + } else if (!strcmp(*argv, "-encrypt")) { + if (enc == 0) { + fprintf(stderr, "%s: `-encrypt` `-decrypt` should not be used together\n", prog); + goto end; + } enc = 1; } else if (!strcmp(*argv, "-decrypt")) { + if (enc == 1) { + fprintf(stderr, "%s: `-encrypt` `-decrypt` should not be used together\n", prog); + goto end; + } enc = 0; + + } else if (!strcmp(*argv, "-cbc")) { if (mode) goto bad; mode = SM4_MODE_CBC; @@ -149,6 +332,22 @@ int sm4_main(int argc, char **argv) } else if (!strcmp(*argv, "-gcm")) { if (mode) goto bad; mode = SM4_MODE_GCM; + } else if (!strcmp(*argv, "-ecb")) { + if (mode) goto bad; + mode = SM4_MODE_ECB; + } else if (!strcmp(*argv, "-cfb")) { + if (mode) goto bad; + mode = SM4_MODE_CFB; + } else if (!strcmp(*argv, "-ofb")) { + if (mode) goto bad; + mode = SM4_MODE_OFB; + } else if (!strcmp(*argv, "-ccm")) { + if (mode) goto bad; + mode = SM4_MODE_CCM; + } else if (!strcmp(*argv, "-xts")) { + if (mode) goto bad; + mode = SM4_MODE_XTS; + } else if (!strcmp(*argv, "-aad")) { if (--argc < 1) goto bad; if (aad) { @@ -173,6 +372,22 @@ int sm4_main(int argc, char **argv) goto end; } aad = aad_buf; + } else if (!strcmp(*argv, "-taglen")) { + if (--argc < 1) goto bad; + taglen = atoi(*(++argv)); + if (taglen < 0 || taglen > 32) { + fprintf(stderr, "%s: `-taglen` invalid integer argument\n", prog); + goto end; + } + + } else if (!strcmp(*argv, "-xts_data_unit_size")) { + if (--argc < 1) goto bad; + xts_data_unit_size = atoi(*(++argv)); + // FIXME: malloc outbuf for XTS + if (xts_data_unit_size > sizeof(outbuf) - 1024) { + error_print(); + goto end; + } } else if (!strcmp(*argv, "-in")) { if (--argc < 1) goto bad; infile = *(++argv); @@ -207,20 +422,62 @@ int sm4_main(int argc, char **argv) fprintf(stderr, "%s: option `-key` missing\n", prog); goto end; } + /* if (!ivhex) { fprintf(stderr, "%s: option `-iv` missing\n", prog); goto end; } + */ + + // check if mode is supported switch (mode) { +#ifdef ENABLE_SM4_ECB + case SM4_MODE_ECB: +#endif + case SM4_MODE_CBC: +#ifdef ENABLE_SM4_CFB + case SM4_MODE_CFB: +#endif +#ifdef ENABLE_SM4_OFB + case SM4_MODE_OFB: +#endif case SM4_MODE_CTR: +#ifdef ENABLE_SM4_CCM + case SM4_MODE_CCM: +#endif + case SM4_MODE_GCM: +#ifdef ENABLE_SM4_XTS + case SM4_MODE_XTS: +#endif + case SM4_MODE_CBC_SM3_HMAC: + case SM4_MODE_CTR_SM3_HMAC: + break; + default: + fprintf(stderr, "%s: mode is not supported\n", prog); + goto end; + } + + // check key length + switch (mode) { + case SM4_MODE_ECB: case SM4_MODE_CBC: + case SM4_MODE_CFB: + case SM4_MODE_OFB: + case SM4_MODE_CTR: + case SM4_MODE_CCM: case SM4_MODE_GCM: if (keylen != 16) { fprintf(stderr, "%s: invalid key length, should be 32 hex digits\n", prog); goto end; } break; + case SM4_MODE_XTS: + if (keylen != 32) { + fprintf(stderr, "%s: invalid key length, should be 64 hex digits\n", prog); + goto end; + } + break; case SM4_MODE_CBC_SM3_HMAC: case SM4_MODE_CTR_SM3_HMAC: if (keylen != 48) { @@ -230,9 +487,18 @@ int sm4_main(int argc, char **argv) break; } + // check iv length switch (mode) { - case SM4_MODE_CTR: + case SM4_MODE_ECB: + if (ivlen != 0) { + fprintf(stderr, "%s: ECB mode need no IV\n", prog); + goto end; + } + break; case SM4_MODE_CBC: + case SM4_MODE_CFB: + case SM4_MODE_OFB: + case SM4_MODE_CTR: case SM4_MODE_CBC_SM3_HMAC: case SM4_MODE_CTR_SM3_HMAC: if (ivlen != 16) { @@ -242,9 +508,14 @@ int sm4_main(int argc, char **argv) break; } + // check aad switch (mode) { + case SM4_MODE_ECB: case SM4_MODE_CBC: + case SM4_MODE_CFB: + case SM4_MODE_OFB: case SM4_MODE_CTR: + case SM4_MODE_XTS: if (aad) { fprintf(stderr, "%s: specified mode does not support `-aad` nor `-aad_hex`\n", prog); goto end; @@ -252,6 +523,7 @@ int sm4_main(int argc, char **argv) break; } + // encrypt/decrypt if (mode == SM4_MODE_CTR) { if (sm4_ctr_encrypt_init(&sm4_ctx.ctr, key, iv) != 1) { error_print(); @@ -285,9 +557,20 @@ int sm4_main(int argc, char **argv) goto end; } + if (mode == SM4_MODE_XTS) { + if (sm4_ccm_crypt(key, keylen, iv, ivlen, aad, aadlen, taglen, infp, outfp, enc, prog) != 1) { + goto end; + } + } + + if (enc) { switch (mode) { + case SM4_MODE_ECB: rv = sm4_ecb_encrypt_init(&sm4_ctx.ecb, key); break; case SM4_MODE_CBC: rv = sm4_cbc_encrypt_init(&sm4_ctx.cbc, key, iv); break; + case SM4_MODE_CFB: rv = sm4_cfb_encrypt_init(&sm4_ctx.cfb, 16, key, iv); break; + case SM4_MODE_OFB: rv = sm4_ofb_encrypt_init(&sm4_ctx.ofb, key, iv); break; + case SM4_MODE_XTS: rv = sm4_xts_encrypt_init(&sm4_ctx.xts, key, iv, xts_data_unit_size); break; case SM4_MODE_GCM: rv = sm4_gcm_encrypt_init(&sm4_ctx.gcm, key, keylen, iv, ivlen, aad, aadlen, GHASH_SIZE); break; case SM4_MODE_CBC_SM3_HMAC: rv = sm4_cbc_sm3_hmac_encrypt_init(&sm4_ctx.cbc_sm3_hmac, key, keylen, iv, ivlen, aad, aadlen); break; case SM4_MODE_CTR_SM3_HMAC: rv = sm4_ctr_sm3_hmac_encrypt_init(&sm4_ctx.ctr_sm3_hmac, key, keylen, iv, ivlen, aad, aadlen); break; @@ -299,7 +582,11 @@ int sm4_main(int argc, char **argv) while ((inlen = fread(inbuf, 1, sizeof(inbuf), infp)) > 0) { switch (mode) { + case SM4_MODE_ECB: rv = sm4_ecb_encrypt_update(&sm4_ctx.ecb, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_CBC: rv = sm4_cbc_encrypt_update(&sm4_ctx.cbc, inbuf, inlen, outbuf, &outlen); break; + case SM4_MODE_CFB: rv = sm4_cfb_encrypt_update(&sm4_ctx.cfb, inbuf, inlen, outbuf, &outlen); break; + case SM4_MODE_OFB: rv = sm4_ofb_encrypt_update(&sm4_ctx.ofb, inbuf, inlen, outbuf, &outlen); break; + case SM4_MODE_XTS: rv = sm4_xts_encrypt_update(&sm4_ctx.xts, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_GCM: rv = sm4_gcm_encrypt_update(&sm4_ctx.gcm, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_CBC_SM3_HMAC: rv = sm4_cbc_sm3_hmac_encrypt_update(&sm4_ctx.cbc_sm3_hmac, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_CTR_SM3_HMAC: rv = sm4_ctr_sm3_hmac_encrypt_update(&sm4_ctx.ctr_sm3_hmac, inbuf, inlen, outbuf, &outlen); break; @@ -315,7 +602,11 @@ int sm4_main(int argc, char **argv) } switch (mode) { + case SM4_MODE_ECB: rv = sm4_ecb_encrypt_finish(&sm4_ctx.ecb, outbuf, &outlen); break; case SM4_MODE_CBC: rv = sm4_cbc_encrypt_finish(&sm4_ctx.cbc, outbuf, &outlen); break; + case SM4_MODE_CFB: rv = sm4_cfb_encrypt_finish(&sm4_ctx.cfb, outbuf, &outlen); break; + case SM4_MODE_OFB: rv = sm4_ofb_encrypt_finish(&sm4_ctx.ofb, outbuf, &outlen); break; + case SM4_MODE_XTS: rv = sm4_xts_encrypt_finish(&sm4_ctx.xts, outbuf, &outlen); break; case SM4_MODE_GCM: rv = sm4_gcm_encrypt_finish(&sm4_ctx.gcm, outbuf, &outlen); break; case SM4_MODE_CBC_SM3_HMAC: rv = sm4_cbc_sm3_hmac_encrypt_finish(&sm4_ctx.cbc_sm3_hmac, outbuf, &outlen); break; case SM4_MODE_CTR_SM3_HMAC: rv = sm4_ctr_sm3_hmac_encrypt_finish(&sm4_ctx.ctr_sm3_hmac, outbuf, &outlen); break; @@ -331,7 +622,11 @@ int sm4_main(int argc, char **argv) } else { switch (mode) { + case SM4_MODE_ECB: rv = sm4_ecb_decrypt_init(&sm4_ctx.ecb, key); break; case SM4_MODE_CBC: rv = sm4_cbc_decrypt_init(&sm4_ctx.cbc, key, iv); break; + case SM4_MODE_CFB: rv = sm4_cfb_decrypt_init(&sm4_ctx.cfb, 16, key, iv); break; + case SM4_MODE_OFB: rv = sm4_ofb_encrypt_init(&sm4_ctx.ofb, key, iv); break; + case SM4_MODE_XTS: rv = sm4_xts_decrypt_init(&sm4_ctx.xts, key, iv, xts_data_unit_size); break; case SM4_MODE_GCM: rv = sm4_gcm_decrypt_init(&sm4_ctx.gcm, key, keylen, iv, ivlen, aad, aadlen, GHASH_SIZE); break; case SM4_MODE_CBC_SM3_HMAC: rv = sm4_cbc_sm3_hmac_decrypt_init(&sm4_ctx.cbc_sm3_hmac, key, keylen, iv, ivlen, aad, aadlen); break; case SM4_MODE_CTR_SM3_HMAC: rv = sm4_ctr_sm3_hmac_decrypt_init(&sm4_ctx.ctr_sm3_hmac, key, keylen, iv, ivlen, aad, aadlen); break; @@ -343,7 +638,11 @@ int sm4_main(int argc, char **argv) while ((inlen = fread(inbuf, 1, sizeof(inbuf), infp)) > 0) { switch (mode) { + case SM4_MODE_ECB: rv = sm4_ecb_decrypt_update(&sm4_ctx.ecb, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_CBC: rv = sm4_cbc_decrypt_update(&sm4_ctx.cbc, inbuf, inlen, outbuf, &outlen); break; + case SM4_MODE_CFB: rv = sm4_cfb_decrypt_update(&sm4_ctx.cfb, inbuf, inlen, outbuf, &outlen); break; + case SM4_MODE_OFB: rv = sm4_ofb_encrypt_update(&sm4_ctx.ofb, inbuf, inlen, outbuf, &outlen); break; + case SM4_MODE_XTS: rv = sm4_xts_encrypt_update(&sm4_ctx.xts, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_GCM: rv = sm4_gcm_decrypt_update(&sm4_ctx.gcm, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_CBC_SM3_HMAC: rv = sm4_cbc_sm3_hmac_decrypt_update(&sm4_ctx.cbc_sm3_hmac, inbuf, inlen, outbuf, &outlen); break; case SM4_MODE_CTR_SM3_HMAC: rv = sm4_ctr_sm3_hmac_decrypt_update(&sm4_ctx.ctr_sm3_hmac, inbuf, inlen, outbuf, &outlen); break; @@ -360,7 +659,11 @@ int sm4_main(int argc, char **argv) } switch (mode) { + case SM4_MODE_ECB: rv = sm4_ecb_decrypt_finish(&sm4_ctx.ecb, outbuf, &outlen); break; case SM4_MODE_CBC: rv = sm4_cbc_decrypt_finish(&sm4_ctx.cbc, outbuf, &outlen); break; + case SM4_MODE_CFB: rv = sm4_cfb_decrypt_finish(&sm4_ctx.cfb, outbuf, &outlen); break; + case SM4_MODE_OFB: rv = sm4_ofb_encrypt_finish(&sm4_ctx.ofb, outbuf, &outlen); break; + case SM4_MODE_XTS: rv = sm4_xts_decrypt_finish(&sm4_ctx.xts, outbuf, &outlen); break; case SM4_MODE_GCM: rv = sm4_gcm_decrypt_finish(&sm4_ctx.gcm, outbuf, &outlen); break; case SM4_MODE_CBC_SM3_HMAC: rv = sm4_cbc_sm3_hmac_decrypt_finish(&sm4_ctx.cbc_sm3_hmac, outbuf, &outlen); break; case SM4_MODE_CTR_SM3_HMAC: rv = sm4_ctr_sm3_hmac_decrypt_finish(&sm4_ctx.ctr_sm3_hmac, outbuf, &outlen); break;