Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement Sqrt() #104

Merged
merged 2 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,31 @@ func BenchmarkSquare(bench *testing.B) {
bench.Run("single/big", benchmarkBig)
}

func BenchmarkSqrt(bench *testing.B) {
benchmarkUint256 := func(bench *testing.B) {
bench.ReportAllocs()
a := new(Int).SetBytes(hex2Bytes("f123456789abcdeffedcba9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9"))

result := new(Int)
bench.ResetTimer()
for i := 0; i < bench.N; i++ {
result.Sqrt(a)
}
}
benchmarkBig := func(bench *testing.B) {
bench.ReportAllocs()
a := new(big.Int).SetBytes(hex2Bytes("f123456789abcdeffedcba9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9"))

result := new(big.Int)
bench.ResetTimer()
for i := 0; i < bench.N; i++ {
result.Sqrt(a)
}
}
bench.Run("single/uint256", benchmarkUint256)
bench.Run("single/big", benchmarkBig)
}

func benchmark_And_Big(bench *testing.B) {
b1 := big.NewInt(0).SetBytes(hex2Bytes("0123456789abcdeffedcba9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9"))
b2 := big.NewInt(0).SetBytes(hex2Bytes("0123456789abcdefaaaaaa9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9"))
Expand Down
54 changes: 54 additions & 0 deletions fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ const (
opMulmod
)

type opUnaryArgFunc func(*Int, *Int) *Int
type bigUnaryArgFunc func(*big.Int, *big.Int) *big.Int

type opDualArgFunc func(*Int, *Int, *Int) *Int
type bigDualArgFunc func(*big.Int, *big.Int, *big.Int) *big.Int

Expand All @@ -42,6 +45,40 @@ func crash(op interface{}, msg string, args ...Int) {
msg, fnName, fnFile, fnLine, strings.Join(strArgs, "\n")))
}

func checkUnaryOp(op opUnaryArgFunc, bigOp bigUnaryArgFunc, x Int) {
origX := x
var result Int
ret := op(&result, &x)
if ret != &result {
crash(op, "returned not the pointer receiver", x)
}
if x != origX {
crash(op, "argument modified", x)
}
expected, _ := FromBig(bigOp(new(big.Int), x.ToBig()))
if result != *expected {
crash(op, "unexpected result", x)
}
// Test again when the receiver is not zero.
var garbage Int
garbage.Sub(&garbage, NewInt(1))
ret = op(&garbage, &x)
if ret != &garbage {
crash(op, "returned not the pointer receiver", x)
}
if garbage != *expected {
crash(op, "unexpected result", x)
}
// Test again with the receiver aliasing arguments.
ret = op(&x, &x)
if ret != &x {
crash(op, "returned not the pointer receiver", x)
}
if x != *expected {
crash(op, "unexpected result", x)
}
}

func checkDualArgOp(op opDualArgFunc, bigOp bigDualArgFunc, x, y Int) {
origX := x
origY := y
Expand Down Expand Up @@ -166,14 +203,31 @@ func checkThreeArgOp(op opThreeArgFunc, bigOp bigThreeArgFunc, x, y, z Int) {
}

func Fuzz(data []byte) int {
if len(data) < 32 {
return 0
} else {

return fuzzUnaryOp(data)
}

switch len(data) {
case 32:
return fuzzUnaryOp(data)
case 64:
return fuzzBinaryOp(data)
case 96:
return fuzzTernaryOp(data)
}
return -1
}

func fuzzUnaryOp(data []byte) int {
var x Int
x.SetBytes(data[0:32])
checkUnaryOp((*Int).Sqrt, (*big.Int).Sqrt, x)
return 1
}

func fuzzBinaryOp(data []byte) int {
var x, y Int
x.SetBytes(data[0:32])
Expand Down
34 changes: 34 additions & 0 deletions uint256.go
Original file line number Diff line number Diff line change
Expand Up @@ -1209,3 +1209,37 @@ func (z *Int) ExtendSign(x, byteNum *Int) *Int {
}
return z
}

// Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z.
func (z *Int) Sqrt(x *Int) *Int {
// This implementation of Sqrt is based on big.Int (see math/big/nat.go).
if x.LtUint64(2) {
return z.Set(x)
}
var (
z1 = &Int{1, 0, 0, 0}
z2 = &Int{}
)
// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
z1 = z1.Lsh(z1, uint(x.BitLen()+1)/2) // must be ≥ √x
for {
z2 = z2.Div(x, z1)
z2 = z2.Add(z2, z1)
{ //z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is DeepSource complain about this additional block.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah deepsource is nitpicky. "Unnecessary block detected" - hmpf.
Do you also think it should be removed? I personally like partitioning stuff like that sometimes, for clarity.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could move the a,b definition inside this clause to make deepsource happy

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could move the a,b definition inside this clause to make deepsource happy

This is good idea.

a := z2[3] << 63
z2[3] = z2[3] >> 1
b := z2[2] << 63
z2[2] = (z2[2] >> 1) | a
a = z2[1] << 63
z2[1] = (z2[1] >> 1) | b
z2[0] = (z2[0] >> 1) | a
}
// end of inlined bitshift

if z2.Cmp(z1) >= 0 {
// z1 is answer.
return z.Set(z1)
}
z1, z2 = z2, z1
}
}
15 changes: 15 additions & 0 deletions uint256_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ var (
unTestCases = []string{
"0",
"1",
"0x80000000000000000000000000000000",
"0x80000000000000010000000000000000",
"0x80000000000000000000000000000001",
"0x12cbafcee8f60f9f3fa308c90fde8d298772ffea667aa6bc109d5c661e7929a5",
"0x8000000000000000000000000000000000000000000000000000000000000000",
"0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe",
Expand Down Expand Up @@ -356,6 +359,17 @@ func TestRandomSMod(t *testing.T) {
)
}

func TestRandomSqrt(t *testing.T) {
testRandomOp(t,
func(f1, f2, f3 *Int) {
f1.Sqrt(f2)
},
func(b1, b2, b3 *big.Int) {
b1.Sqrt(b2)
},
)
}

func set3Int(s1, s2, s3, d1, d2, d3 *Int) {
d1.Set(s1)
d2.Set(s2)
Expand Down Expand Up @@ -1141,6 +1155,7 @@ func TestUnOp(t *testing.T) {

t.Run("Not", func(t *testing.T) { proc(t, (*Int).Not, (*big.Int).Not) })
t.Run("Neg", func(t *testing.T) { proc(t, (*Int).Neg, (*big.Int).Neg) })
t.Run("Sqrt", func(t *testing.T) { proc(t, (*Int).Sqrt, (*big.Int).Sqrt) })
}

func TestBinOp(t *testing.T) {
Expand Down