From 2fd8f75e38eebf5c6826ef594433e5165c3bfbe1 Mon Sep 17 00:00:00 2001 From: Joel Date: Wed, 23 Aug 2023 22:01:56 +0200 Subject: [PATCH] [Assist] Generalize embeddings system (#30835) * generalize embeddings * add test for mapping * use existing mapper * add streaming apis * update embedding processor tests * delete unused code * fix tests * Update lib/ai/embeddingprocessor.go Co-authored-by: Zac Bergquist * Update lib/ai/embeddingprocessor.go Co-authored-by: Zac Bergquist * Update lib/ai/embeddingprocessor.go Co-authored-by: Zac Bergquist * fix feedback --------- Co-authored-by: Zac Bergquist --- lib/ai/embedding/serialization.go | 106 +++++++++++ lib/ai/embeddingprocessor.go | 85 +++++---- lib/ai/embeddingprocessor_test.go | 90 +++++++--- lib/auth/assist/assistv1/service.go | 169 +++++++++++++++--- lib/auth/assist/assistv1/test/service_test.go | 22 ++- lib/service/service.go | 2 +- lib/services/embeddings.go | 2 + lib/services/local/embeddings.go | 14 ++ 8 files changed, 402 insertions(+), 88 deletions(-) diff --git a/lib/ai/embedding/serialization.go b/lib/ai/embedding/serialization.go index ecf33ed7cf47..7fc78f529642 100644 --- a/lib/ai/embedding/serialization.go +++ b/lib/ai/embedding/serialization.go @@ -23,6 +23,26 @@ import ( "github.com/gravitational/teleport/api/types" ) +// SerializeNode converts a serializable resource into text ready to be fed to an +// embedding model. The YAML serialization function was chosen over JSON and +// CSV as it provided better results. +func SerializeResource(resource types.Resource) ([]byte, error) { + switch resource.GetKind() { + case types.KindNode: + return SerializeNode(resource.(types.Server)) + case types.KindKubernetesCluster: + return SerializeKubeCluster(resource.(types.KubeCluster)) + case types.KindApp: + return SerializeApp(resource.(types.Application)) + case types.KindDatabase: + return SerializeDatabase(resource.(types.Database)) + case types.KindWindowsDesktop: + return SerializeWindowsDesktop(resource.(types.WindowsDesktop)) + default: + return nil, trace.BadParameter("unknown resource kind %q", resource.GetKind()) + } +} + // SerializeNode converts a type.Server into text ready to be fed to an // embedding model. The YAML serialization function was chosen over JSON and // CSV as it provided better results. @@ -42,3 +62,89 @@ func SerializeNode(node types.Server) ([]byte, error) { text, err := yaml.Marshal(&a) return text, trace.Wrap(err) } + +// SerializeKubeCluster converts a type.KubeCluster into text ready to be fed to an +// embedding model. The YAML serialization function was chosen over JSON and +// CSV as it provided better results. +func SerializeKubeCluster(cluster types.KubeCluster) ([]byte, error) { + a := struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + SubKind string `yaml:"subkind"` + Labels map[string]string `yaml:"labels"` + }{ + Name: cluster.GetName(), + Kind: types.KindKubernetesCluster, + SubKind: cluster.GetSubKind(), + Labels: cluster.GetAllLabels(), + } + text, err := yaml.Marshal(&a) + return text, trace.Wrap(err) +} + +// SerializeApp converts a type.Application into text ready to be fed to an +// embedding model. The YAML serialization function was chosen over JSON and +// CSV as it provided better results. +func SerializeApp(app types.Application) ([]byte, error) { + a := struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + SubKind string `yaml:"subkind"` + Labels map[string]string `yaml:"labels"` + Description string `yaml:"description"` + }{ + Name: app.GetName(), + Kind: types.KindApp, + SubKind: app.GetSubKind(), + Labels: app.GetAllLabels(), + Description: app.GetDescription(), + } + text, err := yaml.Marshal(&a) + return text, trace.Wrap(err) +} + +// SerializeDatabase converts a type.Database into text ready to be fed to an +// embedding model. The YAML serialization function was chosen over JSON and +// CSV as it provided better results. +func SerializeDatabase(db types.Database) ([]byte, error) { + a := struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + SubKind string `yaml:"subkind"` + Labels map[string]string `yaml:"labels"` + Type string `yaml:"type"` + Description string `yaml:"description"` + }{ + Name: db.GetName(), + Kind: types.KindDatabase, + SubKind: db.GetSubKind(), + Labels: db.GetAllLabels(), + Type: db.GetType(), + Description: db.GetDescription(), + } + text, err := yaml.Marshal(&a) + return text, trace.Wrap(err) +} + +// SerializeWindowsDesktop converts a type.WindowsDesktop into text ready to be fed to an +// embedding model. The YAML serialization function was chosen over JSON and +// CSV as it provided better results. +func SerializeWindowsDesktop(desktop types.WindowsDesktop) ([]byte, error) { + a := struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + SubKind string `yaml:"subkind"` + Labels map[string]string `yaml:"labels"` + Address string `yaml:"address"` + ADDomain string `yaml:"ad_domain"` + }{ + Name: desktop.GetName(), + Kind: types.KindKubernetesCluster, + SubKind: desktop.GetSubKind(), + Labels: desktop.GetAllLabels(), + Address: desktop.GetAddr(), + ADDomain: desktop.GetDomain(), + } + text, err := yaml.Marshal(&a) + return text, trace.Wrap(err) +} diff --git a/lib/ai/embeddingprocessor.go b/lib/ai/embeddingprocessor.go index e02da4d7f93f..a3cf187c3e68 100644 --- a/lib/ai/embeddingprocessor.go +++ b/lib/ai/embeddingprocessor.go @@ -25,12 +25,12 @@ import ( "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" - "github.com/gravitational/teleport/api/defaults" embeddingpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1" "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" embeddinglib "github.com/gravitational/teleport/lib/ai/embedding" + "github.com/gravitational/teleport/lib/services" streamutils "github.com/gravitational/teleport/lib/utils/stream" ) @@ -39,18 +39,13 @@ const maxEmbeddingAPISize = 1000 // Embeddings implements the minimal interface used by the Embedding processor. type Embeddings interface { - // GetEmbeddings returns all embeddings for a given kind. - GetEmbeddings(ctx context.Context, kind string) stream.Stream[*embeddinglib.Embedding] + // GetAllEmbeddings returns all embeddings. + GetAllEmbeddings(ctx context.Context) stream.Stream[*embeddinglib.Embedding] + // UpsertEmbedding creates or update a single ai.Embedding in the backend. UpsertEmbedding(ctx context.Context, embedding *embeddinglib.Embedding) (*embeddinglib.Embedding, error) } -// NodesStreamGetter is a service that gets nodes. -type NodesStreamGetter interface { - // GetNodeStream returns a list of registered servers. - GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server] -} - // MarshalEmbedding marshals the ai.Embedding resource to binary ProtoBuf. func MarshalEmbedding(embedding *embeddinglib.Embedding) ([]byte, error) { data, err := proto.Marshal((*embeddingpb.Embedding)(embedding)) @@ -132,7 +127,7 @@ type EmbeddingProcessorConfig struct { AIClient embeddinglib.Embedder EmbeddingSrv Embeddings EmbeddingsRetriever *SimpleRetriever - NodeSrv NodesStreamGetter + NodeSrv *services.UnifiedResourceCache Log logrus.FieldLogger Jitter retryutils.Jitter } @@ -143,7 +138,7 @@ type EmbeddingProcessor struct { aiClient embeddinglib.Embedder embeddingSrv Embeddings embeddingsRetriever *SimpleRetriever - nodeSrv NodesStreamGetter + nodeSrv *services.UnifiedResourceCache log logrus.FieldLogger jitter retryutils.Jitter } @@ -160,15 +155,15 @@ func NewEmbeddingProcessor(cfg *EmbeddingProcessorConfig) *EmbeddingProcessor { } } -// nodeStringPair is a helper struct that pairs a node with a data string. -type nodeStringPair struct { - node types.Server - data string +// resourceStringPair is a helper struct that pairs a resource with a data string. +type resourceStringPair struct { + resource types.Resource + data string } -// mapProcessFn is a helper function that maps a slice of nodeStringPair, -// compute embeddings and return them as a slice of ai.Embedding. -func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*nodeStringPair) ([]*embeddinglib.Embedding, error) { +// mapProcessFn is a helper function that maps a slice of resourceStringPair, +// compute embeddings and return them as a slice. +func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*resourceStringPair) ([]*embeddinglib.Embedding, error) { dataBatch := make([]string, 0, len(data)) for _, pair := range data { dataBatch = append(dataBatch, pair.data) @@ -181,8 +176,8 @@ func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*nodeStrin results := make([]*embeddinglib.Embedding, 0, len(embeddings)) for i, embedding := range embeddings { - emb := embeddinglib.NewEmbedding(types.KindNode, - data[i].node.GetName(), embedding, + emb := embeddinglib.NewEmbedding(data[i].resource.GetKind(), + data[i].resource.GetName(), embedding, embeddinglib.EmbeddingHash([]byte(data[i].data)), ) results = append(results, emb) @@ -208,7 +203,7 @@ func (e *EmbeddingProcessor) Run(ctx context.Context, initialDelay, period time. } } -// process updates embeddings for all nodes once. +// process updates embeddings for all resources once. func (e *EmbeddingProcessor) process(ctx context.Context) { batch := NewBatchReducer(e.mapProcessFn, maxEmbeddingAPISize, // Max batch size allowed by OpenAI API, @@ -217,19 +212,31 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { e.log.Debugf("embedding processor started") defer e.log.Debugf("embedding processor finished") - embeddingsStream := e.embeddingSrv.GetEmbeddings(ctx, types.KindNode) - nodesStream := e.nodeSrv.GetNodeStream(ctx, defaults.Namespace) + embeddingsStream := e.embeddingSrv.GetAllEmbeddings(ctx) + unifiedResources, err := e.nodeSrv.GetUnifiedResources(ctx) + if err != nil { + e.log.Debugf("embedding processor failed with error: %v", err) + return + } + + resources := make([]types.Resource, len(unifiedResources)) + for i, unifiedResource := range unifiedResources { + resources[i] = unifiedResource + unifiedResources[i] = nil + } + + resourceStream := stream.Slice(resources) s := streamutils.NewZipStreams( - nodesStream, + resourceStream, embeddingsStream, - // On new node callback. Add the node to the batch. - func(node types.Server) error { - nodeData, err := embeddinglib.SerializeNode(node) + // On new resource callback. Add the resource to the batch. + func(resource types.Resource) error { + resourceData, err := embeddinglib.SerializeResource(resource) if err != nil { return trace.Wrap(err) } - vectors, err := batch.Add(ctx, &nodeStringPair{node, string(nodeData)}) + vectors, err := batch.Add(ctx, &resourceStringPair{resource, string(resourceData)}) if err != nil { return trace.Wrap(err) } @@ -239,17 +246,17 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { return nil }, - // On equal node callback. Check if the node's embedding hash matches - // the one in the backend. If not, add the node to the batch. - func(node types.Server, embedding *embeddinglib.Embedding) error { - nodeData, err := embeddinglib.SerializeNode(node) + // On equal resource callback. Check if the resource's embedding hash matches + // the one in the backend. If not, add the resource to the batch. + func(resource types.Resource, embedding *embeddinglib.Embedding) error { + resourceData, err := embeddinglib.SerializeResource(resource) if err != nil { return trace.Wrap(err) } - nodeHash := embeddinglib.EmbeddingHash(nodeData) + resourceHash := embeddinglib.EmbeddingHash(resourceData) - if !EmbeddingHashMatches(embedding, nodeHash) { - vectors, err := batch.Add(ctx, &nodeStringPair{node, string(nodeData)}) + if !EmbeddingHashMatches(embedding, resourceHash) { + vectors, err := batch.Add(ctx, &resourceStringPair{resource, string(resourceData)}) if err != nil { return trace.Wrap(err) } @@ -260,8 +267,8 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { return nil }, // On compare keys callback. Compare the keys for iteration. - func(node types.Server, embeddings *embeddinglib.Embedding) int { - return strings.Compare(node.GetName(), embeddings.GetEmbeddedID()) + func(resource types.Resource, embeddings *embeddinglib.Embedding) int { + return strings.Compare(resource.GetName(), embeddings.GetEmbeddedID()) }, ) @@ -269,7 +276,7 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { e.log.Warnf("Failed to generate nodes embedding: %v", err) } - // Process the remaining nodes in the batch + // Process the remaining resources in the batch vectors, err := batch.Finalize(ctx) if err != nil { e.log.Warnf("Failed to add node to batch: %v", err) @@ -290,7 +297,7 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { // latest embeddings. The new index is created and then swapped with the old one. func (e *EmbeddingProcessor) updateMemIndex(ctx context.Context) error { embeddingsIndex := NewSimpleRetriever() - embeddingsStream := e.embeddingSrv.GetEmbeddings(ctx, types.KindNode) + embeddingsStream := e.embeddingSrv.GetAllEmbeddings(ctx) for embeddingsStream.Next() { embedding := embeddingsStream.Item() diff --git a/lib/ai/embeddingprocessor_test.go b/lib/ai/embeddingprocessor_test.go index 49b143c6877d..26ecc5bc058e 100644 --- a/lib/ai/embeddingprocessor_test.go +++ b/lib/ai/embeddingprocessor_test.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/ai" "github.com/gravitational/teleport/lib/ai/embedding" "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/utils" ) @@ -63,33 +64,59 @@ func (m *MockEmbedder) ComputeEmbeddings(_ context.Context, input []string) ([]e return result, nil } -type mockNodeStreamer struct { - mu sync.Mutex - nodes []types.Server +type mockResourceGetter struct { + mu sync.Mutex + nodes []types.Server + presence *local.PresenceService } -func (m *mockNodeStreamer) UpsertNode(_ context.Context, node types.Server) (*types.KeepAlive, error) { +func (m *mockResourceGetter) UpsertNode(ctx context.Context, node types.Server) (*types.KeepAlive, error) { m.mu.Lock() defer m.mu.Unlock() + found := false for i, n := range m.nodes { // update if n.GetName() == node.GetName() { + found = true m.nodes[i] = node - return nil, nil + break } } // insert - m.nodes = append(m.nodes, node) - return nil, nil + if !found { + m.nodes = append(m.nodes, node) + } + + return m.presence.UpsertNode(ctx, node) } -func (m *mockNodeStreamer) GetNodeStream(_ context.Context, _ string) stream.Stream[types.Server] { +func (m *mockResourceGetter) GetNodes(_ context.Context, _ string) ([]types.Server, error) { m.mu.Lock() defer m.mu.Unlock() - nodes := make([]types.Server, 0, len(m.nodes)) - nodes = append(nodes, m.nodes...) - return stream.Slice(nodes) + d := make([]types.Server, len(m.nodes)) + copy(d, m.nodes) + return d, nil +} + +func (m *mockResourceGetter) GetDatabaseServers(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.DatabaseServer, error) { + return nil, nil +} + +func (m *mockResourceGetter) GetKubernetesServers(_ context.Context) ([]types.KubeServer, error) { + return nil, nil +} + +func (m *mockResourceGetter) GetApplicationServers(_ context.Context, _ string) ([]types.AppServer, error) { + return nil, nil +} + +func (m *mockResourceGetter) GetWindowsDesktops(_ context.Context, _ types.WindowsDesktopFilter) ([]types.WindowsDesktop, error) { + return nil, nil +} + +func (m *mockResourceGetter) ListSAMLIdPServiceProviders(_ context.Context, _ int, _ string) ([]types.SAMLIdPServiceProvider, string, error) { + return nil, "", nil } func TestNodeEmbeddingGeneration(t *testing.T) { @@ -111,14 +138,26 @@ func TestNodeEmbeddingGeneration(t *testing.T) { embedder := MockEmbedder{ timesCalled: make(map[string]int), } - presence := &mockNodeStreamer{} + events := local.NewEventsService(bk) + presence := &mockResourceGetter{ + presence: local.NewPresenceService(bk), + } + cache, err := services.NewUnifiedResourceCache(ctx, services.UnifiedResourceCacheConfig{ + ResourceGetter: &mockResourceGetter{}, + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "resource-watcher", + Client: events, + }, + }) + require.NoError(t, err) + embeddings := local.NewEmbeddingsService(bk) processor := ai.NewEmbeddingProcessor(&ai.EmbeddingProcessorConfig{ AIClient: &embedder, EmbeddingSrv: embeddings, EmbeddingsRetriever: ai.NewSimpleRetriever(), - NodeSrv: presence, + NodeSrv: cache, Log: utils.NewLoggerForTests(), Jitter: retryutils.NewSeventhJitter(), }) @@ -139,14 +178,17 @@ func TestNodeEmbeddingGeneration(t *testing.T) { } require.Eventually(t, func() bool { - items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) + items, err := stream.Collect(embeddings.GetAllEmbeddings(ctx)) assert.NoError(t, err) return len(items) == numInitialNodes - }, 7*time.Second, 200*time.Millisecond) + }, 14*time.Second, 200*time.Millisecond) + + nodesAcquired, err := presence.GetNodes(ctx, defaults.Namespace) + require.NoError(t, err) validateEmbeddings(t, - presence.GetNodeStream(ctx, defaults.Namespace), - embeddings.GetEmbeddings(ctx, types.KindNode)) + nodesAcquired, + embeddings.GetAllEmbeddings(ctx)) for k, v := range embedder.timesCalled { require.Equal(t, 1, v, "expected %v to be computed once, was %d", k, v) @@ -164,7 +206,7 @@ func TestNodeEmbeddingGeneration(t *testing.T) { // Since nodes are streamed in ascending order by names, when embeddings for node6 are calculated, // we can be sure that our recent changes have been fully processed require.Eventually(t, func() bool { - items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) + items, err := stream.Collect(embeddings.GetAllEmbeddings(ctx)) assert.NoError(t, err) return len(items) == numInitialNodes+1 }, 7*time.Second, 200*time.Millisecond) @@ -177,9 +219,12 @@ func TestNodeEmbeddingGeneration(t *testing.T) { require.Equal(t, expected, v, "expected embedding for %q to be computed %d times, got computed %d times", k, expected, v) } + nodesAcquired, err = presence.GetNodes(ctx, defaults.Namespace) + require.NoError(t, err) + validateEmbeddings(t, - presence.GetNodeStream(ctx, defaults.Namespace), - embeddings.GetEmbeddings(ctx, types.KindNode)) + nodesAcquired, + embeddings.GetAllEmbeddings(ctx)) } func TestMarshallUnmarshallEmbedding(t *testing.T) { @@ -210,12 +255,9 @@ func makeNode(num int) types.Server { return node } -func validateEmbeddings(t *testing.T, nodesStream stream.Stream[types.Server], embeddingsStream stream.Stream[*embedding.Embedding]) { +func validateEmbeddings(t *testing.T, nodes []types.Server, embeddingsStream stream.Stream[*embedding.Embedding]) { t.Helper() - nodes, err := stream.Collect(nodesStream) - require.NoError(t, err) - embeddings, err := stream.Collect(embeddingsStream) require.NoError(t, err) diff --git a/lib/auth/assist/assistv1/service.go b/lib/auth/assist/assistv1/service.go index 6fcd88067441..335d89dcf32c 100644 --- a/lib/auth/assist/assistv1/service.go +++ b/lib/auth/assist/assistv1/service.go @@ -49,6 +49,10 @@ type ServiceConfig struct { // Created to avoid circular dependencies. type ResourceGetter interface { GetNode(ctx context.Context, namespace, name string) (types.Server, error) + GetKubernetesCluster(ctx context.Context, name string) (types.KubeCluster, error) + GetApp(ctx context.Context, name string) (types.Application, error) + GetDatabase(ctx context.Context, name string) (types.Database, error) + GetWindowsDesktops(ctx context.Context, filter types.WindowsDesktopFilter) ([]types.WindowsDesktop, error) } // Service implements the teleport.assist.v1.AssistService RPC service. @@ -220,8 +224,13 @@ func (a *Service) IsAssistEnabled(ctx context.Context, _ *assist.IsAssistEnabled } func (a *Service) GetAssistantEmbeddings(ctx context.Context, msg *assist.GetAssistantEmbeddingsRequest) (*assist.GetAssistantEmbeddingsResponse, error) { - // TODO(jakule): The kind needs to be updated when we add more resources. - authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindNode, types.VerbRead, types.VerbList) + switch msg.Kind { + case types.KindNode, types.KindKubernetesCluster, types.KindApp, types.KindDatabase, types.KindWindowsDesktop: + default: + return nil, trace.BadParameter("resource kind %v is not supported", msg.Kind) + } + + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, msg.Kind, types.VerbRead, types.VerbList) if err != nil { return nil, authz.ConvertAuthorizerError(ctx, a.log, err) } @@ -241,27 +250,89 @@ func (a *Service) GetAssistantEmbeddings(ctx context.Context, msg *assist.GetAss // Use default values for the id and content, as we only care about the embeddings. queryEmbeddings := embeddinglib.NewEmbedding(msg.Kind, "", embeddings[0], [32]byte{}) - documents := a.embeddings.GetRelevant(queryEmbeddings, int(msg.Limit), func(id string, embedding *embeddinglib.Embedding) bool { - // Run RBAC check on the embedded resource. - node, err := a.resourceGetter.GetNode(ctx, defaults.Namespace, embedding.GetEmbeddedID()) - if err != nil { - a.log.Tracef("failed to get node %q: %v", embedding.GetName(), err) - return false - } - return authCtx.Checker.CheckAccess(node, services.AccessState{MFAVerified: true}) == nil - }) + accessChecker := accessCheckerForKind(ctx, a, authCtx, msg.Kind) + documents := a.embeddings.GetRelevant(queryEmbeddings, int(msg.Limit), accessChecker) + return assembleEmbeddingResponseForKind(ctx, a, msg.Kind, documents) +} +// userHasAccess returns true if the user should have access to the resource. +func userHasAccess(authCtx *authz.Context, req interface{ GetUsername() string }) bool { + return !authz.IsCurrentUser(*authCtx, req.GetUsername()) && !authz.HasBuiltinRole(*authCtx, string(types.RoleAdmin)) +} + +func assembleEmbeddingResponseForKind(ctx context.Context, a *Service, kind string, documents []*ai.Document) (*assist.GetAssistantEmbeddingsResponse, error) { protoDocs := make([]*assist.EmbeddedDocument, 0, len(documents)) + for _, doc := range documents { - node, err := a.resourceGetter.GetNode(ctx, defaults.Namespace, doc.GetEmbeddedID()) - if err != nil { - return nil, trace.Wrap(err) + var content []byte + + switch kind { + case types.KindNode: + node, err := a.resourceGetter.GetNode(ctx, defaults.Namespace, doc.GetEmbeddedID()) + if err != nil { + return nil, trace.Wrap(err) + } + + content, err = embeddinglib.SerializeNode(node) + if err != nil { + return nil, trace.Wrap(err) + } + case types.KindKubernetesCluster: + cluster, err := a.resourceGetter.GetKubernetesCluster(ctx, doc.GetEmbeddedID()) + if err != nil { + return nil, trace.Wrap(err) + } + + content, err = embeddinglib.SerializeKubeCluster(cluster) + if err != nil { + return nil, trace.Wrap(err) + } + case types.KindApp: + app, err := a.resourceGetter.GetApp(ctx, doc.GetEmbeddedID()) + if err != nil { + return nil, trace.Wrap(err) + } + + content, err = embeddinglib.SerializeApp(app) + if err != nil { + return nil, trace.Wrap(err) + } + case types.KindDatabase: + db, err := a.resourceGetter.GetDatabase(ctx, doc.GetEmbeddedID()) + if err != nil { + return nil, trace.Wrap(err) + } + + content, err = embeddinglib.SerializeDatabase(db) + if err != nil { + return nil, trace.Wrap(err) + } + case types.KindWindowsDesktop: + desktops, err := a.resourceGetter.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{ + Name: doc.GetEmbeddedID(), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + var desktop types.WindowsDesktop + for _, d := range desktops { + if d.GetName() == doc.GetEmbeddedID() { + desktop = d + break + } + } + + if desktop == nil { + return nil, trace.NotFound("windows desktop %q not found", doc.GetEmbeddedID()) + } + + content, err = embeddinglib.SerializeWindowsDesktop(desktop) + if err != nil { + return nil, trace.Wrap(err) + } } - content, err := embeddinglib.SerializeNode(node) - if err != nil { - return nil, trace.Wrap(err) - } protoDocs = append(protoDocs, &assist.EmbeddedDocument{ Id: doc.GetEmbeddedID(), Content: string(content), @@ -274,7 +345,63 @@ func (a *Service) GetAssistantEmbeddings(ctx context.Context, msg *assist.GetAss }, nil } -// userHasAccess returns true if the user should have access to the resource. -func userHasAccess(authCtx *authz.Context, req interface{ GetUsername() string }) bool { - return !authz.IsCurrentUser(*authCtx, req.GetUsername()) && !authz.HasBuiltinRole(*authCtx, string(types.RoleAdmin)) +func accessCheckerForKind(ctx context.Context, a *Service, authCtx *authz.Context, kind string) func(id string, embedding *embeddinglib.Embedding) bool { + return func(id string, embedding *embeddinglib.Embedding) bool { + if embedding.EmbeddedKind != kind { + return false + } + + var resource services.AccessCheckable + var err error + + switch kind { + case types.KindNode: + resource, err = a.resourceGetter.GetNode(ctx, defaults.Namespace, embedding.GetEmbeddedID()) + if err != nil { + a.log.Tracef("failed to get node %q: %v", embedding.GetName(), err) + return false + } + + case types.KindKubernetesCluster: + resource, err = a.resourceGetter.GetKubernetesCluster(ctx, embedding.GetEmbeddedID()) + if err != nil { + a.log.Tracef("failed to get kube cluster %q: %v", embedding.GetName(), err) + return false + } + case types.KindApp: + resource, err = a.resourceGetter.GetApp(ctx, embedding.GetEmbeddedID()) + if err != nil { + a.log.Tracef("failed to get app %q: %v", embedding.GetName(), err) + return false + } + case types.KindDatabase: + resource, err = a.resourceGetter.GetDatabase(ctx, embedding.GetEmbeddedID()) + if err != nil { + a.log.Tracef("failed to get database %q: %v", embedding.GetName(), err) + return false + } + case types.KindWindowsDesktop: + desktops, err := a.resourceGetter.GetWindowsDesktops(ctx, types.WindowsDesktopFilter{ + Name: embedding.GetEmbeddedID(), + }) + if err != nil { + a.log.Tracef("failed to get windows desktop %q: %v", embedding.GetName(), err) + return false + } + + for _, d := range desktops { + if d.GetName() == embedding.GetEmbeddedID() { + resource = d + break + } + } + + if resource == nil { + a.log.Tracef("failed to find windows desktop %q: %v", embedding.GetName(), err) + return false + } + } + + return authCtx.Checker.CheckAccess(resource, services.AccessState{MFAVerified: true}) == nil + } } diff --git a/lib/auth/assist/assistv1/test/service_test.go b/lib/auth/assist/assistv1/test/service_test.go index f6d94f57d383..3fda580cd204 100644 --- a/lib/auth/assist/assistv1/test/service_test.go +++ b/lib/auth/assist/assistv1/test/service_test.go @@ -362,16 +362,32 @@ func initSvc(t *testing.T) (map[string]context.Context, *assistv1.Service) { Backend: local.NewAssistService(backend), Authorizer: authorizer, Embeddings: &ai.SimpleRetriever{}, - ResourceGetter: &nodeGetterFake{}, + ResourceGetter: &resourceGetterFake{}, }) require.NoError(t, err) return ctxs, svc } -type nodeGetterFake struct { +type resourceGetterFake struct { } -func (g *nodeGetterFake) GetNode(ctx context.Context, namespace, name string) (types.Server, error) { +func (g *resourceGetterFake) GetNode(ctx context.Context, namespace, name string) (types.Server, error) { + return nil, nil +} + +func (g *resourceGetterFake) GetKubernetesCluster(ctx context.Context, name string) (types.KubeCluster, error) { + return nil, nil +} + +func (g *resourceGetterFake) GetApp(ctx context.Context, name string) (types.Application, error) { + return nil, nil +} + +func (g *resourceGetterFake) GetDatabase(ctx context.Context, name string) (types.Database, error) { + return nil, nil +} + +func (g *resourceGetterFake) GetWindowsDesktops(ctx context.Context, _ types.WindowsDesktopFilter) ([]types.WindowsDesktop, error) { return nil, nil } diff --git a/lib/service/service.go b/lib/service/service.go index 2f612a998e9e..f35b6dc44fdf 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1784,7 +1784,7 @@ func (process *TeleportProcess) initAuthService() error { AIClient: embedderClient, EmbeddingsRetriever: embeddingsRetriever, EmbeddingSrv: authServer, - NodeSrv: authServer, + NodeSrv: authServer.UnifiedResourceCache, Log: log, Jitter: retryutils.NewFullJitter(), }) diff --git a/lib/services/embeddings.go b/lib/services/embeddings.go index f4b8bf0dbaf9..8c9136cea3d1 100644 --- a/lib/services/embeddings.go +++ b/lib/services/embeddings.go @@ -31,6 +31,8 @@ type Embeddings interface { GetEmbedding(ctx context.Context, kind, resourceID string) (*embedding.Embedding, error) // GetEmbeddings returns all embeddings for a given kind. GetEmbeddings(ctx context.Context, kind string) stream.Stream[*embedding.Embedding] + // GetEmbeddings returns all embeddings. + GetAllEmbeddings(ctx context.Context) stream.Stream[*embedding.Embedding] // UpsertEmbedding creates or updates a single ai.Embedding in the backend. UpsertEmbedding(ctx context.Context, embedding *embedding.Embedding) (*embedding.Embedding, error) } diff --git a/lib/services/local/embeddings.go b/lib/services/local/embeddings.go index 613c99fbb089..680d300b5c61 100644 --- a/lib/services/local/embeddings.go +++ b/lib/services/local/embeddings.go @@ -51,6 +51,20 @@ func (e EmbeddingsService) GetEmbedding(ctx context.Context, kind, resourceID st return ai.UnmarshalEmbedding(result.Value) } +// GetEmbeddings returns a stream of all embeddings +func (e EmbeddingsService) GetAllEmbeddings(ctx context.Context) stream.Stream[*embedding.Embedding] { + startKey := backend.ExactKey(embeddingsPrefix) + items := backend.StreamRange(ctx, e, startKey, backend.RangeEnd(startKey), 50) + return stream.FilterMap(items, func(item backend.Item) (*embedding.Embedding, bool) { + embedding, err := ai.UnmarshalEmbedding(item.Value) + if err != nil { + e.log.Warnf("Skipping embedding at %s, failed to unmarshal: %v", item.Key, err) + return nil, false + } + return embedding, true + }) +} + // GetEmbeddings returns a stream of embeddings for a given kind. func (e EmbeddingsService) GetEmbeddings(ctx context.Context, kind string) stream.Stream[*embedding.Embedding] { startKey := backend.ExactKey(embeddingsPrefix, kind)