diff --git a/include/gmssl/sm9_z256.h b/include/gmssl/sm9_z256.h index 4490984b8..1e7a640c4 100644 --- a/include/gmssl/sm9_z256.h +++ b/include/gmssl/sm9_z256.h @@ -39,6 +39,7 @@ void sm9_z256_copy_conditional(sm9_z256_t dst, const sm9_z256_t src, uint64_t mo void sm9_z256_set_zero(sm9_z256_t r); int sm9_z256_cmp(const sm9_z256_t a, const sm9_z256_t b); void sm9_z256_mul(uint64_t r[8], const sm9_z256_t a, const sm9_z256_t b); +void sm9_z256_mul_low(uint64_t r[4], const sm9_z256_t a, const sm9_z256_t b); int sm9_z256_from_hex(sm9_z256_t r, const char *hex); void sm9_z256_to_hex(const sm9_z256_t r, char hex[64]); int sm9_z256_equ_hex(const sm9_z256_t a, const char *hex); diff --git a/src/sm9_z256_alg.c b/src/sm9_z256_alg.c index 5fed72986..1059eff67 100644 --- a/src/sm9_z256_alg.c +++ b/src/sm9_z256_alg.c @@ -268,6 +268,17 @@ uint64_t sm9_z256_sub(sm9_z256_t r, const sm9_z256_t a, const sm9_z256_t b) return c; } +#ifndef ENABLE_SM9_Z256_ARMV8 +void sm9_z256_mul_low(uint64_t r[4], const sm9_z256_t a, const sm9_z256_t b) +{ + uint64_t t[8]; + sm9_z256_mul(t, a, b); + r[0] = t[0]; + r[1] = t[1]; + r[2] = t[2]; + r[3] = t[3]; +} + void sm9_z256_mul(uint64_t r[8], const sm9_z256_t a, const sm9_z256_t b) { uint64_t a_[8]; @@ -297,6 +308,7 @@ void sm9_z256_mul(uint64_t r[8], const sm9_z256_t a, const sm9_z256_t b) r[i] = (s[2 * i + 1] << 32) | s[2 * i]; } } +#endif uint64_t sm9_z512_add(uint64_t r[8], const uint64_t a[8], const uint64_t b[8]) { @@ -380,7 +392,7 @@ void sm9_z256_to_hex(const sm9_z256_t r, char hex[64]) { int i; for (i = 3; i >= 0; i--) { - (void)sprintf(hex + 16*(3-i), "%016lx", r[i]); + (void)sprintf(hex + 16*(3-i), "%016llx", r[i]); } } @@ -521,8 +533,47 @@ void sm9_u64_mul_add(uint64_t *w0, uint64_t *w1, *w1 = r[1]; } -#ifndef ENABLE_SM9_Z256_ARMV8 -void sm9_z256_fp_mont_mul(sm9_z256_t r, const sm9_z256_t a, const sm9_z256_t b) + +// p = b640000002a3a6f1d603ab4ff58ec74521f2934b1a7aeedbe56f9b27e351457d +// p' = -p^(-1) mod 2^256 = afd2bac5558a13b3966a4b291522b137181ae39613c8dbaf892bc42c2f2ee42b +// sage: -(IntegerModRing(2^256)(p))^-1 +const uint64_t SM9_Z256_P_PRIME[4] = { + 0x892bc42c2f2ee42b, 0x181ae39613c8dbaf, 0x966a4b291522b137, 0xafd2bac5558a13b3, +}; + + +// z = a*b +// c = (z + (z * p' mod 2^256) * p)/2^256 +void sm9_z256_fp_mont_mul(uint64_t r[4], const uint64_t a[4], const uint64_t b[4]) +{ + uint64_t z[8]; + uint64_t t[8]; + uint64_t c; + + // z = a * b + sm9_z256_mul(z, a, b); + + // t = low(z) * p' + sm9_z256_mul_low(t, z, SM9_Z256_P_PRIME); + + // t = low(t) * p + sm9_z256_mul(t, t, SM9_Z256_P); + + // z = z + t + c = sm9_z512_add(z, z, t); + + // r = high(r) + sm9_z256_copy(r, z + 4); + + if (c) { + sm9_z256_add(r, r, SM9_Z256_MODP_MONT_ONE); + + } else if (sm9_z256_cmp(r, SM9_Z256_P) >= 0) { + (void)sm9_z256_sub(r, r, SM9_Z256_P); + } +} + +void sm9_z256_fp_mont_mul_2way(sm9_z256_t r, const sm9_z256_t a, const sm9_z256_t b) { sm9_z256_t d = {0}, e = {0}; uint64_t q, t0, t1, p0, p1, tmp; @@ -562,7 +613,6 @@ void sm9_z256_fp_mont_sqr(sm9_z256_t r, const sm9_z256_t a) { sm9_z256_fp_mont_mul(r, a, a); } -#endif void sm9_z256_fp_pow(sm9_z256_t r, const sm9_z256_t a, const sm9_z256_t e) { @@ -633,7 +683,7 @@ void sm9_z256_fp_to_hex(const sm9_z256_t r, char hex[64]) sm9_z256_fp_from_mont(t, r); int i; for (i = 3; i >= 0; i--) { - (void)sprintf(hex + 16*(3-i), "%016lx", t[i]); + (void)sprintf(hex + 16*(3-i), "%016llx", t[i]); } }