Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

divmod: fix aliasing error, add tests #180

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion ternary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ var ternaryOpFuncs = []struct {
{"AddMod", (*Int).AddMod, bigAddMod},
{"MulMod", (*Int).MulMod, bigMulMod},
{"MulModWithReciprocal", (*Int).mulModWithReciprocalWrapper, bigMulMod},
{"DivModZ", divModZ, bigDivModZ},
{"DivModM", divModM, bigDivModM},
}

func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp bigThreeArgFunc, x, y, z Int) {
Expand Down Expand Up @@ -49,7 +51,10 @@ func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp
t.Fatalf("%v\nsecond argument had been modified: %x", operation, f2)
}
if !f3.Eq(f3orig) {
t.Fatalf("%v\nthird argument had been modified: %x", operation, f3)
if opName != "DivModZ" && opName != "DivModM" {
// DivMod takes m as third argument, modifies it, and returns it. That is by design.
t.Fatalf("%v\nthird argument had been modified: %x", operation, f3)
}
}
// Check if reusing args as result works correctly.
if have = op(f1, f1, f2orig, f3orig); have != f1 {
Expand Down Expand Up @@ -117,3 +122,29 @@ func (z *Int) mulModWithReciprocalWrapper(x, y, mod *Int) *Int {
mu := Reciprocal(mod)
return z.MulModWithReciprocal(x, y, mod, &mu)
}

func divModZ(z, x, y, m *Int) *Int {
z2, _ := z.DivMod(x, y, m)
return z2
}

func bigDivModZ(result, x, y, mod *big.Int) *big.Int {
if y.Sign() == 0 {
return result.SetUint64(0)
}
z2, _ := result.DivMod(x, y, mod)
return z2
}

func divModM(z, x, y, m *Int) *Int {
_, m2 := z.DivMod(x, y, m)
return z.Set(m2)
}

func bigDivModM(result, x, y, mod *big.Int) *big.Int {
if y.Sign() == 0 {
return result.SetUint64(0)
}
_, m2 := result.DivMod(x, y, mod)
return result.Set(m2)
}
28 changes: 17 additions & 11 deletions uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,9 @@ func umul(x, y *Int, res *[8]uint64) {
func (z *Int) Mul(x, y *Int) *Int {
var (
carry0, carry1, carry2 uint64
res1, res2 uint64
x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
res1, res2 uint64
x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
)

carry0, z[0] = bits.Mul64(x0, y0)
Expand Down Expand Up @@ -610,14 +610,20 @@ func (z *Int) Mod(x, y *Int) *Int {
// DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0.
// If y == 0, both z and m are set to 0 (OBS: differs from the big.Int)
func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
if z == m {
// We return both z and m as results, if they are aliased, we have to
// un-alias them to be able to return separate results.
m = new(Int).Set(m)
}
if y.IsZero() {
return z.Clear(), m.Clear()
}
if x.Eq(y) {
return z.SetOne(), m.Clear()
}
if x.Lt(y) {
return z.Clear(), m.Set(x)
m.Set(x)
return z.Clear(), m
}

// At this point:
Expand Down Expand Up @@ -1279,7 +1285,7 @@ func (z *Int) Sqrt(x *Int) *Int {
return z.SetUint64(x0)
}
for {
z2 = (z1 + x0 / z1) >> 1
z2 = (z1 + x0/z1) >> 1
if z2 >= z1 {
return z.SetUint64(z1)
}
Expand All @@ -1291,18 +1297,18 @@ func (z *Int) Sqrt(x *Int) *Int {
z2 := NewInt(0)

// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
z1.Lsh(z1, uint(x.BitLen() + 1) / 2) // must be ≥ √x
z1.Lsh(z1, uint(x.BitLen()+1)/2) // must be ≥ √x

// We can do the first division outside the loop
z2.Rsh(x, uint(x.BitLen() + 1) / 2) // The first div is equal to a right shift
z2.Rsh(x, uint(x.BitLen()+1)/2) // The first div is equal to a right shift

for {
z2.Add(z2, z1)

// z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster
z2[0] = (z2[0] >> 1) | z2[1] << 63
z2[1] = (z2[1] >> 1) | z2[2] << 63
z2[2] = (z2[2] >> 1) | z2[3] << 63
z2[0] = (z2[0] >> 1) | z2[1]<<63
z2[1] = (z2[1] >> 1) | z2[2]<<63
z2[2] = (z2[2] >> 1) | z2[3]<<63
z2[3] >>= 1

if !z2.Lt(z1) {
Expand Down