diff --git a/pkg/smokescreen/config.go b/pkg/smokescreen/config.go index adfcce10..6eb23697 100644 --- a/pkg/smokescreen/config.go +++ b/pkg/smokescreen/config.go @@ -79,9 +79,12 @@ type Config struct { // Custom Dial Timeout function to be called ProxyDialTimeout func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error) - // Customer handler to allow clients to modify reject responses + // Custom handler to allow clients to modify reject responses RejectResponseHandler func(*http.Response) + // Custom handler to allow clients to modify accept responses + AcceptResponseHandler func(*http.Response) + // UnsafeAllowPrivateRanges inverts the default behavior, telling smokescreen to allow private IP // ranges by default (exempting loopback and unicast ranges) // This setting can be used to configure Smokescreen with a blocklist, rather than an allowlist diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index 69e9dd38..6a073c1a 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -541,10 +541,13 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { proxy.OnResponse().DoFunc(func(resp *http.Response, pctx *goproxy.ProxyCtx) *http.Response { sctx := pctx.UserData.(*smokescreenContext) - if resp != nil && resp.Header.Get(errorHeader) != "" { - if pctx.Error == nil && sctx.decision.allow { + if resp != nil && pctx.Error == nil && sctx.decision.allow { + if resp.Header.Get(errorHeader) != "" { resp.Header.Del(errorHeader) } + if sctx.cfg.AcceptResponseHandler != nil { + sctx.cfg.AcceptResponseHandler(resp) + } } if resp == nil && pctx.Error != nil { diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index 4e01ef9a..8fe192b7 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -1064,6 +1064,49 @@ func TestRejectResponseHandler(t *testing.T) { }) } +// Test that Smokescreen calls the custom accept response handler (if defined in the Config struct) +// after every accepted request +func TestAcceptResponseHandler(t *testing.T) { + r := require.New(t) + testHeader := "TestAcceptResponseHandlerHeader" + t.Run("Testing custom accept response handler", func(t *testing.T) { + cfg, err := testConfig("test-local-srv") + + // set a custom AcceptResponseHandler that will set a header on every reject response + cfg.AcceptResponseHandler = func(resp *http.Response) { + resp.Header.Set(testHeader, "This header is added by the AcceptResponseHandler") + } + r.NoError(err) + + proxySrv := proxyServer(cfg) + r.NoError(err) + defer proxySrv.Close() + + // Create a http.Client that uses our proxy + client, err := proxyClient(proxySrv.URL) + r.NoError(err) + + // Send a request that should be allowed + resp, err := client.Get("http://example.com") + r.NoError(err) + + // The AcceptResponseHandler should set our custom header + h := resp.Header.Get(testHeader) + if h == "" { + t.Errorf("Expecting header %s to be set by AcceptResponseHandler", testHeader) + } + // Send a request that should be blocked + resp, err = client.Get("http://127.0.0.1") + r.NoError(err) + + // The header set by our custom reject response handler should not be set + h = resp.Header.Get(testHeader) + if h != "" { + t.Errorf("Expecting header %s to not be set by AcceptResponseHandler", testHeader) + } + }) +} + func TestCustomRequestHandler(t *testing.T) { r := require.New(t) testHeader := "X-Verify-Request-Header"