Skip to content

Commit

Permalink
Handle ptr smart ptrs (#44)
Browse files Browse the repository at this point in the history
- Added new default & set types for `chan`, `map`, `slice`, `time.Time`
and `Pointer` types.
- Updated to handle pointer to a `Slice` or `Map` for situations where
code is not under your control. Fixes #43
  • Loading branch information
deankarn authored May 28, 2023
1 parent 05050dc commit 4214131
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 1 deletion.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package mold
============
![Project status](https://img.shields.io/badge/version-4.4.0-green.svg)
![Project status](https://img.shields.io/badge/version-4.5.0-green.svg)
[![Build Status](https://travis-ci.org/go-playground/mold.svg?branch=v2)](https://travis-ci.org/go-playground/mold)
[![Coverage Status](https://coveralls.io/repos/github/go-playground/mold/badge.svg?branch=v2)](https://coveralls.io/github/go-playground/mold?branch=v2)
[![Go Report Card](https://goreportcard.com/badge/github.com/go-playground/mold)](https://goreportcard.com/report/github.com/go-playground/mold)
Expand Down Expand Up @@ -58,7 +58,14 @@ These functions modify the data in-place.
| ucase | Uppercases the data. |
| ucfirst | Upper cases the first character of the data. |

**Special Notes:**
`default` and `set` modifiers are special in that they can be used to set the value of a field or underlying type information or attributes and both use the same underlying function to set the data.

Setting a Param will have the following special effects on data types where it's not just the value being set:
- Chan - param used to set the buffer size, default = 0.
- Slice - param used to set the capacity, default = 0.
- Map - param used to set the size, default = 0.
- time.Time - param used to set the time format OR value, default = time.Now(), `utc` = time.Now().UTC(), other tries to parse using RFC3339Nano and set a time value.

Scrubbers
----------
Expand Down
53 changes: 53 additions & 0 deletions modifiers/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"context"
"reflect"
"strconv"
"strings"
"time"

"github.com/go-playground/mold/v4"
)

var (
durationType = reflect.TypeOf(time.Duration(0))
timeType = reflect.TypeOf(time.Time{})
)

// defaultValue allows setting of a default value IF no value is already present.
Expand Down Expand Up @@ -73,6 +75,57 @@ func setValue(ctx context.Context, fl mold.FieldLevel) error {
}
fl.Field().SetBool(value)

case reflect.Map:
var n int
var err error
if fl.Param() != "" {
n, err = strconv.Atoi(fl.Param())
if err != nil {
return err
}
}
fl.Field().Set(reflect.MakeMapWithSize(fl.Field().Type(), n))

case reflect.Slice:
var cap int
var err error
if fl.Param() != "" {
cap, err = strconv.Atoi(fl.Param())
if err != nil {
return err
}
}
fl.Field().Set(reflect.MakeSlice(fl.Field().Type(), 0, cap))

case reflect.Struct:
if fl.Field().Type() == timeType {
if fl.Param() != "" {
if strings.ToLower(fl.Param()) == "utc" {
fl.Field().Set(reflect.ValueOf(time.Now().UTC()))
} else {
t, err := time.Parse(time.RFC3339Nano, fl.Param())
if err != nil {
return err
}
fl.Field().Set(reflect.ValueOf(t))
}
} else {
fl.Field().Set(reflect.ValueOf(time.Now()))
}
}
case reflect.Chan:
var buffer int
var err error
if fl.Param() != "" {
buffer, err = strconv.Atoi(fl.Param())
if err != nil {
return err
}
}
fl.Field().Set(reflect.MakeChan(fl.Field().Type(), buffer))

case reflect.Ptr:
fl.Field().Set(reflect.New(fl.Field().Type().Elem()))
}
return nil
}
Expand Down
196 changes: 196 additions & 0 deletions modifiers/multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,202 @@ import (
. "github.com/go-playground/assert/v2"
)

func TestDefaultSetSpecialTypes(t *testing.T) {
conform := New()

tests := []struct {
name string
field interface{}
tags string
vf func(field interface{})
expectError bool
}{
{
name: "default map",
field: (map[string]struct{})(nil),
tags: "default",
vf: func(field interface{}) {
m := field.(map[string]struct{})
Equal(t, len(m), 0)
},
},
{
name: "default map with size",
field: (map[string]struct{})(nil),
tags: "default=5",
vf: func(field interface{}) {
m := field.(map[string]struct{})
Equal(t, len(m), 0)
},
},
{
name: "set map with size",
field: (map[string]struct{})(nil),
tags: "set=5",
vf: func(field interface{}) {
m := field.(map[string]struct{})
Equal(t, len(m), 0)
},
},
{
name: "default slice",
field: ([]string)(nil),
tags: "default",
vf: func(field interface{}) {
m := field.([]string)
Equal(t, len(m), 0)
Equal(t, cap(m), 0)
},
},
{
name: "default slice with capacity",
field: ([]string)(nil),
tags: "default=5",
vf: func(field interface{}) {
m := field.([]string)
Equal(t, len(m), 0)
Equal(t, cap(m), 5)
},
},
{
name: "set slice",
field: ([]string)(nil),
tags: "set",
vf: func(field interface{}) {
m := field.([]string)
Equal(t, len(m), 0)
Equal(t, cap(m), 0)
},
},
{
name: "set slice with capacity",
field: ([]string)(nil),
tags: "set=5",
vf: func(field interface{}) {
m := field.([]string)
Equal(t, len(m), 0)
Equal(t, cap(m), 5)
},
},
{
name: "default chan",
field: (chan struct{})(nil),
tags: "default",
vf: func(field interface{}) {
m := field.(chan struct{})
Equal(t, len(m), 0)
Equal(t, cap(m), 0)
},
},
{
name: "default chan with buffer",
field: (chan struct{})(nil),
tags: "default=5",
vf: func(field interface{}) {
m := field.(chan struct{})
Equal(t, len(m), 0)
Equal(t, cap(m), 5)
},
},
{
name: "default time.Time",
field: time.Time{},
tags: "default",
vf: func(field interface{}) {
m := field.(time.Time)
Equal(t, m.Location(), time.Local)
},
},
{
name: "default time.Time utc",
field: time.Time{},
tags: "default=utc",
vf: func(field interface{}) {
m := field.(time.Time)
Equal(t, m.Location(), time.UTC)
},
},
{
name: "default time.Time to value",
field: time.Time{},
tags: "default=2023-05-28T15:50:31Z",
vf: func(field interface{}) {
m := field.(time.Time)
Equal(t, m.Location(), time.UTC)

tm, err := time.Parse(time.RFC3339Nano, "2023-05-28T15:50:31Z")
Equal(t, err, nil)
Equal(t, tm.Equal(m), true)

},
},
{
name: "set time.Time",
field: time.Time{},
tags: "set",
vf: func(field interface{}) {
m := field.(time.Time)
Equal(t, m.Location(), time.Local)
},
},
{
name: "set time.Time utc",
field: time.Time{},
tags: "set=utc",
vf: func(field interface{}) {
m := field.(time.Time)
Equal(t, m.Location(), time.UTC)
},
},
{
name: "set time.Time to value",
field: time.Time{},
tags: "set=2023-05-28T15:50:31Z",
vf: func(field interface{}) {
m := field.(time.Time)
Equal(t, m.Location(), time.UTC)

tm, err := time.Parse(time.RFC3339Nano, "2023-05-28T15:50:31Z")
Equal(t, err, nil)
Equal(t, tm.Equal(m), true)

},
},
{
name: "default pointer to slice",
field: (*[]string)(nil),
tags: "default",
vf: func(field interface{}) {
m := field.([]string)
Equal(t, len(m), 0)
},
},
{
name: "set pointer to slice",
field: (*[]string)(nil),
tags: "set",
vf: func(field interface{}) {
m := field.([]string)
Equal(t, len(m), 0)
},
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := conform.Field(context.Background(), &tc.field, tc.tags)
if tc.expectError {
NotEqual(t, err, nil)
return
}
Equal(t, err, nil)
tc.vf(tc.field)
})
}
}

func TestSet(t *testing.T) {

type State int
Expand Down
10 changes: 10 additions & 0 deletions mold.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, ct *cT
err = t.setByIterable(ctx, current, ct)
case reflect.Map:
err = t.setByMap(ctx, current, ct)
case reflect.Ptr:
innerKind := current.Type().Elem().Kind()
if innerKind == reflect.Slice || innerKind == reflect.Map {
// is a nil pointer to a slice or map, nothing to do.
return nil
}
// not a valid use of the dive tag
fallthrough
default:
err = ErrInvalidDive
}
Expand Down Expand Up @@ -267,6 +275,8 @@ func (t *Transformer) setByField(ctx context.Context, orig reflect.Value, ct *cT
}); err != nil {
return
}
// value could have been changed or reassigned
current, kind = t.extractType(current)
}
ct = ct.next
}
Expand Down

0 comments on commit 4214131

Please sign in to comment.