Skip to content

Commit

Permalink
divmod: fix aliasing error, add tests (#180)
Browse files Browse the repository at this point in the history
This change fixes a flaw in `DivMod` related to aliasing of input arguments.
  • Loading branch information
holiman authored Jul 25, 2024
1 parent 9fb9e97 commit ce90883
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
33 changes: 32 additions & 1 deletion ternary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ var ternaryOpFuncs = []struct {
{"AddMod", (*Int).AddMod, bigAddMod},
{"MulMod", (*Int).MulMod, bigMulMod},
{"MulModWithReciprocal", (*Int).mulModWithReciprocalWrapper, bigMulMod},
{"DivModZ", divModZ, bigDivModZ},
{"DivModM", divModM, bigDivModM},
}

func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp bigThreeArgFunc, x, y, z Int) {
Expand Down Expand Up @@ -49,7 +51,10 @@ func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp
t.Fatalf("%v\nsecond argument had been modified: %x", operation, f2)
}
if !f3.Eq(f3orig) {
t.Fatalf("%v\nthird argument had been modified: %x", operation, f3)
if opName != "DivModZ" && opName != "DivModM" {
// DivMod takes m as third argument, modifies it, and returns it. That is by design.
t.Fatalf("%v\nthird argument had been modified: %x", operation, f3)
}
}
// Check if reusing args as result works correctly.
if have = op(f1, f1, f2orig, f3orig); have != f1 {
Expand Down Expand Up @@ -117,3 +122,29 @@ func (z *Int) mulModWithReciprocalWrapper(x, y, mod *Int) *Int {
mu := Reciprocal(mod)
return z.MulModWithReciprocal(x, y, mod, &mu)
}

func divModZ(z, x, y, m *Int) *Int {
z2, _ := z.DivMod(x, y, m)
return z2
}

func bigDivModZ(result, x, y, mod *big.Int) *big.Int {
if y.Sign() == 0 {
return result.SetUint64(0)
}
z2, _ := result.DivMod(x, y, mod)
return z2
}

func divModM(z, x, y, m *Int) *Int {
_, m2 := z.DivMod(x, y, m)
return z.Set(m2)
}

func bigDivModM(result, x, y, mod *big.Int) *big.Int {
if y.Sign() == 0 {
return result.SetUint64(0)
}
_, m2 := result.DivMod(x, y, mod)
return result.Set(m2)
}
28 changes: 17 additions & 11 deletions uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,9 @@ func umul(x, y *Int, res *[8]uint64) {
func (z *Int) Mul(x, y *Int) *Int {
var (
carry0, carry1, carry2 uint64
res1, res2 uint64
x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
res1, res2 uint64
x0, x1, x2, x3 = x[0], x[1], x[2], x[3]
y0, y1, y2, y3 = y[0], y[1], y[2], y[3]
)

carry0, z[0] = bits.Mul64(x0, y0)
Expand Down Expand Up @@ -610,14 +610,20 @@ func (z *Int) Mod(x, y *Int) *Int {
// DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0.
// If y == 0, both z and m are set to 0 (OBS: differs from the big.Int)
func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) {
if z == m {
// We return both z and m as results, if they are aliased, we have to
// un-alias them to be able to return separate results.
m = new(Int).Set(m)
}
if y.IsZero() {
return z.Clear(), m.Clear()
}
if x.Eq(y) {
return z.SetOne(), m.Clear()
}
if x.Lt(y) {
return z.Clear(), m.Set(x)
m.Set(x)
return z.Clear(), m
}

// At this point:
Expand Down Expand Up @@ -1279,7 +1285,7 @@ func (z *Int) Sqrt(x *Int) *Int {
return z.SetUint64(x0)
}
for {
z2 = (z1 + x0 / z1) >> 1
z2 = (z1 + x0/z1) >> 1
if z2 >= z1 {
return z.SetUint64(z1)
}
Expand All @@ -1291,18 +1297,18 @@ func (z *Int) Sqrt(x *Int) *Int {
z2 := NewInt(0)

// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
z1.Lsh(z1, uint(x.BitLen() + 1) / 2) // must be ≥ √x
z1.Lsh(z1, uint(x.BitLen()+1)/2) // must be ≥ √x

// We can do the first division outside the loop
z2.Rsh(x, uint(x.BitLen() + 1) / 2) // The first div is equal to a right shift
z2.Rsh(x, uint(x.BitLen()+1)/2) // The first div is equal to a right shift

for {
z2.Add(z2, z1)

// z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster
z2[0] = (z2[0] >> 1) | z2[1] << 63
z2[1] = (z2[1] >> 1) | z2[2] << 63
z2[2] = (z2[2] >> 1) | z2[3] << 63
z2[0] = (z2[0] >> 1) | z2[1]<<63
z2[1] = (z2[1] >> 1) | z2[2]<<63
z2[2] = (z2[2] >> 1) | z2[3]<<63
z2[3] >>= 1

if !z2.Lt(z1) {
Expand Down

0 comments on commit ce90883

Please sign in to comment.