Skip to content

Commit

Permalink
decode: expand test and mark fields decoded recursively
Browse files Browse the repository at this point in the history
This commit expands the test `TestDecodeCustomStructMarkedDecoded`
so that nested data is also tested (thanks to  Martin Tournoij).

This lead to an update of the code to also mark all nested keys as
decoded when a custom `UnmarshalTOML(data any)` interface is used.

It is the job of the implemenator of the UnmarshalTOML() to ensure
everything is decoded. An alternative would be to provide a new
`UnmarshalTOMLWithMetadata(data any, md *MetaData)` inteface that
would allow the implementaor of the custom unmarshaller to mark
fields as decoded. But it's unclear if that is needed because
a UnmarshalTOML() can already error if it "sees" unexpected or
missing data.
  • Loading branch information
mvo5 committed Sep 26, 2024
1 parent 70d427d commit 8ff4571
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
23 changes: 17 additions & 6 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,19 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v any) error {
return md.unify(primValue.undecoded, rvalue(v))
}

// markDecodedRecursive is a helper to mark any key under the given tmap
// as decoded, recursing as needed
func markDecodedRecursive(md *MetaData, tmap map[string]any) {
for key := range tmap {
md.decoded[md.context.add(key).String()] = struct{}{}
if tmap, ok := tmap[key].(map[string]any); ok {
md.context = append(md.context, key)
markDecodedRecursive(md, tmap)
md.context = md.context[0 : len(md.context)-1]
}
}
}

// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
Expand All @@ -222,12 +235,10 @@ func (md *MetaData) unify(data any, rv reflect.Value) error {
if err != nil {
return md.parseErr(err)
}
// assume the Unmarshaler did it's job and decoded all fields
tmap, ok := data.(map[string]any)
if ok {
for key := range tmap {
md.decoded[md.context.add(key).String()] = struct{}{}
}
// assume the Unmarshaler did it's job and mark all
// keys under this map decoded
if tmap, ok := data.(map[string]any); ok {
markDecodedRecursive(md, tmap)
}
return nil
}
Expand Down
28 changes: 25 additions & 3 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1149,19 +1149,31 @@ func BenchmarkKey(b *testing.B) {
}

type CustomStruct struct {
Foo string `json:"foo"`
Foo string
TblB int64
TblInlineC int64
}

func (cs *CustomStruct) UnmarshalTOML(data interface{}) error {
d, _ := data.(map[string]interface{})
cs.Foo = d["foo"].(string)
cs.TblB = d["tbl"].(map[string]interface{})["b"].(int64)
cs.TblInlineC = d["tbl"].(map[string]interface{})["inline"].(map[string]interface{})["c"].(int64)

return nil
}

func TestDecodeCustomStruct(t *testing.T) {
func TestDecodeCustomStructMarkedDecoded(t *testing.T) {
var cs CustomStruct
meta, err := Decode(`
foo = "bar"
foo = "bar"
a = 1
arr = [2]
[tbl]
b = 3
inline = {c = 4}
`, &cs)
if err != nil {
t.Fatalf("Decode failed: %s", err)
Expand All @@ -1170,6 +1182,16 @@ func TestDecodeCustomStruct(t *testing.T) {
if cs.Foo != "bar" {
t.Errorf("\nhave:\n%v\nwant:\n%v\n", cs.Foo, "bar")
}
if cs.TblB != 3 {
t.Errorf("\nhave:\n%v\nwant:\n%v\n", cs.TblB, 3)
}
if cs.TblInlineC != 4 {
t.Errorf("\nhave:\n%v\nwant:\n%v\n", cs.TblB, 4)
}
// Note that even though the custom unmarshaler did not decode
// all fields as far as the metadata is concerned they are handlded.
// It is the job of the unmarshaler to ensure this or we would need
// a more powerful interface like UnmarshalTOML(data any, md *MetaData)
if len(meta.Undecoded()) > 0 {
t.Errorf("\ncustom decode leaves unencoded fields: %v\n", meta.Undecoded())
}
Expand Down

0 comments on commit 8ff4571

Please sign in to comment.