From 445d0d6362799ddbf89930174406d36c1847f28f Mon Sep 17 00:00:00 2001 From: Sergey Rud Date: Wed, 24 May 2023 13:19:47 +0100 Subject: [PATCH] Move the custom request handler call after the main acl check --- pkg/smokescreen/config.go | 3 ++- pkg/smokescreen/smokescreen.go | 38 ++++++++++++++--------------- pkg/smokescreen/smokescreen_test.go | 8 +++--- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/pkg/smokescreen/config.go b/pkg/smokescreen/config.go index c24caeee..adfcce10 100644 --- a/pkg/smokescreen/config.go +++ b/pkg/smokescreen/config.go @@ -89,8 +89,9 @@ type Config struct { // Custom handler for users to allow running code per requests, users can pass in custom methods to verify requests based // on headers, code for metrics etc. + // If smokescreen denies a request, this handler is not called. // If the handler returns an error, smokescreen will deny the request. - CustomRequestHandler func(*http.Request) error + PostDecisionRequestHandler func(*http.Request) error } type missingRoleError struct { diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index b2bbe650..2c3c61d4 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -477,16 +477,6 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { } sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request") - - // Call the custom request handler if it exists - if config.CustomRequestHandler != nil { - err = config.CustomRequestHandler(req) - if err != nil { - pctx.Error = denyError{err} - return req, rejectResponse(pctx, pctx.Error) - } - } - sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination) // Returning any kind of response in this handler is goproxy's way of short circuiting @@ -499,6 +489,15 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { return req, rejectResponse(pctx, denyError{errors.New(sctx.decision.reason)}) } + // Call the custom request handler if it exists + if config.PostDecisionRequestHandler != nil { + err = config.PostDecisionRequestHandler(req) + if err != nil { + pctx.Error = denyError{err} + return req, rejectResponse(pctx, pctx.Error) + } + } + // Proceed with proxying the request return req, nil }) @@ -621,16 +620,6 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { pctx.Error = denyError{err} return "", pctx.Error } - - // Call the custom request handler if it exists - if config.CustomRequestHandler != nil { - err = config.CustomRequestHandler(pctx.Req) - if err != nil { - pctx.Error = denyError{err} - return "", pctx.Error - } - } - sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination) if pctx.Error != nil { return "", denyError{pctx.Error} @@ -639,6 +628,15 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { return "", denyError{errors.New(sctx.decision.reason)} } + // Call the custom request handler if it exists + if config.PostDecisionRequestHandler != nil { + err = config.PostDecisionRequestHandler(pctx.Req) + if err != nil { + pctx.Error = denyError{err} + return "", pctx.Error + } + } + return destination.String(), nil } diff --git a/pkg/smokescreen/smokescreen_test.go b/pkg/smokescreen/smokescreen_test.go index 2823a161..b92f7495 100644 --- a/pkg/smokescreen/smokescreen_test.go +++ b/pkg/smokescreen/smokescreen_test.go @@ -1070,7 +1070,7 @@ func TestCustomRequestHandler(t *testing.T) { return nil } - t.Run("CustomRequestHandler works for HTTPS", func(t *testing.T) { + t.Run("PostDecisionRequestHandler works for HTTPS", func(t *testing.T) { testCases := []struct { header http.Header expectedError bool @@ -1088,7 +1088,7 @@ func TestCustomRequestHandler(t *testing.T) { r.NoError(err) err = cfg.SetAllowAddresses([]string{"127.0.0.1"}) r.NoError(err) - cfg.CustomRequestHandler = customRequestHandler + cfg.PostDecisionRequestHandler = customRequestHandler l, err := net.Listen("tcp", "localhost:0") r.NoError(err) @@ -1119,7 +1119,7 @@ func TestCustomRequestHandler(t *testing.T) { } }) - t.Run("CustomRequestHandler works for HTTP", func(t *testing.T) { + t.Run("PostDecisionRequestHandler works for HTTP", func(t *testing.T) { testCases := []struct { header string expectedError bool @@ -1137,7 +1137,7 @@ func TestCustomRequestHandler(t *testing.T) { r.NoError(err) err = cfg.SetAllowAddresses([]string{"127.0.0.1"}) r.NoError(err) - cfg.CustomRequestHandler = customRequestHandler + cfg.PostDecisionRequestHandler = customRequestHandler l, err := net.Listen("tcp", "localhost:0") r.NoError(err)