From 3442f60b23b15d0b0e6706d3cf66f72b340cff2c Mon Sep 17 00:00:00 2001 From: Timothy Wu Date: Fri, 1 Dec 2023 15:30:50 -0500 Subject: [PATCH 1/3] add Marshaler, and Unmarshaler interfaces with accompanying encode/decode impl --- pkg/scale/decode.go | 17 ++++++++ pkg/scale/decode_test.go | 85 ++++++++++++++++++++++++++++++++++++++++ pkg/scale/encode.go | 20 ++++++++++ pkg/scale/encode_test.go | 23 +++++++++++ 4 files changed, 145 insertions(+) diff --git a/pkg/scale/decode.go b/pkg/scale/decode.go index 45a527f0b8..bcd99c8d1c 100644 --- a/pkg/scale/decode.go +++ b/pkg/scale/decode.go @@ -75,6 +75,11 @@ func Unmarshal(data []byte, dst interface{}) (err error) { return } +// Unmarshaler is the interface for custom SCALE unmarshalling for a given type +type Unmarshaler interface { + UnmarshalSCALE(io.Reader) error +} + // Decoder is used to decode from an io.Reader type Decoder struct { decodeState @@ -108,6 +113,18 @@ type decodeState struct { } func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) { + unmarshalerType := reflect.TypeOf((*Unmarshaler)(nil)).Elem() + if dstv.CanAddr() && dstv.Addr().Type().Implements(unmarshalerType) { + methodVal := dstv.Addr().MethodByName("UnmarshalSCALE") + values := methodVal.Call([]reflect.Value{reflect.ValueOf(ds.Reader)}) + if !values[0].IsNil() { + errIn := values[0].Interface() + err := errIn.(error) + return err + } + return + } + in := dstv.Interface() switch in.(type) { case *big.Int: diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index 3309c58a9b..e578df1bb8 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -5,6 +5,9 @@ package scale import ( "bytes" + "encoding/binary" + "fmt" + "io" "math/big" "reflect" "testing" @@ -535,3 +538,85 @@ func Test_decodeState_decodeUint(t *testing.T) { }) } } + +type myStruct struct { + First uint32 + Middle any + Last uint32 +} + +func (ms *myStruct) UnmarshalSCALE(reader io.Reader) (err error) { + buf := make([]byte, 4) + _, err = reader.Read(buf) + if err != nil { + return + } + ms.First = uint32(binary.LittleEndian.Uint32(buf)) + + buf = make([]byte, 4) + _, err = reader.Read(buf) + if err != nil { + return + } + ms.Middle = uint32(binary.LittleEndian.Uint32(buf)) + + buf = make([]byte, 4) + _, err = reader.Read(buf) + if err != nil { + return + } + ms.Last = uint32(binary.LittleEndian.Uint32(buf)) + return nil +} + +type myStructError struct { + First uint32 + Middle any + Last uint32 +} + +func (mse *myStructError) UnmarshalSCALE(reader io.Reader) (err error) { + err = fmt.Errorf("eh?") + return err +} + +var _ Unmarshaler = &myStruct{} + +func Test_decodeState_Unmarshaller(t *testing.T) { + expected := myStruct{ + First: 1, + Middle: uint32(2), + Last: 3, + } + bytes := MustMarshal(expected) + ms := myStruct{} + Unmarshal(bytes, &ms) + assert.Equal(t, expected, ms) + + type myParentStruct struct { + First uint + Middle myStruct + Last uint + } + expectedParent := myParentStruct{ + First: 1, + Middle: expected, + Last: 3, + } + bytes = MustMarshal(expectedParent) + mps := myParentStruct{} + Unmarshal(bytes, &mps) + assert.Equal(t, expectedParent, mps) +} + +func Test_decodeState_Unmarshaller_Error(t *testing.T) { + expected := myStruct{ + First: 1, + Middle: uint32(2), + Last: 3, + } + bytes := MustMarshal(expected) + mse := myStructError{} + err := Unmarshal(bytes, &mse) + assert.Error(t, err, "eh?") +} diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index d312b85f91..45f50f1f5d 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -47,6 +47,11 @@ func Marshal(v interface{}) (b []byte, err error) { return } +// Marshaler is the interface for custom SCALE marshalling for a given type +type Marshaler interface { + MarshalSCALE() ([]byte, error) +} + // MustMarshal runs Marshal and panics on error. func MustMarshal(v interface{}) (b []byte) { b, err := Marshal(v) @@ -62,6 +67,21 @@ type encodeState struct { } func (es *encodeState) marshal(in interface{}) (err error) { + marshalerType := reflect.TypeOf((*Marshaler)(nil)).Elem() + inv := reflect.ValueOf(in) + if inv.Type().Implements(marshalerType) { + methodVal := inv.MethodByName("MarshalSCALE") + values := methodVal.Call(nil) + if !values[1].IsNil() { + errIn := values[1].Interface() + err := errIn.(error) + return err + } + bytes := values[0].Interface().([]byte) + _, err = es.Write(bytes) + return + } + switch in := in.(type) { case int: err = es.encodeUint(uint(in)) diff --git a/pkg/scale/encode_test.go b/pkg/scale/encode_test.go index 92de411919..8f5d9a60ca 100644 --- a/pkg/scale/encode_test.go +++ b/pkg/scale/encode_test.go @@ -5,6 +5,7 @@ package scale import ( "bytes" + "fmt" "math/big" "reflect" "strings" @@ -1250,3 +1251,25 @@ var byteArray = func(length int) []byte { } return b } + +type myMarshalerType uint64 + +func (mmt myMarshalerType) MarshalSCALE() ([]byte, error) { + return []byte{9, 9, 9}, nil +} + +type myMarshalerTypeError uint64 + +func (mmt myMarshalerTypeError) MarshalSCALE() ([]byte, error) { + return nil, fmt.Errorf("eh?") +} + +func Test_encodeState_Mashaler(t *testing.T) { + bytes := MustMarshal(myMarshalerType(888)) + assert.Equal(t, []byte{9, 9, 9}, bytes) +} + +func Test_encodeState_Mashaler_Error(t *testing.T) { + _, err := Marshal(myMarshalerTypeError(888)) + assert.Error(t, err, "eh?") +} From 9993b83e19ac683533b5112fde7adbafa9b1aa21 Mon Sep 17 00:00:00 2001 From: Timothy Wu Date: Fri, 1 Dec 2023 15:48:28 -0500 Subject: [PATCH 2/3] fix lint --- pkg/scale/decode_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/scale/decode_test.go b/pkg/scale/decode_test.go index e578df1bb8..713f3a7bce 100644 --- a/pkg/scale/decode_test.go +++ b/pkg/scale/decode_test.go @@ -551,21 +551,21 @@ func (ms *myStruct) UnmarshalSCALE(reader io.Reader) (err error) { if err != nil { return } - ms.First = uint32(binary.LittleEndian.Uint32(buf)) + ms.First = binary.LittleEndian.Uint32(buf) buf = make([]byte, 4) _, err = reader.Read(buf) if err != nil { return } - ms.Middle = uint32(binary.LittleEndian.Uint32(buf)) + ms.Middle = binary.LittleEndian.Uint32(buf) buf = make([]byte, 4) _, err = reader.Read(buf) if err != nil { return } - ms.Last = uint32(binary.LittleEndian.Uint32(buf)) + ms.Last = binary.LittleEndian.Uint32(buf) return nil } From 6edad0d1a9af4c5ccbdab270c002286a1c2f2cc5 Mon Sep 17 00:00:00 2001 From: Timothy Wu Date: Sat, 2 Dec 2023 01:27:36 -0500 Subject: [PATCH 3/3] simplify encoding --- pkg/scale/encode.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/pkg/scale/encode.go b/pkg/scale/encode.go index 45f50f1f5d..c9830aef9d 100644 --- a/pkg/scale/encode.go +++ b/pkg/scale/encode.go @@ -67,17 +67,13 @@ type encodeState struct { } func (es *encodeState) marshal(in interface{}) (err error) { - marshalerType := reflect.TypeOf((*Marshaler)(nil)).Elem() - inv := reflect.ValueOf(in) - if inv.Type().Implements(marshalerType) { - methodVal := inv.MethodByName("MarshalSCALE") - values := methodVal.Call(nil) - if !values[1].IsNil() { - errIn := values[1].Interface() - err := errIn.(error) - return err + marshaler, ok := in.(Marshaler) + if ok { + var bytes []byte + bytes, err = marshaler.MarshalSCALE() + if err != nil { + return } - bytes := values[0].Interface().([]byte) _, err = es.Write(bytes) return }