Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New options for encoding Go strings to and from CBOR byte strings #465

Merged
merged 3 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bytestring.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (bs *ByteString) UnmarshalCBOR(data []byte) error {
return &UnmarshalTypeError{CBORType: typ.String(), GoType: typeByteString.String()}
}

*bs = ByteString(d.parseByteString())
b, _ := d.parseByteString()
*bs = ByteString(b)
return nil
}
196 changes: 136 additions & 60 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,23 @@ func (bidm BigIntDecMode) valid() bool {
return bidm >= 0 && bidm < maxBigIntDecMode
}

// ByteStringToStringMode specifies the behavior when decoding a CBOR byte string into a Go string.
type ByteStringToStringMode int

const (
// ByteStringToStringForbidden generates an error on an attempt to decode a CBOR byte string into a Go string.
ByteStringToStringForbidden ByteStringToStringMode = iota

// ByteStringToStringAllowed permits decoding a CBOR byte string into a Go string.
ByteStringToStringAllowed

maxByteStringToStringMode
)

func (bstsm ByteStringToStringMode) valid() bool {
return bstsm >= 0 && bstsm < maxByteStringToStringMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -467,6 +484,15 @@ type DecOptions struct {

// BigIntDec specifies how to decode CBOR bignum to Go interface{}.
BigIntDec BigIntDecMode

// DefaultByteStringType is the Go type that should be produced when decoding a CBOR byte
// string into an empty interface value. Types to which a []byte is convertible are valid
// for this option, except for array and pointer-to-array types. If nil, the default is
// []byte.
DefaultByteStringType reflect.Type

// ByteStringToString specifies the behavior when decoding a CBOR byte string into a Go string.
ByteStringToString ByteStringToStringMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -581,21 +607,29 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.BigIntDec.valid() {
return nil, errors.New("cbor: invalid BigIntDec " + strconv.Itoa(int(opts.BigIntDec)))
}
if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) {
return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType)
}
if !opts.ByteStringToString.valid() {
return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
maxNestedLevels: opts.MaxNestedLevels,
maxArrayElements: opts.MaxArrayElements,
maxMapPairs: opts.MaxMapPairs,
indefLength: opts.IndefLength,
tagsMd: opts.TagsMd,
intDec: opts.IntDec,
mapKeyByteString: opts.MapKeyByteString,
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
}
return &dm, nil
}
Expand Down Expand Up @@ -647,41 +681,45 @@ type DecMode interface {
}

type decMode struct {
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
tags tagProvider
dupMapKey DupMapKeyMode
timeTag DecTagMode
maxNestedLevels int
maxArrayElements int
maxMapPairs int
indefLength IndefLengthMode
tagsMd TagsMode
intDec IntDecMode
mapKeyByteString MapKeyByteStringMode
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
}

var defaultDecMode, _ = DecOptions{}.decMode()

// DecOptions returns user specified options used to create this DecMode.
func (dm *decMode) DecOptions() DecOptions {
return DecOptions{
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
MaxNestedLevels: dm.maxNestedLevels,
MaxArrayElements: dm.maxArrayElements,
MaxMapPairs: dm.maxMapPairs,
IndefLength: dm.indefLength,
TagsMd: dm.tagsMd,
IntDec: dm.intDec,
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
}
}

Expand Down Expand Up @@ -979,8 +1017,8 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return fillNegativeInt(t, nValue, v)

case cborTypeByteString:
b := d.parseByteString()
return fillByteString(t, b, v)
b, copied := d.parseByteString()
return fillByteString(t, b, !copied, v, d.dm.byteStringToString)

