diff --git a/.changeset/empty-bees-fix.md b/.changeset/empty-bees-fix.md new file mode 100644 index 00000000000..e76ee621253 --- /dev/null +++ b/.changeset/empty-bees-fix.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +#wip implement gateway handler that forwards outgoing request from http target capability. introduce gateway http client diff --git a/.mockery.yaml b/.mockery.yaml index b22875e3f9f..709134b05bd 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -212,6 +212,7 @@ packages: HttpServer: HTTPRequestHandler: WebSocketServer: + HTTPClient: github.com/smartcontractkit/chainlink/v2/core/services/job: interfaces: ServiceCtx: diff --git a/core/capabilities/webapi/target/target_test.go b/core/capabilities/webapi/target/target_test.go index 67923ef4c80..a4064c7e7fe 100644 --- a/core/capabilities/webapi/target/target_test.go +++ b/core/capabilities/webapi/target/target_test.go @@ -107,10 +107,10 @@ func gatewayResponse(t *testing.T, msgID string) *api.Message { headers := map[string]string{"Content-Type": "application/json"} body := []byte("response body") responsePayload, err := json.Marshal(webapicapabilities.TargetResponsePayload{ - StatusCode: 200, - Headers: headers, - Body: body, - Success: true, + StatusCode: 200, + Headers: headers, + Body: body, + ExecutionError: false, }) require.NoError(t, err) return &api.Message{ diff --git a/core/capabilities/webapi/target/types.go b/core/capabilities/webapi/target/types.go index 6152d1496b7..63356baa96c 100644 --- a/core/capabilities/webapi/target/types.go +++ b/core/capabilities/webapi/target/types.go @@ -22,7 +22,7 @@ type WorkflowConfig struct { DeliveryMode string `json:"deliveryMode,omitempty"` // DeliveryMode describes how request should be delivered to gateway nodes, defaults to SingleNode. } -// CapabilityConfigConfig is the configuration for the Target capability and handler +// Config is the configuration for the Target capability and handler // TODO: handle retry configurations here CM-472 // Note that workflow executions have their own internal timeouts and retries set by the user // that are separate from this configuration diff --git a/core/scripts/gateway/run_gateway.go b/core/scripts/gateway/run_gateway.go index 2daca5190a5..5dbcd02bf56 100644 --- a/core/scripts/gateway/run_gateway.go +++ b/core/scripts/gateway/run_gateway.go @@ -48,7 +48,7 @@ func main() { lggr, _ := logger.NewLogger() - handlerFactory := gateway.NewHandlerFactory(nil, nil, lggr) + handlerFactory := gateway.NewHandlerFactory(nil, nil, nil, lggr) gw, err := gateway.NewGatewayFromConfig(&cfg, handlerFactory, lggr) if err != nil { fmt.Println("error creating Gateway object:", err) diff --git a/core/services/gateway/config/config.go b/core/services/gateway/config/config.go index a4d94155c8f..02c1b44869f 100644 --- a/core/services/gateway/config/config.go +++ b/core/services/gateway/config/config.go @@ -10,7 +10,9 @@ type GatewayConfig struct { UserServerConfig gw_net.HTTPServerConfig NodeServerConfig gw_net.WebSocketServerConfig ConnectionManagerConfig ConnectionManagerConfig - Dons []DONConfig + // HTTPClientConfig is configuration for outbound HTTP calls to external endpoints + HTTPClientConfig gw_net.HTTPClientConfig + Dons []DONConfig } type ConnectionManagerConfig struct { diff --git a/core/services/gateway/delegate.go b/core/services/gateway/delegate.go index 5a30228db4c..ba059b15a35 100644 --- a/core/services/gateway/delegate.go +++ b/core/services/gateway/delegate.go @@ -12,6 +12,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" ) @@ -54,7 +55,11 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services if err2 != nil { return nil, errors.Wrap(err2, "unmarshal gateway config") } - handlerFactory := NewHandlerFactory(d.legacyChains, d.ds, d.lggr) + httpClient, err := network.NewHTTPClient(gatewayConfig.HTTPClientConfig, d.lggr) + if err != nil { + return nil, err + } + handlerFactory := NewHandlerFactory(d.legacyChains, d.ds, httpClient, d.lggr) gateway, err := NewGatewayFromConfig(&gatewayConfig, handlerFactory, d.lggr) if err != nil { return nil, err diff --git a/core/services/gateway/gateway_test.go b/core/services/gateway/gateway_test.go index 3218c5428a2..7a5457c788c 100644 --- a/core/services/gateway/gateway_test.go +++ b/core/services/gateway/gateway_test.go @@ -57,7 +57,7 @@ Address = "0x0001020304050607080900010203040506070809" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) require.NoError(t, err) } @@ -75,7 +75,7 @@ HandlerName = "dummy" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) require.Error(t, err) } @@ -89,7 +89,7 @@ HandlerName = "no_such_handler" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) require.Error(t, err) } @@ -103,7 +103,7 @@ SomeOtherField = "abcd" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) require.Error(t, err) } @@ -121,7 +121,7 @@ Address = "0xnot_an_address" `) lggr := logger.TestLogger(t) - _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) require.Error(t, err) } @@ -129,7 +129,7 @@ func TestGateway_CleanStartAndClose(t *testing.T) { t.Parallel() lggr := logger.TestLogger(t) - gateway, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, buildConfig("")), gateway.NewHandlerFactory(nil, nil, lggr), lggr) + gateway, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, buildConfig("")), gateway.NewHandlerFactory(nil, nil, nil, lggr), lggr) require.NoError(t, err) servicetest.Run(t, gateway) } diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index 92ad48b5395..0c1eeaf676e 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" ) const ( @@ -23,15 +24,17 @@ type handlerFactory struct { legacyChains legacyevm.LegacyChainContainer ds sqlutil.DataSource lggr logger.Logger + httpClient network.HTTPClient } var _ HandlerFactory = (*handlerFactory)(nil) -func NewHandlerFactory(legacyChains legacyevm.LegacyChainContainer, ds sqlutil.DataSource, lggr logger.Logger) HandlerFactory { +func NewHandlerFactory(legacyChains legacyevm.LegacyChainContainer, ds sqlutil.DataSource, httpClient network.HTTPClient, lggr logger.Logger) HandlerFactory { return &handlerFactory{ legacyChains, ds, lggr, + httpClient, } } @@ -39,10 +42,10 @@ func (hf *handlerFactory) NewHandler(handlerType HandlerType, handlerConfig json switch handlerType { case FunctionsHandlerType: return functions.NewFunctionsHandlerFromConfig(handlerConfig, donConfig, don, hf.legacyChains, hf.ds, hf.lggr) - case WebAPICapabilitiesType: - return webapicapabilities.NewWorkflowHandler(handlerConfig, donConfig, don, hf.lggr) case DummyHandlerType: return handlers.NewDummyHandler(donConfig, don, hf.lggr) + case WebAPICapabilitiesType: + return webapicapabilities.NewHandler(handlerConfig, donConfig, don, hf.httpClient, hf.lggr) default: return nil, fmt.Errorf("unsupported handler type %s", handlerType) } diff --git a/core/services/gateway/handlers/handler.go b/core/services/gateway/handlers/handler.go index 6994488707f..b9fe4234d25 100644 --- a/core/services/gateway/handlers/handler.go +++ b/core/services/gateway/handlers/handler.go @@ -31,7 +31,8 @@ type Handler interface { // 2. waits on callbackCh with a timeout HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- UserCallbackPayload) error - // Handlers should not make any assumptions about goroutines calling HandleNodeMessage + // Handlers should not make any assumptions about goroutines calling HandleNodeMessage. + // should be non-blocking HandleNodeMessage(ctx context.Context, msg *api.Message, nodeAddr string) error } diff --git a/core/services/gateway/handlers/webapicapabilities/handler.go b/core/services/gateway/handlers/webapicapabilities/handler.go index d6caf067dd0..744bdc17406 100644 --- a/core/services/gateway/handlers/webapicapabilities/handler.go +++ b/core/services/gateway/handlers/webapicapabilities/handler.go @@ -13,6 +13,8 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" ) const ( @@ -22,17 +24,22 @@ const ( ) type handler struct { - config HandlerConfig - donConfig *config.DONConfig - don handlers.DON - savedCallbacks map[string]*savedCallback - mu sync.Mutex - lggr logger.Logger + config HandlerConfig + don handlers.DON + donConfig *config.DONConfig + savedCallbacks map[string]*savedCallback + mu sync.Mutex + lggr logger.Logger + httpClient network.HTTPClient + nodeRateLimiter *common.RateLimiter + wg sync.WaitGroup } type HandlerConfig struct { - MaxAllowedMessageAgeSec uint + NodeRateLimiter common.RateLimiterConfig `json:"nodeRateLimiter"` + MaxAllowedMessageAgeSec uint `json:"maxAllowedMessageAgeSec"` } + type savedCallback struct { id string callbackCh chan<- handlers.UserCallbackPayload @@ -40,84 +47,193 @@ type savedCallback struct { var _ handlers.Handler = (*handler)(nil) -func NewWorkflowHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, lggr logger.Logger) (*handler, error) { +func NewHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don handlers.DON, httpClient network.HTTPClient, lggr logger.Logger) (*handler, error) { var cfg HandlerConfig err := json.Unmarshal(handlerConfig, &cfg) if err != nil { return nil, err } + nodeRateLimiter, err := common.NewRateLimiter(cfg.NodeRateLimiter) + if err != nil { + return nil, err + } return &handler{ - config: cfg, - donConfig: donConfig, - don: don, - savedCallbacks: make(map[string]*savedCallback), - lggr: lggr.Named("WorkflowHandler." + donConfig.DonId), + config: cfg, + don: don, + donConfig: donConfig, + lggr: lggr.Named("WebAPIHandler." + donConfig.DonId), + httpClient: httpClient, + nodeRateLimiter: nodeRateLimiter, + wg: sync.WaitGroup{}, + savedCallbacks: make(map[string]*savedCallback), + }, nil +} + +// sendHTTPMessageToClient is an outgoing message from the gateway to external endpoints +// returns message to be sent back to the capability node +func (h *handler) sendHTTPMessageToClient(ctx context.Context, req network.HTTPRequest, msg *api.Message) (*api.Message, error) { + var payload TargetResponsePayload + resp, err := h.httpClient.Send(ctx, req) + if err != nil { + return nil, err + } + payload = TargetResponsePayload{ + ExecutionError: false, + StatusCode: resp.StatusCode, + Headers: resp.Headers, + Body: resp.Body, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + return &api.Message{ + Body: api.MessageBody{ + MessageId: msg.Body.MessageId, + Method: msg.Body.Method, + DonId: msg.Body.DonId, + Payload: payloadBytes, + }, }, nil } -func (d *handler) HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- handlers.UserCallbackPayload) error { - d.mu.Lock() - d.savedCallbacks[msg.Body.MessageId] = &savedCallback{msg.Body.MessageId, callbackCh} - don := d.don - d.mu.Unlock() +func (h *handler) handleWebAPITargetMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { + h.lggr.Debugw("handling web api target message", "messageId", msg.Body.MessageId, "nodeAddr", nodeAddr) + if !h.nodeRateLimiter.Allow(nodeAddr) { + return fmt.Errorf("rate limit exceeded for node %s", nodeAddr) + } + var targetPayload TargetRequestPayload + err := json.Unmarshal(msg.Body.Payload, &targetPayload) + if err != nil { + return err + } + // send message to target + timeout := time.Duration(targetPayload.TimeoutMs) * time.Millisecond + req := network.HTTPRequest{ + Method: targetPayload.Method, + URL: targetPayload.URL, + Headers: targetPayload.Headers, + Body: targetPayload.Body, + Timeout: timeout, + } + // this handle method must be non-blocking + // send response to node (target capability) async + // if there is a non-HTTP error (e.g. malformed request), send payload with success set to false and error messages + h.wg.Add(1) + go func() { + defer h.wg.Done() + // not cancelled when parent is cancelled to ensure the goroutine can finish + newCtx := context.WithoutCancel(ctx) + newCtx, cancel := context.WithTimeout(newCtx, timeout) + defer cancel() + l := h.lggr.With("url", targetPayload.URL, "messageId", msg.Body.MessageId, "method", targetPayload.Method) + respMsg, err := h.sendHTTPMessageToClient(newCtx, req, msg) + if err != nil { + l.Errorw("error while sending HTTP request to external endpoint", "err", err) + payload := TargetResponsePayload{ + ExecutionError: true, + ErrorMessage: err.Error(), + } + payloadBytes, err2 := json.Marshal(payload) + if err2 != nil { + // should not happen + l.Errorw("error while marshalling payload", "err", err2) + return + } + respMsg = &api.Message{ + Body: api.MessageBody{ + MessageId: msg.Body.MessageId, + Method: msg.Body.Method, + DonId: msg.Body.DonId, + Payload: payloadBytes, + }, + } + } + err = h.don.SendToNode(newCtx, nodeAddr, respMsg) + if err != nil { + l.Errorw("failed to send to node", "err", err, "to", nodeAddr) + return + } + }() + return nil +} + +func (h *handler) handleWebAPITriggerMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { + h.mu.Lock() + savedCb, found := h.savedCallbacks[msg.Body.MessageId] + delete(h.savedCallbacks, msg.Body.MessageId) + h.mu.Unlock() + + if found { + // Send first response from a node back to the user, ignore any other ones. + // TODO: in practice, we should wait for at least 2F+1 nodes to respond and then return an aggregated response + // back to the user. + savedCb.callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""} + close(savedCb.callbackCh) + } + return nil +} + +func (h *handler) HandleNodeMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { + switch msg.Body.Method { + case MethodWebAPITrigger: + return h.handleWebAPITriggerMessage(ctx, msg, nodeAddr) + case MethodWebAPITarget: + return h.handleWebAPITargetMessage(ctx, msg, nodeAddr) + default: + return fmt.Errorf("unsupported method: %s", msg.Body.Method) + } +} + +func (h *handler) Start(context.Context) error { + return nil +} + +func (h *handler) Close() error { + h.wg.Wait() + return nil +} + +func (h *handler) HandleUserMessage(ctx context.Context, msg *api.Message, callbackCh chan<- handlers.UserCallbackPayload) error { + h.mu.Lock() + h.savedCallbacks[msg.Body.MessageId] = &savedCallback{msg.Body.MessageId, callbackCh} + don := h.don + h.mu.Unlock() body := msg.Body var payload TriggerRequestPayload err := json.Unmarshal(body.Payload, &payload) if err != nil { - d.lggr.Errorw("error decoding payload", "err", err) + h.lggr.Errorw("error decoding payload", "err", err) callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.UserMessageParseError, ErrMsg: fmt.Sprintf("error decoding payload %s", err.Error())} close(callbackCh) return nil } if payload.Timestamp == 0 { - d.lggr.Errorw("error decoding payload") + h.lggr.Errorw("error decoding payload") callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.UserMessageParseError, ErrMsg: "error decoding payload"} close(callbackCh) return nil } - if uint(time.Now().Unix())-d.config.MaxAllowedMessageAgeSec > uint(payload.Timestamp) { + if uint(time.Now().Unix())-h.config.MaxAllowedMessageAgeSec > uint(payload.Timestamp) { callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.HandlerError, ErrMsg: "stale message"} close(callbackCh) return nil } // TODO: apply allowlist and rate-limiting here if msg.Body.Method != MethodWebAPITrigger { - d.lggr.Errorw("unsupported method", "method", body.Method) + h.lggr.Errorw("unsupported method", "method", body.Method) callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.HandlerError, ErrMsg: fmt.Sprintf("invalid method %s", msg.Body.Method)} close(callbackCh) return nil } // Send to all nodes. - for _, member := range d.donConfig.Members { + for _, member := range h.donConfig.Members { err = multierr.Combine(err, don.SendToNode(ctx, member.Address, msg)) } return err } - -func (d *handler) HandleNodeMessage(ctx context.Context, msg *api.Message, _ string) error { - d.mu.Lock() - savedCb, found := d.savedCallbacks[msg.Body.MessageId] - delete(d.savedCallbacks, msg.Body.MessageId) - d.mu.Unlock() - - if found { - // Send first response from a node back to the user, ignore any other ones. - // TODO: in practice, we should wait for at least 2F+1 nodes to respond and then return an aggregated response - // back to the user. - savedCb.callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""} - close(savedCb.callbackCh) - } - return nil -} - -func (d *handler) Start(context.Context) error { - return nil -} - -func (d *handler) Close() error { - return nil -} diff --git a/core/services/gateway/handlers/webapicapabilities/handler_test.go b/core/services/gateway/handlers/webapicapabilities/handler_test.go index ef278e40ffd..e631111ff1d 100644 --- a/core/services/gateway/handlers/webapicapabilities/handler_test.go +++ b/core/services/gateway/handlers/webapicapabilities/handler_test.go @@ -3,23 +3,28 @@ package webapicapabilities import ( "encoding/json" "fmt" - "strconv" "testing" "time" - "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" + "strconv" + + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" gwcommon "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" - + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" handlermocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/mocks" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network/mocks" ) const ( @@ -34,13 +39,21 @@ const ( address1 = "0x853d51d5d9935964267a5050aC53aa63ECA39bc5" ) -func setupHandler(t *testing.T) (*handler, *handlermocks.DON, []gwcommon.TestNode) { +func setupHandler(t *testing.T) (*handler, *mocks.HTTPClient, *handlermocks.DON, []gwcommon.TestNode) { lggr := logger.TestLogger(t) + httpClient := mocks.NewHTTPClient(t) don := handlermocks.NewDON(t) - + nodeRateLimiterConfig := common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + } handlerConfig := HandlerConfig{ + NodeRateLimiter: nodeRateLimiterConfig, MaxAllowedMessageAgeSec: 30, } + cfgBytes, err := json.Marshal(handlerConfig) require.NoError(t, err) donConfig := &config.DONConfig{ @@ -54,10 +67,125 @@ func setupHandler(t *testing.T) (*handler, *handlermocks.DON, []gwcommon.TestNod Address: n.Address, }) } + handler, err := NewHandler(json.RawMessage(cfgBytes), donConfig, don, httpClient, lggr) + require.NoError(t, err) + return handler, httpClient, don, nodes +} - handler, err := NewWorkflowHandler(json.RawMessage(cfgBytes), donConfig, don, lggr) +func TestHandler_SendHTTPMessageToClient(t *testing.T) { + handler, httpClient, don, nodes := setupHandler(t) + ctx := testutils.Context(t) + nodeAddr := nodes[0].Address + payload := TargetRequestPayload{ + Method: "GET", + URL: "http://example.com", + Headers: map[string]string{}, + Body: nil, + TimeoutMs: 2000, + } + payloadBytes, err := json.Marshal(payload) require.NoError(t, err) - return handler, don, nodes + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: "123", + Method: MethodWebAPITarget, + DonId: "testDonId", + Payload: json.RawMessage(payloadBytes), + }, + } + + t.Run("happy case", func(t *testing.T) { + httpClient.EXPECT().Send(mock.Anything, mock.Anything).Return(&network.HTTPResponse{ + StatusCode: 200, + Headers: map[string]string{}, + Body: []byte("response body"), + }, nil).Once() + + don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { + var payload TargetResponsePayload + err2 := json.Unmarshal(m.Body.Payload, &payload) + if err2 != nil { + return false + } + return "123" == m.Body.MessageId && + MethodWebAPITarget == m.Body.Method && + "testDonId" == m.Body.DonId && + 200 == payload.StatusCode && + 0 == len(payload.Headers) && + string(payload.Body) == "response body" && + !payload.ExecutionError + })).Return(nil).Once() + + err = handler.HandleNodeMessage(ctx, msg, nodeAddr) + require.NoError(t, err) + + require.Eventually(t, func() bool { + // ensure all goroutines close + err2 := handler.Close() + require.NoError(t, err2) + return httpClient.AssertExpectations(t) && don.AssertExpectations(t) + }, tests.WaitTimeout(t), 100*time.Millisecond) + }) + + t.Run("http client non-HTTP error", func(t *testing.T) { + httpClient.EXPECT().Send(mock.Anything, mock.Anything).Return(&network.HTTPResponse{ + StatusCode: 404, + Headers: map[string]string{}, + Body: []byte("access denied"), + }, nil).Once() + + don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { + var payload TargetResponsePayload + err2 := json.Unmarshal(m.Body.Payload, &payload) + if err2 != nil { + return false + } + return "123" == m.Body.MessageId && + MethodWebAPITarget == m.Body.Method && + "testDonId" == m.Body.DonId && + 404 == payload.StatusCode && + string(payload.Body) == "access denied" && + 0 == len(payload.Headers) && + !payload.ExecutionError + })).Return(nil).Once() + + err = handler.HandleNodeMessage(ctx, msg, nodeAddr) + require.NoError(t, err) + + require.Eventually(t, func() bool { + // // ensure all goroutines close + err2 := handler.Close() + require.NoError(t, err2) + return httpClient.AssertExpectations(t) && don.AssertExpectations(t) + }, tests.WaitTimeout(t), 100*time.Millisecond) + }) + + t.Run("http client non-HTTP error", func(t *testing.T) { + httpClient.EXPECT().Send(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("error while marshalling")).Once() + + don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { + var payload TargetResponsePayload + err2 := json.Unmarshal(m.Body.Payload, &payload) + if err2 != nil { + return false + } + return "123" == m.Body.MessageId && + MethodWebAPITarget == m.Body.Method && + "testDonId" == m.Body.DonId && + payload.ExecutionError && + "error while marshalling" == payload.ErrorMessage + })).Return(nil).Once() + + err = handler.HandleNodeMessage(ctx, msg, nodeAddr) + require.NoError(t, err) + + require.Eventually(t, func() bool { + // // ensure all goroutines close + err2 := handler.Close() + require.NoError(t, err2) + return httpClient.AssertExpectations(t) && don.AssertExpectations(t) + }, tests.WaitTimeout(t), 100*time.Millisecond) + }) } func triggerRequest(t *testing.T, privateKey string, topics string, methodName string, timestamp string, payload string) *api.Message { @@ -110,7 +238,7 @@ func requireNoChanMsg[T any](t *testing.T, ch <-chan T) { } func TestHandlerReceiveHTTPMessageFromClient(t *testing.T) { - handler, don, _ := setupHandler(t) + handler, _, don, _ := setupHandler(t) ctx := testutils.Context(t) msg := triggerRequest(t, privateKey1, `["daily_price_update"]`, "", "", "") diff --git a/core/services/gateway/handlers/webapicapabilities/webapi.go b/core/services/gateway/handlers/webapicapabilities/webapi.go index 97ba401881b..25f3bca6c1d 100644 --- a/core/services/gateway/handlers/webapicapabilities/webapi.go +++ b/core/services/gateway/handlers/webapicapabilities/webapi.go @@ -13,11 +13,11 @@ type TargetRequestPayload struct { } type TargetResponsePayload struct { - Success bool `json:"success"` // true if HTTP request was successful - ErrorMessage string `json:"error_message,omitempty"` // error message in case of failure - StatusCode uint8 `json:"statusCode"` // HTTP status code - Headers map[string]string `json:"headers,omitempty"` // HTTP headers - Body []byte `json:"body,omitempty"` // HTTP response body + ExecutionError bool `json:"executionError"` // true if there were non-HTTP errors. false if HTTP request was sent regardless of status (2xx, 4xx, 5xx) + ErrorMessage string `json:"errorMessage,omitempty"` // error message in case of failure + StatusCode int `json:"statusCode,omitempty"` // HTTP status code + Headers map[string]string `json:"headers,omitempty"` // HTTP headers + Body []byte `json:"body,omitempty"` // HTTP response body } // https://gateway-us-1.chain.link/web-trigger diff --git a/core/services/gateway/integration_tests/gateway_integration_test.go b/core/services/gateway/integration_tests/gateway_integration_test.go index 59418819b61..0ddf47bec04 100644 --- a/core/services/gateway/integration_tests/gateway_integration_test.go +++ b/core/services/gateway/integration_tests/gateway_integration_test.go @@ -10,6 +10,7 @@ import ( "strings" "sync/atomic" "testing" + "time" "github.com/jonboulle/clockwork" "github.com/onsi/gomega" @@ -24,6 +25,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" ) const gatewayConfigTemplate = ` @@ -143,7 +145,12 @@ func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T) // Launch Gateway lggr := logger.TestLogger(t) gatewayConfig := fmt.Sprintf(gatewayConfigTemplate, nodeKeys.Address) - gateway, err := gateway.NewGatewayFromConfig(parseGatewayConfig(t, gatewayConfig), gateway.NewHandlerFactory(nil, nil, lggr), lggr) + c, err := network.NewHTTPClient(network.HTTPClientConfig{ + DefaultTimeout: 5 * time.Second, + MaxResponseBytes: 1000, + }, lggr) + require.NoError(t, err) + gateway, err := gateway.NewGatewayFromConfig(parseGatewayConfig(t, gatewayConfig), gateway.NewHandlerFactory(nil, nil, c, lggr), lggr) require.NoError(t, err) servicetest.Run(t, gateway) userPort, nodePort := gateway.GetUserPort(), gateway.GetNodePort() diff --git a/core/services/gateway/network/httpclient.go b/core/services/gateway/network/httpclient.go new file mode 100644 index 00000000000..4aecaaed3cd --- /dev/null +++ b/core/services/gateway/network/httpclient.go @@ -0,0 +1,88 @@ +package network + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// HTTPClient interfaces defines a method to send HTTP requests +type HTTPClient interface { + Send(ctx context.Context, req HTTPRequest) (*HTTPResponse, error) +} + +type HTTPClientConfig struct { + MaxResponseBytes uint32 + DefaultTimeout time.Duration +} + +type HTTPRequest struct { + Method string + URL string + Headers map[string]string + Body []byte + Timeout time.Duration +} +type HTTPResponse struct { + StatusCode int // HTTP status code + Headers map[string]string // HTTP headers + Body []byte // HTTP response body +} + +type httpClient struct { + client *http.Client + config HTTPClientConfig + lggr logger.Logger +} + +// NewHTTPClient creates a new NewHTTPClient +// As of now, the client does not support TLS configuration but may be extended in the future +func NewHTTPClient(config HTTPClientConfig, lggr logger.Logger) (HTTPClient, error) { + return &httpClient{ + config: config, + client: &http.Client{ + Timeout: config.DefaultTimeout, + Transport: http.DefaultTransport, + }, + lggr: lggr, + }, nil +} + +func (c *httpClient) Send(ctx context.Context, req HTTPRequest) (*HTTPResponse, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, req.Timeout) + defer cancel() + r, err := http.NewRequestWithContext(timeoutCtx, req.Method, req.URL, bytes.NewBuffer(req.Body)) + if err != nil { + return nil, err + } + + resp, err := c.client.Do(r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + reader := http.MaxBytesReader(nil, resp.Body, int64(c.config.MaxResponseBytes)) + body, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + headers := make(map[string]string) + for k, v := range resp.Header { + // header values are usually an array of size 1 + // joining them to a single string in case array size is greater than 1 + headers[k] = strings.Join(v, ",") + } + c.lggr.Debugw("received HTTP response", "statusCode", resp.StatusCode, "body", string(body), "url", req.URL, "headers", headers) + + return &HTTPResponse{ + Headers: headers, + StatusCode: resp.StatusCode, + Body: body, + }, nil +} diff --git a/core/services/gateway/network/httpclient_test.go b/core/services/gateway/network/httpclient_test.go new file mode 100644 index 00000000000..2f4cc448ef5 --- /dev/null +++ b/core/services/gateway/network/httpclient_test.go @@ -0,0 +1,147 @@ +package network_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" +) + +func TestHTTPClient_Send(t *testing.T) { + t.Parallel() + + // Setup the test environment + lggr := logger.Test(t) + config := network.HTTPClientConfig{ + MaxResponseBytes: 1024, + DefaultTimeout: 5 * time.Second, + } + client, err := network.NewHTTPClient(config, lggr) + require.NoError(t, err) + + // Define test cases + tests := []struct { + name string + setupServer func() *httptest.Server + request network.HTTPRequest + expectedError error + expectedResp *network.HTTPResponse + }{ + { + name: "successful request", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err2 := w.Write([]byte("success")) + require.NoError(t, err2) + })) + }, + request: network.HTTPRequest{ + Method: "GET", + URL: "/", + Headers: map[string]string{}, + Body: nil, + Timeout: 2 * time.Second, + }, + expectedError: nil, + expectedResp: &network.HTTPResponse{ + StatusCode: http.StatusOK, + Headers: map[string]string{"Content-Length": "7"}, + Body: []byte("success"), + }, + }, + { + name: "request timeout", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Second) + w.WriteHeader(http.StatusOK) + _, err2 := w.Write([]byte("success")) + require.NoError(t, err2) + })) + }, + request: network.HTTPRequest{ + Method: "GET", + URL: "/", + Headers: map[string]string{}, + Body: nil, + Timeout: 1 * time.Second, + }, + expectedError: context.DeadlineExceeded, + expectedResp: nil, + }, + { + name: "server error", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, err2 := w.Write([]byte("error")) + require.NoError(t, err2) + })) + }, + request: network.HTTPRequest{ + Method: "GET", + URL: "/", + Headers: map[string]string{}, + Body: nil, + Timeout: 2 * time.Second, + }, + expectedError: nil, + expectedResp: &network.HTTPResponse{ + StatusCode: http.StatusInternalServerError, + Headers: map[string]string{"Content-Length": "5"}, + Body: []byte("error"), + }, + }, + { + name: "response too long", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err2 := w.Write(make([]byte, 2048)) + require.NoError(t, err2) + })) + }, + request: network.HTTPRequest{ + Method: "GET", + URL: "/", + Headers: map[string]string{}, + Body: nil, + Timeout: 2 * time.Second, + }, + expectedError: &http.MaxBytesError{}, + expectedResp: nil, + }, + } + + // Execute test cases + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := tt.setupServer() + defer server.Close() + + tt.request.URL = server.URL + tt.request.URL + + resp, err := client.Send(context.Background(), tt.request) + if tt.expectedError != nil { + require.Error(t, err) + require.ErrorContains(t, err, tt.expectedError.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedResp.StatusCode, resp.StatusCode) + for k, v := range tt.expectedResp.Headers { + value, ok := resp.Headers[k] + require.True(t, ok) + require.Equal(t, v, value) + } + require.Equal(t, tt.expectedResp.Body, resp.Body) + } + }) + } +} diff --git a/core/services/gateway/network/mocks/http_client.go b/core/services/gateway/network/mocks/http_client.go new file mode 100644 index 00000000000..8b5bff2cccf --- /dev/null +++ b/core/services/gateway/network/mocks/http_client.go @@ -0,0 +1,96 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + network "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" + mock "github.com/stretchr/testify/mock" +) + +// HTTPClient is an autogenerated mock type for the HTTPClient type +type HTTPClient struct { + mock.Mock +} + +type HTTPClient_Expecter struct { + mock *mock.Mock +} + +func (_m *HTTPClient) EXPECT() *HTTPClient_Expecter { + return &HTTPClient_Expecter{mock: &_m.Mock} +} + +// Send provides a mock function with given fields: ctx, req +func (_m *HTTPClient) Send(ctx context.Context, req network.HTTPRequest) (*network.HTTPResponse, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Send") + } + + var r0 *network.HTTPResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, network.HTTPRequest) (*network.HTTPResponse, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, network.HTTPRequest) *network.HTTPResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*network.HTTPResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, network.HTTPRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HTTPClient_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type HTTPClient_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - ctx context.Context +// - req network.HTTPRequest +func (_e *HTTPClient_Expecter) Send(ctx interface{}, req interface{}) *HTTPClient_Send_Call { + return &HTTPClient_Send_Call{Call: _e.mock.On("Send", ctx, req)} +} + +func (_c *HTTPClient_Send_Call) Run(run func(ctx context.Context, req network.HTTPRequest)) *HTTPClient_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(network.HTTPRequest)) + }) + return _c +} + +func (_c *HTTPClient_Send_Call) Return(_a0 *network.HTTPResponse, _a1 error) *HTTPClient_Send_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *HTTPClient_Send_Call) RunAndReturn(run func(context.Context, network.HTTPRequest) (*network.HTTPResponse, error)) *HTTPClient_Send_Call { + _c.Call.Return(run) + return _c +} + +// NewHTTPClient creates a new instance of HTTPClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewHTTPClient(t interface { + mock.TestingT + Cleanup(func()) +}) *HTTPClient { + mock := &HTTPClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}