Skip to content

Commit

Permalink
Add Muldiv() (#110)
Browse files Browse the repository at this point in the history
* Implement MulDivOverflow()

* update document

* update benchmark

* update document

* update authors

* optimize for case x or y == 0

* fuzzing: add muldiv fuzzing + fix fuzzer

* fix benchmark opt

* fix typo

* fuzzing: doc fix

* zero-check directly

* remove wrong optimization

* handle case denominator is greater than numerator

* update func description

* remove  rarely case code

* handle this case in udivrem instead

* fuzz: use correct return value

Co-authored-by: Martin Holst Swende <martin@swende.se>
  • Loading branch information
Planxnx and holiman authored Mar 24, 2022
1 parent d97bdee commit 77643b2
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 12 deletions.
2 changes: 2 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ Kurkó Mihály <kurkomisi@users.noreply.github.com>
Paweł Bylica <chfast@gmail.com>
Yao Zengzeng <yaozengzeng@zju.edu.cn>
Dag Arne Osvik <daosvik@users.noreply.github.com>
Thanee Charattrakool <planxthanee@gmail.com>

39 changes: 39 additions & 0 deletions benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -842,3 +842,42 @@ func Benchmark_DecodeHex(b *testing.B) {
b.Run("large/uint256", func(b *testing.B) { hexDecodeU256(b, &int256Samples) })
b.Run("large/big", func(b *testing.B) { hexDecodeBig(b, &big256Samples) })
}

func BenchmarkMulDivOverflow(b *testing.B) {
benchmarkUint256 := func(b *testing.B, factorsSamples, muldivSamples *[numSamples]Int) {
iter := (b.N + numSamples - 1) / numSamples

for j := 0; j < numSamples; j++ {
x := factorsSamples[j]

for i := 0; i < iter; i++ {
x.MulDivOverflow(&x, &factorsSamples[j], &muldivSamples[j])
}
}
}

benchmarkBig := func(b *testing.B, factorsSamples, muldivSamples *[numSamples]big.Int) {
iter := (b.N + numSamples - 1) / numSamples

for j := 0; j < numSamples; j++ {
x := factorsSamples[j]

for i := 0; i < iter; i++ {
x.Mul(&x, &factorsSamples[j])
x.Div(&x, &muldivSamples[j])
}
}
}

b.Run("small/uint256", func(b *testing.B) { benchmarkUint256(b, &int32SamplesLt, &int32Samples) })
b.Run("div64/uint256", func(b *testing.B) { benchmarkUint256(b, &int64SamplesLt, &int64Samples) })
b.Run("div128/uint256", func(b *testing.B) { benchmarkUint256(b, &int128SamplesLt, &int128Samples) })
b.Run("div192/uint256", func(b *testing.B) { benchmarkUint256(b, &int192SamplesLt, &int192Samples) })
b.Run("div256/uint256", func(b *testing.B) { benchmarkUint256(b, &int256SamplesLt, &int256Samples) })
b.Run("small/big", func(b *testing.B) { benchmarkBig(b, &big32SamplesLt, &big32Samples) })
b.Run("div64/big", func(b *testing.B) { benchmarkBig(b, &big64SamplesLt, &big64Samples) })
b.Run("div128/big", func(b *testing.B) { benchmarkBig(b, &big128SamplesLt, &big128Samples) })
b.Run("div192/big", func(b *testing.B) { benchmarkBig(b, &big192SamplesLt, &big192Samples) })
b.Run("div256/big", func(b *testing.B) { benchmarkBig(b, &big256SamplesLt, &big256Samples) })

}
34 changes: 22 additions & 12 deletions fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,16 @@ func checkThreeArgOp(op opThreeArgFunc, bigOp bigThreeArgFunc, x, y, z Int) {
func Fuzz(data []byte) int {
if len(data) < 32 {
return 0
} else {

return fuzzUnaryOp(data)
}

switch len(data) {
case 32:
return fuzzUnaryOp(data)
case 64:
return fuzzBinaryOp(data)
case 96:
return fuzzTernaryOp(data)
}
switch {
case len(data) < 64:
return fuzzUnaryOp(data) // needs 32 byte
case len(data) < 96:
return fuzzBinaryOp(data) // needs 64 byte
case len(data) < 128:
return fuzzTernaryOp(data) // needs 96 byte
}
// Too large input
return -1
}

Expand Down Expand Up @@ -277,6 +274,16 @@ func intAddMod(f1, f2, f3, f4 *Int) *Int {
return f1.AddMod(f2, f3, f4)
}

func bigMulDiv(b1, b2, b3, b4 *big.Int) *big.Int {
b1.Mul(b2, b3)
return b1.Div(b1, b4)
}

func intMulDiv(f1, f2, f3, f4 *Int) *Int {
f1.MulDivOverflow(f2, f3, f4)
return f1
}

func fuzzTernaryOp(data []byte) int {
var x, y, z Int
x.SetBytes(data[:32])
Expand All @@ -292,5 +299,8 @@ func fuzzTernaryOp(data []byte) int {
{ // addMod
checkThreeArgOp(intAddMod, bigAddMod, x, y, z)
}
{ // mulDiv
checkThreeArgOp(intMulDiv, bigMulDiv, x, y, z)
}
return 1
}
16 changes: 16 additions & 0 deletions uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,22 @@ func (z *Int) MulMod(x, y, m *Int) *Int {
return z.Set(&rem)
}

// MulDivOverflow calculates (x*y)/d with full precision, returns z and whether overflow occurred in multiply process (result does not fit to 256-bit).
// computes 512-bit multiplication and 512 by 256 division.
func (z *Int) MulDivOverflow(x, y, d *Int) (*Int, bool) {
if x.IsZero() || y.IsZero() || d.IsZero() {
return z.Clear(), false
}
p := umul(x, y)

var quot [8]uint64
udivrem(quot[:], p[:], d)

copy(z[:], quot[:4])

return z, (quot[4] | quot[5] | quot[6] | quot[7]) != 0
}

// Abs interprets x as a two's complement signed number,
// and sets z to the absolute value
// Abs(0) = 0
Expand Down
32 changes: 32 additions & 0 deletions uint256_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,38 @@ func TestRandomMulMod(t *testing.T) {
}
}

func TestRandomMulDivOverflow(t *testing.T) {
for i := 0; i < 10000; i++ {
b1, f1, err := randNums()
if err != nil {
t.Fatal(err)
}
b2, f2, err := randNums()
if err != nil {
t.Fatal(err)
}
b3, f3, err := randNums()
if err != nil {
t.Fatal(err)
}
f1a, f2a, f3a := f1.Clone(), f2.Clone(), f3.Clone()

_, overflow := f1.MulDivOverflow(f1, f2, f3)
if b3.BitLen() == 0 {
b1.SetInt64(0)
} else {
b1.Div(b1.Mul(b1, b2), b3)
}

if err := checkOverflow(b1, f1, overflow); err != nil {
t.Fatal(err)
}
if eq := checkEq(b1, f1); !eq {
t.Fatalf("Expected equality:\nf1= %x\nf2= %x\nf3= %x\n[ - ]==\nf= %x\nb= %x\n", f1a, f2a, f3a, f1, b1)
}
}
}

func S256(x *big.Int) *big.Int {
if x.Cmp(bigtt255) < 0 {
return x
Expand Down

0 comments on commit 77643b2

Please sign in to comment.