diff --git a/eager_router.go b/eager_router.go index 656d8e8..e04a1af 100644 --- a/eager_router.go +++ b/eager_router.go @@ -78,7 +78,7 @@ func (fanIn *eagerRouterFanIn) Aggregate( routes []Component // response labels - labels Labels = NewLabelsMap() + labels Labels // index of current primary route currentRouteIdx int @@ -88,6 +88,13 @@ func (fanIn *eagerRouterFanIn) Aggregate( masterResponse Response ) + labelMap, ok := ctx.Value(CtxComponentLabelsKey).(Labels) + if ok { + labels = labelMap + } else { + labels = NewLabelsMap() + } + for masterResponse == nil { select { case resp, ok := <-responseCh: @@ -98,7 +105,11 @@ func (fanIn *eagerRouterFanIn) Aggregate( } case routesOrderResponse, ok := <-routesOrderCh: if ok { - labels = routesOrderResponse.Labels + //Overwrite parent labels with strategy labels + for _, key := range routesOrderResponse.Labels.Keys() { + labels.WithLabel(key, routesOrderResponse.Labels.Label(key)...) + } + if routesOrderResponse.Err != nil { masterResponse = NewErrorResponse(errors.NewFiberError(req.Protocol(), routesOrderResponse.Err)) } else { diff --git a/labels_test.go b/labels_test.go index ead1f21..5e695e4 100644 --- a/labels_test.go +++ b/labels_test.go @@ -1,10 +1,15 @@ package fiber_test import ( + "context" "sort" + "strings" "testing" "github.com/gojek/fiber" + "github.com/gojek/fiber/extras" + "github.com/gojek/fiber/internal/testutils" + testUtilsHttp "github.com/gojek/fiber/internal/testutils/http" "github.com/stretchr/testify/assert" ) @@ -89,3 +94,84 @@ func TestLabelsMapLabel(t *testing.T) { }) } } + +// This test uses RandomRoutingStrategy which will always append context of idx +// Labels should be preserved for fiber.CtxComponentLabelsKey context key +func Test_Router_Dispatch_Labels(t *testing.T) { + lazyRouter := fiber.NewLazyRouter("lazy-router") + lazyRouter.SetStrategy(new(extras.RandomRoutingStrategy)) + + eagerRouter := fiber.NewEagerRouter("eager-router") + eagerRouter.SetStrategy(new(extras.RandomRoutingStrategy)) + + testRouters := []fiber.MultiRouteComponent{eagerRouter, lazyRouter} + + tests := []struct { + name string + initialLabelKey any + initialLabelValue any + expectedLabelKey string + expectedLabelValue string + router []fiber.MultiRouteComponent + }{ + { + name: "new label", + expectedLabelKey: "idx", + expectedLabelValue: "0", + router: testRouters, + }, + { + name: "overwritten label", + initialLabelKey: "idx", + initialLabelValue: "111", + expectedLabelKey: "idx", + expectedLabelValue: "0", + router: testRouters, + }, + { + name: "existing label not preserved, wrong key", + initialLabelKey: "t", + initialLabelValue: "11", + expectedLabelKey: "t", + expectedLabelValue: "", + router: testRouters, + }, + { + name: "existing label preserved", + initialLabelKey: fiber.CtxComponentLabelsKey, + initialLabelValue: fiber.NewLabelsMap().WithLabel("t", "11"), + expectedLabelKey: "t", + expectedLabelValue: "11", + router: testRouters, + }, + { + name: "existing label not preserved, unexpected value type", + initialLabelKey: fiber.CtxComponentLabelsKey, + initialLabelValue: map[string]string{"t": "11"}, + expectedLabelKey: "t", + expectedLabelValue: "", + router: testRouters, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, router := range tt.router { + router.SetRoutes(map[string]fiber.Component{ + "route-a": testutils.NewMockComponent( + "route-a", + testUtilsHttp.DelayedResponse{Response: testUtilsHttp.MockResp(200, "A-OK", nil, nil)}), + }) + ctx := context.Background() + if tt.initialLabelKey != nil { + ctx = context.WithValue(ctx, tt.initialLabelKey, tt.initialLabelValue) + } + request := testUtilsHttp.MockReq("POST", "http://localhost:8080/router", "payload") + resp, ok := <-router.Dispatch(ctx, request).Iter() + assert.True(t, ok) + label := strings.Join(resp.Label(tt.expectedLabelKey), ",") + assert.Equal(t, tt.expectedLabelValue, label) + } + }) + } +} diff --git a/lazy_router.go b/lazy_router.go index 50eb5be..3fe3f0b 100644 --- a/lazy_router.go +++ b/lazy_router.go @@ -2,7 +2,6 @@ package fiber import ( "context" - "github.com/gojek/fiber/errors" "github.com/gojek/fiber/util" ) @@ -50,14 +49,25 @@ func (r *LazyRouter) Dispatch(ctx context.Context, req Request) ResponseQueue { defer close(out) var routes []Component - var labels Labels = NewLabelsMap() + var labels Labels + + labelMap, ok := ctx.Value(CtxComponentLabelsKey).(Labels) + if ok { + labels = labelMap + } else { + labels = NewLabelsMap() + } routesOrderCh := r.strategy.getRoutesOrder(ctx, req, r.routes) select { case routesOrderResponse, ok := <-routesOrderCh: if ok { - labels = routesOrderResponse.Labels + //Overwrite parent labels with strategy labels + for _, key := range routesOrderResponse.Labels.Keys() { + labels.WithLabel(key, routesOrderResponse.Labels.Label(key)...) + } + if routesOrderResponse.Err != nil { out <- NewErrorResponse(errors.NewFiberError(req.Protocol(), routesOrderResponse.Err)).WithLabels(labels) return