Skip to content

Commit

Permalink
Prevent zero length values in slices and maps in codec (#2819)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Mar 8, 2024
1 parent d2d09c2 commit f02d463
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 9 deletions.
2 changes: 2 additions & 0 deletions codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ var (
ErrDoesNotImplementInterface = errors.New("does not implement interface")
ErrUnexportedField = errors.New("unexported field")
ErrExtraSpace = errors.New("trailing buffer space")
ErrMarshalZeroLength = errors.New("can't marshal zero length value")
ErrUnmarshalZeroLength = errors.New("can't unmarshal zero length value")
)

// Codec marshals and unmarshals
Expand Down
25 changes: 25 additions & 0 deletions codec/reflectcodec/type_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ func (c *genericCodec) size(
return 0, false, err
}

if size == 0 {
return 0, false, fmt.Errorf("can't marshal slice of zero length values: %w", codec.ErrMarshalZeroLength)
}

// For fixed-size types we manually calculate lengths rather than
// processing each element separately to improve performance.
if constSize {
Expand Down Expand Up @@ -235,6 +239,10 @@ func (c *genericCodec) size(
return 0, false, err
}

if keySize == 0 && valueSize == 0 {
return 0, false, fmt.Errorf("can't marshal map with zero length entries: %w", codec.ErrMarshalZeroLength)
}

switch {
case keyConstSize && valueConstSize:
numElts := value.Len()
Expand Down Expand Up @@ -394,9 +402,13 @@ func (c *genericCodec) marshal(
return p.Err
}
for i := 0; i < numElts; i++ { // Process each element in the slice
startOffset := p.Offset
if err := c.marshal(value.Index(i), p, typeStack); err != nil {
return err
}
if startOffset == p.Offset {
return fmt.Errorf("couldn't marshal slice of zero length values: %w", codec.ErrMarshalZeroLength)
}
}
return nil
case reflect.Array:
Expand Down Expand Up @@ -479,6 +491,8 @@ func (c *genericCodec) marshal(
allKeyBytes := slices.Clone(p.Bytes[startOffset:p.Offset])
p.Offset = startOffset
for _, key := range sortedKeys {
keyStartOffset := p.Offset

// pack key
startIndex := key.startIndex - startOffset
endIndex := key.endIndex - startOffset
Expand All @@ -492,6 +506,9 @@ func (c *genericCodec) marshal(
if err := c.marshal(value.MapIndex(key.key), p, typeStack); err != nil {
return err
}
if keyStartOffset == p.Offset {
return fmt.Errorf("couldn't marshal map with zero length entries: %w", codec.ErrMarshalZeroLength)
}
}

return nil
Expand Down Expand Up @@ -625,9 +642,14 @@ func (c *genericCodec) unmarshal(
zeroValue := reflect.Zero(innerType)
for i := 0; i < numElts; i++ {
value.Set(reflect.Append(value, zeroValue))

startOffset := p.Offset
if err := c.unmarshal(p, value.Index(i), typeStack); err != nil {
return err
}
if startOffset == p.Offset {
return fmt.Errorf("couldn't unmarshal slice of zero length values: %w", codec.ErrUnmarshalZeroLength)
}
}
return nil
case reflect.Array:
Expand Down Expand Up @@ -755,6 +777,9 @@ func (c *genericCodec) unmarshal(
if err := c.unmarshal(p, mapValue, typeStack); err != nil {
return err
}
if keyStartOffset == p.Offset {
return fmt.Errorf("couldn't unmarshal map with zero length entries: %w", codec.ErrUnmarshalZeroLength)
}

// Assign the key-value pair in the map
value.SetMapIndex(mapKey, mapValue)
Expand Down
75 changes: 66 additions & 9 deletions codec/test_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ var (
TestNilSliceSerialization,
TestEmptySliceSerialization,
TestSliceWithEmptySerialization,
TestSliceWithEmptySerializationOutOfMemory,
TestSliceWithEmptySerializationError,
TestMapWithEmptySerialization,
TestMapWithEmptySerializationError,
TestSliceTooLarge,
TestNegativeNumbers,
TestTooLargeUnmarshal,
Expand Down Expand Up @@ -731,7 +733,7 @@ func TestEmptySliceSerialization(codec GeneralCodec, t testing.TB) {
require.Equal(val, valUnmarshaled)
}

// Test marshaling slice that is not nil and not empty
// Test marshaling empty slice of zero length structs
func TestSliceWithEmptySerialization(codec GeneralCodec, t testing.TB) {
require := require.New(t)

Expand All @@ -745,9 +747,9 @@ func TestSliceWithEmptySerialization(codec GeneralCodec, t testing.TB) {
require.NoError(manager.RegisterCodec(0, codec))

val := &nestedSliceStruct{
Arr: make([]emptyStruct, 1000),
Arr: make([]emptyStruct, 0),
}
expected := []byte{0x00, 0x00, 0x00, 0x00, 0x03, 0xE8} // codec version (0x00, 0x00) then 1000 for numElts
expected := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x00) for numElts
result, err := manager.Marshal(0, val)
require.NoError(err)
require.Equal(expected, result)
Expand All @@ -760,10 +762,10 @@ func TestSliceWithEmptySerialization(codec GeneralCodec, t testing.TB) {
version, err := manager.Unmarshal(expected, &unmarshaled)
require.NoError(err)
require.Zero(version)
require.Len(unmarshaled.Arr, 1000)
require.Empty(unmarshaled.Arr)
}

func TestSliceWithEmptySerializationOutOfMemory(codec GeneralCodec, t testing.TB) {
func TestSliceWithEmptySerializationError(codec GeneralCodec, t testing.TB) {
require := require.New(t)

type emptyStruct struct{}
Expand All @@ -776,14 +778,69 @@ func TestSliceWithEmptySerializationOutOfMemory(codec GeneralCodec, t testing.TB
require.NoError(manager.RegisterCodec(0, codec))

val := &nestedSliceStruct{
Arr: make([]emptyStruct, math.MaxInt32),
Arr: make([]emptyStruct, 1),
}
_, err := manager.Marshal(0, val)
require.ErrorIs(err, ErrMaxSliceLenExceeded)
require.ErrorIs(err, ErrMarshalZeroLength)

_, err = manager.Size(0, val)
require.ErrorIs(err, ErrMarshalZeroLength)

b := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x01) for numElts

unmarshaled := nestedSliceStruct{}
_, err = manager.Unmarshal(b, &unmarshaled)
require.ErrorIs(err, ErrUnmarshalZeroLength)
}

// Test marshaling empty map of zero length structs
func TestMapWithEmptySerialization(codec GeneralCodec, t testing.TB) {
require := require.New(t)

type emptyStruct struct{}

manager := NewDefaultManager()
require.NoError(manager.RegisterCodec(0, codec))

val := make(map[emptyStruct]emptyStruct)
expected := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x00) for numElts
result, err := manager.Marshal(0, val)
require.NoError(err)
require.Equal(expected, result)

bytesLen, err := manager.Size(0, val)
require.NoError(err)
require.Equal(6, bytesLen) // 2 byte codec version + 4 byte length prefix
require.Len(result, bytesLen)

var unmarshaled map[emptyStruct]emptyStruct
version, err := manager.Unmarshal(expected, &unmarshaled)
require.NoError(err)
require.Zero(version)
require.Empty(unmarshaled)
}

func TestMapWithEmptySerializationError(codec GeneralCodec, t testing.TB) {
require := require.New(t)

type emptyStruct struct{}

manager := NewDefaultManager()
require.NoError(manager.RegisterCodec(0, codec))

val := map[emptyStruct]emptyStruct{
{}: {},
}
_, err := manager.Marshal(0, val)
require.ErrorIs(err, ErrMarshalZeroLength)

_, err = manager.Size(0, val)
require.ErrorIs(err, ErrMarshalZeroLength)

b := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x01} // codec version (0x00, 0x00) then (0x00, 0x00, 0x00, 0x01) for numElts

var unmarshaled map[emptyStruct]emptyStruct
_, err = manager.Unmarshal(b, &unmarshaled)
require.ErrorIs(err, ErrUnmarshalZeroLength)
}

func TestSliceTooLarge(codec GeneralCodec, t testing.TB) {
Expand Down

0 comments on commit f02d463

Please sign in to comment.