Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support real regex rules for rewrite and proxy middleware #1767

Merged
merged 12 commits into from
Feb 8, 2021
2 changes: 1 addition & 1 deletion middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
rulesRegex := map[*regexp.Regexp]string{}
for k, v := range rewrite {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
k = strings.Replace(k, `\*`, "(.*?)", -1)
lammel marked this conversation as resolved.
Show resolved Hide resolved
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
Expand Down
22 changes: 16 additions & 6 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ type (
// "/users/*/orders/*": "/user/$1/order/$2",
Rewrite map[string]string

// RegexRewrite defines rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRewrite map[*regexp.Regexp]string
lammel marked this conversation as resolved.
Show resolved Hide resolved

// Context key to store selected ProxyTarget into context.
// Optional. Default value "target".
ContextKey string
Expand All @@ -46,8 +53,6 @@ type (

// ModifyResponse defines function to modify response from ProxyTarget.
ModifyResponse func(*http.Response) error

rewriteRegex map[*regexp.Regexp]string
}

// ProxyTarget defines the upstream target.
Expand Down Expand Up @@ -206,7 +211,14 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
panic("echo: proxy middleware requires balancer")
}

config.rewriteRegex = rewriteRulesRegex(config.Rewrite)
if config.Rewrite != nil {
if config.RegexRewrite == nil {
config.RegexRewrite = make(map[*regexp.Regexp]string)
}
for k, v := range rewriteRulesRegex(config.Rewrite) {
config.RegexRewrite[k] = v
}
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
Expand All @@ -220,7 +232,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
c.Set(config.ContextKey, tgt)

// Set rewrite path and raw path
rewritePath(config.rewriteRegex, req)
rewritePath(config.RegexRewrite, req)

// Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
Expand Down Expand Up @@ -251,5 +263,3 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
}
}
}


142 changes: 102 additions & 40 deletions middleware/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"

"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -83,46 +84,6 @@ func TestProxy(t *testing.T) {
body = rec.Body.String()
assert.Equal(t, "target 2", body)

// Rewrite
e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{
Balancer: rrb,
Rewrite: map[string]string{
"/old": "/new",
"/api/*": "/$1",
"/js/*": "/public/javascripts/$1",
"/users/*/orders/*": "/user/$1/order/$2",
},
}))
req.URL, _ = url.Parse("/api/users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse( "/js/main.js")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse( "/users/jack/orders/1")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
// ModifyResponse
e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{
Expand Down Expand Up @@ -196,3 +157,104 @@ func TestProxyRealIPHeader(t *testing.T) {
assert.Equal(t, tt.extectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor)
}
}

func TestProxyRewrite(t *testing.T) {
// Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer upstream.Close()
url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

// Rewrite
e := echo.New()
e.Use(ProxyWithConfig(ProxyConfig{
Balancer: rrb,
Rewrite: map[string]string{
"/old": "/new",
"/api/*": "/$1",
"/js/*": "/public/javascripts/$1",
"/users/*/orders/*": "/user/$1/order/$2",
},
}))
req.URL, _ = url.Parse("/api/users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/users", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/js/main.js")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/old")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/users/jack/orders/1")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath())
assert.Equal(t, http.StatusOK, rec.Code)
req.URL, _ = url.Parse("/api/new users")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "/new%20users", req.URL.EscapedPath())
}

func TestProxyRewriteRegex(t *testing.T) {
// Setup
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer upstream.Close()
url, _ := url.Parse(upstream.URL)
rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()

// Rewrite
e := echo.New()
e.Use(ProxyWithConfig(ProxyConfig{
Balancer: rrb,
Rewrite: map[string]string{
"^/a/*": "/v1/$1",
"^/b/*/c/*": "/v2/$2/$1",
"^/c/*/*": "/v3/$2",
},
RegexRewrite: map[*regexp.Regexp]string{
regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
},
}))

testCases := []struct {
requestPath string
statusCode int
expectPath string
}{
{"/unmatched", http.StatusOK, "/unmatched"},
{"/a/test", http.StatusOK, "/v1/test"},
{"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"},
{"/c/ignore/test", http.StatusOK, "/v3/test"},
{"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
{"/x/ignore/test", http.StatusOK, "/v4/test"},
{"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
}


for _, tc := range testCases {
aldas marked this conversation as resolved.
Show resolved Hide resolved
t.Run(tc.requestPath, func(t *testing.T) {
req.URL, _ = url.Parse(tc.requestPath)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectPath, req.URL.EscapedPath())
assert.Equal(t, tc.statusCode, rec.Code)
})
}
}
24 changes: 18 additions & 6 deletions middleware/rewrite.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package middleware

import (
"github.com/labstack/echo/v4"
"regexp"

"github.com/labstack/echo/v4"
)

type (
Expand All @@ -21,7 +22,12 @@ type (
// Required.
Rules map[string]string `yaml:"rules"`

rulesRegex map[*regexp.Regexp]string
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
}
)

Expand All @@ -45,14 +51,20 @@ func Rewrite(rules map[string]string) echo.MiddlewareFunc {
// See: `Rewrite()`.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
// Defaults
if config.Rules == nil {
panic("echo: rewrite middleware requires url path rewrite rules")
if config.Rules == nil && config.RegexRules == nil {
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
}

if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
}

config.rulesRegex = rewriteRulesRegex(config.Rules)
if config.RegexRules == nil {
config.RegexRules = make(map[*regexp.Regexp]string)
}
for k, v := range rewriteRulesRegex(config.Rules) {
config.RegexRules[k] = v
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
Expand All @@ -62,7 +74,7 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {

req := c.Request()
// Set rewrite path and raw path
rewritePath(config.rulesRegex, req)
rewritePath(config.RegexRules, req)
return next(c)
}
}
Expand Down
47 changes: 45 additions & 2 deletions middleware/rewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"

"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -55,8 +56,8 @@ func TestEchoRewritePreMiddleware(t *testing.T) {

// Rewrite old url to new one
e.Pre(Rewrite(map[string]string{
"/old": "/new",
},
"/old": "/new",
},
))

// Route
Expand Down Expand Up @@ -129,3 +130,45 @@ func TestEchoRewriteWithCaret(t *testing.T) {
e.ServeHTTP(rec, req)
assert.Equal(t, "/v2/abc/test", req.URL.Path)
}

// Verify regex used with rewrite
func TestEchoRewriteWithRegexRules(t *testing.T) {
e := echo.New()

e.Pre(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{
"^/a/*": "/v1/$1",
"^/b/*/c/*": "/v2/$2/$1",
"^/c/*/*": "/v3/$2",
},
RegexRules: map[*regexp.Regexp]string{
regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
},
}))

var rec *httptest.ResponseRecorder
var req *http.Request

testCases := []struct {
requestPath string
expectPath string
}{
{"/unmatched", "/unmatched"},
{"/a/test", "/v1/test"},
{"/b/foo/c/bar/baz", "/v2/bar/baz/foo"},
{"/c/ignore/test", "/v3/test"},
{"/c/ignore1/test/this", "/v3/test/this"},
{"/x/ignore/test", "/v4/test"},
{"/y/foo/bar", "/v5/bar/foo"},
}

for _, tc := range testCases {
t.Run(tc.requestPath, func(t *testing.T) {
req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectPath, req.URL.EscapedPath())
})
}
}