Skip to content

Commit

Permalink
Merge pull request #1669 from ulasakdeniz/fix-incorrect-cors-headers
Browse files Browse the repository at this point in the history
Fix empty/incorrect CORS headers
  • Loading branch information
lammel committed Nov 20, 2020
2 parents ce95e12 + 871ed9c commit 90bef88
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 6 deletions.
23 changes: 20 additions & 3 deletions middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""

preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)

// No Origin provided
if origin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}

// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
Expand Down Expand Up @@ -138,9 +149,16 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
}

// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}

// Simple request
if req.Method != http.MethodOptions {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
Expand All @@ -152,7 +170,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}

// Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
Expand Down
145 changes: 142 additions & 3 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,31 @@ func TestCORS(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := CORS()(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))

// Wildcard AllowedOrigin with no Origin header in request
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORS()(echo.NotFoundHandler)
h(c)
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)

// Allow origins
req = httptest.NewRequest(http.MethodGet, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})(echo.NotFoundHandler)
req.Header.Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))

// Preflight request
req = httptest.NewRequest(http.MethodOptions, "/", nil)
Expand Down Expand Up @@ -67,6 +79,22 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))

// Preflight request with Access-Control-Request-Headers
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"*"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))

// Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
Expand Down Expand Up @@ -126,7 +154,7 @@ func Test_allowOriginScheme(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}
Expand Down Expand Up @@ -217,7 +245,118 @@ func Test_allowOriginSubdomain(t *testing.T) {
if tt.expected {
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
}
}
}

func TestCorsHeaders(t *testing.T) {
tests := []struct {
domain, allowedOrigin, method string
expected bool
}{
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
},
}

e := echo.New()
for _, tt := range tests {
req := httptest.NewRequest(tt.method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tt.domain != "" {
req.Header.Set(echo.HeaderOrigin, tt.domain)
}
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
})
h := cors(echo.NotFoundHandler)
h(c)

assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))

expectedAllowOrigin := ""
if tt.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tt.domain
}

switch {
case tt.expected && tt.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tt.expected && tt.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}

if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
}
}

0 comments on commit 90bef88

Please sign in to comment.