From b6b4b7c09c8767375cf45f681ada255f693401fb Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Fri, 27 Jul 2018 12:44:39 -0700 Subject: [PATCH] proto: make invalid UTF-8 errors non-fatal The current logic currently treats RequredNotSetError as a non-fatal error such that it continues to proceed with its normal execution, but returns that non-fatal error at the end. We now make it such that invalid UTF-8 is also a distinguishable error that is treated as non-fatal by the execution logic. In the rare event that both an RequiredNotSet error and InvalidUTF8 error is encountered, the first one encountered is returned. In the process of making this change, we also fix a number of cases where RequiredNotSet was treated as fatal, when it should not have been non-fatal (notably in the logic for extensions, message sets, and maps). This change deliberately does not provide an API make it easy to distinguish invalid UTF-8 errors as this is not a normal behavior we want users to depend on. We can always expose additional API for this in the future. Users can test for invalid UTF-8 with: re, ok := err.(interface{ InvalidUTF8() bool }) isInvalidUTF8 := ok && re.InvalidUTF8() --- proto/all_test.go | 33 ++++++++++--- proto/encode.go | 15 ------ proto/lib.go | 62 ++++++++++++++++++++++- proto/table_marshal.go | 104 +++++++++++++++++++++------------------ proto/table_unmarshal.go | 49 +++++++++--------- 5 files changed, 169 insertions(+), 94 deletions(-) diff --git a/proto/all_test.go b/proto/all_test.go index d7b671aeeb..a68d91d2aa 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -2256,26 +2256,32 @@ func TestInvalidUTF8(t *testing.T) { label string proto2 Message proto3 Message + want []byte }{{ label: "Scalar", proto2: &TestUTF8{Scalar: String(invalidUTF8)}, proto3: &pb3.TestUTF8{Scalar: invalidUTF8}, + want: []byte{0x0a, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff}, }, { label: "Vector", proto2: &TestUTF8{Vector: []string{invalidUTF8}}, proto3: &pb3.TestUTF8{Vector: []string{invalidUTF8}}, + want: []byte{0x12, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff}, }, { label: "Oneof", proto2: &TestUTF8{Oneof: &TestUTF8_Field{invalidUTF8}}, proto3: &pb3.TestUTF8{Oneof: &pb3.TestUTF8_Field{invalidUTF8}}, + want: []byte{0x1a, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff}, }, { label: "MapKey", proto2: &TestUTF8{MapKey: map[string]int64{invalidUTF8: 0}}, proto3: &pb3.TestUTF8{MapKey: map[string]int64{invalidUTF8: 0}}, + want: []byte{0x22, 0x0b, 0x0a, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff, 0x10, 0x00}, }, { label: "MapValue", proto2: &TestUTF8{MapValue: map[int64]string{0: invalidUTF8}}, proto3: &pb3.TestUTF8{MapValue: map[int64]string{0: invalidUTF8}}, + want: []byte{0x2a, 0x0b, 0x08, 0x00, 0x12, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff}, }} for _, tt := range tests { @@ -2284,22 +2290,37 @@ func TestInvalidUTF8(t *testing.T) { if err != nil { t.Errorf("Marshal(proto2.%s) = %v, want nil", tt.label, err) } - tt.proto2.Reset() - err = Unmarshal(b, tt.proto2) - if err != nil { + if !bytes.Equal(b, tt.want) { + t.Errorf("Marshal(proto2.%s) = %x, want %x", tt.label, b, tt.want) + } + + m := Clone(tt.proto2) + m.Reset() + if err = Unmarshal(tt.want, m); err != nil { t.Errorf("Unmarshal(proto2.%s) = %v, want nil", tt.label, err) } + if !Equal(m, tt.proto2) { + t.Errorf("proto2.%s: output mismatch:\ngot %v\nwant %v", tt.label, m, tt.proto2) + } // Proto3 should validate UTF-8. - _, err = Marshal(tt.proto3) + b, err = Marshal(tt.proto3) if err == nil { t.Errorf("Marshal(proto3.%s) = %v, want non-nil", tt.label, err) } - tt.proto3.Reset() - err = Unmarshal(b, tt.proto3) + if !bytes.Equal(b, tt.want) { + t.Errorf("Marshal(proto3.%s) = %x, want %x", tt.label, b, tt.want) + } + + m = Clone(tt.proto3) + m.Reset() + err = Unmarshal(tt.want, m) if err == nil { t.Errorf("Unmarshal(proto3.%s) = %v, want non-nil", tt.label, err) } + if !Equal(m, tt.proto3) { + t.Errorf("proto3.%s: output mismatch:\ngot %v\nwant %v", tt.label, m, tt.proto2) + } } } diff --git a/proto/encode.go b/proto/encode.go index 4c35d33738..3abfed2cff 100644 --- a/proto/encode.go +++ b/proto/encode.go @@ -37,24 +37,9 @@ package proto import ( "errors" - "fmt" "reflect" ) -// RequiredNotSetError is an error type returned by either Marshal or Unmarshal. -// Marshal reports this when a required field is not initialized. -// Unmarshal reports this when a required field is missing from the wire data. -type RequiredNotSetError struct { - field string -} - -func (e *RequiredNotSetError) Error() string { - if e.field == "" { - return fmt.Sprintf("proto: required field not set") - } - return fmt.Sprintf("proto: required field %q not set", e.field) -} - var ( // errRepeatedHasNil is the error returned if Marshal is called with // a struct with a repeated field containing a nil element. diff --git a/proto/lib.go b/proto/lib.go index 0e2191b8ad..75565cc6dc 100644 --- a/proto/lib.go +++ b/proto/lib.go @@ -265,7 +265,6 @@ package proto import ( "encoding/json" - "errors" "fmt" "log" "reflect" @@ -274,7 +273,66 @@ import ( "sync" ) -var errInvalidUTF8 = errors.New("proto: invalid UTF-8 string") +// RequiredNotSetError is an error type returned by either Marshal or Unmarshal. +// Marshal reports this when a required field is not initialized. +// Unmarshal reports this when a required field is missing from the wire data. +type RequiredNotSetError struct{ field string } + +func (e *RequiredNotSetError) Error() string { + if e.field == "" { + return fmt.Sprintf("proto: required field not set") + } + return fmt.Sprintf("proto: required field %q not set", e.field) +} +func (e *RequiredNotSetError) RequiredNotSet() bool { + return true +} + +type invalidUTF8Error struct{ field string } + +func (e *invalidUTF8Error) Error() string { + if e.field == "" { + return "proto: invalid UTF-8 detected" + } + return fmt.Sprintf("proto: field %q contains invalid UTF-8", e.field) +} +func (e *invalidUTF8Error) InvalidUTF8() bool { + return true +} + +// errInvalidUTF8 is a sentinel error to identify fields with invalid UTF-8. +// This error should not be exposed to the external API as such errors should +// be recreated with the field information. +var errInvalidUTF8 = &invalidUTF8Error{} + +// isNonFatal reports whether the error is either a RequiredNotSet error +// or a InvalidUTF8 error. +func isNonFatal(err error) bool { + if re, ok := err.(interface{ RequiredNotSet() bool }); ok && re.RequiredNotSet() { + return true + } + if re, ok := err.(interface{ InvalidUTF8() bool }); ok && re.InvalidUTF8() { + return true + } + return false +} + +type nonFatal struct{ E error } + +// Merge merges err into nf and reports whether it was successful. +// Otherwise it returns false for any fatal non-nil errors. +func (nf *nonFatal) Merge(err error) (ok bool) { + if err == nil { + return true // not an error + } + if !isNonFatal(err) { + return false // fatal error + } + if nf.E == nil { + nf.E = err // store first instance of non-fatal error + } + return true +} // Message is implemented by generated protocol buffer messages. type Message interface { diff --git a/proto/table_marshal.go b/proto/table_marshal.go index 5a98be78f0..eafe04d14b 100644 --- a/proto/table_marshal.go +++ b/proto/table_marshal.go @@ -231,7 +231,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte return b, err } - var err, errreq error + var err, errLater error // The old marshaler encodes extensions at beginning. if u.extensions.IsValid() { e := ptr.offset(u.extensions).toExtensions() @@ -252,11 +252,11 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte } } for _, f := range u.fields { - if f.required && errreq == nil { + if f.required && errLater == nil { if ptr.offset(f.field).getPointer().isNil() { // Required field is not set. // We record the error but keep going, to give a complete marshaling. - errreq = &RequiredNotSetError{f.name} + errLater = &RequiredNotSetError{f.name} continue } } @@ -269,8 +269,8 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte if err1, ok := err.(*RequiredNotSetError); ok { // Required field in submessage is not set. // We record the error but keep going, to give a complete marshaling. - if errreq == nil { - errreq = &RequiredNotSetError{f.name + "." + err1.field} + if errLater == nil { + errLater = &RequiredNotSetError{f.name + "." + err1.field} } continue } @@ -278,8 +278,11 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte err = errors.New("proto: repeated field " + f.name + " has nil element") } if err == errInvalidUTF8 { - fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name - err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName) + if errLater == nil { + fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name + errLater = &invalidUTF8Error{fullName} + } + continue } return b, err } @@ -288,7 +291,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte s := *ptr.offset(u.unrecognized).toBytes() b = append(b, s...) } - return b, errreq + return b, errLater } // computeMarshalInfo initializes the marshal info. @@ -2038,52 +2041,68 @@ func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, e return b, nil } func appendUTF8StringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { + var invalidUTF8 bool v := *ptr.toString() if !utf8.ValidString(v) { - return nil, errInvalidUTF8 + invalidUTF8 = true } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) + if invalidUTF8 { + return b, errInvalidUTF8 + } return b, nil } func appendUTF8StringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { + var invalidUTF8 bool v := *ptr.toString() if v == "" { return b, nil } if !utf8.ValidString(v) { - return nil, errInvalidUTF8 + invalidUTF8 = true } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) + if invalidUTF8 { + return b, errInvalidUTF8 + } return b, nil } func appendUTF8StringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { + var invalidUTF8 bool p := *ptr.toStringPtr() if p == nil { return b, nil } v := *p if !utf8.ValidString(v) { - return nil, errInvalidUTF8 + invalidUTF8 = true } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) + if invalidUTF8 { + return b, errInvalidUTF8 + } return b, nil } func appendUTF8StringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { + var invalidUTF8 bool s := *ptr.toStringSlice() for _, v := range s { if !utf8.ValidString(v) { - return nil, errInvalidUTF8 + invalidUTF8 = true } b = appendVarint(b, wiretag) b = appendVarint(b, uint64(len(v))) b = append(b, v...) } + if invalidUTF8 { + return b, errInvalidUTF8 + } return b, nil } func appendBytes(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) { @@ -2162,7 +2181,8 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) { }, func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) { s := ptr.getPointerSlice() - var err, errreq error + var err error + var nerr nonFatal for _, v := range s { if v.isNil() { return b, errRepeatedHasNil @@ -2170,22 +2190,14 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) { b = appendVarint(b, wiretag) // start group b, err = u.marshal(b, v, deterministic) b = appendVarint(b, wiretag+(WireEndGroup-WireStartGroup)) // end group - if err != nil { - if _, ok := err.(*RequiredNotSetError); ok { - // Required field in submessage is not set. - // We record the error but keep going, to give a complete marshaling. - if errreq == nil { - errreq = err - } - continue - } + if !nerr.Merge(err) { if err == ErrNil { err = errRepeatedHasNil } return b, err } } - return b, errreq + return b, nerr.E } } @@ -2229,7 +2241,8 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) { }, func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) { s := ptr.getPointerSlice() - var err, errreq error + var err error + var nerr nonFatal for _, v := range s { if v.isNil() { return b, errRepeatedHasNil @@ -2239,22 +2252,14 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) { b = appendVarint(b, uint64(siz)) b, err = u.marshal(b, v, deterministic) - if err != nil { - if _, ok := err.(*RequiredNotSetError); ok { - // Required field in submessage is not set. - // We record the error but keep going, to give a complete marshaling. - if errreq == nil { - errreq = err - } - continue - } + if !nerr.Merge(err) { if err == ErrNil { err = errRepeatedHasNil } return b, err } } - return b, errreq + return b, nerr.E } } @@ -2317,6 +2322,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) { if len(keys) > 1 && deterministic { sort.Sort(mapKeys(keys)) } + + var nerr nonFatal for _, k := range keys { ki := k.Interface() vi := m.MapIndex(k).Interface() @@ -2326,15 +2333,15 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) { siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1) b = appendVarint(b, uint64(siz)) b, err = keyMarshaler(b, kaddr, keyWireTag, deterministic) - if err != nil { + if !nerr.Merge(err) { return b, err } b, err = valMarshaler(b, vaddr, valWireTag, deterministic) - if err != nil && err != ErrNil { // allow nil value in map + if err != ErrNil && !nerr.Merge(err) { // allow nil value in map return b, err } } - return b, nil + return b, nerr.E } } @@ -2407,6 +2414,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de defer mu.Unlock() var err error + var nerr nonFatal // Fast-path for common cases: zero or one extensions. // Don't bother sorting the keys. @@ -2426,11 +2434,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de v := e.value p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, ei.wiretag, deterministic) - if err != nil { + if !nerr.Merge(err) { return b, err } } - return b, nil + return b, nerr.E } // Sort the keys to provide a deterministic encoding. @@ -2457,11 +2465,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de v := e.value p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, ei.wiretag, deterministic) - if err != nil { + if !nerr.Merge(err) { return b, err } } - return b, nil + return b, nerr.E } // message set format is: @@ -2518,6 +2526,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de defer mu.Unlock() var err error + var nerr nonFatal // Fast-path for common cases: zero or one extensions. // Don't bother sorting the keys. @@ -2544,12 +2553,12 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de v := e.value p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic) - if err != nil { + if !nerr.Merge(err) { return b, err } b = append(b, 1<<3|WireEndGroup) } - return b, nil + return b, nerr.E } // Sort the keys to provide a deterministic encoding. @@ -2583,11 +2592,11 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic) b = append(b, 1<<3|WireEndGroup) - if err != nil { + if nerr.Merge(err) { return b, err } } - return b, nil + return b, nerr.E } // sizeV1Extensions computes the size of encoded data for a V1-API extension field. @@ -2630,6 +2639,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ sort.Ints(keys) var err error + var nerr nonFatal for _, k := range keys { e := m[int32(k)] if e.value == nil || e.desc == nil { @@ -2646,11 +2656,11 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ v := e.value p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, ei.wiretag, deterministic) - if err != nil { + if !nerr.Merge(err) { return b, err } } - return b, nil + return b, nerr.E } // newMarshaler is the interface representing objects that can marshal themselves. diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index 90ec6c2155..de868ae927 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -138,8 +138,8 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { if u.isMessageSet { return UnmarshalMessageSet(b, m.offset(u.extensions).toExtensions()) } - var reqMask uint64 // bitmask of required fields we've seen. - var rnse *RequiredNotSetError // an instance of a RequiredNotSetError returned by a submessage. + var reqMask uint64 // bitmask of required fields we've seen. + var errLater error for len(b) > 0 { // Read tag and wire type. // Special case 1 and 2 byte varints. @@ -175,17 +175,20 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { reqMask |= f.reqMask continue } - if r, ok := err.(*RequiredNotSetError); ok { + if r, ok := err.(*RequiredNotSetError); ok && errLater == nil { // Remember this error, but keep parsing. We need to produce // a full parse even if a required field is missing. - rnse = r + errLater = r reqMask |= f.reqMask continue } if err != errInternalBadWireType { if err == errInvalidUTF8 { - fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name - err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName) + if errLater == nil { + fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name + errLater = &invalidUTF8Error{fullName} + } + continue } return err } @@ -245,20 +248,16 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { emap[int32(tag)] = e } } - if rnse != nil { - // A required field of a submessage/group is missing. Return that error. - return rnse - } - if reqMask != u.reqMask { + if reqMask != u.reqMask && errLater == nil { // A required field of this message is missing. for _, n := range u.reqFields { if reqMask&1 == 0 { - return &RequiredNotSetError{n} + errLater = &RequiredNotSetError{n} } reqMask >>= 1 } } - return nil + return errLater } // computeUnmarshalInfo fills in u with information for use @@ -1529,10 +1528,10 @@ func unmarshalUTF8StringValue(b []byte, f pointer, w int) ([]byte, error) { return nil, io.ErrUnexpectedEOF } v := string(b[:x]) + *f.toString() = v if !utf8.ValidString(v) { - return nil, errInvalidUTF8 + return b[x:], errInvalidUTF8 } - *f.toString() = v return b[x:], nil } @@ -1549,10 +1548,10 @@ func unmarshalUTF8StringPtr(b []byte, f pointer, w int) ([]byte, error) { return nil, io.ErrUnexpectedEOF } v := string(b[:x]) + *f.toStringPtr() = &v if !utf8.ValidString(v) { - return nil, errInvalidUTF8 + return b[x:], errInvalidUTF8 } - *f.toStringPtr() = &v return b[x:], nil } @@ -1569,11 +1568,11 @@ func unmarshalUTF8StringSlice(b []byte, f pointer, w int) ([]byte, error) { return nil, io.ErrUnexpectedEOF } v := string(b[:x]) - if !utf8.ValidString(v) { - return nil, errInvalidUTF8 - } s := f.toStringSlice() *s = append(*s, v) + if !utf8.ValidString(v) { + return b[x:], errInvalidUTF8 + } return b[x:], nil } @@ -1755,6 +1754,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler { // Maps will be somewhat slow. Oh well. // Read key and value from data. + var nerr nonFatal k := reflect.New(kt) v := reflect.New(vt) for len(b) > 0 { @@ -1775,7 +1775,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler { err = errInternalBadWireType // skip unknown tag } - if err == nil { + if nerr.Merge(err) { continue } if err != errInternalBadWireType { @@ -1798,7 +1798,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler { // Insert into map. m.SetMapIndex(k.Elem(), v.Elem()) - return r, nil + return r, nerr.E } } @@ -1824,15 +1824,16 @@ func makeUnmarshalOneof(typ, ityp reflect.Type, unmarshal unmarshaler) unmarshal // Unmarshal data into holder. // We unmarshal into the first field of the holder object. var err error + var nerr nonFatal b, err = unmarshal(b, valToPointer(v).offset(field0), w) - if err != nil { + if !nerr.Merge(err) { return nil, err } // Write pointer to holder into target field. f.asPointerTo(ityp).Elem().Set(v) - return b, nil + return b, nerr.E } }