From 64ed7b735bb608d49da85515a2fcf6172da99223 Mon Sep 17 00:00:00 2001 From: Danila Fomin Date: Mon, 6 Jun 2022 16:56:49 +0300 Subject: [PATCH] add option to enforce nil container marshaling as empty containers --- decode.go | 17 +++++++++++++++++ encode.go | 14 +++++++++++--- encode_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/decode.go b/decode.go index 79ccaf80..31107663 100644 --- a/decode.go +++ b/decode.go @@ -200,6 +200,23 @@ func (m IndefLengthMode) valid() bool { return m < maxIndefLengthMode } +// NilContainersMode specifies how to encode []Type(nil) and map[Key]Type(nil). +type NilContainersMode int + +const ( + // NullForNil enforces null for []Type(nil)/map[Key]Type(nil). + NullForNil NilContainersMode = iota + + // EmptyForNil enforces empty map/list for []Type(nil)/map[Key]Type(nil). + EmptyForNil + + maxNilContainersMode +) + +func (m NilContainersMode) valid() bool { + return m < maxNilContainersMode +} + // TagsMode specifies whether to allow CBOR tags. type TagsMode int diff --git a/encode.go b/encode.go index 95d2c23f..6691d38d 100644 --- a/encode.go +++ b/encode.go @@ -292,6 +292,9 @@ type EncOptions struct { // IndefLength specifies whether to allow indefinite length CBOR items. IndefLength IndefLengthMode + // NilContainers specifies how to encode map[Key]Type(nil)/[]Type(nil) + NilContainers NilContainersMode + // TagsMd specifies whether to allow CBOR tags (major type 6). TagsMd TagsMode } @@ -464,6 +467,9 @@ func (opts EncOptions) encMode() (*encMode, error) { if !opts.IndefLength.valid() { return nil, errors.New("cbor: invalid IndefLength " + strconv.Itoa(int(opts.IndefLength))) } + if !opts.NilContainers.valid() { + return nil, errors.New("cbor: invalid NilContainers " + strconv.Itoa(int(opts.NilContainers))) + } if !opts.TagsMd.valid() { return nil, errors.New("cbor: invalid TagsMd " + strconv.Itoa(int(opts.TagsMd))) } @@ -479,6 +485,7 @@ func (opts EncOptions) encMode() (*encMode, error) { time: opts.Time, timeTag: opts.TimeTag, indefLength: opts.IndefLength, + nilContainers: opts.NilContainers, tagsMd: opts.TagsMd, } return &em, nil @@ -501,6 +508,7 @@ type encMode struct { time TimeMode timeTag EncTagMode indefLength IndefLengthMode + nilContainers NilContainersMode tagsMd TagsMode } @@ -787,7 +795,7 @@ func encodeFloat64(e *encoderBuffer, f64 float64) error { func encodeByteString(e *encoderBuffer, em *encMode, v reflect.Value) error { vk := v.Kind() - if vk == reflect.Slice && v.IsNil() { + if vk == reflect.Slice && v.IsNil() && em.nilContainers == NullForNil { e.Write(cborNil) return nil } @@ -824,7 +832,7 @@ type arrayEncodeFunc struct { } func (ae arrayEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error { - if v.Kind() == reflect.Slice && v.IsNil() { + if v.Kind() == reflect.Slice && v.IsNil() && em.nilContainers == NullForNil { e.Write(cborNil) return nil } @@ -849,7 +857,7 @@ type mapEncodeFunc struct { } func (me mapEncodeFunc) encode(e *encoderBuffer, em *encMode, v reflect.Value) error { - if v.IsNil() { + if v.IsNil() && em.nilContainers == NullForNil { e.Write(cborNil) return nil } diff --git a/encode_test.go b/encode_test.go index 229f7555..7468c08f 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2852,6 +2852,50 @@ func TestInvalidInfConvert(t *testing.T) { } } +func TestNilContainers(t *testing.T) { + nilContainersNull := EncOptions{NilContainers: NullForNil} + nilContainersEmpty := EncOptions{NilContainers: EmptyForNil} + testCases := []struct { + name string + v interface{} + opts EncOptions + wantCborData []byte + }{ + {"map(nil) as null", map[string]string(nil), nilContainersNull, hexDecode("f6")}, + {"map(nil) as empty map", map[string]string(nil), nilContainersEmpty, hexDecode("a0")}, + + {"slice(nil) as null", []int(nil), nilContainersNull, hexDecode("f6")}, + {"slice(nil) as empty list", []int(nil), nilContainersEmpty, hexDecode("80")}, + + {"[]byte(nil) as null", []byte(nil), nilContainersNull, hexDecode("f6")}, + {"[]byte(nil) as empty bytestring", []byte(nil), nilContainersEmpty, hexDecode("40")}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + em, err := tc.opts.EncMode() + if err != nil { + t.Errorf("EncMode() returned an error %v", err) + } + b, err := em.Marshal(tc.v) + if err != nil { + t.Errorf("Marshal(%v) returned error %v", tc.v, err) + } else if !bytes.Equal(b, tc.wantCborData) { + t.Errorf("Marshal(%v) = 0x%x, want 0x%x", tc.v, b, tc.wantCborData) + } + }) + } +} + +func TestInvalidNilContainers(t *testing.T) { + wantErrorMsg := "cbor: invalid NilContainers 100" + _, err := EncOptions{NilContainers: NilContainersMode(100)}.EncMode() + if err == nil { + t.Errorf("EncMode() didn't return an error") + } else if err.Error() != wantErrorMsg { + t.Errorf("EncMode() returned error %q, want %q", err.Error(), wantErrorMsg) + } +} + // Keith Randall's workaround for constant propagation issue https://github.com/golang/go/issues/36400 const ( // qnan 32 bits constants @@ -3306,6 +3350,7 @@ func TestEncOptions(t *testing.T) { Time: TimeRFC3339Nano, TimeTag: EncTagRequired, IndefLength: IndefLengthForbidden, + NilContainers: NullForNil, TagsMd: TagsAllowed, } em, err := opts1.EncMode()