Skip to content

Commit

Permalink
Merge pull request #485 from fxamacker/fxamacker/check-marshaler-data…
Browse files Browse the repository at this point in the history
…-wellformedness

Check well-formedness of data from MarshalCBOR
  • Loading branch information
fxamacker authored Feb 25, 2024
2 parents 00109c1 + a3a1d71 commit cfbd0ff
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 30 deletions.
54 changes: 41 additions & 13 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,69 +606,96 @@ const (
defaultMaxMapPairs = 131072
minMaxMapPairs = 16
maxMaxMapPairs = 2147483647

defaultMaxNestedLevels = 32
minMaxNestedLevels = 4
maxMaxNestedLevels = 65535
)

func (opts DecOptions) decMode() (*decMode, error) {
if !opts.DupMapKey.valid() {
return nil, errors.New("cbor: invalid DupMapKey " + strconv.Itoa(int(opts.DupMapKey)))
}

if !opts.TimeTag.valid() {
return nil, errors.New("cbor: invalid TimeTag " + strconv.Itoa(int(opts.TimeTag)))
}

if !opts.IndefLength.valid() {
return nil, errors.New("cbor: invalid IndefLength " + strconv.Itoa(int(opts.IndefLength)))
}

if !opts.TagsMd.valid() {
return nil, errors.New("cbor: invalid TagsMd " + strconv.Itoa(int(opts.TagsMd)))
}

if !opts.IntDec.valid() {
return nil, errors.New("cbor: invalid IntDec " + strconv.Itoa(int(opts.IntDec)))
}

if !opts.MapKeyByteString.valid() {
return nil, errors.New("cbor: invalid MapKeyByteString " + strconv.Itoa(int(opts.MapKeyByteString)))
}

if opts.MaxNestedLevels == 0 {
opts.MaxNestedLevels = 32
} else if opts.MaxNestedLevels < 4 || opts.MaxNestedLevels > 65535 {
return nil, errors.New("cbor: invalid MaxNestedLevels " + strconv.Itoa(opts.MaxNestedLevels) + " (range is [4, 65535])")
opts.MaxNestedLevels = defaultMaxNestedLevels
} else if opts.MaxNestedLevels < minMaxNestedLevels || opts.MaxNestedLevels > maxMaxNestedLevels {
return nil, errors.New("cbor: invalid MaxNestedLevels " + strconv.Itoa(opts.MaxNestedLevels) +
" (range is [" + strconv.Itoa(minMaxNestedLevels) + ", " + strconv.Itoa(maxMaxNestedLevels) + "])")
}

if opts.MaxArrayElements == 0 {
opts.MaxArrayElements = defaultMaxArrayElements
} else if opts.MaxArrayElements < minMaxArrayElements || opts.MaxArrayElements > maxMaxArrayElements {
return nil, errors.New("cbor: invalid MaxArrayElements " + strconv.Itoa(opts.MaxArrayElements) + " (range is [" + strconv.Itoa(minMaxArrayElements) + ", " + strconv.Itoa(maxMaxArrayElements) + "])")
return nil, errors.New("cbor: invalid MaxArrayElements " + strconv.Itoa(opts.MaxArrayElements) +
" (range is [" + strconv.Itoa(minMaxArrayElements) + ", " + strconv.Itoa(maxMaxArrayElements) + "])")
}

if opts.MaxMapPairs == 0 {
opts.MaxMapPairs = defaultMaxMapPairs
} else if opts.MaxMapPairs < minMaxMapPairs || opts.MaxMapPairs > maxMaxMapPairs {
return nil, errors.New("cbor: invalid MaxMapPairs " + strconv.Itoa(opts.MaxMapPairs) + " (range is [" + strconv.Itoa(minMaxMapPairs) + ", " + strconv.Itoa(maxMaxMapPairs) + "])")
return nil, errors.New("cbor: invalid MaxMapPairs " + strconv.Itoa(opts.MaxMapPairs) +
" (range is [" + strconv.Itoa(minMaxMapPairs) + ", " + strconv.Itoa(maxMaxMapPairs) + "])")
}

if !opts.ExtraReturnErrors.valid() {
return nil, errors.New("cbor: invalid ExtraReturnErrors " + strconv.Itoa(int(opts.ExtraReturnErrors)))
}

if opts.DefaultMapType != nil && opts.DefaultMapType.Kind() != reflect.Map {
return nil, fmt.Errorf("cbor: invalid DefaultMapType %s", opts.DefaultMapType)
}

if !opts.UTF8.valid() {
return nil, errors.New("cbor: invalid UTF8 " + strconv.Itoa(int(opts.UTF8)))
}

if !opts.FieldNameMatching.valid() {
return nil, errors.New("cbor: invalid FieldNameMatching " + strconv.Itoa(int(opts.FieldNameMatching)))
}

if !opts.BigIntDec.valid() {
return nil, errors.New("cbor: invalid BigIntDec " + strconv.Itoa(int(opts.BigIntDec)))
}
if opts.DefaultByteStringType != nil && opts.DefaultByteStringType.Kind() != reflect.String && (opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) {

if opts.DefaultByteStringType != nil &&
opts.DefaultByteStringType.Kind() != reflect.String &&
(opts.DefaultByteStringType.Kind() != reflect.Slice || opts.DefaultByteStringType.Elem().Kind() != reflect.Uint8) {
return nil, fmt.Errorf("cbor: invalid DefaultByteStringType: %s is not of kind string or []uint8", opts.DefaultByteStringType)
}

if !opts.ByteStringToString.valid() {
return nil, errors.New("cbor: invalid ByteStringToString " + strconv.Itoa(int(opts.ByteStringToString)))
}

if !opts.FieldNameByteString.valid() {
return nil, errors.New("cbor: invalid FieldNameByteString " + strconv.Itoa(int(opts.FieldNameByteString)))
}

if !opts.UnrecognizedTagToAny.valid() {
return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny)))
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -689,6 +716,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
}

