Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to support byte string map keys as struct field names #472

Merged
merged 2 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,15 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
copy(flds[i].cborName[n:], flds[i].name)
e.Reset()

// If cborName contains a text string, then cborNameByteString contains a
// string that has the byte string major type but is otherwise identical to
// cborName.
flds[i].cborNameByteString = make([]byte, len(flds[i].cborName))
copy(flds[i].cborNameByteString, flds[i].cborName)
// Reset encoded CBOR type to byte string, preserving the "additional
// information" bits:
flds[i].cborNameByteString[0] = byte(cborTypeByteString) | (flds[i].cborNameByteString[0] & 0x1f)

hasKeyAsStr = true
}

Expand Down
45 changes: 38 additions & 7 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,23 @@ func (bstsm ByteStringToStringMode) valid() bool {
return bstsm >= 0 && bstsm < maxByteStringToStringMode
}

// FieldNameByteStringMode specifies the behavior when decoding a CBOR byte string map key as a Go struct field name.
type FieldNameByteStringMode int

const (
// FieldNameByteStringForbidden generates an error on an attempt to decode a CBOR byte string map key as a Go struct field name.
FieldNameByteStringForbidden FieldNameByteStringMode = iota

// FieldNameByteStringAllowed permits CBOR byte string map keys to be recognized as Go struct field names.
FieldNameByteStringAllowed

maxFieldNameByteStringMode
)

