From bc12130f6c4da2cd5bccfd091600f509d4834b4a Mon Sep 17 00:00:00 2001 From: Suriyan S Date: Wed, 24 Jan 2024 21:43:36 -0500 Subject: [PATCH] Add decoding option for return type of data with an unknown tag Adds a decoding option to specify the preferred return type when unmarshalling a data with an unknown tag. If this option is not set, Unmarshal returns a value of type Tag{}. This ensures backward compatibility. If the option is set, Unmarshal returns the unmarshalled content. Signed-off-by: Suriyan Subbarayan suriyansub710@gmail.com --- decode.go | 131 ++++++++++++++++++++++++++++++------------------- decode_test.go | 66 +++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 51 deletions(-) diff --git a/decode.go b/decode.go index 015ac9d1..189628e4 100644 --- a/decode.go +++ b/decode.go @@ -446,6 +446,23 @@ func (fnbsm FieldNameByteStringMode) valid() bool { return fnbsm >= 0 && fnbsm < maxFieldNameByteStringMode } +// ReturnTypeForEmptyInterface specifies the type of the value to be returned when decoding into an empty interface +type ReturnTypeForEmptyInterface int + +const ( + // ReturnTagForEmptyInterface returns a value of type Tag + ReturnTagForEmptyInterface ReturnTypeForEmptyInterface = iota + + // ReturnTagContentForEmptyInterface returns only the content of the Tag + ReturnTagContentForEmptyInterface + + maxReturnTypesForEmptyInterface +) + +func (rtfei ReturnTypeForEmptyInterface) valid() bool { + return rtfei >= 0 && rtfei < maxReturnTypesForEmptyInterface +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -514,6 +531,9 @@ type DecOptions struct { // FieldNameByteString specifies the behavior when decoding a CBOR byte string map key as a // Go struct field name. FieldNameByteString FieldNameByteStringMode + + // ReturnTypeForEmptyIface specifies the type of the value to be returned when decoding into an empty interface + ReturnTypeForEmptyIface ReturnTypeForEmptyInterface } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -637,24 +657,28 @@ func (opts DecOptions) decMode() (*decMode, error) { if !opts.FieldNameByteString.valid() { return nil, errors.New("cbor: invalid FieldNameByteString " + strconv.Itoa(int(opts.FieldNameByteString))) } + if !opts.ReturnTypeForEmptyIface.valid() { + return nil, errors.New("cbor: invalid ReturnTypeForEmptyIface " + strconv.Itoa(int(opts.ReturnTypeForEmptyIface))) + } dm := decMode{ - dupMapKey: opts.DupMapKey, - timeTag: opts.TimeTag, - maxNestedLevels: opts.MaxNestedLevels, - maxArrayElements: opts.MaxArrayElements, - maxMapPairs: opts.MaxMapPairs, - indefLength: opts.IndefLength, - tagsMd: opts.TagsMd, - intDec: opts.IntDec, - mapKeyByteString: opts.MapKeyByteString, - extraReturnErrors: opts.ExtraReturnErrors, - defaultMapType: opts.DefaultMapType, - utf8: opts.UTF8, - fieldNameMatching: opts.FieldNameMatching, - bigIntDec: opts.BigIntDec, - defaultByteStringType: opts.DefaultByteStringType, - byteStringToString: opts.ByteStringToString, - fieldNameByteString: opts.FieldNameByteString, + dupMapKey: opts.DupMapKey, + timeTag: opts.TimeTag, + maxNestedLevels: opts.MaxNestedLevels, + maxArrayElements: opts.MaxArrayElements, + maxMapPairs: opts.MaxMapPairs, + indefLength: opts.IndefLength, + tagsMd: opts.TagsMd, + intDec: opts.IntDec, + mapKeyByteString: opts.MapKeyByteString, + extraReturnErrors: opts.ExtraReturnErrors, + defaultMapType: opts.DefaultMapType, + utf8: opts.UTF8, + fieldNameMatching: opts.FieldNameMatching, + bigIntDec: opts.BigIntDec, + defaultByteStringType: opts.DefaultByteStringType, + byteStringToString: opts.ByteStringToString, + fieldNameByteString: opts.FieldNameByteString, + returnTypeForEmptyIface: opts.ReturnTypeForEmptyIface, } return &dm, nil } @@ -706,24 +730,25 @@ type DecMode interface { } type decMode struct { - tags tagProvider - dupMapKey DupMapKeyMode - timeTag DecTagMode - maxNestedLevels int - maxArrayElements int - maxMapPairs int - indefLength IndefLengthMode - tagsMd TagsMode - intDec IntDecMode - mapKeyByteString MapKeyByteStringMode - extraReturnErrors ExtraDecErrorCond - defaultMapType reflect.Type - utf8 UTF8Mode - fieldNameMatching FieldNameMatchingMode - bigIntDec BigIntDecMode - defaultByteStringType reflect.Type - byteStringToString ByteStringToStringMode - fieldNameByteString FieldNameByteStringMode + tags tagProvider + dupMapKey DupMapKeyMode + timeTag DecTagMode + maxNestedLevels int + maxArrayElements int + maxMapPairs int + indefLength IndefLengthMode + tagsMd TagsMode + intDec IntDecMode + mapKeyByteString MapKeyByteStringMode + extraReturnErrors ExtraDecErrorCond + defaultMapType reflect.Type + utf8 UTF8Mode + fieldNameMatching FieldNameMatchingMode + bigIntDec BigIntDecMode + defaultByteStringType reflect.Type + byteStringToString ByteStringToStringMode + fieldNameByteString FieldNameByteStringMode + returnTypeForEmptyIface ReturnTypeForEmptyInterface } var defaultDecMode, _ = DecOptions{}.decMode() @@ -731,22 +756,23 @@ var defaultDecMode, _ = DecOptions{}.decMode() // DecOptions returns user specified options used to create this DecMode. func (dm *decMode) DecOptions() DecOptions { return DecOptions{ - DupMapKey: dm.dupMapKey, - TimeTag: dm.timeTag, - MaxNestedLevels: dm.maxNestedLevels, - MaxArrayElements: dm.maxArrayElements, - MaxMapPairs: dm.maxMapPairs, - IndefLength: dm.indefLength, - TagsMd: dm.tagsMd, - IntDec: dm.intDec, - MapKeyByteString: dm.mapKeyByteString, - ExtraReturnErrors: dm.extraReturnErrors, - UTF8: dm.utf8, - FieldNameMatching: dm.fieldNameMatching, - BigIntDec: dm.bigIntDec, - DefaultByteStringType: dm.defaultByteStringType, - ByteStringToString: dm.byteStringToString, - FieldNameByteString: dm.fieldNameByteString, + DupMapKey: dm.dupMapKey, + TimeTag: dm.timeTag, + MaxNestedLevels: dm.maxNestedLevels, + MaxArrayElements: dm.maxArrayElements, + MaxMapPairs: dm.maxMapPairs, + IndefLength: dm.indefLength, + TagsMd: dm.tagsMd, + IntDec: dm.intDec, + MapKeyByteString: dm.mapKeyByteString, + ExtraReturnErrors: dm.extraReturnErrors, + UTF8: dm.utf8, + FieldNameMatching: dm.fieldNameMatching, + BigIntDec: dm.bigIntDec, + DefaultByteStringType: dm.defaultByteStringType, + ByteStringToString: dm.byteStringToString, + FieldNameByteString: dm.fieldNameByteString, + ReturnTypeForEmptyIface: dm.returnTypeForEmptyIface, } } @@ -1426,6 +1452,9 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli if err != nil { return nil, err } + if d.dm.returnTypeForEmptyIface == ReturnTagContentForEmptyInterface { + return content, nil + } return Tag{tagNum, content}, nil case cborTypePrimitives: _, ai, val := d.getHead() diff --git a/decode_test.go b/decode_test.go index 99265b63..efe62d9f 100644 --- a/decode_test.go +++ b/decode_test.go @@ -8366,6 +8366,72 @@ func TestUnmarshalFieldNameByteString(t *testing.T) { } } +func TestDecModeInvalidReturnTypeForEmptyInterface(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + wantErrorMsg string + }{ + { + name: "below range of valid modes", + opts: DecOptions{ReturnTypeForEmptyIface: -1}, + wantErrorMsg: "cbor: invalid ReturnTypeForEmptyIface -1", + }, + { + name: "above range of valid modes", + opts: DecOptions{ReturnTypeForEmptyIface: 101}, + wantErrorMsg: "cbor: invalid ReturnTypeForEmptyIface 101", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.opts.DecMode() + if err == nil { + t.Errorf("DecMode() didn't return an error") + } else if err.Error() != tc.wantErrorMsg { + t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg) + } + }) + } +} + +func TestUnmarshalTaggedDataToEmptyInterfaceWithReturnType(t *testing.T) { + for _, tc := range []struct { + name string + opts DecOptions + in []byte + want interface{} + }{ + { + name: "default to value of type Tag", + opts: DecOptions{}, + in: hexDecode("d8ff00"), + want: Tag{}, + }, + { + name: "Tag's content", + opts: DecOptions{ReturnTypeForEmptyIface: ReturnTagContentForEmptyInterface}, + in: hexDecode("d8ff00"), + want: uint64(0), + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + var got interface{} + if err := dm.Unmarshal(tc.in, &got); err != nil { + t.Errorf("unexpected error: %v", err) + } + + if reflect.TypeOf(tc.want) != reflect.TypeOf(got) { + t.Errorf("got %s, want %s", reflect.TypeOf(got), reflect.TypeOf(tc.want)) + } + }) + } +} + func isCBORNil(data []byte) bool { return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) }