From a33f085ece7621c8498ed7f9e6e7c5a578f07c8d Mon Sep 17 00:00:00 2001 From: Sergey Rud Date: Wed, 24 May 2023 13:19:47 +0100 Subject: [PATCH] Move the custom request handler call after the main acl check --- pkg/smokescreen/smokescreen.go | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/pkg/smokescreen/smokescreen.go b/pkg/smokescreen/smokescreen.go index b2bbe650..d654c766 100644 --- a/pkg/smokescreen/smokescreen.go +++ b/pkg/smokescreen/smokescreen.go @@ -477,16 +477,6 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { } sctx.logger.WithField("url", req.RequestURI).Debug("received HTTP proxy request") - - // Call the custom request handler if it exists - if config.CustomRequestHandler != nil { - err = config.CustomRequestHandler(req) - if err != nil { - pctx.Error = denyError{err} - return req, rejectResponse(pctx, pctx.Error) - } - } - sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, req, destination) // Returning any kind of response in this handler is goproxy's way of short circuiting @@ -499,6 +489,15 @@ func BuildProxy(config *Config) *goproxy.ProxyHttpServer { return req, rejectResponse(pctx, denyError{errors.New(sctx.decision.reason)}) } + // Call the custom request handler if it exists + if config.CustomRequestHandler != nil { + err = config.CustomRequestHandler(req) + if err != nil { + pctx.Error = denyError{err} + return req, rejectResponse(pctx, pctx.Error) + } + } + // Proceed with proxying the request return req, nil }) @@ -621,6 +620,13 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { pctx.Error = denyError{err} return "", pctx.Error } + sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination) + if pctx.Error != nil { + return "", denyError{pctx.Error} + } + if !sctx.decision.allow { + return "", denyError{errors.New(sctx.decision.reason)} + } // Call the custom request handler if it exists if config.CustomRequestHandler != nil { @@ -631,14 +637,6 @@ func handleConnect(config *Config, pctx *goproxy.ProxyCtx) (string, error) { } } - sctx.decision, sctx.lookupTime, pctx.Error = checkIfRequestShouldBeProxied(config, pctx.Req, destination) - if pctx.Error != nil { - return "", denyError{pctx.Error} - } - if !sctx.decision.allow { - return "", denyError{errors.New(sctx.decision.reason)} - } - return destination.String(), nil }