diff --git a/gin.go b/gin.go index 6ab2be66d0..c4ef2c07e0 100644 --- a/gin.go +++ b/gin.go @@ -19,6 +19,9 @@ import ( ) const defaultMultipartMemory = 32 << 20 // 32 MB +const escapedColon = "\\:" +const colon = ":" +const backslash = "\\" var ( default404Body = []byte("404 page not found") @@ -345,6 +348,7 @@ func (engine *Engine) Run(addr ...string) (err error) { if err != nil { return err } + engine.updateRouteTrees() address := resolveAddress(addr) debugPrint("Listening and serving HTTP on %s\n", address) @@ -394,6 +398,25 @@ func (engine *Engine) parseTrustedProxies() error { return err } +// updateRouteTree do update to the route tree recursively +func updateRouteTree(n *node) { + n.path = strings.ReplaceAll(n.path, escapedColon, colon) + n.fullPath = strings.ReplaceAll(n.fullPath, escapedColon, colon) + n.indices = strings.ReplaceAll(n.indices, backslash, colon) + if n.children != nil { + for _, child := range n.children { + updateRouteTree(child) + } + } +} + +// updateRouteTrees do update to the route trees +func (engine *Engine) updateRouteTrees() { + for _, tree := range engine.trees { + updateRouteTree(tree.root) + } +} + // parseIP parse a string representation of an IP and returns a net.IP with the // minimum byte representation or nil if input is invalid. func parseIP(ip string) net.IP { diff --git a/gin_integration_test.go b/gin_integration_test.go index 094c46e871..52052c1672 100644 --- a/gin_integration_test.go +++ b/gin_integration_test.go @@ -525,3 +525,23 @@ func TestTreeRunDynamicRouting(t *testing.T) { testRequest(t, ts.URL+"/addr/dd/aa", "404 Not Found") testRequest(t, ts.URL+"/something/secondthing/121", "404 Not Found") } + +func TestEscapedColon(t *testing.T) { + router := New() + f := func(u string) { + router.GET(u, func(c *Context) { c.String(http.StatusOK, u) }) + } + f("/r/r\\:r") + f("/r/r:r") + f("/r/r/:r") + f("/r/r/\\:r") + + router.updateRouteTrees() + ts := httptest.NewServer(router) + defer ts.Close() + + testRequest(t, ts.URL+"/r/r123", "", "/r/r:r") + testRequest(t, ts.URL+"/r/r:r", "", "/r/r\\:r") + testRequest(t, ts.URL+"/r/r/r123", "", "/r/r/:r") + testRequest(t, ts.URL+"/r/r/:r", "", "/r/r/\\:r") +} diff --git a/tree.go b/tree.go index fb0a5935c2..f505689f91 100644 --- a/tree.go +++ b/tree.go @@ -261,7 +261,20 @@ walk: // Returns -1 as index, if no wildcard was found. func findWildcard(path string) (wildcard string, i int, valid bool) { // Find start + escapeColon := false for start, c := range []byte(path) { + if escapeColon { + if c == ':' { + escapeColon = false + continue + } else { + panic("invalid escaped char in " + path) + } + } + if c == '\\' { + escapeColon = true + continue + } // A wildcard starts with ':' (param) or '*' (catch-all) if c != ':' && c != '*' { continue @@ -402,6 +415,7 @@ func (n *node) getValue(path string, params *Params, unescape bool) (value nodeV var ( skippedPath string latestNode = n // Caching the latest node + latestPath = path ) walk: // Outer loop for walking the tree @@ -427,6 +441,7 @@ walk: // Outer loop for walking the tree handlers: n.handlers, fullPath: n.fullPath, } + latestPath = path } n = n.children[i] @@ -457,7 +472,7 @@ walk: // Outer loop for walking the tree // fix truncate the parameter // tree_test.go line: 204 if matched { - path = prefix + path + path = latestPath // The saved path is used after the prefix route is intercepted by matching if n.indices == "/" { path = skippedPath[1:] diff --git a/tree_test.go b/tree_test.go index cbb37340ef..c7b1e292f0 100644 --- a/tree_test.go +++ b/tree_test.go @@ -185,6 +185,7 @@ func TestTreeWildcard(t *testing.T) { "/get/abc/123abg/:param", "/get/abc/123abf/:param", "/get/abc/123abfff/:param", + "/get/abc/escaped_colon/test\\:param", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) @@ -305,6 +306,7 @@ func TestTreeWildcard(t *testing.T) { {"/get/abc/123abg/test", false, "/get/abc/123abg/:param", Params{Param{Key: "param", Value: "test"}}}, {"/get/abc/123abf/testss", false, "/get/abc/123abf/:param", Params{Param{Key: "param", Value: "testss"}}}, {"/get/abc/123abfff/te", false, "/get/abc/123abfff/:param", Params{Param{Key: "param", Value: "te"}}}, + {"/get/abc/escaped_colon/test\\:param", false, "/get/abc/escaped_colon/test\\:param", nil}, }) checkPriorities(t, tree) @@ -407,6 +409,9 @@ func TestTreeWildcardConflict(t *testing.T) { {"/user_:name", false}, {"/id:id", false}, {"/id/:id", false}, + {"/escape/test\\:d1", false}, + {"/escape/test\\:d2", false}, + {"/escape/test:param", false}, } testRoutes(t, routes) } @@ -886,3 +891,22 @@ func TestTreeWildcardConflictEx(t *testing.T) { } } } + +func TestTreeInvalidEscape(t *testing.T) { + routes := map[string]bool{ + "/r1/r": true, + "/r2/:r": true, + "/r3/\\:r": true, + "/r4/\\\\:r": false, + "/r5/\\~:r": false, + } + tree := &node{} + for route, valid := range routes { + recv := catchPanic(func() { + tree.addRoute(route, fakeHandler(route)) + }) + if recv == nil != valid { + t.Fatalf("%s should be %t but got %v", route, valid, recv) + } + } +}