diff --git a/config/autonat.go b/config/autonat.go index a1a3f699cdac..64856faa6808 100644 --- a/config/autonat.go +++ b/config/autonat.go @@ -77,5 +77,5 @@ type AutoNATThrottleConfig struct { // global/peer dialback limits. // // When unset, this defaults to 1 minute. - Interval Duration `json:",omitempty"` + Interval OptionalDuration `json:",omitempty"` } diff --git a/config/types.go b/config/types.go index 14b9f7cc0209..ac90fa9b82ab 100644 --- a/config/types.go +++ b/config/types.go @@ -1,9 +1,9 @@ package config import ( - "encoding" "encoding/json" "fmt" + "strings" "time" ) @@ -211,27 +211,56 @@ func (p Priority) String() string { var _ json.Unmarshaler = (*Priority)(nil) var _ json.Marshaler = (*Priority)(nil) -// Duration wraps time.Duration to provide json serialization and deserialization. +// OptionalDuration wraps time.Duration to provide json serialization and deserialization. // -// NOTE: the zero value encodes to an empty string. -type Duration time.Duration +// NOTE: the zero value encodes to JSON nill +type OptionalDuration struct { + value *time.Duration +} -func (d *Duration) UnmarshalText(text []byte) error { - dur, err := time.ParseDuration(string(text)) - *d = Duration(dur) - return err +func (d *OptionalDuration) UnmarshalJSON(input []byte) error { + switch string(input) { + case "null", "undefined", "\"null\"", "", "default", "\"\"", "\"default\"": + *d = OptionalDuration{} + return nil + default: + text := strings.Trim(string(input), "\"") + value, err := time.ParseDuration(text) + if err != nil { + return err + } + *d = OptionalDuration{value: &value} + return nil + } } -func (d Duration) MarshalText() ([]byte, error) { - return []byte(time.Duration(d).String()), nil +func (d *OptionalDuration) IsDefault() bool { + return d == nil || d.value == nil } -func (d Duration) String() string { - return time.Duration(d).String() +func (d *OptionalDuration) WithDefault(defaultValue time.Duration) time.Duration { + if d == nil || d.value == nil { + return defaultValue + } + return *d.value +} + +func (d OptionalDuration) MarshalJSON() ([]byte, error) { + if d.value == nil { + return json.Marshal(nil) + } + return json.Marshal(d.value.String()) +} + +func (d OptionalDuration) String() string { + if d.value == nil { + return "default" + } + return d.value.String() } -var _ encoding.TextUnmarshaler = (*Duration)(nil) -var _ encoding.TextMarshaler = (*Duration)(nil) +var _ json.Unmarshaler = (*OptionalDuration)(nil) +var _ json.Marshaler = (*OptionalDuration)(nil) // OptionalInteger represents an integer that has a default value // diff --git a/config/types_test.go b/config/types_test.go index 94a7a633d075..06ea73a260c2 100644 --- a/config/types_test.go +++ b/config/types_test.go @@ -1,40 +1,128 @@ package config import ( + "bytes" "encoding/json" "testing" "time" ) -func TestDuration(t *testing.T) { - out, err := json.Marshal(Duration(time.Second)) - if err != nil { - t.Fatal(err) +func TestOptionalDuration(t *testing.T) { + makeDurationPointer := func(d time.Duration) *time.Duration { return &d } - } - expected := "\"1s\"" - if string(out) != expected { - t.Fatalf("expected %s, got %s", expected, string(out)) - } - var d Duration - err = json.Unmarshal(out, &d) - if err != nil { - t.Fatal(err) - } - if time.Duration(d) != time.Second { - t.Fatal("expected a second") - } - type Foo struct { - D Duration `json:",omitempty"` - } - out, err = json.Marshal(new(Foo)) - if err != nil { - t.Fatal(err) - } - expected = "{}" - if string(out) != expected { - t.Fatal("expected omitempty to omit the duration") - } + t.Run("marshalling and unmarshalling", func(t *testing.T) { + out, err := json.Marshal(OptionalDuration{value: makeDurationPointer(time.Second)}) + if err != nil { + t.Fatal(err) + } + expected := "\"1s\"" + if string(out) != expected { + t.Fatalf("expected %s, got %s", expected, string(out)) + } + var d OptionalDuration + + if err := json.Unmarshal(out, &d); err != nil { + t.Fatal(err) + } + if *d.value != time.Second { + t.Fatal("expected a second") + } + }) + + t.Run("default value", func(t *testing.T) { + for _, jsonStr := range []string{"null", "\"null\"", "\"\"", "\"default\""} { + var d OptionalDuration + if !d.IsDefault() { + t.Fatal("expected value to be the default initially") + } + if err := json.Unmarshal([]byte(jsonStr), &d); err != nil { + t.Fatalf("%s failed to unmarshall with %s", jsonStr, err) + } + if dur := d.WithDefault(time.Hour); dur != time.Hour { + t.Fatalf("expected default value to be used, got %s", dur) + } + if !d.IsDefault() { + t.Fatal("expected value to be the default") + } + } + }) + + t.Run("omitempty with default value", func(t *testing.T) { + type Foo struct { + D *OptionalDuration `json:",omitempty"` + } + // marshall to JSON without empty field + out, err := json.Marshal(new(Foo)) + if err != nil { + t.Fatal(err) + } + if string(out) != "{}" { + t.Fatalf("expected omitempty to omit the duration, got %s", out) + } + // unmarshall missing value and get the default + var foo2 Foo + if err := json.Unmarshal(out, &foo2); err != nil { + t.Fatalf("%s failed to unmarshall with %s", string(out), err) + } + if dur := foo2.D.WithDefault(time.Hour); dur != time.Hour { + t.Fatalf("expected default value to be used, got %s", dur) + } + if !foo2.D.IsDefault() { + t.Fatal("expected value to be the default") + } + }) + + t.Run("roundtrip including the default values", func(t *testing.T) { + for jsonStr, goValue := range map[string]OptionalDuration{ + // there are various footguns user can hit, normalize them to the canonical default + "null": {}, // JSON null → default value + "\"null\"": {}, // JSON string "null" sent/set by "ipfs config" cli → default value + "\"default\"": {}, // explicit "default" as string + "\"\"": {}, // user removed custom value, empty string should also parse as default + "\"1s\"": {value: makeDurationPointer(time.Second)}, + "\"42h1m3s\"": {value: makeDurationPointer(42*time.Hour + 1*time.Minute + 3*time.Second)}, + } { + var d OptionalDuration + err := json.Unmarshal([]byte(jsonStr), &d) + if err != nil { + t.Fatal(err) + } + + if goValue.value == nil && d.value == nil { + } else if goValue.value == nil && d.value != nil { + t.Errorf("expected nil for %s, got %s", jsonStr, d) + } else if *d.value != *goValue.value { + t.Fatalf("expected %s for %s, got %s", goValue, jsonStr, d) + } + + // Test Reverse + out, err := json.Marshal(goValue) + if err != nil { + t.Fatal(err) + } + if goValue.value == nil { + if !bytes.Equal(out, []byte("null")) { + t.Fatalf("expected JSON null for %s, got %s", jsonStr, string(out)) + } + continue + } + if string(out) != jsonStr { + t.Fatalf("expected %s, got %s", jsonStr, string(out)) + } + } + }) + + t.Run("invalid duration values", func(t *testing.T) { + for _, invalid := range []string{ + "\"s\"", "\"1ę\"", "\"-1\"", "\"1H\"", "\"day\"", + } { + var d OptionalDuration + err := json.Unmarshal([]byte(invalid), &d) + if err == nil { + t.Errorf("expected to fail to decode %s as an OptionalDuration, got %s instead", invalid, d) + } + } + }) } func TestOneStrings(t *testing.T) {