diff --git a/tenant/resolver.go b/tenant/resolver.go
new file mode 100644
index 000000000..607a290e5
--- /dev/null
+++ b/tenant/resolver.go
@@ -0,0 +1,179 @@
+package tenant
+
+import (
+	"context"
+	"errors"
+	"net/http"
+	"strings"
+
+	"github.com/weaveworks/common/user"
+)
+
+// defaultResolver backs the package-level TenantID, TenantIDs and
+// ExtractTenantIDFromHTTPRequest functions.
+var defaultResolver Resolver = NewSingleResolver()
+
+// WithDefaultResolver updates the resolver used for the package methods. The
+// assignment is not synchronized with concurrent users of those methods.
+func WithDefaultResolver(r Resolver) {
+	defaultResolver = r
+}
+
+// TenantID returns exactly a single tenant ID from the context. It should be
+// used when a certain endpoint should only support exactly a single
+// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID
+// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present.
+//
+// ignore stutter warning
+//nolint:revive
+func TenantID(ctx context.Context) (string, error) {
+	return defaultResolver.TenantID(ctx)
+}
+
+// TenantIDs returns all tenant IDs from the context. It should return
+// normalized list of ordered and distinct tenant IDs (as produced by
+// NormalizeTenantIDs).
+//
+// ignore stutter warning
+//nolint:revive
+func TenantIDs(ctx context.Context) ([]string, error) {
+	return defaultResolver.TenantIDs(ctx)
+}
+
+// Resolver extracts the tenant ID(s) carried by a context.
+type Resolver interface {
+	// TenantID returns exactly a single tenant ID from the context. It should be
+	// used when a certain endpoint should only support exactly a single
+	// tenant ID. It returns an error user.ErrNoOrgID if there is no tenant ID
+	// supplied or user.ErrTooManyOrgIDs if there are multiple tenant IDs present.
+	TenantID(context.Context) (string, error)
+
+	// TenantIDs returns all tenant IDs from the context. It should return
+	// normalized list of ordered and distinct tenant IDs (as produced by
+	// NormalizeTenantIDs).
+	TenantIDs(context.Context) ([]string, error)
+}
+
+// NewSingleResolver creates a tenant resolver, which restricts all requests to
+// be using a single tenant only. This allows a wider set of characters to be
+// used within the tenant ID and should not impose a breaking change.
+func NewSingleResolver() *SingleResolver {
+	return &SingleResolver{}
+}
+
+// SingleResolver is a Resolver that treats the whole org ID as one tenant ID.
+type SingleResolver struct {
+}
+
+// containsUnsafePathSegments will return true if the string is a directory
+// reference like `.` and `..` or if any path separator character like `/` and
+// `\` can be found.
+func containsUnsafePathSegments(id string) bool {
+	// handle the relative reference to current and parent path.
+	if id == "." || id == ".." {
+		return true
+	}
+
+	return strings.ContainsAny(id, "\\/")
+}
+
+// errInvalidTenantID is returned when a tenant ID contains unsafe path
+// segments.
+var errInvalidTenantID = errors.New("invalid tenant ID")
+
+// TenantID returns the org ID stored in ctx as the tenant ID. It returns the
+// error from user.ExtractOrgID if extraction fails, or errInvalidTenantID if
+// the ID contains unsafe path segments.
+func (t *SingleResolver) TenantID(ctx context.Context) (string, error) {
+	//lint:ignore faillint wrapper around upstream method
+	id, err := user.ExtractOrgID(ctx)
+	if err != nil {
+		return "", err
+	}
+
+	if containsUnsafePathSegments(id) {
+		return "", errInvalidTenantID
+	}
+
+	return id, nil
+}
+
+// TenantIDs returns a one-element slice holding the result of TenantID.
+func (t *SingleResolver) TenantIDs(ctx context.Context) ([]string, error) {
+	orgID, err := t.TenantID(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return []string{orgID}, nil
+}
+
+// MultiResolver is a Resolver that splits the org ID into several tenant IDs.
+type MultiResolver struct {
+}
+
+// NewMultiResolver creates a tenant resolver, which allows requests to have
+// multiple tenant IDs submitted, separated by a '|' character. This enforces
+// further limits on the character set allowed within tenants, as detailed here:
+// https://cortexmetrics.io/docs/guides/limitations/#tenant-id-naming
+func NewMultiResolver() *MultiResolver {
+	return &MultiResolver{}
+}
+
+// TenantID returns the only tenant ID found in ctx, or user.ErrTooManyOrgIDs
+// if TenantIDs yields more than one.
+func (t *MultiResolver) TenantID(ctx context.Context) (string, error) {
+	orgIDs, err := t.TenantIDs(ctx)
+	if err != nil {
+		return "", err
+	}
+
+	if len(orgIDs) > 1 {
+		return "", user.ErrTooManyOrgIDs
+	}
+
+	// strings.Split with a non-empty separator always returns at least one
+	// element, so a successful TenantIDs call never yields an empty slice.
+	return orgIDs[0], nil
+}
+
+// TenantIDs splits the org ID stored in ctx on '|', checks every part with
+// ValidTenantID and containsUnsafePathSegments (returning the first failure),
+// and returns the IDs normalized by NormalizeTenantIDs.
+func (t *MultiResolver) TenantIDs(ctx context.Context) ([]string, error) {
+	//lint:ignore faillint wrapper around upstream method
+	orgID, err := user.ExtractOrgID(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	orgIDs := strings.Split(orgID, tenantIDsLabelSeparator)
+	for _, orgID := range orgIDs {
+		if err := ValidTenantID(orgID); err != nil {
+			return nil, err
+		}
+		if containsUnsafePathSegments(orgID) {
+			return nil, errInvalidTenantID
+		}
+	}
+
+	return NormalizeTenantIDs(orgIDs), nil
+}
+
+// ExtractTenantIDFromHTTPRequest extracts a single tenant ID from an HTTP
+// request using the default resolver. It returns the tenant ID together with
+// the context produced by user.ExtractOrgIDFromHTTPRequest; on error the
+// returned context is nil.
+func ExtractTenantIDFromHTTPRequest(req *http.Request) (string, context.Context, error) {
+	//lint:ignore faillint wrapper around upstream method
+	_, ctx, err := user.ExtractOrgIDFromHTTPRequest(req)
+	if err != nil {
+		return "", nil, err
+	}
+
+	tenantID, err := defaultResolver.TenantID(ctx)
+	if err != nil {
+		return "", nil, err
+	}
+
+	return tenantID, ctx, nil
+}
diff --git a/tenant/resolver_test.go b/tenant/resolver_test.go
new file mode 100644
index 000000000..4d2da2416
--- /dev/null
+++ b/tenant/resolver_test.go
@@ -0,0 +1,149 @@
+package tenant
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/weaveworks/common/user"
+)
+
+func strptr(s string) *string {
+	return &s
+}
+
+type resolverTestCase struct {
+	name         string
+	headerValue  *string
+	errTenantID  error
+	errTenantIDs error
+	tenantID     string
+	tenantIDs    []string
+}
+
+func (tc *resolverTestCase) test(r Resolver) func(t *testing.T) {
+	return func(t *testing.T) {
+
+		ctx := context.Background()
+		if tc.headerValue != nil {
+			ctx = user.InjectOrgID(ctx, *tc.headerValue)
+		}
+
+		tenantID, err := r.TenantID(ctx)
+		if tc.errTenantID != nil {
+			assert.Equal(t, tc.errTenantID, err)
+		} else {
+			assert.NoError(t, err)
+			assert.Equal(t, tc.tenantID, tenantID)
+		}
+
+		tenantIDs, err := r.TenantIDs(ctx)
+		if tc.errTenantIDs != nil {
+			assert.Equal(t, tc.errTenantIDs, err)
+		} else {
+			assert.NoError(t, err)
+			assert.Equal(t, tc.tenantIDs, tenantIDs)
+		}
+	}
+}
+
+var commonResolverTestCases = []resolverTestCase{
+	{
+		name:         "no-header",
+		errTenantID:  user.ErrNoOrgID,
+		errTenantIDs: user.ErrNoOrgID,
+	},
+	{
+		name:        "empty",
+		headerValue: strptr(""),
+		tenantIDs:   []string{""},
+	},
+	{
+		name:        "single-tenant",
+		headerValue: strptr("tenant-a"),
+		tenantID:    "tenant-a",
+		tenantIDs:   []string{"tenant-a"},
+	},
+	{
+		name:         "parent-dir",
+		headerValue:  strptr(".."),
+		errTenantID:  errInvalidTenantID,
+		errTenantIDs: errInvalidTenantID,
+	},
+	{
+		name:         "current-dir",
+		headerValue:  strptr("."),
+		errTenantID:  errInvalidTenantID,
+		errTenantIDs: errInvalidTenantID,
+	},
+}
+
+func TestSingleResolver(t *testing.T) {
+	r := NewSingleResolver()
+	for _, tc := range append(commonResolverTestCases, []resolverTestCase{
+		{
+			name:        "multi-tenant",
+			headerValue: strptr("tenant-a|tenant-b"),
+			tenantID:    "tenant-a|tenant-b",
+			tenantIDs:   []string{"tenant-a|tenant-b"},
+		},
+		{
+			name:         "containing-forward-slash",
+			headerValue:  strptr("forward/slash"),
+			errTenantID:  errInvalidTenantID,
+			errTenantIDs: errInvalidTenantID,
+		},
+		{
+			name:         "containing-backward-slash",
+			headerValue:  strptr(`backward\slash`),
+			errTenantID:  errInvalidTenantID,
+			errTenantIDs: errInvalidTenantID,
+		},
+	}...) {
+		t.Run(tc.name, tc.test(r))
+	}
+}
+
+func TestMultiResolver(t *testing.T) {
+	r := NewMultiResolver()
+	for _, tc := range append(commonResolverTestCases, []resolverTestCase{
+		{
+			name:        "multi-tenant",
+			headerValue: strptr("tenant-a|tenant-b"),
+			errTenantID: user.ErrTooManyOrgIDs,
+			tenantIDs:   []string{"tenant-a", "tenant-b"},
+		},
+		{
+			name:        "multi-tenant-wrong-order",
+			headerValue: strptr("tenant-b|tenant-a"),
+			errTenantID: user.ErrTooManyOrgIDs,
+			tenantIDs:   []string{"tenant-a", "tenant-b"},
+		},
+		{
+			name:        "multi-tenant-duplicate-order",
+			headerValue: strptr("tenant-b|tenant-b|tenant-a"),
+			errTenantID: user.ErrTooManyOrgIDs,
+			tenantIDs:   []string{"tenant-a", "tenant-b"},
+		},
+		{
+			name:         "multi-tenant-with-relative-path",
+			headerValue:  strptr("tenant-a|tenant-b|.."),
+			errTenantID:  errInvalidTenantID,
+			errTenantIDs: errInvalidTenantID,
+		},
+		{
+			name:         "containing-forward-slash",
+			headerValue:  strptr("forward/slash"),
+			errTenantID:  &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"},
+			errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 7, tenantID: "forward/slash"},
+		},
+		{
+			name:         "containing-backward-slash",
+			headerValue:  strptr(`backward\slash`),
+			errTenantID:  &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"},
+			errTenantIDs: &errTenantIDUnsupportedCharacter{pos: 8, tenantID: "backward\\slash"},
+		},
+	}...) {
+		t.Run(tc.name, tc.test(r))
+	}
+}
diff --git a/tenant/tenant.go b/tenant/tenant.go
new file mode 100644
index 000000000..99c1cc4a7
--- /dev/null
+++ b/tenant/tenant.go
@@ -0,0 +1,128 @@
+package tenant
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sort"
+	"strings"
+	"unicode/utf8"
+
+	"github.com/weaveworks/common/user"
+)
+
+var (
+	errTenantIDTooLong = errors.New("tenant ID is too long: max 150 characters")
+)
+
+// errTenantIDUnsupportedCharacter reports an unsupported character in a tenant
+// ID; pos is the byte offset of that character within tenantID.
+type errTenantIDUnsupportedCharacter struct {
+	pos      int
+	tenantID string
+}
+
+// Error implements the error interface. The offending character is decoded as
+// a whole rune starting at pos, so multi-byte characters are reported intact
+// rather than as their first byte.
+func (e *errTenantIDUnsupportedCharacter) Error() string {
+	if e.pos < 0 || e.pos >= len(e.tenantID) {
+		// pos does not point into tenantID: omit the character instead of
+		// panicking while formatting an error.
+		return fmt.Sprintf("tenant ID '%s' contains unsupported character", e.tenantID)
+	}
+
+	r, _ := utf8.DecodeRuneInString(e.tenantID[e.pos:])
+	return fmt.Sprintf(
+		"tenant ID '%s' contains unsupported character '%c'",
+		e.tenantID,
+		r,
+	)
+}
+
+// tenantIDsLabelSeparator separates tenant IDs within a single org ID.
+const tenantIDsLabelSeparator = "|"
+
+// NormalizeTenantIDs creates a normalized form by sorting and de-duplicating
+// the list of tenantIDs. The input slice is sorted and compacted in place; the
+// returned slice shares its backing array.
+func NormalizeTenantIDs(tenantIDs []string) []string {
+	sort.Strings(tenantIDs)
+
+	count := len(tenantIDs)
+	if count <= 1 {
+		return tenantIDs
+	}
+
+	// tenantIDs[:posOut] holds the distinct values found so far.
+	posOut := 1
+	for posIn := 1; posIn < count; posIn++ {
+		if tenantIDs[posIn] != tenantIDs[posIn-1] {
+			tenantIDs[posOut] = tenantIDs[posIn]
+			posOut++
+		}
+	}
+
+	return tenantIDs[0:posOut]
+}
+
+// ValidTenantID returns an error if s contains a character that is not
+// supported in tenant IDs, or if s is longer than 150 bytes. The character
+// check runs first, so an ID failing both reports the unsupported character.
+func ValidTenantID(s string) error {
+	// check if it contains invalid runes
+	for pos, r := range s {
+		if !isSupported(r) {
+			return &errTenantIDUnsupportedCharacter{
+				tenantID: s,
+				pos:      pos,
+			}
+		}
+	}
+
+	// every supported rune is a single byte, so len(s) equals the number of
+	// characters at this point
+	if len(s) > 150 {
+		return errTenantIDTooLong
+	}
+
+	return nil
+}
+
+// JoinTenantIDs joins tenantIDs into a single string separated by '|'.
+func JoinTenantIDs(tenantIDs []string) string {
+	return strings.Join(tenantIDs, tenantIDsLabelSeparator)
+}
+
+// isSupported checks if a rune is supported in tenant IDs (according to
+// https://cortexmetrics.io/docs/guides/limitations/#tenant-id-naming)
+func isSupported(c rune) bool {
+	// characters
+	if ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') {
+		return true
+	}
+
+	// digits
+	if '0' <= c && c <= '9' {
+		return true
+	}
+
+	// special
+	return c == '!' ||
+		c == '-' ||
+		c == '_' ||
+		c == '.' ||
+		c == '*' ||
+		c == '\'' ||
+		c == '(' ||
+		c == ')'
+}
+
+// TenantIDsFromOrgID extracts the tenant IDs from an orgID string value using
+// the default resolver.
+//
+// ignore stutter warning
+//nolint:revive
+func TenantIDsFromOrgID(orgID string) ([]string, error) {
+	return TenantIDs(user.InjectOrgID(context.TODO(), orgID))
+}
diff --git a/tenant/tenant_test.go b/tenant/tenant_test.go
new file mode 100644
index 000000000..b242fd77f
--- /dev/null
+++ b/tenant/tenant_test.go
@@ -0,0 +1,47 @@
+package tenant
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestValidTenantIDs(t *testing.T) {
+	for _, tc := range []struct {
+		name string
+		err  *string
+	}{
+		{
+			name: "tenant-a",
+		},
+		{
+			name: "ABCDEFGHIJKLMNOPQRSTUVWXYZ-abcdefghijklmnopqrstuvwxyz_0987654321!.*'()",
+		},
+		{
+			name: "invalid|",
+			err:  strptr("tenant ID 'invalid|' contains unsupported character '|'"),
+		},
+		{
+			// a multi-byte character must be reported as a whole rune
+			name: "tenant-\u00e9",
+			err:  strptr("tenant ID 'tenant-\u00e9' contains unsupported character '\u00e9'"),
+		},
+		{
+			name: strings.Repeat("a", 150),
+		},
+		{
+			name: strings.Repeat("a", 151),
+			err:  strptr("tenant ID is too long: max 150 characters"),
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			err := ValidTenantID(tc.name)
+			if tc.err == nil {
+				assert.Nil(t, err)
+			} else {
+				assert.EqualError(t, err, *tc.err)
+			}
+		})
+	}
+}