From 144cdd2cfacb95b8b11730cfd142a52afad8092a Mon Sep 17 00:00:00 2001 From: Kanishka Date: Fri, 13 Oct 2023 17:23:25 +0530 Subject: [PATCH 1/6] add btree --- go.mod | 1 + go.sum | 2 ++ pkg/scale/btree.go | 33 +++++++++++++++++++++++ pkg/scale/btree_test.go | 59 +++++++++++++++++++++++++++++++++++++++++ pkg/scale/decode.go | 43 ++++++++++++++++++++++++++++++ pkg/scale/encode.go | 17 ++++++++++++ 6 files changed, 155 insertions(+) create mode 100644 pkg/scale/btree.go create mode 100644 pkg/scale/btree_test.go diff --git a/go.mod b/go.mod index c0df9cba1b..9b873e3253 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/spf13/viper v1.17.0 github.com/stretchr/testify v1.8.4 github.com/tetratelabs/wazero v1.1.0 + github.com/tidwall/btree v1.7.0 github.com/whyrusleeping/mdns v0.0.0-20190826153040-b9b60ed33aa9 golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 diff --git a/go.sum b/go.sum index 2d7ba3ae3b..22db41a454 100644 --- a/go.sum +++ b/go.sum @@ -767,6 +767,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/pkg/scale/btree.go b/pkg/scale/btree.go new file mode 100644 index 0000000000..452f7d0061 --- /dev/null +++ b/pkg/scale/btree.go @@ -0,0 +1,33 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package scale + +import ( + "reflect" + + "github.com/tidwall/btree" +) + +// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items +// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the +// type of the items stored in the BTree in order to decode them. +type BTree struct { + *btree.BTree + Comparator func(a, b interface{}) bool + ItemType reflect.Type +} + +// NewBTree creates a new BTree with the given comparator function. +func NewBTree[T any](comparator func(a, b any) bool) BTree { + // There's no instantiation overhead of the actual type T because we're only creating a slice type and + // getting the element type from it. + var dummySlice []T + elementType := reflect.TypeOf(dummySlice).Elem() + + return BTree{ + BTree: btree.New(comparator), + Comparator: comparator, + ItemType: elementType, + } +} diff --git a/pkg/scale/btree_test.go b/pkg/scale/btree_test.go new file mode 100644 index 0000000000..4618f23385 --- /dev/null +++ b/pkg/scale/btree_test.go @@ -0,0 +1,59 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package scale + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type dummy struct { + Field1 uint32 + Field2 [32]byte +} + +func TestBTree(t *testing.T) { + comparator := func(a, b interface{}) bool { + v1 := a.(dummy) + v2 := b.(dummy) + return v1.Field1 < v2.Field1 + } + + // Create a BTree with 3 dummy items + tree := NewBTree[dummy](comparator) + tree.BTree.Set(dummy{Field1: 1}) + tree.BTree.Set(dummy{Field1: 2}) + tree.BTree.Set(dummy{Field1: 3}) + + encoded, err := Marshal(tree) + require.NoError(t, err) + + //let mut btree = BTreeMap::::new(); + //btree.insert(1, Hash::zero()); + //btree.insert(2, Hash::zero()); + //btree.insert(3, Hash::zero()); + //let encoded = btree.encode(); + //println!("encoded: {:?}", encoded); + expectedEncoded := []byte{12, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + } + require.Equal(t, expectedEncoded, encoded) + + // Output: + expected := NewBTree[dummy](comparator) + err = Unmarshal(encoded, &expected) + require.NoError(t, err) + + // Check that the expected BTree has the same items as the original + require.Equal(t, tree.BTree.Len(), expected.BTree.Len()) + require.Equal(t, tree.ItemType, expected.ItemType) + require.Equal(t, tree.BTree.Min(), expected.BTree.Min()) + require.Equal(t, tree.BTree.Max(), expected.BTree.Max()) + require.Equal(t, tree.BTree.Get(dummy{Field1: 1}), expected.BTree.Get(dummy{Field1: 1})) + require.Equal(t, tree.BTree.Get(dummy{Field1: 2}), expected.BTree.Get(dummy{Field1: 2})) + require.Equal(t, tree.BTree.Get(dummy{Field1: 3}), expected.BTree.Get(dummy{Field1: 3})) +} diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 45a527f0b8..0bda5f16ea 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -11,6 +11,8 @@ import ( "io" "math/big" "reflect" + + "github.com/tidwall/btree" ) // indirect walks down v allocating pointers as needed, @@ -130,6 +132,8 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { err = ds.decodeVaryingDataType(dstv) case VaryingDataTypeSlice: err = ds.decodeVaryingDataTypeSlice(dstv) + case BTree: + err = ds.decodeBTree(dstv) default: t := reflect.TypeOf(in) switch t.Kind() { @@ -768,3 +772,42 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) { dstv.Set(reflect.ValueOf(ui128)) return } + +// decodeBTree accepts a byte array representing a SCALE encoded +// BTree and performs SCALE decoding of the BTree +func (ds *decodeState) decodeBTree(dstv reflect.Value) (err error) { + // Decode the number of items in the tree + length, err := ds.decodeLength() + if err != nil { + return + } + + btreeValue, ok := dstv.Interface().(BTree) + if !ok { + return fmt.Errorf("expected a BTree type") + } + + if btreeValue.Comparator == nil { + return fmt.Errorf("no Comparator function provided for BTree") + } + + if btreeValue.BTree == nil { + btreeValue.BTree = btree.New(btreeValue.Comparator) + } + + // Decode each item in the tree + for i := uint(0); i < length; i++ { + // Decode the value + value := reflect.New(btreeValue.ItemType).Elem() + err = ds.unmarshal(value) + if err != nil { + return + } + + // convert the value to the correct type for the BTree + btreeValue.BTree.Set(value.Interface()) + } + + dstv.Set(reflect.ValueOf(btreeValue)) + return +} diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index d312b85f91..05ddfe4061 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -85,6 +85,8 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeVaryingDataType(in) case VaryingDataTypeSlice: err = es.encodeVaryingDataTypeSlice(in) + case BTree: + err = es.encodeBTree(in) default: switch reflect.TypeOf(in).Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, @@ -206,6 +208,21 @@ func (es *encodeState) encodeVaryingDataTypeSlice(vdts VaryingDataTypeSlice) (er return } +func (es *encodeState) encodeBTree(tree BTree) (err error) { + // write the number of items in the tree + err = es.encodeLength(tree.Len()) + if err != nil { + return + } + + // iterate through the tree and encode each item + tree.Ascend(nil, func(item any) bool { + err = es.marshal(item) + return err == nil + }) + return +} + func (es *encodeState) encodeSlice(in interface{}) (err error) { v := reflect.ValueOf(in) err = es.encodeLength(v.Len()) From 5a60b6f8a0c11af87a019f3f999e0eed42ee480b Mon Sep 17 00:00:00 2001 From: Kanishka Date: Wed, 25 Oct 2023 10:22:43 +0530 Subject: [PATCH 2/6] add btreemap --- pkg/scale/btree.go | 158 +++++++++++++++++++++++++++++++++++++++- pkg/scale/btree_test.go | 52 ++++++++++--- pkg/scale/decode.go | 51 +++++++++++++ pkg/scale/encode.go | 19 +---- 4 files changed, 249 insertions(+), 31 deletions(-) diff --git a/pkg/scale/btree.go b/pkg/scale/btree.go index 452f7d0061..0da9872331 100644 --- a/pkg/scale/btree.go +++ b/pkg/scale/btree.go @@ -4,11 +4,23 @@ package scale import ( + "fmt" "reflect" "github.com/tidwall/btree" ) +type Ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | ~string +} + +type BTreeCodec interface { + Encode(es *encodeState) error + Decode(ds *decodeState, dstv reflect.Value) error +} + // BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items // stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the // type of the items stored in the BTree in order to decode them. @@ -18,16 +30,158 @@ type BTree struct { ItemType reflect.Type } +// Encode encodes the BTree using the given encodeState. +func (bt *BTree) Encode(es *encodeState) error { + // write the number of items in the tree + err := es.encodeLength(bt.Len()) + if err != nil { + return err + } + + bt.Ascend(nil, func(item interface{}) bool { + err = es.marshal(item) + return err == nil + }) + + return err +} + +// Decode decodes the BTree using the given decodeState. +func (bt *BTree) Decode(ds *decodeState, dstv reflect.Value) error { + // Decode the number of items in the tree + length, err := ds.decodeLength() + if err != nil { + return fmt.Errorf("decoding BTree length: %w", err) + } + + if bt.Comparator == nil { + return fmt.Errorf("no Comparator function provided for BTree") + } + + if bt.BTree == nil { + bt.BTree = btree.New(bt.Comparator) + } + + // Decode each item in the tree + for i := uint(0); i < length; i++ { + // Decode the value + value := reflect.New(bt.ItemType).Elem() + err = ds.unmarshal(value) + if err != nil { + return fmt.Errorf("decoding BTree item: %w", err) + } + + // convert the value to the correct type for the BTree + bt.Set(value.Interface()) + } + + dstv.Set(reflect.ValueOf(*bt)) + return nil +} + +// Copy returns a copy of the BTree. +func (bt *BTree) Copy() *BTree { + return &BTree{ + BTree: bt.BTree.Copy(), + Comparator: bt.Comparator, + ItemType: bt.ItemType, + } +} + // NewBTree creates a new BTree with the given comparator function. -func NewBTree[T any](comparator func(a, b any) bool) BTree { +func NewBTree[T any](comparator func(a, b any) bool) *BTree { // There's no instantiation overhead of the actual type T because we're only creating a slice type and // getting the element type from it. var dummySlice []T elementType := reflect.TypeOf(dummySlice).Elem() - return BTree{ + return &BTree{ BTree: btree.New(comparator), Comparator: comparator, ItemType: elementType, } } + +// BTreeMap is a wrapper around tidwall/btree.Map +type BTreeMap[K Ordered, V any] struct { + *btree.Map[K, V] + Degree int +} + +// Encode encodes the BTreeMap using the given encodeState. +func (btm *BTreeMap[K, V]) Encode(es *encodeState) error { + // write the number of items in the tree + err := es.encodeLength(btm.Len()) + if err != nil { + return err + } + + // write each item in the tree + var pivot K + btm.Ascend(pivot, func(key K, value V) bool { + if err = es.marshal(key); err != nil { + return false + } + + if err = es.marshal(value); err != nil { + return false + } + + return true + }) + + return err +} + +// Decode decodes the BTreeMap using the given decodeState. +func (btm *BTreeMap[K, V]) Decode(ds *decodeState, dstv reflect.Value) error { + // Decode the number of items in the tree + length, err := ds.decodeLength() + if err != nil { + return fmt.Errorf("decoding BTreeMap length: %w", err) + } + + if btm.Map == nil { + btm.Map = btree.NewMap[K, V](btm.Degree) + } + + // Decode each item in the tree + for i := uint(0); i < length; i++ { + // Decode the key + keyType := reflect.TypeOf((*K)(nil)).Elem() + keyInstance := reflect.New(keyType).Elem() + err = ds.unmarshal(keyInstance) + if err != nil { + return fmt.Errorf("decoding BTreeMap key: %w", err) + } + key := keyInstance.Interface().(K) + + // Decode the value + valueType := reflect.TypeOf((*V)(nil)).Elem() + valueInstance := reflect.New(valueType).Elem() + err = ds.unmarshal(valueInstance) + if err != nil { + return fmt.Errorf("decoding BTreeMap value: %w", err) + } + value := valueInstance.Interface().(V) + + btm.Map.Set(key, value) + } + + dstv.Set(reflect.ValueOf(*btm)) + return nil +} + +// Copy returns a copy of the BTreeMap. +func (btm *BTreeMap[K, V]) Copy() BTreeMap[K, V] { + return BTreeMap[K, V]{ + Map: btm.Map.Copy(), + } +} + +// NewBTreeMap creates a new BTreeMap with the given degree. +func NewBTreeMap[K Ordered, V any](degree int) *BTreeMap[K, V] { + return &BTreeMap[K, V]{ + Map: btree.NewMap[K, V](degree), + } +} diff --git a/pkg/scale/btree_test.go b/pkg/scale/btree_test.go index 4618f23385..61b0153ab4 100644 --- a/pkg/scale/btree_test.go +++ b/pkg/scale/btree_test.go @@ -14,7 +14,7 @@ type dummy struct { Field2 [32]byte } -func TestBTree(t *testing.T) { +func TestBTree_Codec(t *testing.T) { comparator := func(a, b interface{}) bool { v1 := a.(dummy) v2 := b.(dummy) @@ -23,9 +23,9 @@ func TestBTree(t *testing.T) { // Create a BTree with 3 dummy items tree := NewBTree[dummy](comparator) - tree.BTree.Set(dummy{Field1: 1}) - tree.BTree.Set(dummy{Field1: 2}) - tree.BTree.Set(dummy{Field1: 3}) + tree.Set(dummy{Field1: 1}) + tree.Set(dummy{Field1: 2}) + tree.Set(dummy{Field1: 3}) encoded, err := Marshal(tree) require.NoError(t, err) @@ -43,17 +43,45 @@ func TestBTree(t *testing.T) { } require.Equal(t, expectedEncoded, encoded) - // Output: expected := NewBTree[dummy](comparator) - err = Unmarshal(encoded, &expected) + err = Unmarshal(encoded, expected) require.NoError(t, err) // Check that the expected BTree has the same items as the original - require.Equal(t, tree.BTree.Len(), expected.BTree.Len()) + require.Equal(t, tree.Len(), expected.Len()) require.Equal(t, tree.ItemType, expected.ItemType) - require.Equal(t, tree.BTree.Min(), expected.BTree.Min()) - require.Equal(t, tree.BTree.Max(), expected.BTree.Max()) - require.Equal(t, tree.BTree.Get(dummy{Field1: 1}), expected.BTree.Get(dummy{Field1: 1})) - require.Equal(t, tree.BTree.Get(dummy{Field1: 2}), expected.BTree.Get(dummy{Field1: 2})) - require.Equal(t, tree.BTree.Get(dummy{Field1: 3}), expected.BTree.Get(dummy{Field1: 3})) + require.Equal(t, tree.Min(), expected.Min()) + require.Equal(t, tree.Max(), expected.Max()) + require.Equal(t, tree.Get(dummy{Field1: 1}), expected.Get(dummy{Field1: 1})) + require.Equal(t, tree.Get(dummy{Field1: 2}), expected.Get(dummy{Field1: 2})) + require.Equal(t, tree.Get(dummy{Field1: 3}), expected.Get(dummy{Field1: 3})) +} + +func TestBTreeMap_Codec(t *testing.T) { + btreeMap := NewBTreeMap[uint32, dummy](32) + btreeMap.Set(uint32(1), dummy{Field1: 1}) + btreeMap.Set(uint32(2), dummy{Field1: 2}) + btreeMap.Set(uint32(3), dummy{Field1: 3}) + + encoded, err := Marshal(btreeMap) + require.NoError(t, err) + + //let mut btree = BTreeMap::::new(); + //btree.insert(1, (1, Hash::zero())); + //btree.insert(2, (2, Hash::zero())); + //btree.insert(3, (3, Hash::zero())); + //let encoded = btree.encode(); + //println!("encoded: {:?}", encoded); + expectedEncoded := []byte{12, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + } + require.Equal(t, expectedEncoded, encoded) + + expected := NewBTreeMap[uint32, dummy](32) + err = Unmarshal(encoded, expected) + require.NoError(t, err) + + require.Equal(t, btreeMap.Len(), expected.Len()) } diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 0bda5f16ea..d8f354b989 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -11,6 +11,7 @@ import ( "io" "math/big" "reflect" + "strings" "github.com/tidwall/btree" ) @@ -110,6 +111,20 @@ type decodeState struct { } func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { + // Handle BTreeMap type separately for the following reasons: + // 1. BTreeMap is a generic type, so we can't use the normal type switch + // 2. We cannot use BTreeCodec because we are comparing the type of the dstv.Interface() in the type switch + if isBTree(dstv.Type()) { + if btm, ok := dstv.Addr().Interface().(BTreeCodec); ok { + if err := btm.Decode(ds, dstv); err != nil { + return err + } + } else { + return fmt.Errorf("could not type assert to BTreeCodec") + } + return nil + } + in := dstv.Interface() switch in.(type) { case *big.Int: @@ -811,3 +826,39 @@ func (ds *decodeState) decodeBTree(dstv reflect.Value) (err error) { dstv.Set(reflect.ValueOf(btreeValue)) return } + +func isBTree(t reflect.Type) bool { + if t.Kind() != reflect.Struct { + return false + } + + // For BTreeMap + mapField, hasMap := t.FieldByName("Map") + _, hasDegree := t.FieldByName("Degree") + + // For BTree + btreeField, hasBTree := t.FieldByName("BTree") + comparatorField, hasComparator := t.FieldByName("Comparator") + itemTypeField, hasItemType := t.FieldByName("ItemType") + + if hasMap && hasDegree && + mapField.Type.Kind() == reflect.Ptr && + strings.HasPrefix(mapField.Type.String(), "*btree.Map[") { + return true + } + + if hasBTree && hasComparator && hasItemType { + if btreeField.Type.Kind() != reflect.Ptr || btreeField.Type.String() != "*btree.BTree" { + return false + } + if comparatorField.Type.Kind() != reflect.Func { + return false + } + if itemTypeField.Type != reflect.TypeOf((*reflect.Type)(nil)).Elem() { + return false + } + return true + } + + return false +} diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index 05ddfe4061..0f3da14891 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -85,8 +85,8 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeVaryingDataType(in) case VaryingDataTypeSlice: err = es.encodeVaryingDataTypeSlice(in) - case BTree: - err = es.encodeBTree(in) + case BTreeCodec: + err = in.Encode(es) default: switch reflect.TypeOf(in).Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, @@ -208,21 +208,6 @@ func (es *encodeState) encodeVaryingDataTypeSlice(vdts VaryingDataTypeSlice) (er return } -func (es *encodeState) encodeBTree(tree BTree) (err error) { - // write the number of items in the tree - err = es.encodeLength(tree.Len()) - if err != nil { - return - } - - // iterate through the tree and encode each item - tree.Ascend(nil, func(item any) bool { - err = es.marshal(item) - return err == nil - }) - return -} - func (es *encodeState) encodeSlice(in interface{}) (err error) { v := reflect.ValueOf(in) err = es.encodeLength(v.Len()) From 3be4e0511c89098417efbb847ffd7d8a0e7fdb9c Mon Sep 17 00:00:00 2001 From: Kanishka Date: Wed, 25 Oct 2023 19:59:39 +0530 Subject: [PATCH 3/6] review suggestions --- pkg/scale/btree.go | 17 ++++------------- pkg/scale/decode.go | 2 -- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/pkg/scale/btree.go b/pkg/scale/btree.go index 0da9872331..5e23a33403 100644 --- a/pkg/scale/btree.go +++ b/pkg/scale/btree.go @@ -5,17 +5,12 @@ package scale import ( "fmt" + "golang.org/x/exp/constraints" "reflect" "github.com/tidwall/btree" ) -type Ordered interface { - ~int | ~int8 | ~int16 | ~int32 | ~int64 | - ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | - ~float32 | ~float64 | ~string -} - type BTreeCodec interface { Encode(es *encodeState) error Decode(ds *decodeState, dstv reflect.Value) error @@ -90,11 +85,7 @@ func (bt *BTree) Copy() *BTree { // NewBTree creates a new BTree with the given comparator function. func NewBTree[T any](comparator func(a, b any) bool) *BTree { - // There's no instantiation overhead of the actual type T because we're only creating a slice type and - // getting the element type from it. - var dummySlice []T - elementType := reflect.TypeOf(dummySlice).Elem() - + elementType := reflect.TypeOf((*T)(nil)).Elem() return &BTree{ BTree: btree.New(comparator), Comparator: comparator, @@ -103,7 +94,7 @@ func NewBTree[T any](comparator func(a, b any) bool) *BTree { } // BTreeMap is a wrapper around tidwall/btree.Map -type BTreeMap[K Ordered, V any] struct { +type BTreeMap[K constraints.Ordered, V any] struct { *btree.Map[K, V] Degree int } @@ -180,7 +171,7 @@ func (btm *BTreeMap[K, V]) Copy() BTreeMap[K, V] { } // NewBTreeMap creates a new BTreeMap with the given degree. -func NewBTreeMap[K Ordered, V any](degree int) *BTreeMap[K, V] { +func NewBTreeMap[K constraints.Ordered, V any](degree int) *BTreeMap[K, V] { return &BTreeMap[K, V]{ Map: btree.NewMap[K, V](degree), } diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index d8f354b989..bea98fdf50 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -147,8 +147,6 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { err = ds.decodeVaryingDataType(dstv) case VaryingDataTypeSlice: err = ds.decodeVaryingDataTypeSlice(dstv) - case BTree: - err = ds.decodeBTree(dstv) default: t := reflect.TypeOf(in) switch t.Kind() { From 73f923f0b871bfdf23f291ce5aff4f4936d6e5db Mon Sep 17 00:00:00 2001 From: Kanishka Date: Thu, 26 Oct 2023 01:47:55 +0530 Subject: [PATCH 4/6] cleanup --- pkg/scale/btree.go | 15 ++++++++----- pkg/scale/btree_test.go | 8 +++---- pkg/scale/decode.go | 49 +++++------------------------------------ 3 files changed, 20 insertions(+), 52 deletions(-) diff --git a/pkg/scale/btree.go b/pkg/scale/btree.go index 5e23a33403..ccbdeab3c9 100644 --- a/pkg/scale/btree.go +++ b/pkg/scale/btree.go @@ -5,9 +5,10 @@ package scale import ( "fmt" - "golang.org/x/exp/constraints" "reflect" + "golang.org/x/exp/constraints" + "github.com/tidwall/btree" ) @@ -84,15 +85,17 @@ func (bt *BTree) Copy() *BTree { } // NewBTree creates a new BTree with the given comparator function. -func NewBTree[T any](comparator func(a, b any) bool) *BTree { +func NewBTree[T any](comparator func(a, b any) bool) BTree { elementType := reflect.TypeOf((*T)(nil)).Elem() - return &BTree{ + return BTree{ BTree: btree.New(comparator), Comparator: comparator, ItemType: elementType, } } +var _ BTreeCodec = (*BTree)(nil) + // BTreeMap is a wrapper around tidwall/btree.Map type BTreeMap[K constraints.Ordered, V any] struct { *btree.Map[K, V] @@ -171,8 +174,10 @@ func (btm *BTreeMap[K, V]) Copy() BTreeMap[K, V] { } // NewBTreeMap creates a new BTreeMap with the given degree. -func NewBTreeMap[K constraints.Ordered, V any](degree int) *BTreeMap[K, V] { - return &BTreeMap[K, V]{ +func NewBTreeMap[K constraints.Ordered, V any](degree int) BTreeMap[K, V] { + return BTreeMap[K, V]{ Map: btree.NewMap[K, V](degree), } } + +var _ BTreeCodec = (*BTreeMap[int, string])(nil) diff --git a/pkg/scale/btree_test.go b/pkg/scale/btree_test.go index 61b0153ab4..919acd2703 100644 --- a/pkg/scale/btree_test.go +++ b/pkg/scale/btree_test.go @@ -27,7 +27,7 @@ func TestBTree_Codec(t *testing.T) { tree.Set(dummy{Field1: 2}) tree.Set(dummy{Field1: 3}) - encoded, err := Marshal(tree) + encoded, err := Marshal(&tree) require.NoError(t, err) //let mut btree = BTreeMap::::new(); @@ -44,7 +44,7 @@ func TestBTree_Codec(t *testing.T) { require.Equal(t, expectedEncoded, encoded) expected := NewBTree[dummy](comparator) - err = Unmarshal(encoded, expected) + err = Unmarshal(encoded, &expected) require.NoError(t, err) // Check that the expected BTree has the same items as the original @@ -63,7 +63,7 @@ func TestBTreeMap_Codec(t *testing.T) { btreeMap.Set(uint32(2), dummy{Field1: 2}) btreeMap.Set(uint32(3), dummy{Field1: 3}) - encoded, err := Marshal(btreeMap) + encoded, err := Marshal(&btreeMap) require.NoError(t, err) //let mut btree = BTreeMap::::new(); @@ -80,7 +80,7 @@ func TestBTreeMap_Codec(t *testing.T) { require.Equal(t, expectedEncoded, encoded) expected := NewBTreeMap[uint32, dummy](32) - err = Unmarshal(encoded, expected) + err = Unmarshal(encoded, &expected) require.NoError(t, err) require.Equal(t, btreeMap.Len(), expected.Len()) diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index bea98fdf50..24303de2b8 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -12,8 +12,6 @@ import ( "math/big" "reflect" "strings" - - "github.com/tidwall/btree" ) // indirect walks down v allocating pointers as needed, @@ -786,45 +784,7 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) { return } -// decodeBTree accepts a byte array representing a SCALE encoded -// BTree and performs SCALE decoding of the BTree -func (ds *decodeState) decodeBTree(dstv reflect.Value) (err error) { - // Decode the number of items in the tree - length, err := ds.decodeLength() - if err != nil { - return - } - - btreeValue, ok := dstv.Interface().(BTree) - if !ok { - return fmt.Errorf("expected a BTree type") - } - - if btreeValue.Comparator == nil { - return fmt.Errorf("no Comparator function provided for BTree") - } - - if btreeValue.BTree == nil { - btreeValue.BTree = btree.New(btreeValue.Comparator) - } - - // Decode each item in the tree - for i := uint(0); i < length; i++ { - // Decode the value - value := reflect.New(btreeValue.ItemType).Elem() - err = ds.unmarshal(value) - if err != nil { - return - } - - // convert the value to the correct type for the BTree - btreeValue.BTree.Set(value.Interface()) - } - - dstv.Set(reflect.ValueOf(btreeValue)) - return -} - +// isBTree returns true if the type is a BTree or BTreeMap func isBTree(t reflect.Type) bool { if t.Kind() != reflect.Struct { return false @@ -839,13 +799,16 @@ func isBTree(t reflect.Type) bool { comparatorField, hasComparator := t.FieldByName("Comparator") itemTypeField, hasItemType := t.FieldByName("ItemType") - if hasMap && hasDegree && + if hasMap && + hasDegree && mapField.Type.Kind() == reflect.Ptr && strings.HasPrefix(mapField.Type.String(), "*btree.Map[") { return true } - if hasBTree && hasComparator && hasItemType { + if hasBTree && + hasComparator && + hasItemType { if btreeField.Type.Kind() != reflect.Ptr || btreeField.Type.String() != "*btree.BTree" { return false } From 9ff55c5c47afb02794945ad93a11c57a65e830c6 Mon Sep 17 00:00:00 2001 From: Kanishka Date: Sat, 16 Dec 2023 00:03:59 +0530 Subject: [PATCH 5/6] update --- pkg/btree/btree.go | 187 +++++++++++++++++++++++++++++ pkg/{scale => btree}/btree_test.go | 19 ++- pkg/scale/btree.go | 183 ---------------------------- pkg/scale/decode.go | 66 +--------- pkg/scale/decode_test.go | 16 +-- pkg/scale/encode.go | 2 - 6 files changed, 206 insertions(+), 267 deletions(-) create mode 100644 pkg/btree/btree.go rename pkg/{scale => btree}/btree_test.go (88%) delete mode 100644 pkg/scale/btree.go diff --git a/pkg/btree/btree.go b/pkg/btree/btree.go new file mode 100644 index 0000000000..4e7fce8393 --- /dev/null +++ b/pkg/btree/btree.go @@ -0,0 +1,187 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package btree + +import ( + "fmt" + "github.com/ChainSafe/gossamer/pkg/scale" + "io" + "reflect" + + "golang.org/x/exp/constraints" + + "github.com/tidwall/btree" +) + +type Codec interface { + MarshalSCALE() ([]byte, error) + UnmarshalSCALE(reader io.Reader) error +} + +// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items +// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the +// type of the items stored in the BTree in order to decode them. +type BTree struct { + *btree.BTree + Comparator func(a, b interface{}) bool + ItemType reflect.Type +} + +// MarshalSCALE encodes the BTree using SCALE. +func (bt BTree) MarshalSCALE() ([]byte, error) { + encodedLen, err := scale.Marshal(uint(bt.Len())) + if err != nil { + return nil, fmt.Errorf("failed to encode BTree length: %w", err) + } + + var encodedItems []byte + bt.Ascend(nil, func(item interface{}) bool { + var encodedItem []byte + encodedItem, err = scale.Marshal(item) + if err != nil { + return false + } + + encodedItems = append(encodedItems, encodedItem...) + return true + }) + + return append(encodedLen, encodedItems...), err +} + +// UnmarshalSCALE decodes the BTree using SCALE. +func (bt BTree) UnmarshalSCALE(reader io.Reader) error { + if bt.Comparator == nil { + return fmt.Errorf("comparator not found") + } + + sliceType := reflect.SliceOf(bt.ItemType) + slicePtr := reflect.New(sliceType) + encodedItems, err := io.ReadAll(reader) + if err != nil { + return fmt.Errorf("read BTree items: %w", err) + } + err = scale.Unmarshal(encodedItems, slicePtr.Interface()) + if err != nil { + return fmt.Errorf("decode BTree items: %w", err) + } + + for i := 0; i < slicePtr.Elem().Len(); i++ { + item := slicePtr.Elem().Index(i).Interface() + bt.Set(item) + } + return nil +} + +// Copy returns a copy of the BTree. +func (bt BTree) Copy() *BTree { + return &BTree{ + BTree: bt.BTree.Copy(), + Comparator: bt.Comparator, + ItemType: bt.ItemType, + } +} + +// NewBTree creates a new BTree with the given comparator function. +func NewBTree[T any](comparator func(a, b any) bool) BTree { + elementType := reflect.TypeOf((*T)(nil)).Elem() + return BTree{ + BTree: btree.New(comparator), + Comparator: comparator, + ItemType: elementType, + } +} + +var _ Codec = (*BTree)(nil) + +// Map is a wrapper around tidwall/btree.Map +type Map[K constraints.Ordered, V any] struct { + *btree.Map[K, V] + Degree int +} + +type mapItem[K constraints.Ordered, V any] struct { + Key K + Value V +} + +// MarshalSCALE encodes the Map using SCALE. +func (btm Map[K, V]) MarshalSCALE() ([]byte, error) { + encodedLen, err := scale.Marshal(uint(btm.Len())) + if err != nil { + return nil, fmt.Errorf("failed to encode BTree length: %w", err) + } + + // write each item in the tree + var ( + pivot K + encodedItems []byte + ) + btm.Ascend(pivot, func(key K, value V) bool { + var ( + encodedKey []byte + encodedValue []byte + ) + encodedKey, err = scale.Marshal(key) + if err != nil { + return false + } + + encodedValue, err = scale.Marshal(value) + if err != nil { + return false + } + + encodedItems = append(encodedItems, encodedKey...) + encodedItems = append(encodedItems, encodedValue...) + return true + }) + + return append(encodedLen, encodedItems...), err +} + +// UnmarshalSCALE decodes the Map using SCALE. +func (btm Map[K, V]) UnmarshalSCALE(reader io.Reader) error { + if btm.Degree == 0 { + return fmt.Errorf("nothing to decode into") + } + + if btm.Map == nil { + btm.Map = btree.NewMap[K, V](btm.Degree) + } + + sliceType := reflect.SliceOf(reflect.TypeOf((*mapItem[K, V])(nil)).Elem()) + slicePtr := reflect.New(sliceType) + encodedItems, err := io.ReadAll(reader) + if err != nil { + return fmt.Errorf("read BTree items: %w", err) + } + err = scale.Unmarshal(encodedItems, slicePtr.Interface()) + if err != nil { + return fmt.Errorf("decode BTree items: %w", err) + } + + for i := 0; i < slicePtr.Elem().Len(); i++ { + item := slicePtr.Elem().Index(i).Interface().(mapItem[K, V]) + btm.Map.Set(item.Key, item.Value) + } + return nil +} + +// Copy returns a copy of the Map. +func (btm Map[K, V]) Copy() Map[K, V] { + return Map[K, V]{ + Map: btm.Map.Copy(), + } +} + +// NewBTreeMap creates a new Map with the given degree. +func NewBTreeMap[K constraints.Ordered, V any](degree int) Map[K, V] { + return Map[K, V]{ + Map: btree.NewMap[K, V](degree), + Degree: degree, + } +} + +var _ Codec = (*Map[int, string])(nil) diff --git a/pkg/scale/btree_test.go b/pkg/btree/btree_test.go similarity index 88% rename from pkg/scale/btree_test.go rename to pkg/btree/btree_test.go index 919acd2703..723cb8d96f 100644 --- a/pkg/scale/btree_test.go +++ b/pkg/btree/btree_test.go @@ -1,9 +1,10 @@ // Copyright 2023 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package scale +package btree import ( + "github.com/ChainSafe/gossamer/pkg/scale" "testing" "github.com/stretchr/testify/require" @@ -26,11 +27,10 @@ func TestBTree_Codec(t *testing.T) { tree.Set(dummy{Field1: 1}) tree.Set(dummy{Field1: 2}) tree.Set(dummy{Field1: 3}) - - encoded, err := Marshal(&tree) + encoded, err := scale.Marshal(tree) require.NoError(t, err) - //let mut btree = BTreeMap::::new(); + //let mut btree = Map::::new(); //btree.insert(1, Hash::zero()); //btree.insert(2, Hash::zero()); //btree.insert(3, Hash::zero()); @@ -44,7 +44,7 @@ func TestBTree_Codec(t *testing.T) { require.Equal(t, expectedEncoded, encoded) expected := NewBTree[dummy](comparator) - err = Unmarshal(encoded, &expected) + err = scale.Unmarshal(expectedEncoded, &expected) require.NoError(t, err) // Check that the expected BTree has the same items as the original @@ -62,11 +62,10 @@ func TestBTreeMap_Codec(t *testing.T) { btreeMap.Set(uint32(1), dummy{Field1: 1}) btreeMap.Set(uint32(2), dummy{Field1: 2}) btreeMap.Set(uint32(3), dummy{Field1: 3}) - - encoded, err := Marshal(&btreeMap) + encoded, err := scale.Marshal(btreeMap) require.NoError(t, err) - //let mut btree = BTreeMap::::new(); + //let mut btree = Map::::new(); //btree.insert(1, (1, Hash::zero())); //btree.insert(2, (2, Hash::zero())); //btree.insert(3, (3, Hash::zero())); @@ -78,10 +77,8 @@ func TestBTreeMap_Codec(t *testing.T) { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } require.Equal(t, expectedEncoded, encoded) - expected := NewBTreeMap[uint32, dummy](32) - err = Unmarshal(encoded, &expected) + err = scale.Unmarshal(expectedEncoded, &expected) require.NoError(t, err) - require.Equal(t, btreeMap.Len(), expected.Len()) } diff --git a/pkg/scale/btree.go b/pkg/scale/btree.go deleted file mode 100644 index ccbdeab3c9..0000000000 --- a/pkg/scale/btree.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2023 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package scale - -import ( - "fmt" - "reflect" - - "golang.org/x/exp/constraints" - - "github.com/tidwall/btree" -) - -type BTreeCodec interface { - Encode(es *encodeState) error - Decode(ds *decodeState, dstv reflect.Value) error -} - -// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items -// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the -// type of the items stored in the BTree in order to decode them. -type BTree struct { - *btree.BTree - Comparator func(a, b interface{}) bool - ItemType reflect.Type -} - -// Encode encodes the BTree using the given encodeState. -func (bt *BTree) Encode(es *encodeState) error { - // write the number of items in the tree - err := es.encodeLength(bt.Len()) - if err != nil { - return err - } - - bt.Ascend(nil, func(item interface{}) bool { - err = es.marshal(item) - return err == nil - }) - - return err -} - -// Decode decodes the BTree using the given decodeState. -func (bt *BTree) Decode(ds *decodeState, dstv reflect.Value) error { - // Decode the number of items in the tree - length, err := ds.decodeLength() - if err != nil { - return fmt.Errorf("decoding BTree length: %w", err) - } - - if bt.Comparator == nil { - return fmt.Errorf("no Comparator function provided for BTree") - } - - if bt.BTree == nil { - bt.BTree = btree.New(bt.Comparator) - } - - // Decode each item in the tree - for i := uint(0); i < length; i++ { - // Decode the value - value := reflect.New(bt.ItemType).Elem() - err = ds.unmarshal(value) - if err != nil { - return fmt.Errorf("decoding BTree item: %w", err) - } - - // convert the value to the correct type for the BTree - bt.Set(value.Interface()) - } - - dstv.Set(reflect.ValueOf(*bt)) - return nil -} - -// Copy returns a copy of the BTree. -func (bt *BTree) Copy() *BTree { - return &BTree{ - BTree: bt.BTree.Copy(), - Comparator: bt.Comparator, - ItemType: bt.ItemType, - } -} - -// NewBTree creates a new BTree with the given comparator function. -func NewBTree[T any](comparator func(a, b any) bool) BTree { - elementType := reflect.TypeOf((*T)(nil)).Elem() - return BTree{ - BTree: btree.New(comparator), - Comparator: comparator, - ItemType: elementType, - } -} - -var _ BTreeCodec = (*BTree)(nil) - -// BTreeMap is a wrapper around tidwall/btree.Map -type BTreeMap[K constraints.Ordered, V any] struct { - *btree.Map[K, V] - Degree int -} - -// Encode encodes the BTreeMap using the given encodeState. -func (btm *BTreeMap[K, V]) Encode(es *encodeState) error { - // write the number of items in the tree - err := es.encodeLength(btm.Len()) - if err != nil { - return err - } - - // write each item in the tree - var pivot K - btm.Ascend(pivot, func(key K, value V) bool { - if err = es.marshal(key); err != nil { - return false - } - - if err = es.marshal(value); err != nil { - return false - } - - return true - }) - - return err -} - -// Decode decodes the BTreeMap using the given decodeState. -func (btm *BTreeMap[K, V]) Decode(ds *decodeState, dstv reflect.Value) error { - // Decode the number of items in the tree - length, err := ds.decodeLength() - if err != nil { - return fmt.Errorf("decoding BTreeMap length: %w", err) - } - - if btm.Map == nil { - btm.Map = btree.NewMap[K, V](btm.Degree) - } - - // Decode each item in the tree - for i := uint(0); i < length; i++ { - // Decode the key - keyType := reflect.TypeOf((*K)(nil)).Elem() - keyInstance := reflect.New(keyType).Elem() - err = ds.unmarshal(keyInstance) - if err != nil { - return fmt.Errorf("decoding BTreeMap key: %w", err) - } - key := keyInstance.Interface().(K) - - // Decode the value - valueType := reflect.TypeOf((*V)(nil)).Elem() - valueInstance := reflect.New(valueType).Elem() - err = ds.unmarshal(valueInstance) - if err != nil { - return fmt.Errorf("decoding BTreeMap value: %w", err) - } - value := valueInstance.Interface().(V) - - btm.Map.Set(key, value) - } - - dstv.Set(reflect.ValueOf(*btm)) - return nil -} - -// Copy returns a copy of the BTreeMap. -func (btm *BTreeMap[K, V]) Copy() BTreeMap[K, V] { - return BTreeMap[K, V]{ - Map: btm.Map.Copy(), - } -} - -// NewBTreeMap creates a new BTreeMap with the given degree. -func NewBTreeMap[K constraints.Ordered, V any](degree int) BTreeMap[K, V] { - return BTreeMap[K, V]{ - Map: btree.NewMap[K, V](degree), - } -} - -var _ BTreeCodec = (*BTreeMap[int, string])(nil) diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 847d27fcf5..1c3d965346 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -11,7 +11,6 @@ import ( "io" "math/big" "reflect" - "strings" ) // indirect walks down v allocating pointers as needed, @@ -114,29 +113,8 @@ type decodeState struct { } func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { - // Handle BTreeMap type separately for the following reasons: - // 1. BTreeMap is a generic type, so we can't use the normal type switch - // 2. We cannot use BTreeCodec because we are comparing the type of the dstv.Interface() in the type switch - if isBTree(dstv.Type()) { - if btm, ok := dstv.Addr().Interface().(BTreeCodec); ok { - if err := btm.Decode(ds, dstv); err != nil { - return err - } - } else { - return fmt.Errorf("could not type assert to BTreeCodec") - } - return nil - - unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem() - if dstv.CanAddr() && dstv.Addr().Type().Implements(unmarshalerType) { - methodVal := dstv.Addr().MethodByName("UnmarshalSCALE") - values := methodVal.Call([]reflect.Value{reflect.ValueOf(ds.Reader)}) - if !values[0].IsNil() { - errIn := values[0].Interface() - err := errIn.(error) - return err - } - return + if unmarshaler, ok := dstv.Addr().Interface().(Unmarshaler); ok { + return unmarshaler.UnmarshalSCALE(ds.Reader) } in := dstv.Interface() @@ -799,43 +777,3 @@ func (ds *decodeState) decodeUint128(dstv reflect.Value) (err error) { dstv.Set(reflect.ValueOf(ui128)) return } - -// isBTree returns true if the type is a BTree or BTreeMap -func isBTree(t reflect.Type) bool { - if t.Kind() != reflect.Struct { - return false - } - - // For BTreeMap - mapField, hasMap := t.FieldByName("Map") - _, hasDegree := t.FieldByName("Degree") - - // For BTree - btreeField, hasBTree := t.FieldByName("BTree") - comparatorField, hasComparator := t.FieldByName("Comparator") - itemTypeField, hasItemType := t.FieldByName("ItemType") - - if hasMap && - hasDegree && - mapField.Type.Kind() == reflect.Ptr && - strings.HasPrefix(mapField.Type.String(), "*btree.Map[") { - return true - } - - if hasBTree && - hasComparator && - hasItemType { - if btreeField.Type.Kind() != reflect.Ptr || btreeField.Type.String() != "*btree.BTree" { - return false - } - if comparatorField.Type.Kind() != reflect.Func { - return false - } - if itemTypeField.Type != reflect.TypeOf((*reflect.Type)(nil)).Elem() { - return false - } - return true - } - - return false -} diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index 713f3a7bce..c5c219b521 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -588,9 +588,10 @@ func Test_decodeState_Unmarshaller(t *testing.T) { Middle: uint32(2), Last: 3, } - bytes := MustMarshal(expected) + encoded := MustMarshal(expected) ms := myStruct{} - Unmarshal(bytes, &ms) + err := Unmarshal(encoded, &ms) + assert.NoError(t, err) assert.Equal(t, expected, ms) type myParentStruct struct { @@ -603,9 +604,10 @@ func Test_decodeState_Unmarshaller(t *testing.T) { Middle: expected, Last: 3, } - bytes = MustMarshal(expectedParent) + encoded = MustMarshal(expectedParent) mps := myParentStruct{} - Unmarshal(bytes, &mps) + err = Unmarshal(encoded, &mps) + assert.NoError(t, err) assert.Equal(t, expectedParent, mps) } @@ -615,8 +617,8 @@ func Test_decodeState_Unmarshaller_Error(t *testing.T) { Middle: uint32(2), Last: 3, } - bytes := MustMarshal(expected) + encoded := MustMarshal(expected) mse := myStructError{} - err := Unmarshal(bytes, &mse) - assert.Error(t, err, "eh?") + err := Unmarshal(encoded, &mse) + assert.Error(t, err) } diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index ae26ac9a87..c9830aef9d 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -101,8 +101,6 @@ func (es *encodeState) marshal(in interface{}) (err error) { err = es.encodeVaryingDataType(in) case VaryingDataTypeSlice: err = es.encodeVaryingDataTypeSlice(in) - case BTreeCodec: - err = in.Encode(es) default: switch reflect.TypeOf(in).Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, From 5615c8a753d43cf9bfb377bcd0775ae8a08715a8 Mon Sep 17 00:00:00 2001 From: Kanishka Date: Sat, 16 Dec 2023 00:16:29 +0530 Subject: [PATCH 6/6] cleanup --- pkg/btree/btree.go | 44 ++++++++++++++++++++--------------------- pkg/btree/btree_test.go | 15 +++++++------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/pkg/btree/btree.go b/pkg/btree/btree.go index 4e7fce8393..3cf2d85ac5 100644 --- a/pkg/btree/btree.go +++ b/pkg/btree/btree.go @@ -5,10 +5,11 @@ package btree import ( "fmt" - "github.com/ChainSafe/gossamer/pkg/scale" "io" "reflect" + "github.com/ChainSafe/gossamer/pkg/scale" + "golang.org/x/exp/constraints" "github.com/tidwall/btree" @@ -19,17 +20,17 @@ type Codec interface { UnmarshalSCALE(reader io.Reader) error } -// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items -// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the -// type of the items stored in the BTree in order to decode them. -type BTree struct { +// Tree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items +// stored in the BTree. This is needed during decoding because the Tree item is a generic type, and we need to know it +// at the time of decoding. +type Tree struct { *btree.BTree Comparator func(a, b interface{}) bool ItemType reflect.Type } -// MarshalSCALE encodes the BTree using SCALE. -func (bt BTree) MarshalSCALE() ([]byte, error) { +// MarshalSCALE encodes the Tree using SCALE. +func (bt Tree) MarshalSCALE() ([]byte, error) { encodedLen, err := scale.Marshal(uint(bt.Len())) if err != nil { return nil, fmt.Errorf("failed to encode BTree length: %w", err) @@ -50,8 +51,8 @@ func (bt BTree) MarshalSCALE() ([]byte, error) { return append(encodedLen, encodedItems...), err } -// UnmarshalSCALE decodes the BTree using SCALE. -func (bt BTree) UnmarshalSCALE(reader io.Reader) error { +// UnmarshalSCALE decodes the Tree using SCALE. +func (bt Tree) UnmarshalSCALE(reader io.Reader) error { if bt.Comparator == nil { return fmt.Errorf("comparator not found") } @@ -74,26 +75,26 @@ func (bt BTree) UnmarshalSCALE(reader io.Reader) error { return nil } -// Copy returns a copy of the BTree. -func (bt BTree) Copy() *BTree { - return &BTree{ +// Copy returns a copy of the Tree. +func (bt Tree) Copy() *Tree { + return &Tree{ BTree: bt.BTree.Copy(), Comparator: bt.Comparator, ItemType: bt.ItemType, } } -// NewBTree creates a new BTree with the given comparator function. -func NewBTree[T any](comparator func(a, b any) bool) BTree { +// NewTree creates a new Tree with the given comparator function. +func NewTree[T any](comparator func(a, b any) bool) Tree { elementType := reflect.TypeOf((*T)(nil)).Elem() - return BTree{ + return Tree{ BTree: btree.New(comparator), Comparator: comparator, ItemType: elementType, } } -var _ Codec = (*BTree)(nil) +var _ Codec = (*Tree)(nil) // Map is a wrapper around tidwall/btree.Map type Map[K constraints.Ordered, V any] struct { @@ -110,10 +111,9 @@ type mapItem[K constraints.Ordered, V any] struct { func (btm Map[K, V]) MarshalSCALE() ([]byte, error) { encodedLen, err := scale.Marshal(uint(btm.Len())) if err != nil { - return nil, fmt.Errorf("failed to encode BTree length: %w", err) + return nil, fmt.Errorf("failed to encode Map length: %w", err) } - // write each item in the tree var ( pivot K encodedItems []byte @@ -155,11 +155,11 @@ func (btm Map[K, V]) UnmarshalSCALE(reader io.Reader) error { slicePtr := reflect.New(sliceType) encodedItems, err := io.ReadAll(reader) if err != nil { - return fmt.Errorf("read BTree items: %w", err) + return fmt.Errorf("read Map items: %w", err) } err = scale.Unmarshal(encodedItems, slicePtr.Interface()) if err != nil { - return fmt.Errorf("decode BTree items: %w", err) + return fmt.Errorf("decode Map items: %w", err) } for i := 0; i < slicePtr.Elem().Len(); i++ { @@ -176,8 +176,8 @@ func (btm Map[K, V]) Copy() Map[K, V] { } } -// NewBTreeMap creates a new Map with the given degree. -func NewBTreeMap[K constraints.Ordered, V any](degree int) Map[K, V] { +// NewMap creates a new Map with the given degree. +func NewMap[K constraints.Ordered, V any](degree int) Map[K, V] { return Map[K, V]{ Map: btree.NewMap[K, V](degree), Degree: degree, diff --git a/pkg/btree/btree_test.go b/pkg/btree/btree_test.go index 723cb8d96f..3d73e4187b 100644 --- a/pkg/btree/btree_test.go +++ b/pkg/btree/btree_test.go @@ -4,9 +4,10 @@ package btree import ( - "github.com/ChainSafe/gossamer/pkg/scale" "testing" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/require" ) @@ -22,8 +23,8 @@ func TestBTree_Codec(t *testing.T) { return v1.Field1 < v2.Field1 } - // Create a BTree with 3 dummy items - tree := NewBTree[dummy](comparator) + // Create a Tree with 3 dummy items + tree := NewTree[dummy](comparator) tree.Set(dummy{Field1: 1}) tree.Set(dummy{Field1: 2}) tree.Set(dummy{Field1: 3}) @@ -43,11 +44,11 @@ func TestBTree_Codec(t *testing.T) { } require.Equal(t, expectedEncoded, encoded) - expected := NewBTree[dummy](comparator) + expected := NewTree[dummy](comparator) err = scale.Unmarshal(expectedEncoded, &expected) require.NoError(t, err) - // Check that the expected BTree has the same items as the original + // Check that the expected Tree has the same items as the original require.Equal(t, tree.Len(), expected.Len()) require.Equal(t, tree.ItemType, expected.ItemType) require.Equal(t, tree.Min(), expected.Min()) @@ -58,7 +59,7 @@ func TestBTree_Codec(t *testing.T) { } func TestBTreeMap_Codec(t *testing.T) { - btreeMap := NewBTreeMap[uint32, dummy](32) + btreeMap := NewMap[uint32, dummy](32) btreeMap.Set(uint32(1), dummy{Field1: 1}) btreeMap.Set(uint32(2), dummy{Field1: 2}) btreeMap.Set(uint32(3), dummy{Field1: 3}) @@ -77,7 +78,7 @@ func TestBTreeMap_Codec(t *testing.T) { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } require.Equal(t, expectedEncoded, encoded) - expected := NewBTreeMap[uint32, dummy](32) + expected := NewMap[uint32, dummy](32) err = scale.Unmarshal(expectedEncoded, &expected) require.NoError(t, err) require.Equal(t, btreeMap.Len(), expected.Len())