From 855da70def0924dd2d71910263c13d531d9fd51e Mon Sep 17 00:00:00 2001 From: Aurora Gaffney Date: Sun, 17 Mar 2024 13:15:19 -0500 Subject: [PATCH] feat: custom CBOR tag handling This moves the handling for custom CBOR tag types from cbor.Value to the generic encoder/decoder to make them more widely usable. Fixes #548 --- cbor/cbor.go | 16 +-- cbor/decode.go | 4 +- cbor/encode.go | 4 +- cbor/tags.go | 143 +++++++++++++++++++++++++++ cbor/tags_test.go | 94 ++++++++++++++++++ cbor/value.go | 147 +++++++++++++++------------- cbor/value_test.go | 53 ++++++++-- protocol/localtxmonitor/messages.go | 2 +- 8 files changed, 364 insertions(+), 99 deletions(-) create mode 100644 cbor/tags.go create mode 100644 cbor/tags_test.go diff --git a/cbor/cbor.go b/cbor/cbor.go index b3057e05..e57595a5 100644 --- a/cbor/cbor.go +++ b/cbor/cbor.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -30,20 +30,6 @@ const ( // Max value able to be stored in a single byte without type prefix CborMaxUintSimple uint8 = 0x17 - - // Useful tag numbers - CborTagCbor = 24 - CborTagRational = 30 - CborTagSet = 258 - CborTagMap = 259 - - // Tag ranges for "alternatives" - // https://www.ietf.org/archive/id/draft-bormann-cbor-notable-tags-07.html#name-enumerated-alternative-data - CborTagAlternative1Min = 121 - CborTagAlternative1Max = 127 - CborTagAlternative2Min = 1280 - CborTagAlternative2Max = 1400 - CborTagAlternative3 = 101 ) // Create an alias for RawMessage for convenience diff --git a/cbor/decode.go b/cbor/decode.go index 2e1210d5..ec30f8c1 100644 --- a/cbor/decode.go +++ b/cbor/decode.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ func Decode(dataBytes []byte, dest interface{}) (int, error) { // This defaults to 32, but there are blocks in the wild using >64 nested levels MaxNestedLevels: 256, } - decMode, err := decOptions.DecMode() + decMode, err := decOptions.DecModeWithTags(customTagSet) if err != nil { return 0, err } diff --git a/cbor/encode.go b/cbor/encode.go index 88184b4d..c6df0729 100644 --- a/cbor/encode.go +++ b/cbor/encode.go @@ -1,4 +1,4 @@ -// Copyright 2023 Blink Labs Software +// Copyright 2024 Blink Labs Software // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ import ( func Encode(data interface{}) ([]byte, error) { buf := bytes.NewBuffer(nil) - em, err := _cbor.CoreDetEncOptions().EncMode() + em, err := _cbor.CoreDetEncOptions().EncModeWithTags(customTagSet) if err != nil { return nil, err } diff --git a/cbor/tags.go b/cbor/tags.go new file mode 100644 index 00000000..f0c32eb8 --- /dev/null +++ b/cbor/tags.go @@ -0,0 +1,143 @@ +// Copyright 2024 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cbor + +import ( + "math/big" + "reflect" + + _cbor "github.com/fxamacker/cbor/v2" +) + +const ( + // Useful tag numbers + CborTagCbor = 24 + CborTagRational = 30 + CborTagSet = 258 + CborTagMap = 259 + + // Tag ranges for "alternatives" + // https://www.ietf.org/archive/id/draft-bormann-cbor-notable-tags-07.html#name-enumerated-alternative-data + CborTagAlternative1Min = 121 + CborTagAlternative1Max = 127 + CborTagAlternative2Min = 1280 + CborTagAlternative2Max = 1400 + CborTagAlternative3 = 101 +) + +var customTagSet _cbor.TagSet + +func init() { + // Build custom tagset + customTagSet = _cbor.NewTagSet() + tagOpts := _cbor.TagOptions{EncTag: _cbor.EncTagRequired, DecTag: _cbor.DecTagRequired} + // Wrapped CBOR + if err := customTagSet.Add( + tagOpts, + reflect.TypeOf(WrappedCbor{}), + CborTagCbor, + ); err != nil { + panic(err) + } + // Rational numbers + if err := customTagSet.Add( + tagOpts, + reflect.TypeOf(Rat{}), + CborTagRational, + ); err != nil { + panic(err) + } + // Sets + if err := customTagSet.Add( + tagOpts, + reflect.TypeOf(Set{}), + CborTagSet, + ); err != nil { + panic(err) + } + // Maps + if err := customTagSet.Add( + tagOpts, + reflect.TypeOf(Map{}), + CborTagMap, + ); err != nil { + panic(err) + } +} + +type WrappedCbor []byte + +func (w *WrappedCbor) UnmarshalCBOR(cborData []byte) error { + var tmpData []byte + if _, err := Decode(cborData, &tmpData); err != nil { + return err + } + *w = WrappedCbor(tmpData[:]) + return nil +} + +func (w WrappedCbor) Bytes() []byte { + return w[:] +} + +type Rat struct { + *big.Rat +} + +func (r *Rat) UnmarshalCBOR(cborData []byte) error { + tmpRat := []int64{} + if _, err := Decode(cborData, &tmpRat); err != nil { + return err + } + r.Rat = big.NewRat(tmpRat[0], tmpRat[1]) + return nil +} + +func (r *Rat) MarshalCBOR() ([]byte, error) { + tmpData := _cbor.Tag{ + Number: CborTagRational, + Content: []uint64{ + r.Num().Uint64(), + r.Denom().Uint64(), + }, + } + return Encode(&tmpData) +} + +func (r *Rat) ToBigRat() *big.Rat { + return r.Rat +} + +type Set []any + +func (s *Set) UnmarshalCBOR(cborData []byte) error { + var tmpData []any + if _, err := Decode(cborData, &tmpData); err != nil { + return err + } + *s = tmpData[:] + return nil +} + +type Map map[any]any + +func (m *Map) UnmarshalCBOR(cborData []byte) error { + tmpData := make(map[any]any) + if _, err := Decode(cborData, &tmpData); err != nil { + return err + } + *m = tmpData + return nil +} diff --git a/cbor/tags_test.go b/cbor/tags_test.go new file mode 100644 index 00000000..cf9a907f --- /dev/null +++ b/cbor/tags_test.go @@ -0,0 +1,94 @@ +// Copyright 2023 Blink Labs Software +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cbor_test + +import ( + "encoding/hex" + "math/big" + "reflect" + "testing" + + "github.com/blinklabs-io/gouroboros/cbor" +) + +var tagsTestDefs = []struct { + cborHex string + object any +}{ + { + cborHex: "d81843abcdef", + object: cbor.WrappedCbor([]byte{0xab, 0xcd, 0xef}), + }, + { + cborHex: "d81e82031903e8", + object: cbor.Rat{ + Rat: big.NewRat(3, 1000), + }, + }, + { + cborHex: "d9010283010203", + object: cbor.Set( + []any{ + uint64(1), uint64(2), uint64(3), + }, + ), + }, + { + cborHex: "d90103a201020304", + object: cbor.Map( + map[any]any{ + uint64(1): uint64(2), + uint64(3): uint64(4), + }, + ), + }, +} + +func TestTagsDecode(t *testing.T) { + for _, testDef := range tagsTestDefs { + cborData, err := hex.DecodeString(testDef.cborHex) + if err != nil { + t.Fatalf("failed to decode CBOR hex: %s", err) + } + var dest any + if _, err := cbor.Decode(cborData, &dest); err != nil { + t.Fatalf("failed to decode CBOR: %s", err) + } + if !reflect.DeepEqual(dest, testDef.object) { + t.Fatalf( + "CBOR did not decode to expected object\n got: %#v\n wanted: %#v", + dest, + testDef.object, + ) + } + } +} + +func TestTagsEncode(t *testing.T) { + for _, testDef := range tagsTestDefs { + cborData, err := cbor.Encode(testDef.object) + if err != nil { + t.Fatalf("failed to encode object to CBOR: %s", err) + } + cborHex := hex.EncodeToString(cborData) + if cborHex != testDef.cborHex { + t.Fatalf( + "object did not encode to expected CBOR\n got: %s\n wanted: %s", + cborHex, + testDef.cborHex, + ) + } + } +} diff --git a/cbor/value.go b/cbor/value.go index 0a8b10b9..1a815c03 100644 --- a/cbor/value.go +++ b/cbor/value.go @@ -59,41 +59,26 @@ func (v *Value) UnmarshalCBOR(data []byte) error { if _, err := Decode(data, &tmpTag); err != nil { return err } - switch tmpTag.Number { - case CborTagCbor: - v.value = tmpTag.Content - case CborTagRational: - tmpRat := []int64{} - if _, err := Decode(tmpTag.Content, &tmpRat); err != nil { + if (tmpTag.Number >= CborTagAlternative1Min && tmpTag.Number <= CborTagAlternative1Max) || + (tmpTag.Number >= CborTagAlternative2Min && tmpTag.Number <= CborTagAlternative2Max) || + tmpTag.Number == CborTagAlternative3 { + // Constructors/alternatives + var tmpConstr Constructor + if _, err := Decode(data, &tmpConstr); err != nil { return err } - v.value = big.NewRat(tmpRat[0], tmpRat[1]) - case CborTagSet: - return v.processArray(tmpTag.Content) - case CborTagMap: - return v.processMap(tmpTag.Content) - default: - if (tmpTag.Number >= CborTagAlternative1Min && tmpTag.Number <= CborTagAlternative1Max) || - (tmpTag.Number >= CborTagAlternative2Min && tmpTag.Number <= CborTagAlternative2Max) || - tmpTag.Number == CborTagAlternative3 { - // Constructors/alternatives - var tmpConstr Constructor - if _, err := Decode(data, &tmpConstr); err != nil { - return err - } - v.value = tmpConstr - } else { - // Fall back to standard CBOR tag parsing for our supported types - var tmpTagDecode interface{} - if _, err := Decode(data, &tmpTagDecode); err != nil { - return err - } - switch tmpTagDecode.(type) { - case int, uint, int64, uint64, bool, big.Int: - v.value = tmpTagDecode - default: - return fmt.Errorf("unsupported CBOR tag number: %d", tmpTag.Number) - } + v.value = tmpConstr + } else { + // Fall back to standard CBOR tag parsing for our supported types + var tmpTagDecode interface{} + if _, err := Decode(data, &tmpTagDecode); err != nil { + return err + } + switch tmpTagDecode.(type) { + case int, uint, int64, uint64, bool, big.Int, WrappedCbor, Rat, Set, Map: + v.value = tmpTagDecode + default: + return fmt.Errorf("unsupported CBOR tag number: %d", tmpTag.Number) } } default: @@ -179,45 +164,16 @@ func generateAstJson(obj interface{}) ([]byte, error) { switch v := obj.(type) { case ByteString: tmpJsonObj["bytes"] = hex.EncodeToString(v.Bytes()) + case WrappedCbor: + tmpJsonObj["bytes"] = hex.EncodeToString(v.Bytes()) case []interface{}: - tmpJson := `{"list":[` - for idx, val := range v { - tmpVal, err := generateAstJson(val) - if err != nil { - return nil, err - } - tmpJson += string(tmpVal) - if idx != (len(v) - 1) { - tmpJson += `,` - } - } - tmpJson += `]}` - return []byte(tmpJson), nil + return generateAstJsonList[[]any](v) + case Set: + return generateAstJsonList[Set](v) case map[interface{}]interface{}: - tmpItems := []string{} - for key, val := range v { - keyAstJson, err := generateAstJson(key) - if err != nil { - return nil, err - } - valAstJson, err := generateAstJson(val) - if err != nil { - return nil, err - } - tmpJson := fmt.Sprintf( - `{"k":%s,"v":%s}`, - keyAstJson, - valAstJson, - ) - tmpItems = append(tmpItems, string(tmpJson)) - } - // We naively sort the rendered map items to give consistent ordering - sort.Strings(tmpItems) - tmpJson := fmt.Sprintf( - `{"map":[%s]}`, - strings.Join(tmpItems, ","), - ) - return []byte(tmpJson), nil + return generateAstJsonMap[map[any]any](v) + case Map: + return generateAstJsonMap[Map](v) case Constructor: return json.Marshal(obj) case big.Int: @@ -226,6 +182,13 @@ func generateAstJson(obj interface{}) ([]byte, error) { v.String(), ) return []byte(tmpJson), nil + case Rat: + return generateAstJson( + []any{ + v.Num().Uint64(), + v.Denom().Uint64(), + }, + ) case int, uint, uint64, int64: tmpJsonObj["int"] = v case bool: @@ -238,6 +201,50 @@ func generateAstJson(obj interface{}) ([]byte, error) { return json.Marshal(&tmpJsonObj) } +func generateAstJsonList[T []any | Set](v T) ([]byte, error) { + tmpJson := `{"list":[` + for idx, val := range v { + tmpVal, err := generateAstJson(val) + if err != nil { + return nil, err + } + tmpJson += string(tmpVal) + if idx != (len(v) - 1) { + tmpJson += `,` + } + } + tmpJson += `]}` + return []byte(tmpJson), nil +} + +func generateAstJsonMap[T map[any]any | Map](v T) ([]byte, error) { + tmpItems := []string{} + for key, val := range v { + keyAstJson, err := generateAstJson(key) + if err != nil { + return nil, err + } + valAstJson, err := generateAstJson(val) + if err != nil { + return nil, err + } + tmpJson := fmt.Sprintf( + `{"k":%s,"v":%s}`, + keyAstJson, + valAstJson, + ) + tmpItems = append(tmpItems, string(tmpJson)) + } + // We naively sort the rendered map items to give consistent ordering + sort.Strings(tmpItems) + tmpJson := fmt.Sprintf( + `{"map":[%s]}`, + strings.Join(tmpItems, ","), + ) + return []byte(tmpJson), nil + +} + type Constructor struct { DecodeStoreCbor constructor uint diff --git a/cbor/value_test.go b/cbor/value_test.go index eb5f00bb..dcf89744 100644 --- a/cbor/value_test.go +++ b/cbor/value_test.go @@ -86,6 +86,41 @@ var testDefs = []struct { }, expectedAstJson: `{"list":[{"int":22318265904693663008365},{"int":8535038193994223137511702528}]}`, }, + // 24('abcdef') + { + cborHex: "d81843abcdef", + expectedObject: cbor.WrappedCbor([]byte{0xab, 0xcd, 0xef}), + expectedAstJson: `{"bytes":"abcdef"}`, + }, + // 30([3, 1000]) + { + cborHex: "d81e82031903e8", + expectedObject: cbor.Rat{ + Rat: big.NewRat(3, 1000), + }, + expectedAstJson: `{"list":[{"int":3},{"int":1000}]}`, + }, + // 258([1, 2, 3]) + { + cborHex: "d9010283010203", + expectedObject: cbor.Set( + []any{ + uint64(1), uint64(2), uint64(3), + }, + ), + expectedAstJson: `{"list":[{"int":1},{"int":2},{"int":3}]}`, + }, + // 259({1: 2, 3: 4}) + { + cborHex: "d90103a201020304", + expectedObject: cbor.Map( + map[any]any{ + uint64(1): uint64(2), + uint64(3): uint64(4), + }, + ), + expectedAstJson: `{"map":[{"k":{"int":1},"v":{"int":2}},{"k":{"int":3},"v":{"int":4}}]}`, + }, } func TestValueDecode(t *testing.T) { @@ -99,7 +134,7 @@ func TestValueDecode(t *testing.T) { if testDef.expectedDecodeError != nil { if err.Error() != testDef.expectedDecodeError.Error() { t.Fatalf( - "did not receive expected decode error, got: %s, wanted: %s", + "did not receive expected decode error, got: %s, wanted: %s", err, testDef.expectedDecodeError, ) @@ -116,7 +151,7 @@ func TestValueDecode(t *testing.T) { newObj := tmpValue.Value() if !reflect.DeepEqual(newObj, testDef.expectedObject) { t.Fatalf( - "CBOR did not decode to expected object\n got: %#v\n wanted: %#v", + "CBOR did not decode to expected object\n got: %#v\n wanted: %#v", newObj, testDef.expectedObject, ) @@ -156,7 +191,7 @@ func TestValueMarshalJSON(t *testing.T) { } if !test.JsonStringsEqual(jsonData, []byte(fullExpectedJson)) { t.Fatalf( - "CBOR did not marshal to expected JSON\n got: %s\n wanted: %s", + "CBOR did not marshal to expected JSON\n got: %s\n wanted: %s", jsonData, fullExpectedJson, ) @@ -175,7 +210,7 @@ func TestLazyValueDecode(t *testing.T) { if testDef.expectedDecodeError != nil { if err.Error() != testDef.expectedDecodeError.Error() { t.Fatalf( - "did not receive expected decode error, got: %s, wanted: %s", + "did not receive expected decode error, got: %s, wanted: %s", err, testDef.expectedDecodeError, ) @@ -190,7 +225,7 @@ func TestLazyValueDecode(t *testing.T) { if testDef.expectedDecodeError != nil { if err.Error() != testDef.expectedDecodeError.Error() { t.Fatalf( - "did not receive expected decode error, got: %s, wanted: %s", + "did not receive expected decode error, got: %s, wanted: %s", err, testDef.expectedDecodeError, ) @@ -206,7 +241,7 @@ func TestLazyValueDecode(t *testing.T) { } if !reflect.DeepEqual(newObj, testDef.expectedObject) { t.Fatalf( - "CBOR did not decode to expected object\n got: %#v\n wanted: %#v", + "CBOR did not decode to expected object\n got: %#v\n wanted: %#v", newObj, testDef.expectedObject, ) @@ -246,7 +281,7 @@ func TestLazyValueMarshalJSON(t *testing.T) { } if !test.JsonStringsEqual(jsonData, []byte(fullExpectedJson)) { t.Fatalf( - "CBOR did not marshal to expected JSON\n got: %s\n wanted: %s", + "CBOR did not marshal to expected JSON\n got: %s\n wanted: %s", jsonData, fullExpectedJson, ) @@ -303,7 +338,7 @@ func TestConstructorDecode(t *testing.T) { } if !reflect.DeepEqual(tmpConstr.Fields(), testDef.expectedObj.Fields()) { t.Fatalf( - "did not decode to expected fields\n got: %#v\n wanted: %#v", + "did not decode to expected fields\n got: %#v\n wanted: %#v", tmpConstr.Fields(), testDef.expectedObj.Fields(), ) @@ -320,7 +355,7 @@ func TestConstructorEncode(t *testing.T) { cborDataHex := hex.EncodeToString(cborData) if cborDataHex != strings.ToLower(testDef.cborHex) { t.Fatalf( - "did not encode to expected CBOR\n got: %s\n wanted: %s", + "did not encode to expected CBOR\n got: %s\n wanted: %s", cborDataHex, strings.ToLower(testDef.cborHex), ) diff --git a/protocol/localtxmonitor/messages.go b/protocol/localtxmonitor/messages.go index 1bf919ed..f68a3fad 100644 --- a/protocol/localtxmonitor/messages.go +++ b/protocol/localtxmonitor/messages.go @@ -174,7 +174,7 @@ func (m *MsgReplyNextTx) UnmarshalCBOR(data []byte) error { txWrapper := tmp[1].([]interface{}) m.Transaction = MsgReplyNextTxTransaction{ EraId: uint8(txWrapper[0].(uint64)), - Tx: txWrapper[1].(cbor.Tag).Content.([]byte), + Tx: txWrapper[1].(cbor.WrappedCbor).Bytes(), } } return nil