Skip to content

Commit

Permalink
add btreemap
Browse files Browse the repository at this point in the history
  • Loading branch information
kanishkatn committed Oct 25, 2023
1 parent 144cdd2 commit 5a60b6f
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 31 deletions.
158 changes: 156 additions & 2 deletions pkg/scale/btree.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@
package scale

import (
"fmt"
"reflect"

"github.com/tidwall/btree"
)

type Ordered interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr |
~float32 | ~float64 | ~string
}

type BTreeCodec interface {
Encode(es *encodeState) error
Decode(ds *decodeState, dstv reflect.Value) error
}

// BTree is a wrapper around tidwall/btree.BTree that also stores the comparator function and the type of the items
// stored in the BTree. This is needed during decoding because the BTree is a generic type, and we need to know the
// type of the items stored in the BTree in order to decode them.
Expand All @@ -18,16 +30,158 @@ type BTree struct {
ItemType reflect.Type
}

// Encode encodes the BTree using the given encodeState.
func (bt *BTree) Encode(es *encodeState) error {
// write the number of items in the tree
err := es.encodeLength(bt.Len())
if err != nil {
return err
}

bt.Ascend(nil, func(item interface{}) bool {
err = es.marshal(item)
return err == nil
})

return err
}

// Decode decodes the BTree using the given decodeState.
func (bt *BTree) Decode(ds *decodeState, dstv reflect.Value) error {
// Decode the number of items in the tree
length, err := ds.decodeLength()
if err != nil {
return fmt.Errorf("decoding BTree length: %w", err)
}

if bt.Comparator == nil {
return fmt.Errorf("no Comparator function provided for BTree")
}

if bt.BTree == nil {
bt.BTree = btree.New(bt.Comparator)
}

// Decode each item in the tree
for i := uint(0); i < length; i++ {
// Decode the value
value := reflect.New(bt.ItemType).Elem()
err = ds.unmarshal(value)
if err != nil {
return fmt.Errorf("decoding BTree item: %w", err)
}

// convert the value to the correct type for the BTree
bt.Set(value.Interface())
}

dstv.Set(reflect.ValueOf(*bt))
return nil
}

// Copy returns a copy of the BTree.
func (bt *BTree) Copy() *BTree {
return &BTree{
BTree: bt.BTree.Copy(),
Comparator: bt.Comparator,
ItemType: bt.ItemType,
}
}

