diff --git a/pkg/rhttp/mux/mux.go b/pkg/rhttp/mux/mux.go index 577e0508f7d..3ec998d48b6 100644 --- a/pkg/rhttp/mux/mux.go +++ b/pkg/rhttp/mux/mux.go @@ -121,10 +121,10 @@ func (m *ServeMux) Options(path string, handler http.Handler, o ...Option) { } func (m *ServeMux) Walk(ctx context.Context, f WalkFunc) { - m.tree.root.walk(ctx, "", f) + m.tree.root.walk(ctx, "", &nodeOptions{}, f) } -func (n *node) walk(ctx context.Context, prefix string, f WalkFunc) { +func (n *node) walk(ctx context.Context, prefix string, merged *nodeOptions, f WalkFunc) { select { case <-ctx.Done(): return @@ -147,16 +147,19 @@ func (n *node) walk(ctx context.Context, prefix string, f WalkFunc) { } path := prefix + current + opts := merged.merge(&n.opts) + for method, h := range n.handlers.perMethod { - f(method, path, h, n.opts.get(method)) + f(method, path, h, opts.get(method)) } if g := n.handlers.global; g != nil { - f(MethodAll, path, g, n.opts.global) + o := n.opts.global.merge(merged.global) + f(MethodAll, path, g, o) } for _, c := range n.children { - c.walk(ctx, path, f) + c.walk(ctx, path, opts, f) } } diff --git a/pkg/rhttp/mux/radix.go b/pkg/rhttp/mux/radix.go index 1ed972c0eee..f9abab54118 100644 --- a/pkg/rhttp/mux/radix.go +++ b/pkg/rhttp/mux/radix.go @@ -74,6 +74,25 @@ func (n *nodeOptions) get(method string) *Options { return global.merge(perMethod) } +func (n *nodeOptions) merge(other *nodeOptions) *nodeOptions { + merged := nodeOptions{} + merged.global = n.global + merged.opts = nilMap[*Options]{} + if other.global != nil { + merged.global = merged.global.merge(other.global) + } + for method, opt := range n.opts { + merged.set(method, other.get(method).merge(opt)) + } + for method, opt := range other.opts { + if _, ok := merged.opts[method]; ok { + continue + } + merged.opts.add(method, opt) + } + return &merged +} + type handlers struct { global http.Handler perMethod nilMap[http.Handler] @@ -295,7 +314,7 @@ func (n *node) mergeOptions(method string, opts *Options) *Options { func (n *node) insert(method, path string, handler http.Handler, opts *Options) { if n.prefix == "" { // the tree is empty - n.insertChild(method, path, handler, opts) + n.insertChild(method, path, handler, nil, opts) return } @@ -324,8 +343,7 @@ walk: } } - opts = merged.merge(opts) - current.insertChild(method, path, handler, opts) + current.insertChild(method, path, handler, merged, opts) return } } @@ -339,13 +357,14 @@ func wildcardIndex(s string) int { return -1 } -func (n *node) insertChild(method, path string, handler http.Handler, opts *Options) { +func (n *node) insertChild(method, path string, handler http.Handler, merged, opts *Options) { current := n for { if path == "" { if handler != nil { - if n.middlewareFactory != nil && opts != nil { - for _, mid := range n.middlewareFactory(opts) { + if n.middlewareFactory != nil { + merged = merged.merge(opts) + for _, mid := range n.middlewareFactory(merged) { handler = mid(handler) } } diff --git a/pkg/rhttp/mux/radix_test.go b/pkg/rhttp/mux/radix_test.go index beb0b7c3b63..360be50ea5d 100644 --- a/pkg/rhttp/mux/radix_test.go +++ b/pkg/rhttp/mux/radix_test.go @@ -20,6 +20,7 @@ package mux import ( "net/http" + "net/http/httptest" "testing" "github.com/cs3org/reva/pkg/rhttp/middlewares" @@ -611,7 +612,7 @@ func TestInsertOptions(t *testing.T) { { prefix: "blog", ntype: static, - opts: nodeOptions{opts: nilMap[*Options]{"GET": &Options{Unprotected: true}}}, + opts: nodeOptions{}, }, }, }, @@ -626,17 +627,34 @@ func TestInsertOptions(t *testing.T) { func TestMultipleMiddlewaresAlongTheWay(t *testing.T) { nop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - m := middlewares.Middleware(func(h http.Handler) http.Handler { return nop }) + var count int + m := middlewares.Middleware(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + h.ServeHTTP(w, r) + }) + }) tree := newTree() + tree.root.middlewareFactory = func(o *Options) []middlewares.Middleware { + return o.Middlewares + } tree.insert("GET", "/", nop, &Options{Middlewares: []middlewares.Middleware{m}}) tree.insert("POST", "/", nop, &Options{Middlewares: []middlewares.Middleware{m}}) tree.insert(MethodAll, "/test/path", nop, &Options{Middlewares: []middlewares.Middleware{m}}) tree.insert(MethodAll, "/testing", nop, &Options{Middlewares: []middlewares.Middleware{m}}) tree.insert("POST", "/test/path/other", nop, &Options{Middlewares: []middlewares.Middleware{m}}) + tree.insert("POST", "/test/path/other/some/thing", nop, &Options{Middlewares: []middlewares.Middleware{m}}) + + n, _, ok := tree.root.lookup("/test/path/other/some/thing") + assert.Equal(t, true, ok) + assert.Equal(t, 1, len(n.opts.opts["POST"].Middlewares)) - n, _, ok := tree.root.lookup("/test/path/other") + handler, ok := n.handlers.get("POST") assert.Equal(t, true, ok) - assert.Equal(t, 3, len(n.opts.opts["POST"].Middlewares)) + w := httptest.NewRecorder() + r, _ := http.NewRequest("POST", "/test/path/other/some/thing", nil) + handler.ServeHTTP(w, r) + assert.Equal(t, 4, count) }