Skip to content

Commit

Permalink
Fix open redirect vulnerability with AddTrailingSlashWithConfig and R…
Browse files Browse the repository at this point in the history
…emoveTrailingSlashWithConfig (#1775,#1771)

* fix open redirect vulnerability with AddTrailingSlashWithConfig and RemoveTrailingSlashWithConfig (fix #1771)
* rename trimMultipleSlashes to sanitizeURI
  • Loading branch information
aldas committed Feb 11, 2021
1 parent 932976d commit f09f2bd
Show file tree
Hide file tree
Showing 2 changed files with 273 additions and 82 deletions.
13 changes: 11 additions & 2 deletions middleware/slash.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc

// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
}

// Forward
Expand Down Expand Up @@ -108,7 +108,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu

// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
}

// Forward
Expand All @@ -119,3 +119,12 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
}
}
}

func sanitizeURI(uri string) string {
// double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
// we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
uri = "/" + strings.TrimLeft(uri, `/\`)
}
return uri
}
342 changes: 262 additions & 80 deletions middleware/slash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,88 +9,270 @@ import (
"github.com/stretchr/testify/assert"
)

func TestAddTrailingSlashWithConfig(t *testing.T) {
var testCases = []struct {
whenURL string
whenMethod string
expectPath string
expectLocation []string
expectStatus int
}{
{
whenURL: "/add-slash",
whenMethod: http.MethodGet,
expectPath: "/add-slash",
expectLocation: []string{`/add-slash/`},
},
{
whenURL: "/add-slash?key=value",
whenMethod: http.MethodGet,
expectPath: "/add-slash",
expectLocation: []string{`/add-slash/?key=value`},
},
{
whenURL: "/",
whenMethod: http.MethodConnect,
expectPath: "/",
expectLocation: nil,
expectStatus: http.StatusOK,
},
// cases for open redirect vulnerability
{
whenURL: "http://localhost:1323/%5Cexample.com",
expectPath: `/\example.com`,
expectLocation: []string{`/example.com/`},
},
{
whenURL: `http://localhost:1323/\example.com`,
expectPath: `/\example.com`,
expectLocation: []string{`/example.com/`},
},
{
whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com`,
expectPath: `/\\\////\\\\example.com`,
expectLocation: []string{`/example.com/`},
},
{
whenURL: "http://localhost:1323//example.com",
expectPath: `//example.com`,
expectLocation: []string{`/example.com/`},
},
{
whenURL: "http://localhost:1323/%5C%5C",
expectPath: `/\\`,
expectLocation: []string{`/`},
},
}
for _, tc := range testCases {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()

mw := AddTrailingSlashWithConfig(TrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})
h := mw(func(c echo.Context) error {
return nil
})

rec := httptest.NewRecorder()
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
c := e.NewContext(req, rec)

err := h(c)
assert.NoError(t, err)

assert.Equal(t, tc.expectPath, req.URL.Path)
assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation])
if tc.expectStatus == 0 {
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
} else {
assert.Equal(t, tc.expectStatus, rec.Code)
}
})
}
}

func TestAddTrailingSlash(t *testing.T) {
is := assert.New(t)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/add-slash", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := AddTrailingSlash()(func(c echo.Context) error {
return nil
})
is.NoError(h(c))
is.Equal("/add-slash/", req.URL.Path)
is.Equal("/add-slash/", req.RequestURI)

// Method Connect must not fail:
req = httptest.NewRequest(http.MethodConnect, "", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = AddTrailingSlash()(func(c echo.Context) error {
return nil
})
is.NoError(h(c))
is.Equal("/", req.URL.Path)
is.Equal("/", req.RequestURI)

// With config
req = httptest.NewRequest(http.MethodGet, "/add-slash?key=value", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = AddTrailingSlashWithConfig(TrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})(func(c echo.Context) error {
return nil
})
is.NoError(h(c))
is.Equal(http.StatusMovedPermanently, rec.Code)
is.Equal("/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation))
var testCases = []struct {
whenURL string
whenMethod string
expectPath string
expectLocation []string
}{
{
whenURL: "/add-slash",
whenMethod: http.MethodGet,
expectPath: "/add-slash/",
},
{
whenURL: "/add-slash?key=value",
whenMethod: http.MethodGet,
expectPath: "/add-slash/",
},
{
whenURL: "/",
whenMethod: http.MethodConnect,
expectPath: "/",
expectLocation: nil,
},
}
for _, tc := range testCases {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()

h := AddTrailingSlash()(func(c echo.Context) error {
return nil
})

rec := httptest.NewRecorder()
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
c := e.NewContext(req, rec)

err := h(c)
assert.NoError(t, err)

assert.Equal(t, tc.expectPath, req.URL.Path)
assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation])
assert.Equal(t, http.StatusOK, rec.Code)
})
}
}

