diff --git a/go.mod b/go.mod index c0df9cba1b7..9b873e3253c 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/spf13/viper v1.17.0 github.com/stretchr/testify v1.8.4 github.com/tetratelabs/wazero v1.1.0 + github.com/tidwall/btree v1.7.0 github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9 golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 diff --git a/go.sum b/go.sum index 2d7ba3ae3b1..22db41a454f 100644 --- a/go.sum +++ b/go.sum @@ -767,6 +767,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/pkg/scale/btree.go b/pkg/scale/btree.go new file mode 100644 index 00000000000..452f7d0061e --- /dev/null +++ b/pkg/scale/btree.go @@ -0,0 +1,33 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package scale + +import ( + "reflect" + + "github.com/tidwall/btree" +) + +// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items +// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the +// type of the items stored in the BTree in order to decode them. +type BTree struct { + *btree.BTree + Comparator func(a, b interface{}) bool + ItemType reflect.Type +} + +// NewBTree creates a new BTree with the given comparator function. +func NewBTree[T any](comparator func(a, b any) bool) BTree { + // There's no instantiation overhead of the actual type T because we're only creating a slice type and + // getting the element type from it. + var dummySlice []T + elementType := reflect.TypeOf(dummySlice).Elem() + + return BTree{ + BTree: btree.New(comparator), + Comparator: comparator, + ItemType: elementType, + } +} diff --git a/pkg/scale/btree_test.go b/pkg/scale/btree_test.go new file mode 100644 index 00000000000..647e5c4d194 --- /dev/null +++ b/pkg/scale/btree_test.go @@ -0,0 +1,54 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package scale + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type dummy struct { + Field1 uint32 + Field2 [32]byte +} + +func TestBTree(t *testing.T) { + comparator := func(a, b interface{}) bool { + v1 := a.(dummy) + v2 := b.(dummy) + return v1.Field1 < v2.Field1 + } + + // Create a BTree with 3 dummy items + tree := NewBTree[dummy](comparator) + tree.BTree.Set(dummy{Field1: 1}) + tree.BTree.Set(dummy{Field1: 2}) + tree.BTree.Set(dummy{Field1: 3}) + + encoded, err := Marshal(tree) + require.NoError(t, err) + + // taken from the rust codec + expectedEncoded := []byte{12, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + } + require.Equal(t, expectedEncoded, encoded) + + // Output: + expected := NewBTree[dummy](comparator) + err = Unmarshal(encoded, &expected) + require.NoError(t, err) + + // Check that the expected BTree has the same items as the original + require.Equal(t, tree.BTree.Len(), expected.BTree.Len()) + require.Equal(t, tree.ItemType, expected.ItemType) + require.Equal(t, tree.BTree.Min(), expected.BTree.Min()) + require.Equal(t, tree.BTree.Max(), expected.BTree.Max()) + require.Equal(t, tree.BTree.Get(dummy{Field1: 1}), expected.BTree.Get(dummy{Field1: 1})) + require.Equal(t, tree.BTree.Get(dummy{Field1: 2}), expected.BTree.Get(dummy{Field1: 2})) + require.Equal(t, tree.BTree.Get(dummy{Field1: 3}), expected.BTree.Get(dummy{Field1: 3})) +} diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 45a527f0b84..0bda5f16ea4 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -11,6 +11,8 @@ import ( "io" "math/big" "reflect" + + "github.com/tidwall/btree" ) // indirect walks down v allocating pointers as needed, @@ -130,6 +132,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { err = ds.decodeVaryingDataType(dstv) case VaryingDataTypeSlice: err = ds.decodeVaryingDataTypeSlice(dstv) + case BTree: + err = ds.decodeBTree(dstv) default: t := reflect.TypeOf(in) switch t.Kind() { @@ -768,3 +772,42 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) { dstv.Set(reflect.ValueOf(ui128)) return } + +// decodeBTree accepts a byte array representing a SCALE encoded +// BTree and performs SCALE decoding of the BTree +func (ds *decodeState) decodeBTree(dstv reflect.Value) (err error) { + // Decode the number of items in the tree + length, err := ds.decodeLength() + if err != nil { + return + } + + btreeValue, ok := dstv.Interface().(BTree) + if !ok { + return fmt.Errorf("expected a BTree type") + } + + if btreeValue.Comparator == nil { + return fmt.Errorf("no Comparator function provided for BTree") + } + + if btreeValue.BTree == nil { + btreeValue.BTree = btree.New(btreeValue.Comparator) + } + + // Decode each item in the tree + for i := uint(0); i < length; i++ { + // Decode the value + value := reflect.New(btreeValue.ItemType).Elem() + err = ds.unmarshal(value) + if err != nil { + return + } + + // convert the value to the correct type for the BTree + btreeValue.BTree.Set(value.Interface()) + } + + dstv.Set(reflect.ValueOf(btreeValue)) + return +} diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index d312b85f913..05ddfe40619 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -85,6 +85,8 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeVaryingDataType(in) case VaryingDataTypeSlice: err = es.encodeVaryingDataTypeSlice(in) + case BTree: + err = es.encodeBTree(in) default: switch reflect.TypeOf(in).Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, @@ -206,6 +208,21 @@ func (es *encodeState) encodeVaryingDataTypeSlice(vdts VaryingDataTypeSlice) (er return } +func (es *encodeState) encodeBTree(tree BTree) (err error) { + // write the number of items in the tree + err = es.encodeLength(tree.Len()) + if err != nil { + return + } + + // iterate through the tree and encode each item + tree.Ascend(nil, func(item any) bool { + err = es.marshal(item) + return err == nil + }) + return +} + func (es *encodeState) encodeSlice(in interface{}) (err error) { v := reflect.ValueOf(in) err = es.encodeLength(v.Len())