Skip to content

Commit

Permalink
add btree
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn committed Oct 16, 2023
1 parent ff3eeab commit 144cdd2
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ require (
github.com/spf13/viper v1.17.0
github.com/stretchr/testify v1.8.4
github.com/tetratelabs/wazero v1.1.0
github.com/tidwall/btree v1.7.0
github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20230905200255-921286631fa9
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
Expand Down
33 changes: 33 additions & 0 deletions pkg/scale/btree.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package scale

import (
"reflect"

"github.com/tidwall/btree"
)

// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items
// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the
// type of the items stored in the BTree in order to decode them.
type BTree struct {
*btree.BTree
Comparator func(a, b interface{}) bool
ItemType reflect.Type
}

// NewBTree creates a new BTree with the given comparator function.
func NewBTree[T any](comparator func(a, b any) bool) BTree {
// There's no instantiation overhead of the actual type T because we're only creating a slice type and
// getting the element type from it.
var dummySlice []T
elementType := reflect.TypeOf(dummySlice).Elem()

return BTree{
BTree: btree.New(comparator),
Comparator: comparator,
ItemType: elementType,
}
}
59 changes: 59 additions & 0 deletions pkg/scale/btree_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2023 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package scale

import (
"testing"

"github.com/stretchr/testify/require"
)

type dummy struct {
Field1 uint32
Field2 [32]byte
}

func TestBTree(t *testing.T) {
comparator := func(a, b interface{}) bool {
v1 := a.(dummy)
v2 := b.(dummy)
return v1.Field1 < v2.Field1
}

// Create a BTree with 3 dummy items
tree := NewBTree[dummy](comparator)
tree.BTree.Set(dummy{Field1: 1})
tree.BTree.Set(dummy{Field1: 2})
tree.BTree.Set(dummy{Field1: 3})

encoded, err := Marshal(tree)
require.NoError(t, err)

//let mut btree = BTreeMap::<u32, Hash>::new();
//btree.insert(1, Hash::zero());
//btree.insert(2, Hash::zero());
//btree.insert(3, Hash::zero());
//let encoded = btree.encode();
//println!("encoded: {:?}", encoded);
expectedEncoded := []byte{12,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}
require.Equal(t, expectedEncoded, encoded)

// Output:
expected := NewBTree[dummy](comparator)
err = Unmarshal(encoded, &expected)
require.NoError(t, err)

// Check that the expected BTree has the same items as the original
require.Equal(t, tree.BTree.Len(), expected.BTree.Len())
require.Equal(t, tree.ItemType, expected.ItemType)
require.Equal(t, tree.BTree.Min(), expected.BTree.Min())
require.Equal(t, tree.BTree.Max(), expected.BTree.Max())
require.Equal(t, tree.BTree.Get(dummy{Field1: 1}), expected.BTree.Get(dummy{Field1: 1}))
require.Equal(t, tree.BTree.Get(dummy{Field1: 2}), expected.BTree.Get(dummy{Field1: 2}))
require.Equal(t, tree.BTree.Get(dummy{Field1: 3}), expected.BTree.Get(dummy{Field1: 3}))
}
43 changes: 43 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"io"
"math/big"
"reflect"

"github.com/tidwall/btree"
)

// indirect walks down v allocating pointers as needed,
Expand Down Expand Up @@ -130,6 +132,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
err = ds.decodeVaryingDataType(dstv)
case VaryingDataTypeSlice:
err = ds.decodeVaryingDataTypeSlice(dstv)
case BTree:
err = ds.decodeBTree(dstv)
default:
t := reflect.TypeOf(in)
switch t.Kind() {
Expand Down Expand Up @@ -768,3 +772,42 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) {
dstv.Set(reflect.ValueOf(ui128))
return
}

// decodeBTree accepts a byte array representing a SCALE encoded
// BTree and performs SCALE decoding of the BTree
func (ds *decodeState) decodeBTree(dstv reflect.Value) (err error) {
// Decode the number of items in the tree
length, err := ds.decodeLength()
if err != nil {
return
}

btreeValue, ok := dstv.Interface().(BTree)
if !ok {
return fmt.Errorf("expected a BTree type")
}

if btreeValue.Comparator == nil {
return fmt.Errorf("no Comparator function provided for BTree")
}

if btreeValue.BTree == nil {
btreeValue.BTree = btree.New(btreeValue.Comparator)
}

// Decode each item in the tree
for i := uint(0); i < length; i++ {
// Decode the value
value := reflect.New(btreeValue.ItemType).Elem()
err = ds.unmarshal(value)
if err != nil {
return
}

// convert the value to the correct type for the BTree
btreeValue.BTree.Set(value.Interface())
}

dstv.Set(reflect.ValueOf(btreeValue))
return
}
17 changes: 17 additions & 0 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ func (es *encodeState) marshal(in interface{}) (err error) {
err = es.encodeVaryingDataType(in)
case VaryingDataTypeSlice:
err = es.encodeVaryingDataTypeSlice(in)
case BTree:
err = es.encodeBTree(in)
default:
switch reflect.TypeOf(in).Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16,
Expand Down Expand Up @@ -206,6 +208,21 @@ func (es *encodeState) encodeVaryingDataTypeSlice(vdts VaryingDataTypeSlice) (er
return
}

func (es *encodeState) encodeBTree(tree BTree) (err error) {
// write the number of items in the tree
err = es.encodeLength(tree.Len())
if err != nil {
return
}

// iterate through the tree and encode each item
tree.Ascend(nil, func(item any) bool {
err = es.marshal(item)
return err == nil
})
return
}

func (es *encodeState) encodeSlice(in interface{}) (err error) {
v := reflect.ValueOf(in)
err = es.encodeLength(v.Len())
Expand Down

0 comments on commit 144cdd2

Please sign in to comment.