func TestRemoveTrailingSlashWithConfig(t *testing.T) {
var testCases = []struct {
whenURL string
whenMethod string
expectPath string
expectLocation []string
expectStatus int
}{
{
whenURL: "/remove-slash/",
whenMethod: http.MethodGet,
expectPath: "/remove-slash/",
expectLocation: []string{`/remove-slash`},
},
{
whenURL: "/remove-slash/?key=value",
whenMethod: http.MethodGet,
expectPath: "/remove-slash/",
expectLocation: []string{`/remove-slash?key=value`},
},
{
whenURL: "/",
whenMethod: http.MethodConnect,
expectPath: "/",
expectLocation: nil,
expectStatus: http.StatusOK,
},
{
whenURL: "http://localhost",
whenMethod: http.MethodGet,
expectPath: "",
expectLocation: nil,
expectStatus: http.StatusOK,
},
// cases for open redirect vulnerability
{
whenURL: "http://localhost:1323/%5Cexample.com/",
expectPath: `/\example.com/`,
expectLocation: []string{`/example.com`},
},
{
whenURL: `http://localhost:1323/\example.com/`,
expectPath: `/\example.com/`,
expectLocation: []string{`/example.com`},
},
{
whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com/`,
expectPath: `/\\\////\\\\example.com/`,
expectLocation: []string{`/example.com`},
},
{
whenURL: "http://localhost:1323//example.com/",
expectPath: `//example.com/`,
expectLocation: []string{`/example.com`},
},
{
whenURL: "http://localhost:1323/%5C%5C/",
expectPath: `/\\/`,
expectLocation: []string{`/`},
},
}
for _, tc := range testCases {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()

mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})
h := mw(func(c echo.Context) error {
return nil
})

rec := httptest.NewRecorder()
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
c := e.NewContext(req, rec)

err := h(c)
assert.NoError(t, err)

assert.Equal(t, tc.expectPath, req.URL.Path)
assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation])
if tc.expectStatus == 0 {
assert.Equal(t, http.StatusMovedPermanently, rec.Code)
} else {
assert.Equal(t, tc.expectStatus, rec.Code)
}
})
}
}

func TestRemoveTrailingSlash(t *testing.T) {
is := assert.New(t)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/remove-slash/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := RemoveTrailingSlash()(func(c echo.Context) error {
return nil
})
is.NoError(h(c))
is.Equal("/remove-slash", req.URL.Path)
is.Equal("/remove-slash", req.RequestURI)

// Method Connect must not fail:
req = httptest.NewRequest(http.MethodConnect, "", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = RemoveTrailingSlash()(func(c echo.Context) error {
return nil
})
is.NoError(h(c))
is.Equal("", req.URL.Path)
is.Equal("", req.RequestURI)

// With config
req = httptest.NewRequest(http.MethodGet, "/remove-slash/?key=value", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently,
})(func(c echo.Context) error {
return nil
})
is.NoError(h(c))
is.Equal(http.StatusMovedPermanently, rec.Code)
is.Equal("/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation))

// With bare URL
req = httptest.NewRequest(http.MethodGet, "http://localhost", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h = RemoveTrailingSlash()(func(c echo.Context) error {
return nil
})
is.NoError(h(c))
is.Equal("", req.URL.Path)
var testCases = []struct {
whenURL string
whenMethod string
expectPath string
}{
{
whenURL: "/remove-slash/",
whenMethod: http.MethodGet,
expectPath: "/remove-slash",
},
{
whenURL: "/remove-slash/?key=value",
whenMethod: http.MethodGet,
expectPath: "/remove-slash",
},
{
whenURL: "/",
whenMethod: http.MethodConnect,
expectPath: "/",
},
{
whenURL: "http://localhost",
whenMethod: http.MethodGet,
expectPath: "",
},
}
for _, tc := range testCases {
t.Run(tc.whenURL, func(t *testing.T) {
e := echo.New()

h := RemoveTrailingSlash()(func(c echo.Context) error {
return nil
})

rec := httptest.NewRecorder()
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
c := e.NewContext(req, rec)

err := h(c)
assert.NoError(t, err)

assert.Equal(t, tc.expectPath, req.URL.Path)
assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation])
assert.Equal(t, http.StatusOK, rec.Code)
})
}
}

0 comments on commit f09f2bd

Please sign in to comment.