Skip to content

Commit

Permalink
add AcceptResponseHandler to modify accepted responses (#196)
Browse files Browse the repository at this point in the history
* add AcceptResponseHandler to modify accepted responses

* customer->custom
  • Loading branch information
cmoresco-stripe authored Jul 26, 2023
1 parent 81a59fd commit 6f13b30
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
5 changes: 4 additions & 1 deletion pkg/smokescreen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,12 @@ type Config struct {
// Custom Dial Timeout function to be called
ProxyDialTimeout func(ctx context.Context, network, address string, timeout time.Duration) (net.Conn, error)

// Customer handler to allow clients to modify reject responses
// Custom handler to allow clients to modify reject responses
RejectResponseHandler func(*http.Response)

// Custom handler to allow clients to modify accept responses
AcceptResponseHandler func(*http.Response)

// UnsafeAllowPrivateRanges inverts the default behavior, telling smokescreen to allow private IP
// ranges by default (exempting loopback and unicast ranges)
// This setting can be used to configure Smokescreen with a blocklist, rather than an allowlist
Expand Down
7 changes: 5 additions & 2 deletions pkg/smokescreen/smokescreen.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,10 +541,13 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer {
proxy.OnResponse().DoFunc(func(resp *http.Response, pctx *goproxy.ProxyCtx) *http.Response {
sctx := pctx.UserData.(*smokescreenContext)

if resp != nil && resp.Header.Get(errorHeader) != "" {
if pctx.Error == nil && sctx.decision.allow {
if resp != nil && pctx.Error == nil && sctx.decision.allow {
if resp.Header.Get(errorHeader) != "" {
resp.Header.Del(errorHeader)
}
if sctx.cfg.AcceptResponseHandler != nil {
sctx.cfg.AcceptResponseHandler(resp)
}
}

if resp == nil && pctx.Error != nil {
Expand Down
43 changes: 43 additions & 0 deletions pkg/smokescreen/smokescreen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,49 @@ func TestRejectResponseHandler(t *testing.T) {
})
}

// Test that Smokescreen calls the custom accept response handler (if defined in the Config struct)
// after every accepted request
func TestAcceptResponseHandler(t *testing.T) {
r := require.New(t)
testHeader := "TestAcceptResponseHandlerHeader"
t.Run("Testing custom accept response handler", func(t *testing.T) {
cfg, err := testConfig("test-local-srv")

// set a custom AcceptResponseHandler that will set a header on every reject response
cfg.AcceptResponseHandler = func(resp *http.Response) {
resp.Header.Set(testHeader, "This header is added by the AcceptResponseHandler")
}
r.NoError(err)

proxySrv := proxyServer(cfg)
r.NoError(err)
defer proxySrv.Close()

// Create a http.Client that uses our proxy
client, err := proxyClient(proxySrv.URL)
r.NoError(err)

// Send a request that should be allowed
resp, err := client.Get("http://example.com")
r.NoError(err)

// The AcceptResponseHandler should set our custom header
h := resp.Header.Get(testHeader)
if h == "" {
t.Errorf("Expecting header %s to be set by AcceptResponseHandler", testHeader)
}
// Send a request that should be blocked
resp, err = client.Get("http://127.0.0.1")
r.NoError(err)

// The header set by our custom reject response handler should not be set
h = resp.Header.Get(testHeader)
if h != "" {
t.Errorf("Expecting header %s to not be set by AcceptResponseHandler", testHeader)
}
})
}

func TestCustomRequestHandler(t *testing.T) {
r := require.New(t)
testHeader := "X-Verify-Request-Header"
Expand Down

0 comments on commit 6f13b30

Please sign in to comment.