diff --git a/benchmarks_test.go b/benchmarks_test.go
index 7a262109..bb210665 100644
--- a/benchmarks_test.go
+++ b/benchmarks_test.go
@@ -64,7 +64,7 @@ func initSamples() bool {
 	l := newRandInt(1)
 	g := newRandInt(1)
 	if g.Lt(&l) {
-		g,l = l,g
+		g, l = l, g
 	}
 	if g[0] == 0 {
 		g[0]++
@@ -77,7 +77,7 @@ func initSamples() bool {
 	l = newRandInt(2)
 	g = newRandInt(2)
 	if g.Lt(&l) {
-		g,l = l,g
+		g, l = l, g
 	}
 	if g[1] == 0 {
 		g[1]++
@@ -90,7 +90,7 @@ func initSamples() bool {
 	l = newRandInt(3)
 	g = newRandInt(3)
 	if g.Lt(&l) {
-		g,l = l,g
+		g, l = l, g
 	}
 	if g[2] == 0 {
 		g[2]++
@@ -103,7 +103,7 @@ func initSamples() bool {
 	l = newRandInt(4)
 	g = newRandInt(4)
 	if g.Lt(&l) {
-		g,l = l,g
+		g, l = l, g
 	}
 	if g[3] == 0 {
 		g[3]++
@@ -599,14 +599,14 @@ func BenchmarkDiv(b *testing.B) {
 	}
 
 	b.Run("small/uint256", func(b *testing.B) { benchmarkDivUint256(b, &int32Samples, &int32SamplesLt) })
-	b.Run("small/big", func(b *testing.B) { benchmarkDivBig(b, &big32Samples, &big32SamplesLt) })
 	b.Run("mod64/uint256", func(b *testing.B) { benchmarkDivUint256(b, &int256Samples, &int64Samples) })
-	b.Run("mod64/big", func(b *testing.B) { benchmarkDivBig(b, &big256Samples, &big64Samples) })
 	b.Run("mod128/uint256", func(b *testing.B) { benchmarkDivUint256(b, &int256Samples, &int128Samples) })
-	b.Run("mod128/big", func(b *testing.B) { benchmarkDivBig(b, &big256Samples, &big128Samples) })
 	b.Run("mod192/uint256", func(b *testing.B) { benchmarkDivUint256(b, &int256Samples, &int192Samples) })
-	b.Run("mod192/big", func(b *testing.B) { benchmarkDivBig(b, &big256Samples, &big192Samples) })
 	b.Run("mod256/uint256", func(b *testing.B) { benchmarkDivUint256(b, &int256Samples, &int256SamplesLt) })
+	b.Run("small/big", func(b *testing.B) { benchmarkDivBig(b, &big32Samples, &big32SamplesLt) })
+	b.Run("mod64/big", func(b *testing.B) { benchmarkDivBig(b, &big256Samples, &big64Samples) })
+	b.Run("mod128/big", func(b *testing.B) { benchmarkDivBig(b, &big256Samples, &big128Samples) })
+	b.Run("mod192/big", func(b *testing.B) { benchmarkDivBig(b, &big256Samples, &big192Samples) })
 	b.Run("mod256/big", func(b *testing.B) { benchmarkDivBig(b, &big256Samples, &big256SamplesLt) })
 }
 
@@ -629,14 +629,14 @@ func BenchmarkMod(b *testing.B) {
 	}
 
 	b.Run("small/uint256", func(b *testing.B) { benchmarkModUint256(b, &int32Samples, &int32SamplesLt) })
-	b.Run("small/big", func(b *testing.B) { benchmarkModBig(b, &big32Samples, &big32SamplesLt) })
 	b.Run("mod64/uint256", func(b *testing.B) { benchmarkModUint256(b, &int256Samples, &int64Samples) })
-	b.Run("mod64/big", func(b *testing.B) { benchmarkModBig(b, &big256Samples, &big64Samples) })
 	b.Run("mod128/uint256", func(b *testing.B) { benchmarkModUint256(b, &int256Samples, &int128Samples) })
-	b.Run("mod128/big", func(b *testing.B) { benchmarkModBig(b, &big256Samples, &big128Samples) })
 	b.Run("mod192/uint256", func(b *testing.B) { benchmarkModUint256(b, &int256Samples, &int192Samples) })
-	b.Run("mod192/big", func(b *testing.B) { benchmarkModBig(b, &big256Samples, &big192Samples) })
 	b.Run("mod256/uint256", func(b *testing.B) { benchmarkModUint256(b, &int256Samples, &int256SamplesLt) })
+	b.Run("small/big", func(b *testing.B) { benchmarkModBig(b, &big32Samples, &big32SamplesLt) })
+	b.Run("mod64/big", func(b *testing.B) { benchmarkModBig(b, &big256Samples, &big64Samples) })
+	b.Run("mod128/big", func(b *testing.B) { benchmarkModBig(b, &big256Samples, &big128Samples) })
+	b.Run("mod192/big", func(b *testing.B) { benchmarkModBig(b, &big256Samples, &big192Samples) })
 	b.Run("mod256/big", func(b *testing.B) { benchmarkModBig(b, &big256Samples, &big256SamplesLt) })
 }
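The hunks above are gofmt spacing fixes in initSamples plus a regrouping of sub-benchmarks: the uint256 cases now run contiguously before their big.Int counterparts, so the two families appear as adjacent blocks in the benchmark output. To exercise just these benchmarks with the standard tooling:

$ go test -run '^$' -bench 'BenchmarkDiv|BenchmarkMod'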
@@ -667,19 +667,38 @@ func BenchmarkAddMod(b *testing.B) {
 		}
 	}
 
-	b.Run("small/uint256", func(b *testing.B) { benchmarkAddModUint256 (b, &int32SamplesLt, &int32Samples) })
-	b.Run("small/big", func(b *testing.B) { benchmarkAddModBig (b, &big32SamplesLt, &big32Samples) })
-	b.Run("mod64/uint256", func(b *testing.B) { benchmarkAddModUint256 (b, &int64SamplesLt, &int64Samples) })
-	b.Run("mod64/big", func(b *testing.B) { benchmarkAddModBig (b, &big64SamplesLt, &big64Samples) })
-	b.Run("mod128/uint256", func(b *testing.B) { benchmarkAddModUint256 (b, &int128SamplesLt, &int128Samples) })
-	b.Run("mod128/big", func(b *testing.B) { benchmarkAddModBig (b, &big128SamplesLt, &big128Samples) })
-	b.Run("mod192/uint256", func(b *testing.B) { benchmarkAddModUint256 (b, &int192SamplesLt, &int192Samples) })
-	b.Run("mod192/big", func(b *testing.B) { benchmarkAddModBig (b, &big192SamplesLt, &big192Samples) })
-	b.Run("mod256/uint256", func(b *testing.B) { benchmarkAddModUint256 (b, &int256SamplesLt, &int256Samples) })
-	b.Run("mod256/big", func(b *testing.B) { benchmarkAddModBig (b, &big256SamplesLt, &big256Samples) })
+	b.Run("small/uint256", func(b *testing.B) { benchmarkAddModUint256(b, &int32SamplesLt, &int32Samples) })
+	b.Run("mod64/uint256", func(b *testing.B) { benchmarkAddModUint256(b, &int64SamplesLt, &int64Samples) })
+	b.Run("mod128/uint256", func(b *testing.B) { benchmarkAddModUint256(b, &int128SamplesLt, &int128Samples) })
+	b.Run("mod192/uint256", func(b *testing.B) { benchmarkAddModUint256(b, &int192SamplesLt, &int192Samples) })
+	b.Run("mod256/uint256", func(b *testing.B) { benchmarkAddModUint256(b, &int256SamplesLt, &int256Samples) })
+	b.Run("small/big", func(b *testing.B) { benchmarkAddModBig(b, &big32SamplesLt, &big32Samples) })
+	b.Run("mod64/big", func(b *testing.B) { benchmarkAddModBig(b, &big64SamplesLt, &big64Samples) })
+	b.Run("mod128/big", func(b *testing.B) { benchmarkAddModBig(b, &big128SamplesLt, &big128Samples) })
+	b.Run("mod192/big", func(b *testing.B) { benchmarkAddModBig(b, &big192SamplesLt, &big192Samples) })
+	b.Run("mod256/big", func(b *testing.B) { benchmarkAddModBig(b, &big256SamplesLt, &big256Samples) })
 }
 
 func BenchmarkMulMod(b *testing.B) {
+	benchmarkMulModUint256R := func(b *testing.B, factorsSamples, modSamples *[numSamples]Int) {
+		iter := (b.N + numSamples - 1) / numSamples
+
+		var mu [numSamples][5]uint64
+
+		for i := 0; i < numSamples; i++ {
+			mu[i] = Reciprocal(&modSamples[i])
+		}
+
+		b.ResetTimer()
+
+		for j := 0; j < numSamples; j++ {
+			x := factorsSamples[j]
+
+			for i := 0; i < iter; i++ {
+				x.MulModWithReciprocal(&x, &factorsSamples[j], &modSamples[j], &mu[j])
+			}
+		}
+	}
 	benchmarkMulModUint256 := func(b *testing.B, factorsSamples, modSamples *[numSamples]Int) {
 		iter := (b.N + numSamples - 1) / numSamples
@@ -704,16 +723,17 @@ func BenchmarkMulMod(b *testing.B) {
 		}
 	}
 
-	b.Run("small/uint256", func(b *testing.B) { benchmarkMulModUint256 (b, &int32SamplesLt, &int32Samples) })
-	b.Run("small/big", func(b *testing.B) { benchmarkMulModBig (b, &big32SamplesLt, &big32Samples) })
-	b.Run("mod64/uint256", func(b *testing.B) { benchmarkMulModUint256 (b, &int64SamplesLt, &int64Samples) })
-	b.Run("mod64/big", func(b *testing.B) { benchmarkMulModBig (b, &big64SamplesLt, &big64Samples) })
-	b.Run("mod128/uint256", func(b *testing.B) { benchmarkMulModUint256 (b, &int128SamplesLt, &int128Samples) })
-	b.Run("mod128/big", func(b *testing.B) { benchmarkMulModBig (b, &big128SamplesLt, &big128Samples) })
-	b.Run("mod192/uint256", func(b *testing.B) { benchmarkMulModUint256 (b, &int192SamplesLt, &int192Samples) })
-	b.Run("mod192/big", func(b *testing.B) { benchmarkMulModBig (b, &big192SamplesLt, &big192Samples) })
-	b.Run("mod256/uint256", func(b *testing.B) { benchmarkMulModUint256 (b, &int256SamplesLt, &int256Samples) })
-	b.Run("mod256/big", func(b *testing.B) { benchmarkMulModBig (b, &big256SamplesLt, &big256Samples) })
+	b.Run("small/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int32SamplesLt, &int32Samples) })
+	b.Run("mod64/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int64SamplesLt, &int64Samples) })
+	b.Run("mod128/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int128SamplesLt, &int128Samples) })
+	b.Run("mod192/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int192SamplesLt, &int192Samples) })
+	b.Run("mod256/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int256SamplesLt, &int256Samples) })
+	b.Run("mod256/uint256r", func(b *testing.B) { benchmarkMulModUint256R(b, &int256SamplesLt, &int256Samples) })
+	b.Run("small/big", func(b *testing.B) { benchmarkMulModBig(b, &big32SamplesLt, &big32Samples) })
+	b.Run("mod64/big", func(b *testing.B) { benchmarkMulModBig(b, &big64SamplesLt, &big64Samples) })
+	b.Run("mod128/big", func(b *testing.B) { benchmarkMulModBig(b, &big128SamplesLt, &big128Samples) })
+	b.Run("mod192/big", func(b *testing.B) { benchmarkMulModBig(b, &big192SamplesLt, &big192Samples) })
+	b.Run("mod256/big", func(b *testing.B) { benchmarkMulModBig(b, &big256SamplesLt, &big256Samples) })
 }
 
 func benchmark_SdivLarge_Big(bench *testing.B) {
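benchmarkMulModUint256R mirrors benchmarkMulModUint256 but fills a table of reciprocals before b.ResetTimer(), so the new mod256/uint256r case times only the reduction itself. Comparing it against mod256/uint256, where MulMod derives the reciprocal on every call, isolates the amortizable cost of Reciprocal; for example:

$ go test -run '^$' -bench 'BenchmarkMulMod/mod256'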
b.Run("mod192/uint256", func(b *testing.B) { benchmarkMulModUint256 (b, &int192SamplesLt, &int192Samples) }) - b.Run("mod192/big", func(b *testing.B) { benchmarkMulModBig (b, &big192SamplesLt, &big192Samples) }) - b.Run("mod256/uint256", func(b *testing.B) { benchmarkMulModUint256 (b, &int256SamplesLt, &int256Samples) }) - b.Run("mod256/big", func(b *testing.B) { benchmarkMulModBig (b, &big256SamplesLt, &big256Samples) }) + b.Run("small/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int32SamplesLt, &int32Samples) }) + b.Run("mod64/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int64SamplesLt, &int64Samples) }) + b.Run("mod128/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int128SamplesLt, &int128Samples) }) + b.Run("mod192/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int192SamplesLt, &int192Samples) }) + b.Run("mod256/uint256", func(b *testing.B) { benchmarkMulModUint256(b, &int256SamplesLt, &int256Samples) }) + b.Run("mod256/uint256r", func(b *testing.B) { benchmarkMulModUint256R(b, &int256SamplesLt, &int256Samples) }) + b.Run("small/big", func(b *testing.B) { benchmarkMulModBig(b, &big32SamplesLt, &big32Samples) }) + b.Run("mod64/big", func(b *testing.B) { benchmarkMulModBig(b, &big64SamplesLt, &big64Samples) }) + b.Run("mod128/big", func(b *testing.B) { benchmarkMulModBig(b, &big128SamplesLt, &big128Samples) }) + b.Run("mod192/big", func(b *testing.B) { benchmarkMulModBig(b, &big192SamplesLt, &big192Samples) }) + b.Run("mod256/big", func(b *testing.B) { benchmarkMulModBig(b, &big256SamplesLt, &big256Samples) }) } func benchmark_SdivLarge_Big(bench *testing.B) { diff --git a/mod.go b/mod.go new file mode 100644 index 00000000..e52bfde5 --- /dev/null +++ b/mod.go @@ -0,0 +1,481 @@ +// uint256: Fixed size 256-bit math library +// Copyright 2021 uint256 Authors +// SPDX-License-Identifier: BSD-3-Clause + +package uint256 + +import ( + "math/bits" +) + +// Some utility functions + +func leadingZeros(x *Int) (z int) { + var t int + z = bits.LeadingZeros64(x[3]) + t = bits.LeadingZeros64(x[2]); if z == 64 { z = t + 64 } + t = bits.LeadingZeros64(x[1]); if z == 128 { z = t + 128 } + t = bits.LeadingZeros64(x[0]); if z == 192 { z = t + 192 } + return z +} + +// Reciprocal computes a 320-bit value representing 1/m +// +// Notes: +// - specialized for m[3] != 0, hence limited to 2^192 <= m < 2^256 +// - returns zero if m[3] == 0 +// - starts with a 32-bit division, refines with newton-raphson iterations +func Reciprocal(m *Int) (mu [5]uint64) { + + if m[3] == 0 { + return mu + } + + s := bits.LeadingZeros64(m[3]) // Replace with leadingZeros(m) for general case + p := 255 - s // floor(log_2(m)), m>0 + + // 0 or a power of 2? 
+// Reciprocal computes a 320-bit value representing 1/m
+//
+// Notes:
+// - specialized for m[3] != 0, hence limited to 2^192 <= m < 2^256
+// - returns zero if m[3] == 0
+// - starts with a 32-bit division, refines with Newton-Raphson iterations
+func Reciprocal(m *Int) (mu [5]uint64) {
+
+	if m[3] == 0 {
+		return mu
+	}
+
+	s := bits.LeadingZeros64(m[3]) // Replace with leadingZeros(m) for general case
+	p := 255 - s                   // floor(log_2(m)), m > 0
+
+	// 0 or a power of 2?
+
+	// Check if at least one bit is set in m[2], m[1] or m[0],
+	// or at least two bits in m[3]
+
+	if m[0] | m[1] | m[2] | (m[3] & (m[3]-1)) == 0 {
+
+		mu[4] = ^uint64(0) >> uint(p & 63)
+		mu[3] = ^uint64(0)
+		mu[2] = ^uint64(0)
+		mu[1] = ^uint64(0)
+		mu[0] = ^uint64(0)
+
+		return mu
+	}
+
+	// Maximise division precision by left-aligning divisor
+
+	var (
+		y  Int    // left-aligned copy of m
+		r0 uint32 // estimate of 2^31/y
+	)
+
+	y.Lsh(m, uint(s)) // 1/2 < y < 1
+
+	// Extract most significant 32 bits
+
+	yh := uint32(y[3] >> 32)
+
+	if yh == 0x80000000 { // Avoid overflow in division
+		r0 = 0xffffffff
+	} else {
+		r0, _ = bits.Div32(0x80000000, 0, yh)
+	}
+
+	// First iteration: 32 -> 64
+
+	t1 := uint64(r0)             // 2^31/y
+	t1 *= t1                     // 2^62/y^2
+	t1, _ = bits.Mul64(t1, y[3]) // 2^62/y^2 * 2^64/y / 2^64 = 2^62/y
+
+	r1 := uint64(r0) << 32 // 2^63/y
+	r1 -= t1               // 2^63/y - 2^62/y = 2^62/y
+	r1 *= 2                // 2^63/y
+
+	if (r1 | (y[3] << 1)) == 0 {
+		r1 = ^uint64(0)
+	}
+
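The refinement here is plain Newton-Raphson for f(r) = 1/r - y: each pass computes r' = r*(2 - y*r) and roughly doubles the number of correct bits (32 -> 64 -> 128 -> 192 -> 320 across the iterations, tracked in fixed point with explicit weights such as 2^62/y). A floating-point toy version of the same recurrence, purely to illustrate the quadratic convergence:

package main

import "fmt"

func main() {
	y := 0.6180339887498949 // stands in for the left-aligned divisor: 1/2 < y < 1
	r := 1.0                // crude seed, like the 32-bit division above

	for i := 1; i <= 5; i++ {
		r = r * (2 - y*r) // one Newton-Raphson step; the error is squared
		fmt.Printf("iteration %d: r = %.17g (error %.3g)\n", i, r, 1/y-r)
	}
}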
+	// Second iteration: 64 -> 128
+
+	// square: 2^126/y^2
+	a2h, a2l := bits.Mul64(r1, r1)
+
+	// multiply by y: e2h:e2l:b2h = 2^126/y^2 * 2^128/y / 2^128 = 2^126/y
+	b2h, _ := bits.Mul64(a2l, y[2])
+	c2h, c2l := bits.Mul64(a2l, y[3])
+	d2h, d2l := bits.Mul64(a2h, y[2])
+	e2h, e2l := bits.Mul64(a2h, y[3])
+
+	b2h, c := bits.Add64(b2h, c2l, 0)
+	e2l, c = bits.Add64(e2l, c2h, c)
+	e2h, _ = bits.Add64(e2h, 0, c)
+
+	_, c = bits.Add64(b2h, d2l, 0)
+	e2l, c = bits.Add64(e2l, d2h, c)
+	e2h, _ = bits.Add64(e2h, 0, c)
+
+	// subtract: t2h:t2l = 2^127/y - 2^126/y = 2^126/y
+	t2l, b := bits.Sub64(0, e2l, 0)
+	t2h, _ := bits.Sub64(r1, e2h, b)
+
+	// double: r2h:r2l = 2^127/y
+	r2l, c := bits.Add64(t2l, t2l, 0)
+	r2h, _ := bits.Add64(t2h, t2h, c)
+
+	if (r2h | r2l | (y[3] << 1)) == 0 {
+		r2h = ^uint64(0)
+		r2l = ^uint64(0)
+	}
+
+	// Third iteration: 128 -> 192
+
+	// square r2 (keep 256 bits): 2^190/y^2
+	a3h, a3l := bits.Mul64(r2l, r2l)
+	b3h, b3l := bits.Mul64(r2l, r2h)
+	c3h, c3l := bits.Mul64(r2h, r2h)
+
+	a3h, c = bits.Add64(a3h, b3l, 0)
+	c3l, c = bits.Add64(c3l, b3h, c)
+	c3h, _ = bits.Add64(c3h, 0, c)
+
+	a3h, c = bits.Add64(a3h, b3l, 0)
+	c3l, c = bits.Add64(c3l, b3h, c)
+	c3h, _ = bits.Add64(c3h, 0, c)
+
+	// multiply by y: q = 2^190/y^2 * 2^192/y / 2^192 = 2^190/y
+
+	x0 := a3l
+	x1 := a3h
+	x2 := c3l
+	x3 := c3h
+
+	var q0, q1, q2, q3, q4, t0 uint64
+
+	q0, _ = bits.Mul64(x2, y[0])
+	q1, t0 = bits.Mul64(x3, y[0]); q0, c = bits.Add64(q0, t0, 0); q1, _ = bits.Add64(q1, 0, c)
+
+	t1, _ = bits.Mul64(x1, y[1]); q0, c = bits.Add64(q0, t1, 0)
+	q2, t0 = bits.Mul64(x3, y[1]); q1, c = bits.Add64(q1, t0, c); q2, _ = bits.Add64(q2, 0, c)
+
+	t1, t0 = bits.Mul64(x2, y[1]); q0, c = bits.Add64(q0, t0, 0); q1, c = bits.Add64(q1, t1, c); q2, _ = bits.Add64(q2, 0, c)
+
+	t1, t0 = bits.Mul64(x1, y[2]); q0, c = bits.Add64(q0, t0, 0); q1, c = bits.Add64(q1, t1, c)
+	q3, t0 = bits.Mul64(x3, y[2]); q2, c = bits.Add64(q2, t0, c); q3, _ = bits.Add64(q3, 0, c)
+
+	t1, _ = bits.Mul64(x0, y[2]); q0, c = bits.Add64(q0, t1, 0)
+	t1, t0 = bits.Mul64(x2, y[2]); q1, c = bits.Add64(q1, t0, c); q2, c = bits.Add64(q2, t1, c); q3, _ = bits.Add64(q3, 0, c)
+
+	t1, t0 = bits.Mul64(x1, y[3]); q1, c = bits.Add64(q1, t0, 0); q2, c = bits.Add64(q2, t1, c)
+	q4, t0 = bits.Mul64(x3, y[3]); q3, c = bits.Add64(q3, t0, c); q4, _ = bits.Add64(q4, 0, c)
+
+	t1, t0 = bits.Mul64(x0, y[3]); q0, c = bits.Add64(q0, t0, 0); q1, c = bits.Add64(q1, t1, c)
+	t1, t0 = bits.Mul64(x2, y[3]); q2, c = bits.Add64(q2, t0, c); q3, c = bits.Add64(q3, t1, c); q4, _ = bits.Add64(q4, 0, c)
+
+	// subtract: t3 = 2^191/y - 2^190/y = 2^190/y
+	_, b = bits.Sub64(0, q0, 0)
+	_, b = bits.Sub64(0, q1, b)
+	t3l, b := bits.Sub64(0, q2, b)
+	t3m, b := bits.Sub64(r2l, q3, b)
+	t3h, _ := bits.Sub64(r2h, q4, b)
+
+	// double: r3 = 2^191/y
+	r3l, c := bits.Add64(t3l, t3l, 0)
+	r3m, c := bits.Add64(t3m, t3m, c)
+	r3h, _ := bits.Add64(t3h, t3h, c)
+
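Each iteration squares the current estimate and multiplies by y, but only the most significant words of the product are kept, so partial products are accumulated column by column with bits.Mul64/bits.Add64 and the discarded low columns contribute only their carries. A minimal two-word rendition of that pattern, cross-checked against math/big (the helper name is mine, not part of the diff):

package main

import (
	"fmt"
	"math/big"
	"math/bits"
)

// high128 returns the top 128 bits of the 256-bit product of a = ah:al
// and b = bh:bl, accumulating partial products column by column.
func high128(ah, al, bh, bl uint64) (hi, lo uint64) {
	llh, _ := bits.Mul64(al, bl) // only the high half of low*low can carry up
	m1h, m1l := bits.Mul64(ah, bl)
	m2h, m2l := bits.Mul64(al, bh)
	hh, hl := bits.Mul64(ah, bh)

	// column at weight 2^64: llh + m1l + m2l; only its carries matter
	var c1 uint64
	t, c := bits.Add64(llh, m1l, 0)
	c1 += c
	_, c = bits.Add64(t, m2l, 0)
	c1 += c

	// column at weight 2^128: hl + m1h + m2h + c1
	var c2 uint64
	lo, c = bits.Add64(hl, m1h, 0)
	c2 += c
	lo, c = bits.Add64(lo, m2h, 0)
	c2 += c
	lo, c = bits.Add64(lo, c1, 0)
	c2 += c

	// column at weight 2^192: cannot overflow, the full product fits 256 bits
	hi = hh + c2
	return hi, lo
}

func main() {
	ah, al := uint64(0xfedcba9876543210), uint64(0x0123456789abcdef)
	bh, bl := uint64(0xdeadbeefdeadbeef), uint64(0xcafebabecafebabe)

	hi, lo := high128(ah, al, bh, bl)

	a := new(big.Int).Lsh(new(big.Int).SetUint64(ah), 64)
	a.Or(a, new(big.Int).SetUint64(al))
	b := new(big.Int).Lsh(new(big.Int).SetUint64(bh), 64)
	b.Or(b, new(big.Int).SetUint64(bl))
	want := new(big.Int).Rsh(a.Mul(a, b), 128)

	fmt.Printf("high128: %016x%016x\nbig.Int: %032x\n", hi, lo, want)
}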
+	// Fourth iteration: 192 -> 320
+
+	// square r3
+
+	a4h, a4l := bits.Mul64(r3l, r3l)
+	b4h, b4l := bits.Mul64(r3l, r3m)
+	c4h, c4l := bits.Mul64(r3l, r3h)
+	d4h, d4l := bits.Mul64(r3m, r3m)
+	e4h, e4l := bits.Mul64(r3m, r3h)
+	f4h, f4l := bits.Mul64(r3h, r3h)
+
+	b4h, c = bits.Add64(b4h, c4l, 0)
+	e4l, c = bits.Add64(e4l, c4h, c)
+	e4h, _ = bits.Add64(e4h, 0, c)
+
+	a4h, c = bits.Add64(a4h, b4l, 0)
+	d4l, c = bits.Add64(d4l, b4h, c)
+	d4h, c = bits.Add64(d4h, e4l, c)
+	f4l, c = bits.Add64(f4l, e4h, c)
+	f4h, _ = bits.Add64(f4h, 0, c)
+
+	a4h, c = bits.Add64(a4h, b4l, 0)
+	d4l, c = bits.Add64(d4l, b4h, c)
+	d4h, c = bits.Add64(d4h, e4l, c)
+	f4l, c = bits.Add64(f4l, e4h, c)
+	f4h, _ = bits.Add64(f4h, 0, c)
+
+	// multiply by y
+
+	x1, x0 = bits.Mul64(d4h, y[0])
+	x3, x2 = bits.Mul64(f4h, y[0])
+	t1, t0 = bits.Mul64(f4l, y[0]); x1, c = bits.Add64(x1, t0, 0); x2, c = bits.Add64(x2, t1, c)
+	x3, _ = bits.Add64(x3, 0, c)
+
+	t1, t0 = bits.Mul64(d4h, y[1]); x1, c = bits.Add64(x1, t0, 0); x2, c = bits.Add64(x2, t1, c)
+	x4, t0 := bits.Mul64(f4h, y[1]); x3, c = bits.Add64(x3, t0, c); x4, _ = bits.Add64(x4, 0, c)
+	t1, t0 = bits.Mul64(d4l, y[1]); x0, c = bits.Add64(x0, t0, 0); x1, c = bits.Add64(x1, t1, c)
+	t1, t0 = bits.Mul64(f4l, y[1]); x2, c = bits.Add64(x2, t0, c); x3, c = bits.Add64(x3, t1, c)
+	x4, _ = bits.Add64(x4, 0, c)
+
+	t1, t0 = bits.Mul64(a4h, y[2]); x0, c = bits.Add64(x0, t0, 0); x1, c = bits.Add64(x1, t1, c)
+	t1, t0 = bits.Mul64(d4h, y[2]); x2, c = bits.Add64(x2, t0, c); x3, c = bits.Add64(x3, t1, c)
+	x5, t0 := bits.Mul64(f4h, y[2]); x4, c = bits.Add64(x4, t0, c); x5, _ = bits.Add64(x5, 0, c)
+	t1, t0 = bits.Mul64(d4l, y[2]); x1, c = bits.Add64(x1, t0, 0); x2, c = bits.Add64(x2, t1, c)
+	t1, t0 = bits.Mul64(f4l, y[2]); x3, c = bits.Add64(x3, t0, c); x4, c = bits.Add64(x4, t1, c)
+	x5, _ = bits.Add64(x5, 0, c)
+
+	t1, t0 = bits.Mul64(a4h, y[3]); x1, c = bits.Add64(x1, t0, 0); x2, c = bits.Add64(x2, t1, c)
+	t1, t0 = bits.Mul64(d4h, y[3]); x3, c = bits.Add64(x3, t0, c); x4, c = bits.Add64(x4, t1, c)
+	x6, t0 := bits.Mul64(f4h, y[3]); x5, c = bits.Add64(x5, t0, c); x6, _ = bits.Add64(x6, 0, c)
+	t1, t0 = bits.Mul64(a4l, y[3]); x0, c = bits.Add64(x0, t0, 0); x1, c = bits.Add64(x1, t1, c)
+	t1, t0 = bits.Mul64(d4l, y[3]); x2, c = bits.Add64(x2, t0, c); x3, c = bits.Add64(x3, t1, c)
+	t1, t0 = bits.Mul64(f4l, y[3]); x4, c = bits.Add64(x4, t0, c); x5, c = bits.Add64(x5, t1, c)
+	x6, _ = bits.Add64(x6, 0, c)
+
+	// subtract
+	_, b = bits.Sub64(0, x0, 0)
+	_, b = bits.Sub64(0, x1, b)
+	r4l, b := bits.Sub64(0, x2, b)
+	r4k, b := bits.Sub64(0, x3, b)
+	r4j, b := bits.Sub64(r3l, x4, b)
+	r4i, b := bits.Sub64(r3m, x5, b)
+	r4h, _ := bits.Sub64(r3h, x6, b)
+
+	// Multiply candidate for 1/4y by y, with full precision
+
+	x0 = r4l
+	x1 = r4k
+	x2 = r4j
+	x3 = r4i
+	x4 = r4h
+
+	q1, q0 = bits.Mul64(x0, y[0])
+	q3, q2 = bits.Mul64(x2, y[0])
+	q5, q4 := bits.Mul64(x4, y[0])
+
+	t1, t0 = bits.Mul64(x1, y[0]); q1, c = bits.Add64(q1, t0, 0); q2, c = bits.Add64(q2, t1, c)
+	t1, t0 = bits.Mul64(x3, y[0]); q3, c = bits.Add64(q3, t0, c); q4, c = bits.Add64(q4, t1, c); q5, _ = bits.Add64(q5, 0, c)
+
+	t1, t0 = bits.Mul64(x0, y[1]); q1, c = bits.Add64(q1, t0, 0); q2, c = bits.Add64(q2, t1, c)
+	t1, t0 = bits.Mul64(x2, y[1]); q3, c = bits.Add64(q3, t0, c); q4, c = bits.Add64(q4, t1, c)
+	q6, t0 := bits.Mul64(x4, y[1]); q5, c = bits.Add64(q5, t0, c); q6, _ = bits.Add64(q6, 0, c)
+
+	t1, t0 = bits.Mul64(x1, y[1]); q2, c = bits.Add64(q2, t0, 0); q3, c = bits.Add64(q3, t1, c)
+	t1, t0 = bits.Mul64(x3, y[1]); q4, c = bits.Add64(q4, t0, c); q5, c = bits.Add64(q5, t1, c); q6, _ = bits.Add64(q6, 0, c)
+
+	t1, t0 = bits.Mul64(x0, y[2]); q2, c = bits.Add64(q2, t0, 0); q3, c = bits.Add64(q3, t1, c)
+	t1, t0 = bits.Mul64(x2, y[2]); q4, c = bits.Add64(q4, t0, c); q5, c = bits.Add64(q5, t1, c)
+	q7, t0 := bits.Mul64(x4, y[2]); q6, c = bits.Add64(q6, t0, c); q7, _ = bits.Add64(q7, 0, c)
+
+	t1, t0 = bits.Mul64(x1, y[2]); q3, c = bits.Add64(q3, t0, 0); q4, c = bits.Add64(q4, t1, c)
+	t1, t0 = bits.Mul64(x3, y[2]); q5, c = bits.Add64(q5, t0, c); q6, c = bits.Add64(q6, t1, c); q7, _ = bits.Add64(q7, 0, c)
+
+	t1, t0 = bits.Mul64(x0, y[3]); q3, c = bits.Add64(q3, t0, 0); q4, c = bits.Add64(q4, t1, c)
+	t1, t0 = bits.Mul64(x2, y[3]); q5, c = bits.Add64(q5, t0, c); q6, c = bits.Add64(q6, t1, c)
+	q8, t0 := bits.Mul64(x4, y[3]); q7, c = bits.Add64(q7, t0, c); q8, _ = bits.Add64(q8, 0, c)
+
+	t1, t0 = bits.Mul64(x1, y[3]); q4, c = bits.Add64(q4, t0, 0); q5, c = bits.Add64(q5, t1, c)
+	t1, t0 = bits.Mul64(x3, y[3]); q6, c = bits.Add64(q6, t0, c); q7, c = bits.Add64(q7, t1, c); q8, _ = bits.Add64(q8, 0, c)
+
+	// Final adjustment
+
+	// subtract q from 1/4
+	_, b = bits.Sub64(0, q0, 0)
+	_, b = bits.Sub64(0, q1, b)
+	_, b = bits.Sub64(0, q2, b)
+	_, b = bits.Sub64(0, q3, b)
+	_, b = bits.Sub64(0, q4, b)
+	_, b = bits.Sub64(0, q5, b)
+	_, b = bits.Sub64(0, q6, b)
+	_, b = bits.Sub64(0, q7, b)
+	_, b = bits.Sub64(uint64(1)<<62, q8, b)
+
+	// decrement the result
+	x0, t := bits.Sub64(r4l, 1, 0)
+	x1, t = bits.Sub64(r4k, 0, t)
+	x2, t = bits.Sub64(r4j, 0, t)
+	x3, t = bits.Sub64(r4i, 0, t)
+	x4, _ = bits.Sub64(r4h, 0, t)
+
+	// commit the decrement if the subtraction underflowed (reciprocal was too large)
+	if b != 0 {
+		r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0
+	}
+
+	// Shift to correct bit alignment, truncating excess bits
+
+	p = (p & 63) - 1
+
+	x0, c = bits.Add64(r4l, r4l, 0)
+	x1, c = bits.Add64(r4k, r4k, c)
+	x2, c = bits.Add64(r4j, r4j, c)
+	x3, c = bits.Add64(r4i, r4i, c)
+	x4, _ = bits.Add64(r4h, r4h, c)
+
+	if p < 0 {
+		r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0
+		p = 0 // avoid negative shift below
+	}
+
+	{
+		r := uint(p)      // right shift
+		l := uint(64 - r) // left shift
+
+		x0 = (r4l >> r) | (r4k << l)
+		x1 = (r4k >> r) | (r4j << l)
+		x2 = (r4j >> r) | (r4i << l)
+		x3 = (r4i >> r) | (r4h << l)
+		x4 = (r4h >> r)
+	}
+
+	if p > 0 {
+		r4h, r4i, r4j, r4k, r4l = x4, x3, x2, x1, x0
+	}
+
+	mu[0] = r4l
+	mu[1] = r4k
+	mu[2] = r4j
+	mu[3] = r4i
+	mu[4] = r4h
+
+	return mu
+}
+
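For the power-of-two early-out, the all-ones assignments produce exactly 2^(512-p) - 1 when m = 2^p, i.e. floor((2^512 - 1)/m), which is consistent with reading mu as a 320-bit fixed-point reciprocal scaled by 2^512; the iterative path appears designed to match that value, but treat the general-case claim as my reading rather than a documented contract. A sketch of a check for the power-of-two case (would need math/big imported into mod_test.go):

func TestReciprocalPowerOfTwo(t *testing.T) {
	for _, p := range []uint{192, 200, 255} {
		m := new(Int).Lsh(NewInt(1), p) // m = 2^p, so m[3] != 0
		mu := Reciprocal(m)

		got := new(big.Int)
		for i := 4; i >= 0; i-- {
			got.Lsh(got, 64)
			got.Or(got, new(big.Int).SetUint64(mu[i]))
		}

		want := new(big.Int).Lsh(big.NewInt(1), 512)
		want.Sub(want, big.NewInt(1))
		want.Div(want, m.ToBig()) // floor((2^512 - 1)/m) = 2^(512-p) - 1

		if got.Cmp(want) != 0 {
			t.Errorf("Reciprocal(2^%d) = %x, want %x", p, got, want)
		}
	}
}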
+// reduce4 computes the least non-negative residue of x modulo m
+//
+// requires a four-word modulus (m[3] > 0) and its inverse (mu)
+func reduce4(x [8]uint64, m *Int, mu [5]uint64) (z Int) {
+
+	// NB: Most variable names in the comments match the pseudocode for
+	// Barrett reduction in the Handbook of Applied Cryptography.
+
+	// q1 = x/2^192
+
+	x0 := x[3]
+	x1 := x[4]
+	x2 := x[5]
+	x3 := x[6]
+	x4 := x[7]
+
+	// q2 = q1 * mu; q3 = q2 / 2^320
+
+	var q0, q1, q2, q3, q4, q5, t0, t1, c uint64
+
+	q0, _ = bits.Mul64(x3, mu[0])
+	q1, t0 = bits.Mul64(x4, mu[0]); q0, c = bits.Add64(q0, t0, 0); q1, _ = bits.Add64(q1, 0, c)
+
+	t1, _ = bits.Mul64(x2, mu[1]); q0, c = bits.Add64(q0, t1, 0)
+	q2, t0 = bits.Mul64(x4, mu[1]); q1, c = bits.Add64(q1, t0, c); q2, _ = bits.Add64(q2, 0, c)
+
+	t1, t0 = bits.Mul64(x3, mu[1]); q0, c = bits.Add64(q0, t0, 0); q1, c = bits.Add64(q1, t1, c); q2, _ = bits.Add64(q2, 0, c)
+
+	t1, t0 = bits.Mul64(x2, mu[2]); q0, c = bits.Add64(q0, t0, 0); q1, c = bits.Add64(q1, t1, c)
+	q3, t0 = bits.Mul64(x4, mu[2]); q2, c = bits.Add64(q2, t0, c); q3, _ = bits.Add64(q3, 0, c)
+
+	t1, _ = bits.Mul64(x1, mu[2]); q0, c = bits.Add64(q0, t1, 0)
+	t1, t0 = bits.Mul64(x3, mu[2]); q1, c = bits.Add64(q1, t0, c); q2, c = bits.Add64(q2, t1, c); q3, _ = bits.Add64(q3, 0, c)
+
+	t1, _ = bits.Mul64(x0, mu[3]); q0, c = bits.Add64(q0, t1, 0)
+	t1, t0 = bits.Mul64(x2, mu[3]); q1, c = bits.Add64(q1, t0, c); q2, c = bits.Add64(q2, t1, c)
+	q4, t0 = bits.Mul64(x4, mu[3]); q3, c = bits.Add64(q3, t0, c); q4, _ = bits.Add64(q4, 0, c)
+
+	t1, t0 = bits.Mul64(x1, mu[3]); q0, c = bits.Add64(q0, t0, 0); q1, c = bits.Add64(q1, t1, c)
+	t1, t0 = bits.Mul64(x3, mu[3]); q2, c = bits.Add64(q2, t0, c); q3, c = bits.Add64(q3, t1, c); q4, _ = bits.Add64(q4, 0, c)
+
+	t1, t0 = bits.Mul64(x0, mu[4]); _, c = bits.Add64(q0, t0, 0); q1, c = bits.Add64(q1, t1, c)
+	t1, t0 = bits.Mul64(x2, mu[4]); q2, c = bits.Add64(q2, t0, c); q3, c = bits.Add64(q3, t1, c)
+	q5, t0 = bits.Mul64(x4, mu[4]); q4, c = bits.Add64(q4, t0, c); q5, _ = bits.Add64(q5, 0, c)
+
+	t1, t0 = bits.Mul64(x1, mu[4]); q1, c = bits.Add64(q1, t0, 0); q2, c = bits.Add64(q2, t1, c)
+	t1, t0 = bits.Mul64(x3, mu[4]); q3, c = bits.Add64(q3, t0, c); q4, c = bits.Add64(q4, t1, c); q5, _ = bits.Add64(q5, 0, c)
+
+	// Drop the fractional part of q3
+
+	q0 = q1
+	q1 = q2
+	q2 = q3
+	q3 = q4
+	q4 = q5
+
+	// r1 = x mod 2^320
+
+	x0 = x[0]
+	x1 = x[1]
+	x2 = x[2]
+	x3 = x[3]
+	x4 = x[4]
+
+	// r2 = q3 * m mod 2^320
+
+	var r0, r1, r2, r3, r4 uint64
+
+	r4, r3 = bits.Mul64(q0, m[3])
+	_, t0 = bits.Mul64(q1, m[3]); r4, _ = bits.Add64(r4, t0, 0)
+
+	t1, r2 = bits.Mul64(q0, m[2]); r3, c = bits.Add64(r3, t1, 0)
+	_, t0 = bits.Mul64(q2, m[2]); r4, _ = bits.Add64(r4, t0, c)
+
+	t1, t0 = bits.Mul64(q1, m[2]); r3, c = bits.Add64(r3, t0, 0); r4, _ = bits.Add64(r4, t1, c)
+
+	t1, r1 = bits.Mul64(q0, m[1]); r2, c = bits.Add64(r2, t1, 0)
+	t1, t0 = bits.Mul64(q2, m[1]); r3, c = bits.Add64(r3, t0, c); r4, _ = bits.Add64(r4, t1, c)
+
+	t1, t0 = bits.Mul64(q1, m[1]); r2, c = bits.Add64(r2, t0, 0); r3, c = bits.Add64(r3, t1, c)
+	_, t0 = bits.Mul64(q3, m[1]); r4, _ = bits.Add64(r4, t0, c)
+
+	t1, r0 = bits.Mul64(q0, m[0]); r1, c = bits.Add64(r1, t1, 0)
+	t1, t0 = bits.Mul64(q2, m[0]); r2, c = bits.Add64(r2, t0, c); r3, c = bits.Add64(r3, t1, c)
+	_, t0 = bits.Mul64(q4, m[0]); r4, _ = bits.Add64(r4, t0, c)
+
+	t1, t0 = bits.Mul64(q1, m[0]); r1, c = bits.Add64(r1, t0, 0); r2, c = bits.Add64(r2, t1, c)
+	t1, t0 = bits.Mul64(q3, m[0]); r3, c = bits.Add64(r3, t0, c); r4, _ = bits.Add64(r4, t1, c)
+
+	// r = r1 - r2
+
+	var b uint64
+
+	r0, b = bits.Sub64(x0, r0, 0)
+	r1, b = bits.Sub64(x1, r1, b)
+	r2, b = bits.Sub64(x2, r2, b)
+	r3, b = bits.Sub64(x3, r3, b)
+	r4, b = bits.Sub64(x4, r4, b)
+
+	// if r<0 then r+=m
+
+	if b != 0 {
+		r0, c = bits.Add64(r0, m[0], 0)
+		r1, c = bits.Add64(r1, m[1], c)
+		r2, c = bits.Add64(r2, m[2], c)
+		r3, c = bits.Add64(r3, m[3], c)
+		r4, _ = bits.Add64(r4, 0, c)
+	}
+
+	// while (r>=m) r-=m
+
+	for {
+		// q = r - m
+		q0, b = bits.Sub64(r0, m[0], 0)
+		q1, b = bits.Sub64(r1, m[1], b)
+		q2, b = bits.Sub64(r2, m[2], b)
+		q3, b = bits.Sub64(r3, m[3], b)
+		q4, b = bits.Sub64(r4, 0, b)
+
+		// if borrow break
+		if b != 0 {
+			break
+		}
+
+		// r = q
+		r4, r3, r2, r1, r0 = q4, q3, q2, q1, q0
+	}
+
+	z[3], z[2], z[1], z[0] = r3, r2, r1, r0
+
+	return z
+}
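Stripped of the limb scheduling, reduce4 is textbook Barrett reduction: estimate q close to floor(x/m) from the precomputed reciprocal using only shifts and one truncated multiplication, subtract q*m, and finish with at most a few conditional subtractions; the limb code additionally works modulo 2^320 and repairs a possible borrow by adding m back. A plain math/big transcription of the control flow, under the same assumption about mu as above:

package main

import (
	"fmt"
	"math/big"
)

// barrett follows reduce4's shape: q1 = x / 2^192, q3 = q1*mu / 2^320 as
// the approximate quotient, r = x - q3*m, then a short correction loop.
func barrett(x, m, mu *big.Int) *big.Int {
	q3 := new(big.Int).Rsh(x, 192)
	q3.Mul(q3, mu)
	q3.Rsh(q3, 320)

	r := new(big.Int).Mul(q3, m)
	r.Sub(x, r)
	for r.Cmp(m) >= 0 {
		r.Sub(r, m) // runs at most a few times
	}
	return r
}

func main() {
	m, _ := new(big.Int).SetString("c90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b22", 16)
	x, _ := new(big.Int).SetString("ffddbbaa99887766554433221100ffeeddccbbaa99887766554433221100ffeeddccbbaa99887766554433221100ffeeddccbbaa99887766554433221100ffee", 16)

	mu := new(big.Int).Lsh(big.NewInt(1), 512)
	mu.Sub(mu, big.NewInt(1))
	mu.Div(mu, m) // my reading of Reciprocal's output: floor((2^512 - 1)/m)

	fmt.Println(barrett(x, m, mu))
	fmt.Println(new(big.Int).Mod(x, m)) // agrees
}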
diff --git a/mod_test.go b/mod_test.go
new file mode 100644
index 00000000..9d24b489
--- /dev/null
+++ b/mod_test.go
@@ -0,0 +1,47 @@
+// uint256: Fixed size 256-bit math library
+// Copyright 2021 uint256 Authors
+// SPDX-License-Identifier: BSD-3-Clause
+
+package uint256
+
+import "testing"
+
+func TestLeadingZeros(t *testing.T) {
+	one := Int{1, 0, 0, 0}
+
+	testCases := []Int{
+		Int{0, 0, 0, 0},
+		Int{1, 0, 0, 0},
+		Int{0x7fffffffffffffff, 0, 0, 0},
+		Int{0x8000000000000000, 0, 0, 0},
+		Int{0xffffffffffffffff, 0, 0, 0},
+		Int{0, 1, 0, 0},
+		Int{0, 0x7fffffffffffffff, 0, 0},
+		Int{0, 0x8000000000000000, 0, 0},
+		Int{0, 0xffffffffffffffff, 0, 0},
+		Int{0, 0, 1, 0},
+		Int{0, 0, 0x7fffffffffffffff, 0},
+		Int{0, 0, 0x8000000000000000, 0},
+		Int{0, 0, 0xffffffffffffffff, 0},
+		Int{0, 0, 0, 1},
+		Int{0, 0, 0, 0x7fffffffffffffff},
+		Int{0, 0, 0, 0x8000000000000000},
+		Int{0, 0, 0, 0xffffffffffffffff},
+	}
+
+	for _, x := range testCases {
+		z := leadingZeros(&x)
+		if z >= 0 && z < 256 {
+			allZeros := new(Int).Rsh(&x, uint(256-z))
+			oneBit := new(Int).Rsh(&x, uint(255-z))
+			if allZeros.IsZero() && oneBit.Eq(&one) {
+				continue
+			}
+		} else if z == 256 {
+			if x.IsZero() {
+				continue
+			}
+		}
+		t.Errorf("wrong leading zeros %d of %x", z, x)
+	}
+}
diff --git a/uint256.go b/uint256.go
index 05cd1a6e..4a3f9461 100644
--- a/uint256.go
+++ b/uint256.go
@@ -195,10 +195,43 @@ func (z *Int) AddOverflow(x, y *Int) (*Int, bool) {
 // AddMod sets z to the sum ( x+y ) mod m, and returns z.
 // If m == 0, z is set to 0 (OBS: differs from the big.Int)
 func (z *Int) AddMod(x, y, m *Int) *Int {
+
+	// Fast path for m >= 2^192, with x and y at most slightly bigger than m.
+	// This is always the case when x and y are already reduced modulo such m.
+
+	if (m[3] != 0) && (x[3] <= m[3]) && (y[3] <= m[3]) {
+		var (
+			s, t     Int
+			overflow bool
+		)
+
+		s = *x
+		if _, overflow = s.SubOverflow(&s, m); overflow {
+			s = *x
+		}
+
+		t = *y
+		if _, overflow = t.SubOverflow(&t, m); overflow {
+			t = *y
+		}
+
+		if _, overflow = s.AddOverflow(&s, &t); overflow {
+			s.Sub(&s, m)
+		}
+
+		t = s
+		if _, overflow = s.SubOverflow(&s, m); overflow {
+			s = t
+		}
+
+		*z = s
+		return z
+	}
+
 	if m.IsZero() {
 		return z.Clear()
 	}
-	if z == m { // z is an alias for m // TODO: Understand why needed and add tests for all "division" methods.
+	if z == m { // z is an alias for m and will be overwritten by AddOverflow before m is read
 		m = m.Clone()
 	}
 	if _, overflow := z.AddOverflow(x, y); overflow {
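The guard x[3] <= m[3] (with m[3] != 0) is what makes the fast path sound: it bounds x < (m[3]+1)*2^192 <= 2*m[3]*2^192 <= 2m, so one conditional subtraction of m fully reduces each operand, the reduced sum stays below 2m, and a final conditional subtraction completes the reduction; the SubOverflow calls simply roll back whenever the subtraction would underflow. A boundary-case check of that reasoning (in-package sketch; math/big is already imported by uint256_test.go):

func TestAddModFastPathBoundary(t *testing.T) {
	m := Int{0, 0, 0, 0x8000000000000000} // m = 2^255
	x := Int{1, 0, 0, 0x8000000000000000} // x = m + 1, so m < x < 2m
	y := Int{^uint64(0), ^uint64(0), ^uint64(0), 0x7fffffffffffffff} // y = m - 1

	var z Int
	z.AddMod(&x, &y, &m) // fast path: x[3] <= m[3] and y[3] <= m[3]

	want := new(big.Int).Add(x.ToBig(), y.ToBig())
	want.Mod(want, m.ToBig()) // (m+1 + m-1) mod m = 0

	if z.ToBig().Cmp(want) != 0 {
		t.Errorf("AddMod = %x, want %x", &z, want)
	}
}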
@@ -576,6 +609,38 @@ func (z *Int) SMod(x, y *Int) *Int {
 	return z
 }
 
+// MulModWithReciprocal calculates the modulo-m multiplication of x and y
+// and returns z, using the reciprocal of m provided as the mu parameter.
+// Use uint256.Reciprocal to calculate mu from m.
+// If m == 0, z is set to 0 (OBS: differs from the big.Int)
+func (z *Int) MulModWithReciprocal(x, y, m *Int, mu *[5]uint64) *Int {
+	if x.IsZero() || y.IsZero() || m.IsZero() {
+		return z.Clear()
+	}
+	p := umul(x, y)
+
+	if m[3] != 0 {
+		r := reduce4(p, m, *mu)
+		return z.Set(&r)
+	}
+
+	var (
+		pl Int
+		ph Int
+	)
+	copy(pl[:], p[:4])
+	copy(ph[:], p[4:])
+
+	// If the multiplication is within 256 bits use Mod().
+	if ph.IsZero() {
+		return z.Mod(&pl, m)
+	}
+
+	var quot [8]uint64
+	rem := udivrem(quot[:], p[:], m)
+	return z.Set(&rem)
+}
+
 // MulMod calculates the modulo-m multiplication of x and y and
 // returns z.
 // If m == 0, z is set to 0 (OBS: differs from the big.Int)
@@ -584,6 +649,13 @@ func (z *Int) MulMod(x, y, m *Int) *Int {
 		return z.Clear()
 	}
 	p := umul(x, y)
+
+	if m[3] != 0 {
+		mu := Reciprocal(m)
+		r := reduce4(p, m, mu)
+		return z.Set(&r)
+	}
+
 	var (
 		pl Int
 		ph Int
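MulMod now derives the reciprocal on the fly whenever the modulus occupies all four words, while MulModWithReciprocal lets callers hoist that work out of a loop when many multiplications share one modulus. A usage sketch (the helper name is mine, not part of the diff):

// mulManyMod folds a list of factors into an accumulator modulo m,
// computing the reciprocal once when the fast reduction applies.
func mulManyMod(factors []Int, m *Int) Int {
	var acc Int
	acc.SetOne()

	if m[3] != 0 { // the reciprocal path only applies to four-word moduli
		mu := Reciprocal(m)
		for i := range factors {
			acc.MulModWithReciprocal(&acc, &factors[i], m, &mu)
		}
		return acc
	}

	for i := range factors {
		acc.MulMod(&acc, &factors[i], m)
	}
	return acc
}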
diff --git a/uint256_test.go b/uint256_test.go
index 8ae6e969..b4f70b5e 100644
--- a/uint256_test.go
+++ b/uint256_test.go
@@ -89,6 +89,15 @@ var (
 	{"0xffffffffffffffffffffffffffff000004020041fffffffffc00000060000020", "0xffffffffffffffffffffffffffffffe6000000ffffffe60000febebeffffffff", "0xffffffffffffffffffe6000000ffffffe60000febebeffffffffffffffffffff"},
 	{"0xffffffffffffffffffffffffffffffff00ffffe6ff0000000000000060000020", "0xffffffffffffffffffffffffffffffffffe6000000ffff00e60000febebeffff", "0xffffffffffffffffffe6000000ffff00e60000fe0000ffff00e60000febebeff"},
 	{"0xfffffffffffffffffffffffff600000000005af50100bebe000000004a00be0a", "0xffffffffffffffffffffffffffffeaffdfd9fffffffffffff5f60000000000ff", "0xffffffffffffffffffffffeaffdfd9fffffffffffffff60000000000ffffffff"},
+	{"0x8000000000000001000000000000000000000000000000000000000000000000", "0x800000000000000100000000000000000000000000000000000000000000000b", "0x8000000000000000000000000000000000000000000000000000000000000000"},
+	{"0x8000000000000000000000000000000000000000000000000000000000000000", "0x8000000000000001000000000000000000000000000000000000000000000000", "0x8000000000000000000000000000000000000000000000000000000000000000"},
+	{"0x8000000000000000000000000000000000000000000000000000000000000000", "0x8000000000000001000000000000000000000000000000000000000000000000", "0x8000000000000001000000000000000000000000000000000000000000000000"},
+	{"0x8000000000000000000000000000000000000000000000000000000000000000", "0x8000000000000000000000000000000100000000000000000000000000000000", "0x8000000000000000000000000000000000000000000000000000000000000001"},
+	{"1", "1", "0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff"},
+	{"1", "1", "0x1000000003030303030303030303030303030303030303030303030303030"},
+	{"1", "1", "0x4000000000000000130303030303030303030303030303030303030303030"},
+	{"1", "1", "0x8000000000000000000000000000000043030303000000000"},
+	{"1", "1", "0x8000000000000000000000000000000003030303030303030"},
 	}
 )
 
@@ -347,6 +356,18 @@ func TestRandomSMod(t *testing.T) {
 	)
 }
 
+func set3Int(s1, s2, s3, d1, d2, d3 *Int) {
+	d1.Set(s1)
+	d2.Set(s2)
+	d3.Set(s3)
+}
+
+func set3Big(s1, s2, s3, d1, d2, d3 *big.Int) {
+	d1.Set(s1)
+	d2.Set(s2)
+	d3.Set(s3)
+}
+
 func TestRandomMulMod(t *testing.T) {
 	for i := 0; i < 10000; i++ {
 		b1, f1, err := randNums()
@@ -378,6 +399,318 @@ func TestRandomMulMod(t *testing.T) {
 		if !checkEq(b1, f1) {
 			t.Fatalf("Expected equality:\nf2= %x\nf3= %x\nf4= %x\n[ op ]==\nf = %x\nb = %x\n", f2, f3, f4, f1, b1)
 		}
+
+		f1.mulModWithReciprocalWrapper(f2, f3, f4)
+
+		if !checkEq(b1, f1) {
+			t.Fatalf("Expected equality:\nf2= %x\nf3= %x\nf4= %x\n[ op ]==\nf = %x\nb = %x\n", f2, f3, f4, f1, b1)
+		}
 	}
+
+	// Tests related to powers of 2
+
+	f_minusone := &Int{^uint64(0), ^uint64(0), ^uint64(0), ^uint64(0)}
+
+	b_one := big.NewInt(1)
+	b_minusone := big.NewInt(0)
+	b_minusone = b_minusone.Sub(b_minusone, b_one)
+
+	for i := uint(0); i < 256; i++ {
+		b := big.NewInt(0)
+		f := NewInt(0)
+
+		t1, t2, t3 := b, b, b
+		u1, u2, u3 := f, f, f
+
+		b1 := b.Lsh(b, i)
+		f1 := f.Lsh(f, i)
+
+		b2, f2, err := randNums()
+		if err != nil {
+			t.Fatalf("Error getting a random number: %v", err)
+		}
+		for b2.Cmp(big.NewInt(0)) == 0 {
+			b2, f2, err = randNums()
+			if err != nil {
+				t.Fatalf("Error getting a random number: %v", err)
+			}
+		}
+
+		b3, f3, err := randNums()
+		if err != nil {
+			t.Fatalf("Error getting a random number: %v", err)
+		}
+		for b3.Cmp(big.NewInt(0)) == 0 {
+			b3, f3, err = randNums()
+			if err != nil {
+				t.Fatalf("Error getting a random number: %v", err)
+			}
+		}
+
+		// Tests with one operand a power of 2
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t1, t2), t3)
+		f.MulMod(u1, u2, u3)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf1= 0x%x\nf2= 0x%x\nf3= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f1, f2, f3, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t1, t3), t2)
+		f.MulMod(u1, u3, u2)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf1= 0x%x\nf3= 0x%x\nf2= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f1, f3, f2, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t2, t1), t3)
+		f.MulMod(u2, u1, u3)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf2= 0x%x\nf1= 0x%x\nf3= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f2, f1, f3, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t2, t3), t1)
+		f.MulMod(u2, u3, u1)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf2= 0x%x\nf3= 0x%x\nf1= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f2, f3, f1, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t3, t1), t2)
+		f.MulMod(u3, u1, u2)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf3= 0x%x\nf1= 0x%x\nf2= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f1, f2, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t3, t2), t1)
+		f.MulMod(u3, u2, u1)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf3= 0x%x\nf2= 0x%x\nf1= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f2, f1, f, b)
+		}
+
+		// Tests with one operand 2^256 minus a power of 2
+
+		f1.Xor(f1, f_minusone)
+		b1.Xor(b1, b_minusone)
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t1, t2), t3)
+		f.MulMod(u1, u2, u3)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf1= 0x%x\nf2= 0x%x\nf3= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f1, f2, f3, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t1, t3), t2)
+		f.MulMod(u1, u3, u2)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf1= 0x%x\nf3= 0x%x\nf2= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f1, f3, f2, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t2, t1), t3)
+		f.MulMod(u2, u1, u3)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf2= 0x%x\nf1= 0x%x\nf3= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f2, f1, f3, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t2, t3), t1)
+		f.MulMod(u2, u3, u1)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf2= 0x%x\nf3= 0x%x\nf1= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f2, f3, f1, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t3, t1), t2)
+		f.MulMod(u3, u1, u2)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf3= 0x%x\nf1= 0x%x\nf2= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f1, f2, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t3, t2), t1)
+		f.MulMod(u3, u2, u1)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf3= 0x%x\nf2= 0x%x\nf1= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f2, f1, f, b)
+		}
+
+		f1.Xor(f1, f_minusone)
+		b1.Xor(b1, b_minusone)
+
+		// Tests with one operand a power of 2 plus 1
+
+		b1.Add(b1, b_one)
+		f1.AddUint64(f1, 1)
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t1, t2), t3)
+		f.MulMod(u1, u2, u3)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf1= 0x%x\nf2= 0x%x\nf3= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f1, f2, f3, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t1, t3), t2)
+		f.MulMod(u1, u3, u2)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf1= 0x%x\nf3= 0x%x\nf2= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f1, f3, f2, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t2, t1), t3)
+		f.MulMod(u2, u1, u3)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf2= 0x%x\nf1= 0x%x\nf3= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f2, f1, f3, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t2, t3), t1)
+		f.MulMod(u2, u3, u1)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf2= 0x%x\nf3= 0x%x\nf1= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f2, f3, f1, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t3, t1), t2)
+		f.MulMod(u3, u1, u2)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf3= 0x%x\nf1= 0x%x\nf2= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f1, f2, f, b)
+		}
+
+		set3Big(b1, b2, b3, t1, t2, t3)
+		set3Int(f1, f2, f3, u1, u2, u3)
+
+		b.Mod(b.Mul(t3, t2), t1)
+		f.MulMod(u3, u2, u1)
+
+		if !checkEq(b, f) {
+			t.Fatalf("Expected equality:\nf3= 0x%x\nf2= 0x%x\nf1= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f2, f1, f, b)
+		}
+
%x\n", f2, f3, f1, f, b) + } + + set3Big(b1, b2, b3, t1, t2, t3) + set3Int(f1, f2, f3, u1, u2, u3) + + b.Mod(b.Mul(t3, t1), t2) + f.MulMod(u3, u1, u2) + + if !checkEq(b, f) { + t.Fatalf("Expected equality:\nf3= 0x%x\nf1= 0x%x\nf2= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f1, f2, f, b) + } + + set3Big(b1, b2, b3, t1, t2, t3) + set3Int(f1, f2, f3, u1, u2, u3) + + b.Mod(b.Mul(t3, t2), t1) + f.MulMod(u3, u2, u1) + + if !checkEq(b, f) { + t.Fatalf("Expected equality:\nf3= 0x%x\nf2= 0x%x\nf1= 0x%x\n[ op ]==\nf = %x\nb = %x\n", f3, f2, f1, f, b) + } } } @@ -732,6 +1065,11 @@ func mulMod(result, x, y, mod *big.Int) *big.Int { return result.Mod(result.Mul(x, y), mod) } +func (z *Int) mulModWithReciprocalWrapper(x, y, mod *Int) *Int { + mu := Reciprocal(mod) + return z.MulModWithReciprocal(x, y, mod, &mu) +} + func referenceExp(base, exponent *big.Int) *big.Int { // TODO: Maybe use the Exp() procedure from above? res := new(big.Int) @@ -993,6 +1331,14 @@ func TestTernOp(t *testing.T) { return mulMod(z, x, y, m) }) }) + t.Run("MulModWithReciprocal", func(t *testing.T) { + proc(t, (*Int).mulModWithReciprocalWrapper, func(z, x, y, m *big.Int) *big.Int { + if m.Sign() == 0 { + return z.SetUint64(0) + } + return mulMod(z, x, y, m) + }) + }) } func TestCmpOp(t *testing.T) {