Skip to content

Commit

Permalink
fix merge of options while inserting handler in the radix tree
Browse files Browse the repository at this point in the history
  • Loading branch information
gmgigi96 committed Jul 17, 2023
1 parent b7dea87 commit 3f7f582
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
13 changes: 8 additions & 5 deletions pkg/rhttp/mux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ func (m *ServeMux) Options(path string, handler http.Handler, o ...Option) {
}

func (m *ServeMux) Walk(ctx context.Context, f WalkFunc) {
m.tree.root.walk(ctx, "", f)
m.tree.root.walk(ctx, "", &nodeOptions{}, f)
}

func (n *node) walk(ctx context.Context, prefix string, f WalkFunc) {
func (n *node) walk(ctx context.Context, prefix string, merged *nodeOptions, f WalkFunc) {
select {
case <-ctx.Done():
return
Expand All @@ -147,16 +147,19 @@ func (n *node) walk(ctx context.Context, prefix string, f WalkFunc) {
}

path := prefix + current
opts := merged.merge(&n.opts)

for method, h := range n.handlers.perMethod {
f(method, path, h, n.opts.get(method))
f(method, path, h, opts.get(method))
}

if g := n.handlers.global; g != nil {
f(MethodAll, path, g, n.opts.global)
o := n.opts.global.merge(merged.global)
f(MethodAll, path, g, o)
}

for _, c := range n.children {
c.walk(ctx, path, f)
c.walk(ctx, path, opts, f)
}
}

Expand Down
31 changes: 25 additions & 6 deletions pkg/rhttp/mux/radix.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,25 @@ func (n *nodeOptions) get(method string) *Options {
return global.merge(perMethod)
}

func (n *nodeOptions) merge(other *nodeOptions) *nodeOptions {
merged := nodeOptions{}
merged.global = n.global
merged.opts = nilMap[*Options]{}
if other.global != nil {
merged.global = merged.global.merge(other.global)
}
for method, opt := range n.opts {
merged.set(method, other.get(method).merge(opt))
}
for method, opt := range other.opts {
if _, ok := merged.opts[method]; ok {
continue
}
merged.opts.add(method, opt)
}
return &merged
}

type handlers struct {
global http.Handler
perMethod nilMap[http.Handler]
Expand Down Expand Up @@ -295,7 +314,7 @@ func (n *node) mergeOptions(method string, opts *Options) *Options {
func (n *node) insert(method, path string, handler http.Handler, opts *Options) {
if n.prefix == "" {
// the tree is empty
n.insertChild(method, path, handler, opts)
n.insertChild(method, path, handler, nil, opts)
return
}

Expand Down Expand Up @@ -324,8 +343,7 @@ walk:
}
}

opts = merged.merge(opts)
current.insertChild(method, path, handler, opts)
current.insertChild(method, path, handler, merged, opts)
return
}
}
Expand All @@ -339,13 +357,14 @@ func wildcardIndex(s string) int {
return -1
}

func (n *node) insertChild(method, path string, handler http.Handler, opts *Options) {
func (n *node) insertChild(method, path string, handler http.Handler, merged, opts *Options) {
current := n
for {
if path == "" {
if handler != nil {
if n.middlewareFactory != nil && opts != nil {
for _, mid := range n.middlewareFactory(opts) {
if n.middlewareFactory != nil {
merged = merged.merge(opts)
for _, mid := range n.middlewareFactory(merged) {
handler = mid(handler)
}
}
Expand Down
26 changes: 22 additions & 4 deletions pkg/rhttp/mux/radix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package mux

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/cs3org/reva/pkg/rhttp/middlewares"
Expand Down Expand Up @@ -611,7 +612,7 @@ func TestInsertOptions(t *testing.T) {
{
prefix: "blog",
ntype: static,
opts: nodeOptions{opts: nilMap[*Options]{"GET": &Options{Unprotected: true}}},
opts: nodeOptions{},
},
},
},
Expand All @@ -626,17 +627,34 @@ func TestInsertOptions(t *testing.T) {

func TestMultipleMiddlewaresAlongTheWay(t *testing.T) {
nop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
m := middlewares.Middleware(func(h http.Handler) http.Handler { return nop })
var count int
m := middlewares.Middleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
h.ServeHTTP(w, r)
})
})

tree := newTree()
tree.root.middlewareFactory = func(o *Options) []middlewares.Middleware {
return o.Middlewares
}

tree.insert("GET", "/", nop, &Options{Middlewares: []middlewares.Middleware{m}})
tree.insert("POST", "/", nop, &Options{Middlewares: []middlewares.Middleware{m}})
tree.insert(MethodAll, "/test/path", nop, &Options{Middlewares: []middlewares.Middleware{m}})
tree.insert(MethodAll, "/testing", nop, &Options{Middlewares: []middlewares.Middleware{m}})
tree.insert("POST", "/test/path/other", nop, &Options{Middlewares: []middlewares.Middleware{m}})
tree.insert("POST", "/test/path/other/some/thing", nop, &Options{Middlewares: []middlewares.Middleware{m}})

n, _, ok := tree.root.lookup("/test/path/other/some/thing")
assert.Equal(t, true, ok)
assert.Equal(t, 1, len(n.opts.opts["POST"].Middlewares))

n, _, ok := tree.root.lookup("/test/path/other")
handler, ok := n.handlers.get("POST")
assert.Equal(t, true, ok)
assert.Equal(t, 3, len(n.opts.opts["POST"].Middlewares))
w := httptest.NewRecorder()
r, _ := http.NewRequest("POST", "/test/path/other/some/thing", nil)
handler.ServeHTTP(w, r)
assert.Equal(t, 4, count)
}

0 comments on commit 3f7f582

Please sign in to comment.