return &dm, nil
}

Expand Down Expand Up @@ -795,9 +823,9 @@ func (dm *decMode) Unmarshal(data []byte, v interface{}) error {
d := decoder{data: data, dm: dm}

// Check well-formedness.
off := d.off // Save offset before data validation
err := d.wellformed(false) // don't allow any extra data after valid data item.
d.off = off // Restore offset
off := d.off // Save offset before data validation
err := d.wellformed(false, false) // don't allow any extra data after valid data item.
d.off = off // Restore offset
if err != nil {
return err
}
Expand All @@ -815,9 +843,9 @@ func (dm *decMode) UnmarshalFirst(data []byte, v interface{}) (rest []byte, err
d := decoder{data: data, dm: dm}

// check well-formedness.
off := d.off // Save offset before data validation
err = d.wellformed(true) // allow extra data after well-formed data item
d.off = off // Restore offset
off := d.off // Save offset before data validation
err = d.wellformed(true, false) // allow extra data after well-formed data item
d.off = off // Restore offset

// If it is well-formed, parse the value. This is structured like this to allow
// better test coverage
Expand Down Expand Up @@ -858,7 +886,7 @@ func (dm *decMode) Valid(data []byte) error {
// an ExtraneousDataError is returned.
func (dm *decMode) Wellformed(data []byte) error {
d := decoder{data: data, dm: dm}
return d.wellformed(false)
return d.wellformed(false, false)
}

// NewDecoder returns a new decoder that reads from r using dm DecMode.
Expand Down
2 changes: 1 addition & 1 deletion diagnose.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (di *diagnose) diagFirst() (string, []byte, error) {

func (di *diagnose) wellformed(allowExtraData bool) error {
off := di.d.off
err := di.d.wellformed(allowExtraData)
err := di.d.wellformed(allowExtraData, false)
di.d.off = off
return err
}
Expand Down
94 changes: 94 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,23 @@ type Marshaler interface {
MarshalCBOR() ([]byte, error)
}

// MarshalerError represents error from checking encoded CBOR data item
// returned from MarshalCBOR for well-formedness and some very limited tag validation.
type MarshalerError struct {
typ reflect.Type
err error
}

func (e *MarshalerError) Error() string {
return "cbor: error calling MarshalCBOR for type " +
e.typ.String() +
": " + e.err.Error()
}

func (e *MarshalerError) Unwrap() error {
return e.err
}

// UnsupportedTypeError is returned by Marshal when attempting to encode value
// of an unsupported type.
type UnsupportedTypeError struct {
Expand Down Expand Up @@ -632,6 +649,75 @@ type encMode struct {

var defaultEncMode, _ = EncOptions{}.encMode()

// These four decoding modes are used by getMarshalerDecMode.
// maxNestedLevels, maxArrayElements, and maxMapPairs are
// set to max allowed limits to avoid rejecting Marshaler
// output that would have been the allowable output of a
// non-Marshaler object that exceeds default limits.
var (
marshalerForbidIndefLengthForbidTagsDecMode = decMode{
maxNestedLevels: maxMaxNestedLevels,
maxArrayElements: maxMaxArrayElements,
maxMapPairs: maxMaxMapPairs,
indefLength: IndefLengthForbidden,
tagsMd: TagsForbidden,
}

marshalerAllowIndefLengthForbidTagsDecMode = decMode{
maxNestedLevels: maxMaxNestedLevels,
maxArrayElements: maxMaxArrayElements,
maxMapPairs: maxMaxMapPairs,
indefLength: IndefLengthAllowed,
tagsMd: TagsForbidden,
}

marshalerForbidIndefLengthAllowTagsDecMode = decMode{
maxNestedLevels: maxMaxNestedLevels,
maxArrayElements: maxMaxArrayElements,
maxMapPairs: maxMaxMapPairs,
indefLength: IndefLengthForbidden,
tagsMd: TagsAllowed,
}

marshalerAllowIndefLengthAllowTagsDecMode = decMode{
maxNestedLevels: maxMaxNestedLevels,
maxArrayElements: maxMaxArrayElements,
maxMapPairs: maxMaxMapPairs,
indefLength: IndefLengthAllowed,
tagsMd: TagsAllowed,
}
)

// getMarshalerDecMode returns one of four existing decoding modes
// which can be reused (safe for parallel use) for the purpose of
// checking if data returned by Marshaler is well-formed.
func getMarshalerDecMode(indefLength IndefLengthMode, tagsMd TagsMode) *decMode {
switch {
case indefLength == IndefLengthAllowed && tagsMd == TagsAllowed:
return &marshalerAllowIndefLengthAllowTagsDecMode

case indefLength == IndefLengthAllowed && tagsMd == TagsForbidden:
return &marshalerAllowIndefLengthForbidTagsDecMode

case indefLength == IndefLengthForbidden && tagsMd == TagsAllowed:
return &marshalerForbidIndefLengthAllowTagsDecMode

case indefLength == IndefLengthForbidden && tagsMd == TagsForbidden:
return &marshalerForbidIndefLengthForbidTagsDecMode

default:
// This should never happen, unless we add new options to
// IndefLengthMode or TagsMode without updating this function.
return &decMode{
maxNestedLevels: maxMaxNestedLevels,
maxArrayElements: maxMaxArrayElements,
maxMapPairs: maxMaxMapPairs,
indefLength: indefLength,
tagsMd: tagsMd,
}
}
}

// EncOptions returns user specified options used to create this EncMode.
func (em *encMode) EncOptions() EncOptions {
return EncOptions{
Expand Down Expand Up @@ -1345,6 +1431,14 @@ func encodeMarshalerType(e *encoderBuffer, em *encMode, v reflect.Value) error {
if err != nil {
return err
}

// Verify returned CBOR data item from MarshalCBOR() is well-formed and passes tag validity for builtin tags 0-3.
d := decoder{data: data, dm: getMarshalerDecMode(em.indefLength, em.tagsMd)}
err = d.wellformed(false, true)
if err != nil {
return &MarshalerError{typ: v.Type(), err: err}
}

e.Write(data)
return nil
}
Expand Down
Loading

0 comments on commit cfbd0ff

Please sign in to comment.