func (fnbsm FieldNameByteStringMode) valid() bool {
return fnbsm >= 0 && fnbsm < maxFieldNameByteStringMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -493,6 +510,10 @@ type DecOptions struct {

// ByteStringToString specifies the behavior when decoding a CBOR byte string into a Go string.
ByteStringToString ByteStringToStringMode

// FieldNameByteString specifies the behavior when decoding a CBOR byte string map key as a
// Go struct field name.
FieldNameByteString FieldNameByteStringMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -613,6 +634,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.ByteStringToString.valid() {
return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString)))
}
if !opts.FieldNameByteString.valid() {
return nil, errors.New("cbor: invalid FieldNameByteString " + strconv.Itoa(int(opts.FieldNameByteString)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -630,6 +654,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
bigIntDec: opts.BigIntDec,
defaultByteStringType: opts.DefaultByteStringType,
byteStringToString: opts.ByteStringToString,
fieldNameByteString: opts.FieldNameByteString,
}
return &dm, nil
}
Expand Down Expand Up @@ -698,6 +723,7 @@ type decMode struct {
bigIntDec BigIntDecMode
defaultByteStringType reflect.Type
byteStringToString ByteStringToStringMode
fieldNameByteString FieldNameByteStringMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -720,6 +746,7 @@ func (dm *decMode) DecOptions() DecOptions {
BigIntDec: dm.bigIntDec,
DefaultByteStringType: dm.defaultByteStringType,
ByteStringToString: dm.byteStringToString,
FieldNameByteString: dm.fieldNameByteString,
}
}

Expand Down Expand Up @@ -1848,15 +1875,19 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n
var k interface{} // Used by duplicate map key detection

t := d.nextCBORType()
if t == cborTypeTextString {
if t == cborTypeTextString || (t == cborTypeByteString && d.dm.fieldNameByteString == FieldNameByteStringAllowed) {
var keyBytes []byte
keyBytes, lastErr = d.parseTextString()
if lastErr != nil {
if err == nil {
err = lastErr
if t == cborTypeTextString {
keyBytes, lastErr = d.parseTextString()
if lastErr != nil {
if err == nil {
err = lastErr
}
d.skip() // skip value
continue
}
d.skip() // skip value
continue
} else { // cborTypeByteString
keyBytes, _ = d.parseByteString()
}

keyLen := len(keyBytes)
Expand Down
64 changes: 64 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8302,6 +8302,70 @@ func TestUnmarshalByteStringToString(t *testing.T) {
}
}

func TestDecModeInvalidFieldNameByteStringMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{FieldNameByteString: -1},
wantErrorMsg: "cbor: invalid FieldNameByteString -1",
},
{
name: "above range of valid modes",
opts: DecOptions{FieldNameByteString: 101},
wantErrorMsg: "cbor: invalid FieldNameByteString 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestUnmarshalFieldNameByteString(t *testing.T) {
allowed, err := DecOptions{
FieldNameByteString: FieldNameByteStringAllowed,
}.DecMode()
if err != nil {
t.Fatal(err)
}

var s struct {
F int64 `json:"f"`
}

err = allowed.Unmarshal(hexDecode("a1414601"), &s) // {h'46': 1}
if err != nil {
t.Fatal(err)
}

if s.F != 1 {
t.Errorf("expected field F to be set to 1, got %d", s.F)
}

forbidden, err := DecOptions{
FieldNameByteString: FieldNameByteStringForbidden,
}.DecMode()
if err != nil {
t.Fatal(err)
}

const wantMsg = "cbor: cannot unmarshal byte string into Go value of type string (map key is of type byte string and cannot be used to match struct field name)"
if err := forbidden.Unmarshal(hexDecode("a1414601"), &s); err == nil {
t.Errorf("expected non-nil error")
} else if gotMsg := err.Error(); gotMsg != wantMsg {
t.Errorf("expected error %q, got %q", wantMsg, gotMsg)
}
}

func isCBORNil(data []byte) bool {
return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7)
}
38 changes: 36 additions & 2 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,23 @@ func (om OmitEmptyMode) valid() bool {
return om >= 0 && om < maxOmitEmptyMode
}

// FieldNameMode specifies the CBOR type to use when encoding struct field names.
type FieldNameMode int

const (
// FieldNameToTextString encodes struct fields to CBOR text string (major type 3).
FieldNameToTextString FieldNameMode = iota

// FieldNameToTextString encodes struct fields to CBOR byte string (major type 2).
FieldNameToByteString

maxFieldNameMode
)

func (fnm FieldNameMode) valid() bool {
return fnm >= 0 && fnm < maxFieldNameMode
}

// EncOptions specifies encoding options.
type EncOptions struct {
// Sort specifies sorting order.
Expand Down Expand Up @@ -381,6 +398,9 @@ type EncOptions struct {
// - CBOR text string (major type 3) is default
// - CBOR byte string (major type 2)
String StringMode

// FieldName specifies the CBOR type to use when encoding struct field names.
FieldName FieldNameMode
}

// CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding,
Expand Down Expand Up @@ -563,6 +583,9 @@ func (opts EncOptions) encMode() (*encMode, error) {
if err != nil {
return nil, err
}
if !opts.FieldName.valid() {
return nil, errors.New("cbor: invalid FieldName " + strconv.Itoa(int(opts.FieldName)))
}
em := encMode{
sort: opts.Sort,
shortestFloat: opts.ShortestFloat,
Expand All @@ -577,6 +600,7 @@ func (opts EncOptions) encMode() (*encMode, error) {
omitEmpty: opts.OmitEmpty,
stringType: opts.String,
stringMajorType: stringMajorType,
fieldName: opts.FieldName,
}
return &em, nil
}
Expand All @@ -603,6 +627,7 @@ type encMode struct {
omitEmpty OmitEmptyMode
stringType StringMode
stringMajorType cborType
fieldName FieldNameMode
}

var defaultEncMode, _ = EncOptions{}.encMode()
Expand All @@ -621,6 +646,7 @@ func (em *encMode) EncOptions() EncOptions {
TagsMd: em.tagsMd,
OmitEmpty: em.omitEmpty,
String: em.stringType,
FieldName: em.fieldName,
}
}

Expand Down Expand Up @@ -1137,7 +1163,11 @@ func encodeFixedLengthStruct(e *encoderBuffer, em *encMode, v reflect.Value, fld

for i := 0; i < len(flds); i++ {
f := flds[i]
e.Write(f.cborName)
if !f.keyAsInt && em.fieldName == FieldNameToByteString {
e.Write(f.cborNameByteString)
} else { // int or text string
e.Write(f.cborName)
}

fv := v.Field(f.idx[0])
if err := f.ef(e, em, fv); err != nil {
Expand Down Expand Up @@ -1189,7 +1219,11 @@ func encodeStruct(e *encoderBuffer, em *encMode, v reflect.Value) (err error) {
}
}

kve.Write(f.cborName)
if !f.keyAsInt && em.fieldName == FieldNameToByteString {
kve.Write(f.cborNameByteString)
} else { // int or text string
kve.Write(f.cborName)
}

if err := f.ef(kve, em, fv); err != nil {
putEncoderBuffer(kve)
Expand Down
94 changes: 94 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3698,6 +3698,34 @@ func TestEncModeStringType(t *testing.T) {
}
}

func TestEncModeInvalidFieldNameMode(t *testing.T) {
for _, tc := range []struct {
name string
opts EncOptions
wantErrorMsg string
}{
{
name: "",
opts: EncOptions{FieldName: -1},
wantErrorMsg: "cbor: invalid FieldName -1",
},
{
name: "",
opts: EncOptions{FieldName: 101},
wantErrorMsg: "cbor: invalid FieldName 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.EncMode()
if err == nil {
t.Errorf("EncMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("EncMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestEncIndefiniteLengthOption(t *testing.T) {
// Default option allows indefinite length items
var buf bytes.Buffer
Expand Down Expand Up @@ -4055,3 +4083,69 @@ func TestMarshalStringType(t *testing.T) {
})
}
}

func TestMarshalFieldNameType(t *testing.T) {
for _, tc := range []struct {
name string
opts EncOptions
in interface{}
want []byte
}{
{
name: "fixed-length to text string",
opts: EncOptions{FieldName: FieldNameToTextString},
in: struct {
F1 int `cbor:"1,keyasint"`
F2 int `cbor:"a"`
F3 int `cbor:"-3,keyasint"`
}{},
want: hexDecode("a301006161002200"),
},
{
name: "fixed-length to byte string",
opts: EncOptions{FieldName: FieldNameToByteString},
in: struct {
F1 int `cbor:"1,keyasint"`
F2 int `cbor:"a"`
F3 int `cbor:"-3,keyasint"`
}{},
want: hexDecode("a301004161002200"),
},
{
name: "variable-length to text string",
opts: EncOptions{FieldName: FieldNameToTextString},
in: struct {
F1 int `cbor:"1,omitempty,keyasint"`
F2 int `cbor:"a,omitempty"`
F3 int `cbor:"-3,omitempty,keyasint"`
}{F1: 7, F2: 7, F3: 7},
want: hexDecode("a301076161072207"),
},
{
name: "variable-length to byte string",
opts: EncOptions{FieldName: FieldNameToByteString},
in: struct {
F1 int `cbor:"1,omitempty,keyasint"`
F2 int `cbor:"a,omitempty"`
F3 int `cbor:"-3,omitempty,keyasint"`
}{F1: 7, F2: 7, F3: 7},
want: hexDecode("a301074161072207"),
},
} {
t.Run(tc.name, func(t *testing.T) {
em, err := tc.opts.EncMode()
if err != nil {
t.Fatal(err)
}

got, err := em.Marshal(tc.in)
if err != nil {
t.Errorf("unexpected error from Marshal(%q): %v", tc.in, err)
}

if !bytes.Equal(got, tc.want) {
t.Errorf("Marshal(%q): wanted %x, got %x", tc.in, tc.want, got)
}
})
}
}
Loading
Loading