Skip to content

Commit

Permalink
Add option to permit decoding CBOR byte strings into Go strings.
Browse files Browse the repository at this point in the history
The unchanged default behavior is to produce an UnmarshalTypeError when decoding a CBOR byte string
into a Go string.

Signed-off-by: Ben Luddy <bluddy@redhat.com>
  • Loading branch information
benluddy committed Jan 5, 2024
1 parent a077161 commit 5af0d26
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 4 deletions.
38 changes: 34 additions & 4 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,23 @@ func (bidm BigIntDecMode) valid() bool {
return bidm >= 0 && bidm < maxBigIntDecMode
}

// ByteStringToStringMode specifies the behavior when decoding a CBOR byte string into a Go string.
type ByteStringToStringMode int

const (
// ByteStringToStringError generates an error on an attempt to decode a CBOR byte string into a Go string.
ByteStringToStringError ByteStringToStringMode = iota

// ByteStringToStringAllow permits decoding a CBOR byte string into a Go string.
ByteStringToStringAllow

maxByteStringToStringMode
)

func (bstsm ByteStringToStringMode) valid() bool {
return bstsm >= 0 && bstsm < maxByteStringToStringMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -473,6 +490,9 @@ type DecOptions struct {
// for this option, except for array and pointer-to-array types. If nil, the default is
// []byte.
DefaultByteStringType reflect.Type

// ByteStringToString specifies the behavior when decoding a CBOR byte string into a Go string.
ByteStringToString ByteStringToStringMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -590,6 +610,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) {
return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType)
}
if !opts.ByteStringToString.valid() {
return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -606,6 +629,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
}
return &dm, nil
}
Expand Down Expand Up @@ -673,6 +697,7 @@ type decMode struct {
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -694,6 +719,7 @@ func (dm *decMode) DecOptions() DecOptions {
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
}
}

Expand Down Expand Up @@ -992,7 +1018,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin

case cborTypeByteString:
b, copied := d.parseByteString()
return fillByteString(t, b, !copied, v)
return fillByteString(t, b, !copied, v, d.dm.byteStringToString)

case cborTypeTextString:
b, err := d.parseTextString()
Expand Down Expand Up @@ -1037,7 +1063,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v)
return fillByteString(t, b, !copied, v, ByteStringToStringError)
}
if bi.IsUint64() {
return fillPositiveInt(t, bi.Uint64(), v)
Expand All @@ -1059,7 +1085,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v)
return fillByteString(t, b, !copied, v, ByteStringToStringError)
}
if bi.IsInt64() {
return fillNegativeInt(t, bi.Int64(), v)
Expand Down Expand Up @@ -2219,7 +2245,7 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillByteString(t cborType, val []byte, shared bool, v reflect.Value) error {
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode) error {
if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) {
if v.CanAddr() {
v = v.Addr()
Expand All @@ -2232,6 +2258,10 @@ func fillByteString(t cborType, val []byte, shared bool, v reflect.Value) error
}
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
if bsts == ByteStringToStringAllow && v.Kind() == reflect.String {
v.SetString(string(val))
return nil
}
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
src := val
if shared {
Expand Down
58 changes: 58 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8246,6 +8246,64 @@ func TestUnmarshalDefaultByteStringType(t *testing.T) {
}
}

func TestDecModeInvalidByteStringToStringMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{ByteStringToString: -1},
wantErrorMsg: "cbor: invalid ByteStringToString -1",
},
{
name: "above range of valid modes",
opts: DecOptions{ByteStringToString: 101},
wantErrorMsg: "cbor: invalid ByteStringToString 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestUnmarshalByteStringToString(t *testing.T) {
var s string

derror, err := DecOptions{ByteStringToString: ByteStringToStringError}.DecMode()
if err != nil {
t.Fatal(err)
}

if err = derror.Unmarshal(hexDecode("43414243"), &s); err == nil {
t.Error("expected non-nil error from Unmarshal")
}

if s != "" {
t.Errorf("expected destination string to be empty, got %q", s)
}

dallow, err := DecOptions{ByteStringToString: ByteStringToStringAllow}.DecMode()
if err != nil {
t.Fatal(err)
}

if err = dallow.Unmarshal(hexDecode("43414243"), &s); err != nil {
t.Errorf("expected nil error from Unmarshal, got: %v", err)
}

if s != "ABC" {
t.Errorf("expected destination string to be \"ABC\", got %q", s)
}
}

func isCBORNil(data []byte) bool {
return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7)
}

0 comments on commit 5af0d26

Please sign in to comment.