diff --git a/core/capabilities/compute/transformer.go b/core/capabilities/compute/transformer.go new file mode 100644 index 00000000000..4b705d1dab0 --- /dev/null +++ b/core/capabilities/compute/transformer.go @@ -0,0 +1,149 @@ +package compute + +import ( + "errors" + "fmt" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" +) + +type Transformer[T any, U any] interface { + Transform(T, ...func(*U)) (*U, error) +} + +type ConfigTransformer = Transformer[*values.Map, ParsedConfig] + +type ParsedConfig struct { + Binary []byte + Config []byte + ModuleConfig *host.ModuleConfig +} + +type transformer struct{} + +func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*ParsedConfig, error) { + binary, err := popValue[[]byte](in, binaryKey) + if err != nil { + return nil, NewInvalidRequestError(err) + } + + config, err := popValue[[]byte](in, configKey) + if err != nil { + return nil, NewInvalidRequestError(err) + } + + maxMemoryMBs, err := popOptionalValue[int64](in, "maxMemoryMBs") + if err != nil { + return nil, NewInvalidRequestError(err) + } + + mc := &host.ModuleConfig{ + MaxMemoryMBs: maxMemoryMBs, + } + + timeout, err := popOptionalValue[string](in, "timeout") + if err != nil { + return nil, NewInvalidRequestError(err) + } + + var td time.Duration + if timeout != "" { + td, err = time.ParseDuration(timeout) + if err != nil { + return nil, NewInvalidRequestError(err) + } + mc.Timeout = &td + } + + tickInterval, err := popOptionalValue[string](in, "tickInterval") + if err != nil { + return nil, NewInvalidRequestError(err) + } + + var ti time.Duration + if tickInterval != "" { + ti, err = time.ParseDuration(tickInterval) + if err != nil { + return nil, NewInvalidRequestError(err) + } + mc.TickInterval = ti + } + + pc := &ParsedConfig{ + Binary: binary, + Config: config, + ModuleConfig: mc, + } + + for _, opt := range opts { + opt(pc) + } + + return pc, nil +} + +func NewTransformer() *transformer { + return &transformer{} +} + +func WithLogger(l logger.Logger) func(*ParsedConfig) { + return func(pc *ParsedConfig) { + pc.ModuleConfig.Logger = l + } +} + +func popOptionalValue[T any](m *values.Map, key string) (T, error) { + v, err := popValue[T](m, key) + if err != nil { + var nfe *NotFoundError + if errors.As(err, &nfe) { + return v, nil + } + return v, err + } + return v, nil +} + +func popValue[T any](m *values.Map, key string) (T, error) { + var empty T + + wrapped, ok := m.Underlying[key] + if !ok { + return empty, NewNotFoundError(key) + } + + delete(m.Underlying, key) + err := wrapped.UnwrapTo(&empty) + if err != nil { + return empty, fmt.Errorf("could not unwrap value: %w", err) + } + + return empty, nil +} + +type NotFoundError struct { + Key string +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("could not find %q in map", e.Key) +} + +func NewNotFoundError(key string) *NotFoundError { + return &NotFoundError{Key: key} +} + +type InvalidRequestError struct { + Err error +} + +func (e *InvalidRequestError) Error() string { + return fmt.Sprintf("invalid request: %v", e.Err) +} + +func NewInvalidRequestError(err error) *InvalidRequestError { + return &InvalidRequestError{Err: err} +} diff --git a/core/capabilities/compute/transformer_test.go b/core/capabilities/compute/transformer_test.go new file mode 100644 index 00000000000..3da152de28b --- /dev/null +++ b/core/capabilities/compute/transformer_test.go @@ -0,0 +1,167 @@ +package compute + +import ( + "testing" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" + "github.com/smartcontractkit/chainlink/v2/core/logger" + + "github.com/stretchr/testify/assert" +) + +func Test_NotFoundError(t *testing.T) { + nfe := NewNotFoundError("test") + assert.Equal(t, "could not find \"test\" in map", nfe.Error()) +} + +func Test_popValue(t *testing.T) { + m, err := values.NewMap( + map[string]any{ + "test": "value", + "mismatch": 42, + }, + ) + assert.NoError(t, err) + + t.Run("success", func(t *testing.T) { + var gotValue string + gotValue, err = popValue[string](m, "test") + assert.NoError(t, err) + assert.Equal(t, "value", gotValue) + }) + + t.Run("not found", func(t *testing.T) { + _, err = popValue[string](m, "foo") + var nfe *NotFoundError + assert.ErrorAs(t, err, &nfe) + }) + + t.Run("type mismatch", func(t *testing.T) { + _, err = popValue[string](m, "mismatch") + assert.Error(t, err) + assert.ErrorContains(t, err, "could not unwrap value") + }) + + assert.Len(t, m.Underlying, 0) +} + +func Test_popOptionalValue(t *testing.T) { + m, err := values.NewMap( + map[string]any{ + "test": "value", + "buzz": "fizz", + }, + ) + assert.NoError(t, err) + t.Run("found value", func(t *testing.T) { + var gotValue string + gotValue, err = popOptionalValue[string](m, "test") + assert.NoError(t, err) + assert.Equal(t, "value", gotValue) + }) + + t.Run("not found returns nil error", func(t *testing.T) { + var gotValue string + gotValue, err = popOptionalValue[string](m, "foo") + assert.NoError(t, err) + assert.Zero(t, gotValue) + }) + + t.Run("some other error fails", func(t *testing.T) { + var gotValue int + gotValue, err = popOptionalValue[int](m, "buzz") + assert.Error(t, err) + assert.Zero(t, gotValue) + }) + + assert.Len(t, m.Underlying, 0) +} + +func Test_transformer(t *testing.T) { + t.Run("success", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "maxMemoryMBs": 1024, + "timeout": "4s", + "tickInterval": "8s", + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + wantTO := 4 * time.Second + wantConfig := &ParsedConfig{ + Binary: []byte{0x01, 0x02, 0x03}, + Config: []byte{0x04, 0x05, 0x06}, + ModuleConfig: &host.ModuleConfig{ + MaxMemoryMBs: 1024, + Timeout: &wantTO, + TickInterval: 8 * time.Second, + Logger: lgger, + }, + } + + tf := NewTransformer() + gotConfig, err := tf.Transform(giveMap, WithLogger(lgger)) + + assert.NoError(t, err) + assert.Equal(t, wantConfig, gotConfig) + }) + + t.Run("success missing optional fields", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + wantConfig := &ParsedConfig{ + Binary: []byte{0x01, 0x02, 0x03}, + Config: []byte{0x04, 0x05, 0x06}, + ModuleConfig: &host.ModuleConfig{ + Logger: lgger, + }, + } + + tf := NewTransformer() + gotConfig, err := tf.Transform(giveMap, WithLogger(lgger)) + + assert.NoError(t, err) + assert.Equal(t, wantConfig, gotConfig) + }) + + t.Run("fails parsing timeout", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "timeout": "not a duration", + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + tf := NewTransformer() + _, err = tf.Transform(giveMap, WithLogger(lgger)) + + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid request") + }) + + t.Run("fails parsing tick interval", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "tickInterval": "not a duration", + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + tf := NewTransformer() + _, err = tf.Transform(giveMap, WithLogger(lgger)) + + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid request") + }) +}