diff --git a/internal/gateway/gateway.go b/internal/gateway/gateway.go index c3698b4bc..80e96714d 100644 --- a/internal/gateway/gateway.go +++ b/internal/gateway/gateway.go @@ -26,6 +26,9 @@ import ( "github.com/cloudevents/sdk-go/v2/client" "github.com/cloudevents/sdk-go/v2/protocol" cehttp "github.com/cloudevents/sdk-go/v2/protocol/http" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/credentials/insecure" + "github.com/vanus-labs/vanus/internal/gateway/proxy" "github.com/vanus-labs/vanus/internal/primitive" "github.com/vanus-labs/vanus/internal/primitive/vanus" @@ -35,8 +38,6 @@ import ( "github.com/vanus-labs/vanus/proto/pkg/cloudevents" "github.com/vanus-labs/vanus/proto/pkg/codec" proxypb "github.com/vanus-labs/vanus/proto/pkg/proxy" - "go.opentelemetry.io/otel/trace" - "google.golang.org/grpc/credentials/insecure" ) var requestDataFromContext = cehttp.RequestDataFromContext @@ -124,31 +125,48 @@ func (ga *ceGateway) receive(ctx context.Context, event v2.Event) (re *v2.Event, } const ( - httpRequestPrefix = "/gateway" + httpRequestPrefix = "gateway" ) func (ga *ceGateway) getEventbusFromPath(ctx context.Context, reqData *cehttp.RequestData) (vanus.ID, error) { // TODO validate reqPathStr := reqData.URL.String() + reqPathStr = reqPathStr[1:] var ( ns string name string ) - if strings.HasPrefix(reqPathStr, httpRequestPrefix) { // Deprecated, just for compatibility of older than v0.7.0 - ns = primitive.DefaultNamespace - name = strings.TrimLeft(reqPathStr[len(httpRequestPrefix):], "/") - } else { - // namespaces/:namespace_name/eventbus/:eventbus_name/events - path := strings.TrimLeft(reqData.URL.String(), "/") - strs := strings.Split(path, "/") - if len(strs) != 5 { + paths := strings.Split(reqPathStr, "/") + switch len(paths) { + case 2: + if paths[0] == httpRequestPrefix { // Deprecated, just for compatibility of older than v0.7.0 + // gateway/eb_name + ns = primitive.DefaultNamespace + name = paths[1] + } else if paths[1] == "events" { + // eb_id/events + eventbusID, err := vanus.NewIDFromString(paths[0]) + if err != nil { + return 0, err + } + // check eb exist + _, err = ga.ctrl.EventbusService().GetEventbus(ctx, eventbusID.Uint64()) + if err != nil { + return 0, err + } + return eventbusID, nil + } else { return 0, errors.New("invalid request path") } - if strs[0] != "namespaces" && strs[2] != "eventbus" && strs[4] != "events" { + case 5: + // namespaces/:namespace_name/eventbus/:eventbus_name/events + if paths[0] != "namespaces" && paths[2] != "eventbus" && paths[4] != "events" { return 0, errors.New("invalid request path") } - ns = strs[1] - name = strs[3] + ns = paths[1] + name = paths[3] + default: + return 0, errors.New("invalid request path") } if ns == "" { diff --git a/internal/gateway/gateway_test.go b/internal/gateway/gateway_test.go index 3b51c8cfc..e4c585587 100644 --- a/internal/gateway/gateway_test.go +++ b/internal/gateway/gateway_test.go @@ -157,23 +157,41 @@ func TestGateway_getEventbusFromPath(t *testing.T) { ebSvc := cluster.NewMockEventbusService(ctrl) cctrl.EXPECT().EventbusService().AnyTimes().Return(ebSvc) - ebID := vanus.NewTestID().Uint64() + ebID := vanus.NewTestID() ebSvc.EXPECT().GetEventbusByName(Any(), "default", "test").Times(1).Return( &metapb.Eventbus{ Name: "test", LogNumber: 1, - Id: ebID, + Id: ebID.Uint64(), Description: "desc", NamespaceId: vanus.NewTestID().Uint64(), }, nil) + ebSvc.EXPECT().GetEventbus(Any(), ebID.Uint64()).Times(1).Return(&metapb.Eventbus{ + Name: "test", + LogNumber: 1, + Id: ebID.Uint64(), + Description: "desc", + NamespaceId: vanus.NewTestID().Uint64(), + }, nil) ga.ctrl = cctrl - Convey("test get eventbus from path return path ", t, func() { + Convey("test get eventbus from path name return path ", t, func() { reqData := &cehttp.RequestData{ URL: &url.URL{ Opaque: "/namespaces/default/eventbus/test/events", }, } + id, err := ga.getEventbusFromPath(context.Background(), reqData) + So(err, ShouldBeNil) + So(id, ShouldEqual, ebID) + }) + Convey("test get eventbus from path id return path ", t, func() { + reqData := &cehttp.RequestData{ + URL: &url.URL{ + Opaque: fmt.Sprintf("/%s/events", ebID.String()), + }, + } + id, err := ga.getEventbusFromPath(context.Background(), reqData) So(err, ShouldBeNil) So(id, ShouldEqual, ebID)