diff --git a/.gitignore b/.gitignore index 38a82dbea7c..7c0d863b3c6 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,7 @@ tools/clroot/db.sqlite3-wal debug.env *.txt operator_ui/install +.devenv # codeship *.aes diff --git a/core/capabilities/compute/compute.go b/core/capabilities/compute/compute.go index 7dedfb80d17..5a43b7bf40b 100644 --- a/core/capabilities/compute/compute.go +++ b/core/capabilities/compute/compute.go @@ -16,7 +16,6 @@ import ( capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" "github.com/smartcontractkit/chainlink-common/pkg/logger" coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core" - "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" ) @@ -24,8 +23,11 @@ import ( const ( CapabilityIDCompute = "custom_compute@1.0.0" - binaryKey = "binary" - configKey = "config" + binaryKey = "binary" + configKey = "config" + maxMemoryMBsKey = "maxMemoryMBs" + timeoutKey = "timeout" + tickIntervalKey = "tickInterval" ) var ( @@ -65,6 +67,8 @@ type Compute struct { log logger.Logger registry coretypes.CapabilitiesRegistry modules *moduleCache + + transformer ConfigTransformer } func (c *Compute) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error { @@ -91,21 +95,16 @@ func copyRequest(req capabilities.CapabilityRequest) capabilities.CapabilityRequ func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { copied := copyRequest(request) - binary, err := c.popBytesValue(copied.Config, binaryKey) - if err != nil { - return capabilities.CapabilityResponse{}, fmt.Errorf("invalid request: %w", err) - } - - config, err := c.popBytesValue(copied.Config, configKey) + cfg, err := c.transformer.Transform(copied.Config, WithLogger(c.log)) if err != nil { - return capabilities.CapabilityResponse{}, fmt.Errorf("invalid request: %w", err) + return capabilities.CapabilityResponse{}, fmt.Errorf("invalid request: could not transform config: %w", err) } - id := generateID(binary) + id := generateID(cfg.Binary) m, ok := c.modules.get(id) if !ok { - mod, err := c.initModule(id, binary, request.Metadata.WorkflowID, request.Metadata.ReferenceID) + mod, err := c.initModule(id, cfg.ModuleConfig, cfg.Binary, request.Metadata.WorkflowID, request.Metadata.ReferenceID) if err != nil { return capabilities.CapabilityResponse{}, err } @@ -113,12 +112,12 @@ func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRe m = mod } - return c.executeWithModule(m.module, config, request) + return c.executeWithModule(m.module, cfg.Config, request) } -func (c *Compute) initModule(id string, binary []byte, workflowID, referenceID string) (*module, error) { +func (c *Compute) initModule(id string, cfg *host.ModuleConfig, binary []byte, workflowID, referenceID string) (*module, error) { initStart := time.Now() - mod, err := host.NewModule(&host.ModuleConfig{Logger: c.log}, binary) + mod, err := host.NewModule(cfg, binary) if err != nil { return nil, fmt.Errorf("failed to instantiate WASM module: %w", err) } @@ -133,21 +132,6 @@ func (c *Compute) initModule(id string, binary []byte, workflowID, referenceID s return m, nil } -func (c *Compute) popBytesValue(m *values.Map, key string) ([]byte, error) { - v, ok := m.Underlying[key] - if !ok { - return nil, fmt.Errorf("could not find %q in map", key) - } - - vb, ok := v.(*values.Bytes) - if !ok { - return nil, fmt.Errorf("value is not bytes: %q", key) - } - - delete(m.Underlying, key) - return vb.Underlying, nil -} - func (c *Compute) executeWithModule(module *host.Module, config []byte, req capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { executeStart := time.Now() capReq := capabilitiespb.CapabilityRequestToProto(req) @@ -204,9 +188,10 @@ func (c *Compute) Close() error { func NewAction(log logger.Logger, registry coretypes.CapabilitiesRegistry) *Compute { compute := &Compute{ - log: logger.Named(log, "CustomCompute"), - registry: registry, - modules: newModuleCache(clockwork.NewRealClock(), 1*time.Minute, 10*time.Minute, 3), + log: logger.Named(log, "CustomCompute"), + registry: registry, + modules: newModuleCache(clockwork.NewRealClock(), 1*time.Minute, 10*time.Minute, 3), + transformer: NewTransformer(), } return compute } diff --git a/core/capabilities/compute/transformer.go b/core/capabilities/compute/transformer.go new file mode 100644 index 00000000000..7eca7b7c72f --- /dev/null +++ b/core/capabilities/compute/transformer.go @@ -0,0 +1,149 @@ +package compute + +import ( + "errors" + "fmt" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" +) + +type Transformer[T any, U any] interface { + Transform(T, ...func(*U)) (*U, error) +} + +type ConfigTransformer = Transformer[*values.Map, ParsedConfig] + +type ParsedConfig struct { + Binary []byte + Config []byte + ModuleConfig *host.ModuleConfig +} + +type transformer struct{} + +func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*ParsedConfig, error) { + binary, err := popValue[[]byte](in, binaryKey) + if err != nil { + return nil, NewInvalidRequestError(err) + } + + config, err := popValue[[]byte](in, configKey) + if err != nil { + return nil, NewInvalidRequestError(err) + } + + maxMemoryMBs, err := popOptionalValue[int64](in, maxMemoryMBsKey) + if err != nil { + return nil, NewInvalidRequestError(err) + } + + mc := &host.ModuleConfig{ + MaxMemoryMBs: maxMemoryMBs, + } + + timeout, err := popOptionalValue[string](in, timeoutKey) + if err != nil { + return nil, NewInvalidRequestError(err) + } + + var td time.Duration + if timeout != "" { + td, err = time.ParseDuration(timeout) + if err != nil { + return nil, NewInvalidRequestError(err) + } + mc.Timeout = &td + } + + tickInterval, err := popOptionalValue[string](in, tickIntervalKey) + if err != nil { + return nil, NewInvalidRequestError(err) + } + + var ti time.Duration + if tickInterval != "" { + ti, err = time.ParseDuration(tickInterval) + if err != nil { + return nil, NewInvalidRequestError(err) + } + mc.TickInterval = ti + } + + pc := &ParsedConfig{ + Binary: binary, + Config: config, + ModuleConfig: mc, + } + + for _, opt := range opts { + opt(pc) + } + + return pc, nil +} + +func NewTransformer() *transformer { + return &transformer{} +} + +func WithLogger(l logger.Logger) func(*ParsedConfig) { + return func(pc *ParsedConfig) { + pc.ModuleConfig.Logger = l + } +} + +func popOptionalValue[T any](m *values.Map, key string) (T, error) { + v, err := popValue[T](m, key) + if err != nil { + var nfe *NotFoundError + if errors.As(err, &nfe) { + return v, nil + } + return v, err + } + return v, nil +} + +func popValue[T any](m *values.Map, key string) (T, error) { + var empty T + + wrapped, ok := m.Underlying[key] + if !ok { + return empty, NewNotFoundError(key) + } + + delete(m.Underlying, key) + err := wrapped.UnwrapTo(&empty) + if err != nil { + return empty, fmt.Errorf("could not unwrap value: %w", err) + } + + return empty, nil +} + +type NotFoundError struct { + Key string +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("could not find %q in map", e.Key) +} + +func NewNotFoundError(key string) *NotFoundError { + return &NotFoundError{Key: key} +} + +type InvalidRequestError struct { + Err error +} + +func (e *InvalidRequestError) Error() string { + return fmt.Sprintf("invalid request: %v", e.Err) +} + +func NewInvalidRequestError(err error) *InvalidRequestError { + return &InvalidRequestError{Err: err} +} diff --git a/core/capabilities/compute/transformer_test.go b/core/capabilities/compute/transformer_test.go new file mode 100644 index 00000000000..3da152de28b --- /dev/null +++ b/core/capabilities/compute/transformer_test.go @@ -0,0 +1,167 @@ +package compute + +import ( + "testing" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" + "github.com/smartcontractkit/chainlink/v2/core/logger" + + "github.com/stretchr/testify/assert" +) + +func Test_NotFoundError(t *testing.T) { + nfe := NewNotFoundError("test") + assert.Equal(t, "could not find \"test\" in map", nfe.Error()) +} + +func Test_popValue(t *testing.T) { + m, err := values.NewMap( + map[string]any{ + "test": "value", + "mismatch": 42, + }, + ) + assert.NoError(t, err) + + t.Run("success", func(t *testing.T) { + var gotValue string + gotValue, err = popValue[string](m, "test") + assert.NoError(t, err) + assert.Equal(t, "value", gotValue) + }) + + t.Run("not found", func(t *testing.T) { + _, err = popValue[string](m, "foo") + var nfe *NotFoundError + assert.ErrorAs(t, err, &nfe) + }) + + t.Run("type mismatch", func(t *testing.T) { + _, err = popValue[string](m, "mismatch") + assert.Error(t, err) + assert.ErrorContains(t, err, "could not unwrap value") + }) + + assert.Len(t, m.Underlying, 0) +} + +func Test_popOptionalValue(t *testing.T) { + m, err := values.NewMap( + map[string]any{ + "test": "value", + "buzz": "fizz", + }, + ) + assert.NoError(t, err) + t.Run("found value", func(t *testing.T) { + var gotValue string + gotValue, err = popOptionalValue[string](m, "test") + assert.NoError(t, err) + assert.Equal(t, "value", gotValue) + }) + + t.Run("not found returns nil error", func(t *testing.T) { + var gotValue string + gotValue, err = popOptionalValue[string](m, "foo") + assert.NoError(t, err) + assert.Zero(t, gotValue) + }) + + t.Run("some other error fails", func(t *testing.T) { + var gotValue int + gotValue, err = popOptionalValue[int](m, "buzz") + assert.Error(t, err) + assert.Zero(t, gotValue) + }) + + assert.Len(t, m.Underlying, 0) +} + +func Test_transformer(t *testing.T) { + t.Run("success", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "maxMemoryMBs": 1024, + "timeout": "4s", + "tickInterval": "8s", + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + wantTO := 4 * time.Second + wantConfig := &ParsedConfig{ + Binary: []byte{0x01, 0x02, 0x03}, + Config: []byte{0x04, 0x05, 0x06}, + ModuleConfig: &host.ModuleConfig{ + MaxMemoryMBs: 1024, + Timeout: &wantTO, + TickInterval: 8 * time.Second, + Logger: lgger, + }, + } + + tf := NewTransformer() + gotConfig, err := tf.Transform(giveMap, WithLogger(lgger)) + + assert.NoError(t, err) + assert.Equal(t, wantConfig, gotConfig) + }) + + t.Run("success missing optional fields", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + wantConfig := &ParsedConfig{ + Binary: []byte{0x01, 0x02, 0x03}, + Config: []byte{0x04, 0x05, 0x06}, + ModuleConfig: &host.ModuleConfig{ + Logger: lgger, + }, + } + + tf := NewTransformer() + gotConfig, err := tf.Transform(giveMap, WithLogger(lgger)) + + assert.NoError(t, err) + assert.Equal(t, wantConfig, gotConfig) + }) + + t.Run("fails parsing timeout", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "timeout": "not a duration", + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + tf := NewTransformer() + _, err = tf.Transform(giveMap, WithLogger(lgger)) + + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid request") + }) + + t.Run("fails parsing tick interval", func(t *testing.T) { + lgger := logger.TestLogger(t) + giveMap, err := values.NewMap(map[string]any{ + "tickInterval": "not a duration", + "binary": []byte{0x01, 0x02, 0x03}, + "config": []byte{0x04, 0x05, 0x06}, + }) + assert.NoError(t, err) + + tf := NewTransformer() + _, err = tf.Transform(giveMap, WithLogger(lgger)) + + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid request") + }) +} diff --git a/core/scripts/keystone/src/05_deploy_initialize_capabilities_registry.go b/core/scripts/keystone/src/05_deploy_initialize_capabilities_registry.go index 2fb3065f7fe..f4e394b7da5 100644 --- a/core/scripts/keystone/src/05_deploy_initialize_capabilities_registry.go +++ b/core/scripts/keystone/src/05_deploy_initialize_capabilities_registry.go @@ -156,6 +156,12 @@ var ( EncryptionPublicKey: "0x87cf298dd236a307ea887cd5d81eb0b708e3dd48c984c0700bb26c072e427942", }, } + + defaultComputeModuleConfig = map[string]any{ + "defaultTickInterval": "100ms", + "defaultTimeout": "300ms", + "defaultMaxMemoryMBs": int64(64), + } ) type deployAndInitializeCapabilitiesRegistryCommand struct{} @@ -223,9 +229,30 @@ func peerToNode(nopID uint32, p peer) (kcr.CapabilitiesRegistryNodeParams, error }, nil } -func newCapabilityConfig() *capabilitiespb.CapabilityConfig { +// newCapabilityConfig returns a new capability config with the default config set as empty. +// Override the empty default config with functional options. +func newCapabilityConfig(opts ...func(*values.Map)) *capabilitiespb.CapabilityConfig { + dc := values.EmptyMap() + for _, opt := range opts { + opt(dc) + } + return &capabilitiespb.CapabilityConfig{ - DefaultConfig: values.Proto(values.EmptyMap()).GetMapValue(), + DefaultConfig: values.ProtoMap(dc), + } +} + +// withDefaultConfig returns a function that sets the default config for a capability by merging +// the provided map with the existing default config. This is a shallow merge. +func withDefaultConfig(m map[string]any) func(*values.Map) { + return func(dc *values.Map) { + overrides, err := values.NewMap(m) + if err != nil { + panic(err) + } + for k, v := range overrides.Underlying { + dc.Underlying[k] = v + } } } @@ -292,6 +319,16 @@ func (c *deployAndInitializeCapabilitiesRegistryCommand) Run(args []string) { panic(err) } + computeAction := kcr.CapabilitiesRegistryCapability{ + LabelledName: "custom-compute", + Version: "1.0.0", + CapabilityType: uint8(1), // action + } + aid, err := reg.GetHashedCapabilityId(&bind.CallOpts{}, computeAction.LabelledName, computeAction.Version) + if err != nil { + panic(err) + } + writeChain := kcr.CapabilitiesRegistryCapability{ LabelledName: "write_ethereum-testnet-sepolia", Version: "1.0.0", @@ -328,6 +365,7 @@ func (c *deployAndInitializeCapabilitiesRegistryCommand) Run(args []string) { aptosWriteChain, ocr, cronTrigger, + computeAction, }) if err != nil { log.Printf("failed to call AddCapabilities: %s", err) @@ -413,6 +451,12 @@ func (c *deployAndInitializeCapabilitiesRegistryCommand) Run(args []string) { panic(err) } + computeCfg := newCapabilityConfig(withDefaultConfig(defaultComputeModuleConfig)) + ccfgb, err := proto.Marshal(computeCfg) + if err != nil { + panic(err) + } + cfgs := []kcr.CapabilitiesRegistryCapabilityConfiguration{ { CapabilityId: ocrid, @@ -422,6 +466,10 @@ func (c *deployAndInitializeCapabilitiesRegistryCommand) Run(args []string) { CapabilityId: ctid, Config: ccb, }, + { + CapabilityId: aid, + Config: ccfgb, + }, } _, err = reg.AddDON(env.Owner, ps, cfgs, true, true, 2) if err != nil { diff --git a/core/services/workflows/engine.go b/core/services/workflows/engine.go index 313fb05014f..ffe7da643ad 100644 --- a/core/services/workflows/engine.go +++ b/core/services/workflows/engine.go @@ -776,11 +776,10 @@ func (e *Engine) configForStep(ctx context.Context, executionID string, step *st return step.config, nil } - // Merge the configs for now; note that this means that a workflow can override - // all of the config set by the capability. This is probably not desirable in - // the long-term, but we don't know much about those use cases so stick to a simpler - // implementation for now. - return merge(capConfig.DefaultConfig, step.config), nil + // Merge the configs with registry config overriding the step config. This is because + // some config fields are sensitive and could affect the safe running of the capability, + // so we avoid user provided values by overriding them with config from the capabilities registry. + return merge(step.config, capConfig.DefaultConfig), nil } // executeStep executes the referenced capability within a step and returns the result. diff --git a/core/services/workflows/engine_test.go b/core/services/workflows/engine_test.go index 048c353c747..382662afeb1 100644 --- a/core/services/workflows/engine_test.go +++ b/core/services/workflows/engine_test.go @@ -1024,16 +1024,23 @@ func TestEngine_Error(t *testing.T) { } func TestEngine_MergesWorkflowConfigAndCRConfig(t *testing.T) { - ctx := testutils.Context(t) - reg := coreCap.NewRegistry(logger.TestLogger(t)) - - trigger, _ := mockTrigger(t) + var ( + ctx = testutils.Context(t) + writeID = "write_polygon-testnet-mumbai@1.0.0" + gotConfig = values.EmptyMap() + wantConfigKeys = []string{"deltaStage", "schedule", "address", "params", "abi"} + ) - require.NoError(t, reg.Add(ctx, trigger)) - require.NoError(t, reg.Add(ctx, mockConsensus(""))) - writeID := "write_polygon-testnet-mumbai@1.0.0" + giveRegistryConfig, err := values.WrapMap(map[string]any{ + "deltaStage": "1s", + "schedule": "allAtOnce", + }) + assert.NoError(t, err, "failed to wrap map of registry config") - gotConfig := values.EmptyMap() + // Mock the capabilities of the simple workflow. + reg := coreCap.NewRegistry(logger.TestLogger(t)) + trigger, _ := mockTrigger(t) + consensus := mockConsensus("") target := newMockCapability( // Create a remote capability so we don't use the local transmission protocol. capabilities.MustNewRemoteCapabilityInfo( @@ -1043,6 +1050,7 @@ func TestEngine_MergesWorkflowConfigAndCRConfig(t *testing.T) { &capabilities.DON{ID: 1}, ), func(req capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { + // Replace the empty config with the write target config. gotConfig = req.Config return capabilities.CapabilityResponse{ @@ -1050,6 +1058,9 @@ func TestEngine_MergesWorkflowConfigAndCRConfig(t *testing.T) { }, nil }, ) + + require.NoError(t, reg.Add(ctx, trigger)) + require.NoError(t, reg.Add(ctx, consensus)) require.NoError(t, reg.Add(ctx, target)) eng, testHooks := newTestEngineWithYAMLSpec( @@ -1063,16 +1074,149 @@ func TestEngine_MergesWorkflowConfigAndCRConfig(t *testing.T) { return registrysyncer.CapabilityConfiguration{}, nil } - cm, err := values.WrapMap(map[string]any{ - "deltaStage": "1s", - "schedule": "allAtOnce", + var cb []byte + cb, err = proto.Marshal(&capabilitiespb.CapabilityConfig{ + DefaultConfig: values.ProtoMap(giveRegistryConfig), }) - if err != nil { - return registrysyncer.CapabilityConfiguration{}, err + return registrysyncer.CapabilityConfiguration{ + Config: cb, + }, err + }, + }) + + servicetest.Run(t, eng) + + eid := getExecutionId(t, eng, testHooks) + + state, err := eng.executionStates.Get(ctx, eid) + require.NoError(t, err) + + assert.Equal(t, state.Status, store.StatusCompleted) + + // Assert that the config from the CR is merged with the default config from the registry. + m, err := values.Unwrap(gotConfig) + require.NoError(t, err) + assert.Equal(t, m.(map[string]any)["deltaStage"], "1s") + assert.Equal(t, m.(map[string]any)["schedule"], "allAtOnce") + + for _, k := range wantConfigKeys { + assert.Contains(t, m.(map[string]any), k) + } +} + +const customComputeWorkflow = ` +triggers: + - id: "mercury-trigger@1.0.0" + config: + feedlist: + - "0x1111111111111111111100000000000000000000000000000000000000000000" # ETHUSD + - "0x2222222222222222222200000000000000000000000000000000000000000000" # LINKUSD + - "0x3333333333333333333300000000000000000000000000000000000000000000" # BTCUSD + +actions: + - id: custom_compute@1.0.0 + ref: custom_compute + config: + maxMemoryMBs: 128 + tickInterval: 100ms + timeout: 300ms + inputs: + action: + - $(trigger.outputs) + +consensus: + - id: "offchain_reporting@1.0.0" + ref: "evm_median" + inputs: + observations: + - "$(trigger.outputs)" + config: + aggregation_method: "data_feeds_2_0" + aggregation_config: + "0x1111111111111111111100000000000000000000000000000000000000000000": + deviation: "0.001" + heartbeat: 3600 + "0x2222222222222222222200000000000000000000000000000000000000000000": + deviation: "0.001" + heartbeat: 3600 + "0x3333333333333333333300000000000000000000000000000000000000000000": + deviation: "0.001" + heartbeat: 3600 + encoder: "EVM" + encoder_config: + abi: "mercury_reports bytes[]" + +targets: + - id: "write_ethereum-testnet-sepolia@1.0.0" + inputs: "$(evm_median.outputs)" + config: + address: "0x54e220867af6683aE6DcBF535B4f952cB5116510" + params: ["$(report)"] + abi: "receive(report bytes)" +` + +// TestEngine_MergesWorkflowConfigAndCRConfig_CRConfigPrecedence tests that the engine merges the +// workflow config with the CR config, with the CR config taking precedence. +func TestEngine_MergesWorkflowConfigAndCRConfig_CRConfigPrecedence(t *testing.T) { + var ( + ctx = testutils.Context(t) + actionID = "custom_compute@1.0.0" + giveTimeout = 300 * time.Millisecond + giveTickInterval = 100 * time.Millisecond + registryConfig = map[string]any{ + "maxMemoryMBs": int64(64), + "timeout": giveTimeout.String(), + "tickInterval": giveTickInterval.String(), + } + gotConfig = values.EmptyMap() + ) + + giveRegistryConfig, err := values.WrapMap(registryConfig) + assert.NoError(t, err, "failed to wrap map of registry config") + + // Mock the capabilities of the simple workflow. + reg := coreCap.NewRegistry(logger.TestLogger(t)) + trigger, _ := mockTrigger(t) + target := mockTarget("write_ethereum-testnet-sepolia@1.0.0") + action := newMockCapability( + // Create a remote capability so we don't use the local transmission protocol. + capabilities.MustNewRemoteCapabilityInfo( + actionID, + capabilities.CapabilityTypeAction, + "a custom compute action with custom config", + &capabilities.DON{ID: 1}, + ), + func(req capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { + // Replace the empty config with the write target config. + gotConfig = req.Config + + return capabilities.CapabilityResponse{ + Value: req.Inputs, + }, nil + }, + ) + + consensus := mockConsensus("") + + require.NoError(t, reg.Add(ctx, trigger)) + require.NoError(t, reg.Add(ctx, action)) + require.NoError(t, reg.Add(ctx, target)) + require.NoError(t, reg.Add(ctx, consensus)) + + eng, testHooks := newTestEngineWithYAMLSpec( + t, + reg, + customComputeWorkflow, + ) + reg.SetLocalRegistry(testConfigProvider{ + configForCapability: func(ctx context.Context, capabilityID string, donID uint32) (registrysyncer.CapabilityConfiguration, error) { + if capabilityID != actionID { + return registrysyncer.CapabilityConfiguration{}, nil } - cb, err := proto.Marshal(&capabilitiespb.CapabilityConfig{ - DefaultConfig: values.ProtoMap(cm), + var cb []byte + cb, err = proto.Marshal(&capabilitiespb.CapabilityConfig{ + DefaultConfig: values.ProtoMap(giveRegistryConfig), }) return registrysyncer.CapabilityConfiguration{ Config: cb, @@ -1089,10 +1233,13 @@ func TestEngine_MergesWorkflowConfigAndCRConfig(t *testing.T) { assert.Equal(t, state.Status, store.StatusCompleted) + // Assert that the config from the CR is merged with the default config from the registry. With + // the CR config taking precedence. m, err := values.Unwrap(gotConfig) require.NoError(t, err) - assert.Equal(t, m.(map[string]any)["deltaStage"], "1s") - assert.Equal(t, m.(map[string]any)["schedule"], "allAtOnce") + assert.Equalf(t, registryConfig["maxMemoryMBs"], m.(map[string]any)["maxMemoryMBs"], "maxMemoryMBs should be %d", registryConfig["maxMemoryMBs"]) + assert.Equalf(t, registryConfig["timeout"], m.(map[string]any)["timeout"], "timeout should be %s", registryConfig["timeout"]) + assert.Equalf(t, registryConfig["tickInterval"], m.(map[string]any)["tickInterval"], "tickInterval should be %s", registryConfig["tickInterval"]) } func TestEngine_HandlesNilConfigOnchain(t *testing.T) {