// NewBTree creates a new BTree with the given comparator function.
func NewBTree[T any](comparator func(a, b any) bool) BTree {
func NewBTree[T any](comparator func(a, b any) bool) *BTree {
// There's no instantiation overhead of the actual type T because we're only creating a slice type and
// getting the element type from it.
var dummySlice []T
elementType := reflect.TypeOf(dummySlice).Elem()

return BTree{
return &BTree{
BTree: btree.New(comparator),
Comparator: comparator,
ItemType: elementType,
}
}

// BTreeMap is a wrapper around tidwall/btree.Map
type BTreeMap[K Ordered, V any] struct {
*btree.Map[K, V]
Degree int
}

// Encode encodes the BTreeMap using the given encodeState.
func (btm *BTreeMap[K, V]) Encode(es *encodeState) error {
// write the number of items in the tree
err := es.encodeLength(btm.Len())
if err != nil {
return err
}

// write each item in the tree
var pivot K
btm.Ascend(pivot, func(key K, value V) bool {
if err = es.marshal(key); err != nil {
return false
}

if err = es.marshal(value); err != nil {
return false
}

return true
})

return err
}

// Decode decodes the BTreeMap using the given decodeState.
func (btm *BTreeMap[K, V]) Decode(ds *decodeState, dstv reflect.Value) error {
// Decode the number of items in the tree
length, err := ds.decodeLength()
if err != nil {
return fmt.Errorf("decoding BTreeMap length: %w", err)
}

if btm.Map == nil {
btm.Map = btree.NewMap[K, V](btm.Degree)
}

// Decode each item in the tree
for i := uint(0); i < length; i++ {
// Decode the key
keyType := reflect.TypeOf((*K)(nil)).Elem()
keyInstance := reflect.New(keyType).Elem()
err = ds.unmarshal(keyInstance)
if err != nil {
return fmt.Errorf("decoding BTreeMap key: %w", err)
}
key := keyInstance.Interface().(K)

// Decode the value
valueType := reflect.TypeOf((*V)(nil)).Elem()
valueInstance := reflect.New(valueType).Elem()
err = ds.unmarshal(valueInstance)
if err != nil {
return fmt.Errorf("decoding BTreeMap value: %w", err)
}
value := valueInstance.Interface().(V)

btm.Map.Set(key, value)
}

dstv.Set(reflect.ValueOf(*btm))
return nil
}

// Copy returns a copy of the BTreeMap.
func (btm *BTreeMap[K, V]) Copy() BTreeMap[K, V] {
return BTreeMap[K, V]{
Map: btm.Map.Copy(),
}
}

// NewBTreeMap creates a new BTreeMap with the given degree.
func NewBTreeMap[K Ordered, V any](degree int) *BTreeMap[K, V] {
return &BTreeMap[K, V]{
Map: btree.NewMap[K, V](degree),
}
}
52 changes: 40 additions & 12 deletions pkg/scale/btree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type dummy struct {
Field2 [32]byte
}

func TestBTree(t *testing.T) {
func TestBTree_Codec(t *testing.T) {
comparator := func(a, b interface{}) bool {
v1 := a.(dummy)
v2 := b.(dummy)
Expand All @@ -23,9 +23,9 @@ func TestBTree(t *testing.T) {

// Create a BTree with 3 dummy items
tree := NewBTree[dummy](comparator)
tree.BTree.Set(dummy{Field1: 1})
tree.BTree.Set(dummy{Field1: 2})
tree.BTree.Set(dummy{Field1: 3})
tree.Set(dummy{Field1: 1})
tree.Set(dummy{Field1: 2})
tree.Set(dummy{Field1: 3})

encoded, err := Marshal(tree)
require.NoError(t, err)
Expand All @@ -43,17 +43,45 @@ func TestBTree(t *testing.T) {
}
require.Equal(t, expectedEncoded, encoded)

// Output:
expected := NewBTree[dummy](comparator)
err = Unmarshal(encoded, &expected)
err = Unmarshal(encoded, expected)
require.NoError(t, err)

// Check that the expected BTree has the same items as the original
require.Equal(t, tree.BTree.Len(), expected.BTree.Len())
require.Equal(t, tree.Len(), expected.Len())
require.Equal(t, tree.ItemType, expected.ItemType)
require.Equal(t, tree.BTree.Min(), expected.BTree.Min())
require.Equal(t, tree.BTree.Max(), expected.BTree.Max())
require.Equal(t, tree.BTree.Get(dummy{Field1: 1}), expected.BTree.Get(dummy{Field1: 1}))
require.Equal(t, tree.BTree.Get(dummy{Field1: 2}), expected.BTree.Get(dummy{Field1: 2}))
require.Equal(t, tree.BTree.Get(dummy{Field1: 3}), expected.BTree.Get(dummy{Field1: 3}))
require.Equal(t, tree.Min(), expected.Min())
require.Equal(t, tree.Max(), expected.Max())
require.Equal(t, tree.Get(dummy{Field1: 1}), expected.Get(dummy{Field1: 1}))
require.Equal(t, tree.Get(dummy{Field1: 2}), expected.Get(dummy{Field1: 2}))
require.Equal(t, tree.Get(dummy{Field1: 3}), expected.Get(dummy{Field1: 3}))
}

func TestBTreeMap_Codec(t *testing.T) {
btreeMap := NewBTreeMap[uint32, dummy](32)
btreeMap.Set(uint32(1), dummy{Field1: 1})
btreeMap.Set(uint32(2), dummy{Field1: 2})
btreeMap.Set(uint32(3), dummy{Field1: 3})

encoded, err := Marshal(btreeMap)
require.NoError(t, err)

//let mut btree = BTreeMap::<u32, (u32, Hash)>::new();
//btree.insert(1, (1, Hash::zero()));
//btree.insert(2, (2, Hash::zero()));
//btree.insert(3, (3, Hash::zero()));
//let encoded = btree.encode();
//println!("encoded: {:?}", encoded);
expectedEncoded := []byte{12, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}
require.Equal(t, expectedEncoded, encoded)

expected := NewBTreeMap[uint32, dummy](32)
err = Unmarshal(encoded, expected)
require.NoError(t, err)

require.Equal(t, btreeMap.Len(), expected.Len())
}
51 changes: 51 additions & 0 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"math/big"
"reflect"
"strings"

"github.com/tidwall/btree"
)
Expand Down Expand Up @@ -110,6 +111,20 @@ type decodeState struct {
}

func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
// Handle BTreeMap type separately for the following reasons:
// 1. BTreeMap is a generic type, so we can't use the normal type switch
// 2. We cannot use BTreeCodec because we are comparing the type of the dstv.Interface() in the type switch
if isBTree(dstv.Type()) {
if btm, ok := dstv.Addr().Interface().(BTreeCodec); ok {
if err := btm.Decode(ds, dstv); err != nil {
return err
}
} else {
return fmt.Errorf("could not type assert to BTreeCodec")
}
return nil
}

in := dstv.Interface()
switch in.(type) {
case *big.Int:
Expand Down Expand Up @@ -811,3 +826,39 @@ func (ds *decodeState) decodeBTree(dstv reflect.Value) (err error) {
dstv.Set(reflect.ValueOf(btreeValue))
return
}

func isBTree(t reflect.Type) bool {
if t.Kind() != reflect.Struct {
return false
}

// For BTreeMap
mapField, hasMap := t.FieldByName("Map")
_, hasDegree := t.FieldByName("Degree")

// For BTree
btreeField, hasBTree := t.FieldByName("BTree")
comparatorField, hasComparator := t.FieldByName("Comparator")
itemTypeField, hasItemType := t.FieldByName("ItemType")

if hasMap && hasDegree &&
mapField.Type.Kind() == reflect.Ptr &&
strings.HasPrefix(mapField.Type.String(), "*btree.Map[") {
return true
}

if hasBTree && hasComparator && hasItemType {
if btreeField.Type.Kind() != reflect.Ptr || btreeField.Type.String() != "*btree.BTree" {
return false
}
if comparatorField.Type.Kind() != reflect.Func {
return false
}
if itemTypeField.Type != reflect.TypeOf((*reflect.Type)(nil)).Elem() {
return false
}
return true
}

return false
}
19 changes: 2 additions & 17 deletions pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func (es *encodeState) marshal(in interface{}) (err error) {
err = es.encodeVaryingDataType(in)
case VaryingDataTypeSlice:
err = es.encodeVaryingDataTypeSlice(in)
case BTree:
err = es.encodeBTree(in)
case BTreeCodec:
err = in.Encode(es)
default:
switch reflect.TypeOf(in).Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16,
Expand Down Expand Up @@ -208,21 +208,6 @@ func (es *encodeState) encodeVaryingDataTypeSlice(vdts VaryingDataTypeSlice) (er
return
}

func (es *encodeState) encodeBTree(tree BTree) (err error) {
// write the number of items in the tree
err = es.encodeLength(tree.Len())
if err != nil {
return
}

// iterate through the tree and encode each item
tree.Ascend(nil, func(item any) bool {
err = es.marshal(item)
return err == nil
})
return
}

func (es *encodeState) encodeSlice(in interface{}) (err error) {
v := reflect.ValueOf(in)
err = es.encodeLength(v.Len())
Expand Down

0 comments on commit 5a60b6f

Please sign in to comment.