diff --git a/include/gmssl/sm2.h b/include/gmssl/sm2.h index e2057ab02..cc183be8f 100644 --- a/include/gmssl/sm2.h +++ b/include/gmssl/sm2.h @@ -341,6 +341,20 @@ int sm2_do_ecdh(const SM2_KEY *key, const SM2_POINT *peer_public, SM2_POINT *out _gmssl_export int sm2_ecdh(const SM2_KEY *key, const uint8_t *peer_public, size_t peer_public_len, SM2_POINT *out); +typedef struct { + SM2_KEY sm2_key; + uint8_t buf[SM2_MAX_CIPHERTEXT_SIZE]; + size_t buf_size; +} SM2_ENC_CTX; + +_gmssl_export int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key); +_gmssl_export int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +_gmssl_export int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +_gmssl_export int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key); +_gmssl_export int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); +_gmssl_export int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen); + + #ifdef __cplusplus } #endif diff --git a/src/sm2_ctx.c b/src/sm2_ctx.c index 7aef5378e..0e4ced429 100644 --- a/src/sm2_ctx.c +++ b/src/sm2_ctx.c @@ -135,3 +135,182 @@ int sm2_verify_finish(SM2_SIGN_CTX *ctx, const uint8_t *sig, size_t siglen) return 1; } + + + +int sm2_encrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key) +{ + if (!ctx || !sm2_key) { + error_print(); + return -1; + } + + memset(ctx, 0, sizeof(*ctx)); + ctx->sm2_key = *sm2_key; + + return 1; +} + +int sm2_encrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + if (!ctx || !outlen) { + error_print(); + return -1; + } + + if (ctx->buf_size > SM2_MAX_PLAINTEXT_SIZE) { + error_print(); + return -1; + } + + if (!out) { + *outlen = 0; + return 1; + } + + if (in) { + if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) { + error_print(); + return -1; + } + + memcpy(ctx->buf + ctx->buf_size, in, inlen); + ctx->buf_size += inlen; + } + + *outlen = 0; + return 1; +} + +int sm2_encrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + if (!ctx || !outlen) { + error_print(); + return -1; + } + + if (ctx->buf_size > SM2_MAX_PLAINTEXT_SIZE) { + error_print(); + return -1; + } + + if (!out) { + *outlen = SM2_MAX_CIPHERTEXT_SIZE; + return 1; + } + + if (ctx->buf_size) { + if (in) { + if (inlen > SM2_MAX_PLAINTEXT_SIZE - ctx->buf_size) { + error_print(); + return -1; + } + memcpy(ctx->buf + ctx->buf_size, in, inlen); + ctx->buf_size += inlen; + } + if (sm2_encrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) { + error_print(); + return -1; + } + } else { + if (!in || !inlen || inlen > SM2_MAX_PLAINTEXT_SIZE) { + error_print(); + return -1; + } + if (sm2_encrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) { + error_print(); + return -1; + } + } + + return 1; +} + + +int sm2_decrypt_init(SM2_ENC_CTX *ctx, const SM2_KEY *sm2_key) +{ + if (!ctx || !sm2_key) { + error_print(); + return -1; + } + + memset(ctx, 0, sizeof(*ctx)); + ctx->sm2_key = *sm2_key; + + return 1; +} + +int sm2_decrypt_update(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + if (!ctx || !outlen) { + error_print(); + return -1; + } + + if (ctx->buf_size > SM2_MAX_CIPHERTEXT_SIZE) { + error_print(); + return -1; + } + + if (!out) { + *outlen = 0; + return 1; + } + + if (in) { + if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) { + error_print(); + return -1; + } + + memcpy(ctx->buf + ctx->buf_size, in, inlen); + ctx->buf_size += inlen; + } + + *outlen = 0; + return 1; +} + +int sm2_decrypt_finish(SM2_ENC_CTX *ctx, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) +{ + if (!ctx || !outlen) { + error_print(); + return -1; + } + + if (ctx->buf_size > SM2_MAX_CIPHERTEXT_SIZE) { + error_print(); + return -1; + } + + if (!out) { + *outlen = SM2_MAX_PLAINTEXT_SIZE; + return 1; + } + + if (ctx->buf_size) { + if (in) { + if (inlen > SM2_MAX_CIPHERTEXT_SIZE - ctx->buf_size) { + error_print(); + return -1; + } + memcpy(ctx->buf + ctx->buf_size, in, inlen); + ctx->buf_size += inlen; + } + if (sm2_decrypt(&ctx->sm2_key, ctx->buf, ctx->buf_size, out, outlen) != 1) { + error_print(); + return -1; + } + } else { + if (!in || !inlen || inlen > SM2_MAX_CIPHERTEXT_SIZE) { + error_print(); + return -1; + } + if (sm2_decrypt(&ctx->sm2_key, in, inlen, out, outlen) != 1) { + error_print(); + return -1; + } + } + + return 1; +} diff --git a/tools/sm2decrypt.c b/tools/sm2decrypt.c index 2e25707cf..b4aac58dd 100644 --- a/tools/sm2decrypt.c +++ b/tools/sm2decrypt.c @@ -30,6 +30,7 @@ int sm2decrypt_main(int argc, char **argv) FILE *infp = stdin; FILE *outfp = stdout; SM2_KEY key; + SM2_ENC_CTX ctx; uint8_t inbuf[SM2_MAX_CIPHERTEXT_SIZE]; uint8_t outbuf[SM2_MAX_CIPHERTEXT_SIZE]; size_t inlen, outlen; @@ -101,7 +102,12 @@ int sm2decrypt_main(int argc, char **argv) fprintf(stderr, "%s: read input failed : %s\n", prog, strerror(errno)); goto end; } - if (sm2_decrypt(&key, inbuf, inlen, outbuf, &outlen) != 1) { + + if (sm2_decrypt_init(&ctx, &key) != 1) { + fprintf(stderr, "%s: sm2_decrypt_init failed\n", prog); + goto end; + } + if (sm2_decrypt_finish(&ctx, inbuf, inlen, outbuf, &outlen) != 1) { fprintf(stderr, "%s: decryption failure\n", prog); goto end; } @@ -112,6 +118,8 @@ int sm2decrypt_main(int argc, char **argv) ret = 0; end: gmssl_secure_clear(&key, sizeof(key)); + gmssl_secure_clear(&ctx, sizeof(ctx)); + gmssl_secure_clear(outbuf, sizeof(outbuf)); if (keyfp) fclose(keyfp); if (infile && infp) fclose(infp); if (outfile && outfp) fclose(outfp); diff --git a/tools/sm2encrypt.c b/tools/sm2encrypt.c index 3e7580288..2f00842d4 100644 --- a/tools/sm2encrypt.c +++ b/tools/sm2encrypt.c @@ -12,10 +12,10 @@ #include #include #include +#include #include #include - static const char *options = "(-pubkey pem | -cert pem) [-in file] [-out file]"; int sm2encrypt_main(int argc, char **argv) @@ -33,6 +33,7 @@ int sm2encrypt_main(int argc, char **argv) uint8_t cert[1024]; size_t certlen; SM2_KEY key; + SM2_ENC_CTX ctx; uint8_t inbuf[SM2_MAX_PLAINTEXT_SIZE + 1]; uint8_t outbuf[SM2_MAX_CIPHERTEXT_SIZE]; size_t inlen, outlen = sizeof(outbuf); @@ -124,8 +125,12 @@ int sm2encrypt_main(int argc, char **argv) goto end; } - if (sm2_encrypt(&key, inbuf, inlen, outbuf, &outlen) != 1) { - fprintf(stderr, "%s: inner error\n", prog); + if (sm2_encrypt_init(&ctx, &key) != 1) { + fprintf(stderr, "%s: sm2_encrypt_init failed\n", prog); + goto end; + } + if (sm2_encrypt_finish(&ctx, inbuf, inlen, outbuf, &outlen) != 1) { + fprintf(stderr, "%s: sm2_encrypt_finish error\n", prog); goto end; } @@ -136,6 +141,8 @@ int sm2encrypt_main(int argc, char **argv) ret = 0; end: + gmssl_secure_clear(&ctx, sizeof(ctx)); + gmssl_secure_clear(inbuf, sizeof(inbuf)); if (infile && infp) fclose(infp); if (outfile && outfp) fclose(outfp); if (pubkeyfp) fclose(pubkeyfp);