From 26ab188922a69f822833a460f364a06f56404513 Mon Sep 17 00:00:00 2001 From: Pierre Rousset Date: Fri, 9 Oct 2020 18:07:29 +0900 Subject: [PATCH 1/2] CORS: add an optional custom function to validate the origin --- middleware/cors.go | 71 +++++++++++++++++++++++++---------------- middleware/cors_test.go | 47 +++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 27 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index 07df0e57e..c1e22e4e6 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -19,6 +19,13 @@ type ( // Optional. Default value []string{"*"}. AllowOrigins []string `yaml:"allow_origins"` + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. @@ -113,40 +120,50 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return c.NoContent(http.StatusNoContent) } - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { - allowOrigin = origin - break - } - if o == "*" || o == origin { - allowOrigin = o - break - } - if matchSubdomain(origin, o) { - allowOrigin = origin - break - } - } - - // Check allowed origin patterns - for _, re := range allowOriginPatterns { - if allowOrigin == "" { - didx := strings.Index(origin, "://") - if didx == -1 { - continue + if config.AllowOriginFunc == nil { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials { + allowOrigin = origin + break } - domAuth := origin[didx+3:] - // to avoid regex cost by invalid long domain - if len(domAuth) > 253 { + if o == "*" || o == origin { + allowOrigin = o break } - - if match, _ := regexp.MatchString(re, origin); match { + if matchSubdomain(origin, o) { allowOrigin = origin break } } + + // Check allowed origin patterns + for _, re := range allowOriginPatterns { + if allowOrigin == "" { + didx := strings.Index(origin, "://") + if didx == -1 { + continue + } + domAuth := origin[didx+3:] + // to avoid regex cost by invalid long domain + if len(domAuth) > 253 { + break + } + + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } + } + } else { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { + allowOrigin = origin + } } // Origin not allowed diff --git a/middleware/cors_test.go b/middleware/cors_test.go index fc34694db..717abe498 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) { } } } + +func Test_allowOriginFunc(t *testing.T) { + returnTrue := func(origin string) (bool, error) { + return true, nil + } + returnFalse := func(origin string) (bool, error) { + return false, nil + } + returnError := func(origin string) (bool, error) { + return true, errors.New("this is a test error") + } + + allowOriginFuncs := []func(origin string) (bool, error){ + returnTrue, + returnFalse, + returnError, + } + + const origin = "http://example.com" + + e := echo.New() + for _, allowOriginFunc := range allowOriginFuncs { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, origin) + cors := CORSWithConfig(CORSConfig{ + AllowOriginFunc: allowOriginFunc, + }) + h := cors(echo.NotFoundHandler) + err := h(c) + + expected, expectedErr := allowOriginFunc(origin) + if expectedErr != nil { + assert.Equal(t, expectedErr, err) + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + continue + } + + if expected { + assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } + } +} From e6f24aa8b1cb1263d462d4d4a1126827c3d8e7f7 Mon Sep 17 00:00:00 2001 From: Pierre Rousset Date: Mon, 16 Nov 2020 12:53:49 +0900 Subject: [PATCH 2/2] Addressed PR feedback --- middleware/cors.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index c1e22e4e6..d6ef89644 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -120,7 +120,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return c.NoContent(http.StatusNoContent) } - if config.AllowOriginFunc == nil { + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { + allowOrigin = origin + } + } else { // Check allowed origins for _, o := range config.AllowOrigins { if o == "*" && config.AllowCredentials { @@ -156,14 +164,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } } } - } else { - allowed, err := config.AllowOriginFunc(origin) - if err != nil { - return err - } - if allowed { - allowOrigin = origin - } } // Origin not allowed