case cborTypeTextString:
b, err := d.parseTextString()
Expand Down Expand Up @@ -1017,15 +1055,15 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
switch tagNum {
case 2:
// Bignum (tag 2) can be decoded to uint, int, float, slice, array, or big.Int.
b := d.parseByteString()
b, copied := d.parseByteString()
bi := new(big.Int).SetBytes(b)

if tInfo.nonPtrType == typeBigInt {
v.Set(reflect.ValueOf(*bi))
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, v)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
}
if bi.IsUint64() {
return fillPositiveInt(t, bi.Uint64(), v)
Expand All @@ -1037,7 +1075,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
}
case 3:
// Bignum (tag 3) can be decoded to int, float, slice, array, or big.Int.
b := d.parseByteString()
b, copied := d.parseByteString()
bi := new(big.Int).SetBytes(b)
bi.Add(bi, big.NewInt(1))
bi.Neg(bi)
Expand All @@ -1047,7 +1085,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, v)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
}
if bi.IsInt64() {
return fillNegativeInt(t, bi.Int64(), v)
Expand Down Expand Up @@ -1279,7 +1317,29 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nValue, nil

case cborTypeByteString:
return d.parseByteString(), nil
switch d.dm.defaultByteStringType {
case nil, typeByteSlice:
b, copied := d.parseByteString()
if copied {
return b, nil
}
clone := make([]byte, len(b))
copy(clone, b)
return clone, nil
case typeString:
b, _ := d.parseByteString()
return string(b), nil
default:
b, copied := d.parseByteString()
if copied || d.dm.defaultByteStringType.Kind() == reflect.String {
// Avoid an unnecessary copy since the conversion to string must
// copy the underlying bytes.
return reflect.ValueOf(b).Convert(d.dm.defaultByteStringType).Interface(), nil
}
clone := make([]byte, len(b))
copy(clone, b)
return reflect.ValueOf(clone).Convert(d.dm.defaultByteStringType).Interface(), nil
}
case cborTypeTextString:
b, err := d.parseTextString()
if err != nil {
Expand All @@ -1296,15 +1356,15 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
d.off = tagOff
return d.parseToTime()
case 2:
b := d.parseByteString()
b, _ := d.parseByteString()
bi := new(big.Int).SetBytes(b)

if d.dm.bigIntDec == BigIntDecodePointer {
return bi, nil
}
return *bi, nil
case 3:
b := d.parseByteString()
b, _ := d.parseByteString()
bi := new(big.Int).SetBytes(b)
bi.Add(bi, big.NewInt(1))
bi.Neg(bi)
Expand Down Expand Up @@ -1376,15 +1436,16 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return nil, nil
}

// parseByteString parses CBOR encoded byte string. It returns a byte slice
// pointing to a copy of parsed data.
func (d *decoder) parseByteString() []byte {
// parseByteString parses a CBOR encoded byte string. The returned byte slice
// may be backed directly by the input. The second return value will be true if
// and only if the slice is backed by a copy of the input. Callers are
// responsible for making a copy if necessary.
func (d *decoder) parseByteString() ([]byte, bool) {
fxamacker marked this conversation as resolved.
Show resolved Hide resolved
_, ai, val := d.getHead()
if ai != 31 {
b := make([]byte, int(val))
copy(b, d.data[d.off:d.off+int(val)])
b := d.data[d.off : d.off+int(val)]
d.off += int(val)
return b
return b, false
}
// Process indefinite length string chunks.
b := []byte{}
Expand All @@ -1393,7 +1454,7 @@ func (d *decoder) parseByteString() []byte {
b = append(b, d.data[d.off:d.off+int(val)]...)
d.off += int(val)
}
return b
return b, true
}

// parseTextString parses CBOR encoded text string. It returns a byte slice
Expand Down Expand Up @@ -2082,6 +2143,8 @@ var (
typeBigInt = reflect.TypeOf(big.Int{})
typeUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
typeString = reflect.TypeOf("")
typeByteSlice = reflect.TypeOf([]byte(nil))
)

func fillNil(_ cborType, v reflect.Value) error {
Expand Down Expand Up @@ -2184,18 +2247,31 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillByteString(t cborType, val []byte, v reflect.Value) error {
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode) error {
if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) {
if v.CanAddr() {
v = v.Addr()
if u, ok := v.Interface().(encoding.BinaryUnmarshaler); ok {
// The contract of BinaryUnmarshaler forbids
// retaining the input bytes, so no copying is
// required even if val is shared.
fxamacker marked this conversation as resolved.
Show resolved Hide resolved
return u.UnmarshalBinary(val)
}
}
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
if bsts == ByteStringToStringAllowed && v.Kind() == reflect.String {
v.SetString(string(val))
return nil
}
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
v.SetBytes(val)
src := val
if shared {
// SetBytes shares the underlying bytes of the source slice.
src = make([]byte, len(val))
copy(src, val)
}
v.SetBytes(src)
return nil
}
if v.Kind() == reflect.Array && v.Type().Elem().Kind() == reflect.Uint8 {
Expand Down
Loading
Loading