diff --git a/middleware/cors.go b/middleware/cors.go index c263f7319..07df0e57e 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -102,6 +102,17 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { origin := req.Header.Get(echo.HeaderOrigin) allowOrigin := "" + preflight := req.Method == http.MethodOptions + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + + // No Origin provided + if origin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) + } + // Check allowed origins for _, o := range config.AllowOrigins { if o == "*" && config.AllowCredentials { @@ -138,9 +149,16 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } } + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) + } + // Simple request - if req.Method != http.MethodOptions { - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + if !preflight { res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") @@ -152,7 +170,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } // Preflight request - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index ca922321c..fc34694db 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -17,19 +17,31 @@ func TestCORS(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := CORS()(echo.NotFoundHandler) + req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + // Wildcard AllowedOrigin with no Origin header in request + req = httptest.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h = CORS()(echo.NotFoundHandler) + h(c) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + // Allow origins req = httptest.NewRequest(http.MethodGet, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, })(echo.NotFoundHandler) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) // Preflight request req = httptest.NewRequest(http.MethodOptions, "/", nil) @@ -67,6 +79,22 @@ func TestCORS(t *testing.T) { assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) + // Preflight request with Access-Control-Request-Headers + req = httptest.NewRequest(http.MethodOptions, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, "localhost") + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header") + cors = CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + }) + h = cors(echo.NotFoundHandler) + h(c) + assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) + assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) + // Preflight request with `AllowOrigins` which allow all subdomains with * req = httptest.NewRequest(http.MethodOptions, "/", nil) rec = httptest.NewRecorder() @@ -126,7 +154,7 @@ func Test_allowOriginScheme(t *testing.T) { if tt.expected { assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { - assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) } } } @@ -217,7 +245,118 @@ func Test_allowOriginSubdomain(t *testing.T) { if tt.expected { assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { - assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + } + } +} + +func TestCorsHeaders(t *testing.T) { + tests := []struct { + domain, allowedOrigin, method string + expected bool + }{ + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodOptions, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: true, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(tt.method, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tt.domain != "" { + req.Header.Set(echo.HeaderOrigin, tt.domain) + } + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.allowedOrigin}, + //AllowCredentials: true, + //MaxAge: 3600, + }) + h := cors(echo.NotFoundHandler) + h(c) + + assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + + expectedAllowOrigin := "" + if tt.allowedOrigin == "*" { + expectedAllowOrigin = "*" + } else { + expectedAllowOrigin = tt.domain + } + + switch { + case tt.expected && tt.method == http.MethodOptions: + assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) + case tt.expected && tt.method == http.MethodGet: + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + default: + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + } + + if tt.method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, rec.Code) } } }