[Assist] Generalize embeddings system (#30835)
* generalize embeddings

* add test for mapping

* use existing mapper

* add streaming apis

* update embedding processor tests

* delete unused code

* fix tests

* Update lib/ai/embeddingprocessor.go

Co-authored-by: Zac Bergquist <zac.bergquist@goteleport.com>

* Update lib/ai/embeddingprocessor.go

Co-authored-by: Zac Bergquist <zac.bergquist@goteleport.com>

* Update lib/ai/embeddingprocessor.go

Co-authored-by: Zac Bergquist <zac.bergquist@goteleport.com>

* fix feedback

---------

Co-authored-by: Zac Bergquist <zac.bergquist@goteleport.com>
xacrimon and zmb3 authored Aug 23, 2023
1 parent c0a459b commit 2fd8f75
Showing 8 changed files with 402 additions and 88 deletions.
106 changes: 106 additions & 0 deletions lib/ai/embedding/serialization.go
@@ -23,6 +23,26 @@ import (
"github.com/gravitational/teleport/api/types"
)

// SerializeResource converts a serializable resource into text ready to be fed to an
// embedding model. It dispatches to the kind-specific serializer. The YAML
// serialization format was chosen over JSON and CSV as it provided better results.
func SerializeResource(resource types.Resource) ([]byte, error) {
switch resource.GetKind() {
case types.KindNode:
return SerializeNode(resource.(types.Server))
case types.KindKubernetesCluster:
return SerializeKubeCluster(resource.(types.KubeCluster))
case types.KindApp:
return SerializeApp(resource.(types.Application))
case types.KindDatabase:
return SerializeDatabase(resource.(types.Database))
case types.KindWindowsDesktop:
return SerializeWindowsDesktop(resource.(types.WindowsDesktop))
default:
return nil, trace.BadParameter("unknown resource kind %q", resource.GetKind())
}
}

// SerializeNode converts a types.Server into text ready to be fed to an
// embedding model. The YAML serialization function was chosen over JSON and
// CSV as it provided better results.
@@ -42,3 +62,89 @@ func SerializeNode(node types.Server) ([]byte, error) {
text, err := yaml.Marshal(&a)
return text, trace.Wrap(err)
}

// SerializeKubeCluster converts a types.KubeCluster into text ready to be fed to an
// embedding model. The YAML serialization function was chosen over JSON and
// CSV as it provided better results.
func SerializeKubeCluster(cluster types.KubeCluster) ([]byte, error) {
a := struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
SubKind string `yaml:"subkind"`
Labels map[string]string `yaml:"labels"`
}{
Name: cluster.GetName(),
Kind: types.KindKubernetesCluster,
SubKind: cluster.GetSubKind(),
Labels: cluster.GetAllLabels(),
}
text, err := yaml.Marshal(&a)
return text, trace.Wrap(err)
}

// SerializeApp converts a types.Application into text ready to be fed to an
// embedding model. The YAML serialization function was chosen over JSON and
// CSV as it provided better results.
func SerializeApp(app types.Application) ([]byte, error) {
a := struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
SubKind string `yaml:"subkind"`
Labels map[string]string `yaml:"labels"`
Description string `yaml:"description"`
}{
Name: app.GetName(),
Kind: types.KindApp,
SubKind: app.GetSubKind(),
Labels: app.GetAllLabels(),
Description: app.GetDescription(),
}
text, err := yaml.Marshal(&a)
return text, trace.Wrap(err)
}

// SerializeDatabase converts a types.Database into text ready to be fed to an
// embedding model. The YAML serialization function was chosen over JSON and
// CSV as it provided better results.
func SerializeDatabase(db types.Database) ([]byte, error) {
a := struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
SubKind string `yaml:"subkind"`
Labels map[string]string `yaml:"labels"`
Type string `yaml:"type"`
Description string `yaml:"description"`
}{
Name: db.GetName(),
Kind: types.KindDatabase,
SubKind: db.GetSubKind(),
Labels: db.GetAllLabels(),
Type: db.GetType(),
Description: db.GetDescription(),
}
text, err := yaml.Marshal(&a)
return text, trace.Wrap(err)
}

// SerializeWindowsDesktop converts a types.WindowsDesktop into text ready to be fed to an
// embedding model. The YAML serialization function was chosen over JSON and
// CSV as it provided better results.
func SerializeWindowsDesktop(desktop types.WindowsDesktop) ([]byte, error) {
a := struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
SubKind string `yaml:"subkind"`
Labels map[string]string `yaml:"labels"`
Address string `yaml:"address"`
ADDomain string `yaml:"ad_domain"`
}{
Name: desktop.GetName(),
Kind: types.KindWindowsDesktop,
SubKind: desktop.GetSubKind(),
Labels: desktop.GetAllLabels(),
Address: desktop.GetAddr(),
ADDomain: desktop.GetDomain(),
}
text, err := yaml.Marshal(&a)
return text, trace.Wrap(err)
}
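
For illustration, here is a minimal, self-contained sketch of the anonymous-struct-to-YAML pattern these serializers share. It uses a stand-in function with plain string arguments rather than the real types interfaces, and assumes gopkg.in/yaml.v3 for marshaling; none of these names are part of the commit.

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// serializeDatabaseLike mirrors the shape of SerializeDatabase above: copy a
// small, stable set of fields into an anonymous struct, marshal it to YAML,
// and hand the resulting text to the embedding model.
func serializeDatabaseLike(name, subKind, dbType, desc string, labels map[string]string) ([]byte, error) {
	a := struct {
		Name        string            `yaml:"name"`
		Kind        string            `yaml:"kind"`
		SubKind     string            `yaml:"subkind"`
		Labels      map[string]string `yaml:"labels"`
		Type        string            `yaml:"type"`
		Description string            `yaml:"description"`
	}{
		Name:        name,
		Kind:        "db", // illustrative literal standing in for types.KindDatabase
		SubKind:     subKind,
		Labels:      labels,
		Type:        dbType,
		Description: desc,
	}
	return yaml.Marshal(&a)
}

func main() {
	text, err := serializeDatabaseLike("orders", "", "postgres", "primary orders DB",
		map[string]string{"env": "prod"})
	if err != nil {
		panic(err)
	}
	fmt.Print(string(text))
	// Prints, roughly:
	//   name: orders
	//   kind: db
	//   subkind: ""
	//   labels:
	//       env: prod
	//   type: postgres
	//   description: primary orders DB
}
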
85 changes: 46 additions & 39 deletions lib/ai/embeddingprocessor.go
@@ -25,12 +25,12 @@ import (
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"

"github.com/gravitational/teleport/api/defaults"
embeddingpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
embeddinglib "github.com/gravitational/teleport/lib/ai/embedding"
"github.com/gravitational/teleport/lib/services"
streamutils "github.com/gravitational/teleport/lib/utils/stream"
)

@@ -39,18 +39,13 @@ const maxEmbeddingAPISize = 1000

// Embeddings implements the minimal interface used by the Embedding processor.
type Embeddings interface {
// GetEmbeddings returns all embeddings for a given kind.
GetEmbeddings(ctx context.Context, kind string) stream.Stream[*embeddinglib.Embedding]
// GetAllEmbeddings returns all embeddings.
GetAllEmbeddings(ctx context.Context) stream.Stream[*embeddinglib.Embedding]

// UpsertEmbedding creates or updates a single ai.Embedding in the backend.
UpsertEmbedding(ctx context.Context, embedding *embeddinglib.Embedding) (*embeddinglib.Embedding, error)
}

// NodesStreamGetter is a service that gets nodes.
type NodesStreamGetter interface {
// GetNodeStream returns a list of registered servers.
GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server]
}

// MarshalEmbedding marshals the ai.Embedding resource to binary ProtoBuf.
func MarshalEmbedding(embedding *embeddinglib.Embedding) ([]byte, error) {
data, err := proto.Marshal((*embeddingpb.Embedding)(embedding))
@@ -132,7 +127,7 @@ type EmbeddingProcessorConfig struct {
AIClient embeddinglib.Embedder
EmbeddingSrv Embeddings
EmbeddingsRetriever *SimpleRetriever
NodeSrv NodesStreamGetter
NodeSrv *services.UnifiedResourceCache
Log logrus.FieldLogger
Jitter retryutils.Jitter
}
@@ -143,7 +138,7 @@ type EmbeddingProcessor struct {
aiClient embeddinglib.Embedder
embeddingSrv Embeddings
embeddingsRetriever *SimpleRetriever
nodeSrv NodesStreamGetter
nodeSrv *services.UnifiedResourceCache
log logrus.FieldLogger
jitter retryutils.Jitter
}
@@ -160,15 +155,15 @@ func NewEmbeddingProcessor(cfg *EmbeddingProcessorConfig) *EmbeddingProcessor {
}
}

// nodeStringPair is a helper struct that pairs a node with a data string.
type nodeStringPair struct {
node types.Server
data string
// resourceStringPair is a helper struct that pairs a resource with a data string.
type resourceStringPair struct {
resource types.Resource
data string
}

// mapProcessFn is a helper function that maps a slice of nodeStringPair,
// compute embeddings and return them as a slice of ai.Embedding.
func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*nodeStringPair) ([]*embeddinglib.Embedding, error) {
// mapProcessFn is a helper function that maps a slice of resourceStringPair,
// computes embeddings, and returns them as a slice.
func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*resourceStringPair) ([]*embeddinglib.Embedding, error) {
dataBatch := make([]string, 0, len(data))
for _, pair := range data {
dataBatch = append(dataBatch, pair.data)
@@ -181,8 +176,8 @@ func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*nodeStrin

results := make([]*embeddinglib.Embedding, 0, len(embeddings))
for i, embedding := range embeddings {
emb := embeddinglib.NewEmbedding(types.KindNode,
data[i].node.GetName(), embedding,
emb := embeddinglib.NewEmbedding(data[i].resource.GetKind(),
data[i].resource.GetName(), embedding,
embeddinglib.EmbeddingHash([]byte(data[i].data)),
)
results = append(results, emb)
@@ -208,7 +203,7 @@ func (e *EmbeddingProcessor) Run(ctx context.Context, initialDelay, period time.
}
}

// process updates embeddings for all nodes once.
// process updates embeddings for all resources once.
func (e *EmbeddingProcessor) process(ctx context.Context) {
batch := NewBatchReducer(e.mapProcessFn,
maxEmbeddingAPISize, // Max batch size allowed by OpenAI API,
@@ -217,19 +212,31 @@ func (e *EmbeddingProcessor) process(ctx context.Context) {
e.log.Debugf("embedding processor started")
defer e.log.Debugf("embedding processor finished")

embeddingsStream := e.embeddingSrv.GetEmbeddings(ctx, types.KindNode)
nodesStream := e.nodeSrv.GetNodeStream(ctx, defaults.Namespace)
embeddingsStream := e.embeddingSrv.GetAllEmbeddings(ctx)
unifiedResources, err := e.nodeSrv.GetUnifiedResources(ctx)
if err != nil {
e.log.Debugf("embedding processor failed with error: %v", err)
return
}

resources := make([]types.Resource, len(unifiedResources))
for i, unifiedResource := range unifiedResources {
resources[i] = unifiedResource
unifiedResources[i] = nil
}

resourceStream := stream.Slice(resources)

s := streamutils.NewZipStreams(
nodesStream,
resourceStream,
embeddingsStream,
// On new node callback. Add the node to the batch.
func(node types.Server) error {
nodeData, err := embeddinglib.SerializeNode(node)
// On new resource callback. Add the resource to the batch.
func(resource types.Resource) error {
resourceData, err := embeddinglib.SerializeResource(resource)
if err != nil {
return trace.Wrap(err)
}
vectors, err := batch.Add(ctx, &nodeStringPair{node, string(nodeData)})
vectors, err := batch.Add(ctx, &resourceStringPair{resource, string(resourceData)})
if err != nil {
return trace.Wrap(err)
}
@@ -239,17 +246,17 @@ func (e *EmbeddingProcessor) process(ctx context.Context) {

return nil
},
// On equal node callback. Check if the node's embedding hash matches
// the one in the backend. If not, add the node to the batch.
func(node types.Server, embedding *embeddinglib.Embedding) error {
nodeData, err := embeddinglib.SerializeNode(node)
// On equal resource callback. Check if the resource's embedding hash matches
// the one in the backend. If not, add the resource to the batch.
func(resource types.Resource, embedding *embeddinglib.Embedding) error {
resourceData, err := embeddinglib.SerializeResource(resource)
if err != nil {
return trace.Wrap(err)
}
nodeHash := embeddinglib.EmbeddingHash(nodeData)
resourceHash := embeddinglib.EmbeddingHash(resourceData)

if !EmbeddingHashMatches(embedding, nodeHash) {
vectors, err := batch.Add(ctx, &nodeStringPair{node, string(nodeData)})
if !EmbeddingHashMatches(embedding, resourceHash) {
vectors, err := batch.Add(ctx, &resourceStringPair{resource, string(resourceData)})
if err != nil {
return trace.Wrap(err)
}
@@ -260,16 +267,16 @@ func (e *EmbeddingProcessor) process(ctx context.Context) {
return nil
},
// On compare keys callback. Compare the keys for iteration.
func(node types.Server, embeddings *embeddinglib.Embedding) int {
return strings.Compare(node.GetName(), embeddings.GetEmbeddedID())
func(resource types.Resource, embeddings *embeddinglib.Embedding) int {
return strings.Compare(resource.GetName(), embeddings.GetEmbeddedID())
},
)

if err := s.Process(); err != nil {
e.log.Warnf("Failed to generate nodes embedding: %v", err)
}

// Process the remaining nodes in the batch
// Process the remaining resources in the batch
vectors, err := batch.Finalize(ctx)
if err != nil {
e.log.Warnf("Failed to add node to batch: %v", err)
@@ -290,7 +297,7 @@ func (e *EmbeddingProcessor) process(ctx context.Context) {
// latest embeddings. The new index is created and then swapped with the old one.
func (e *EmbeddingProcessor) updateMemIndex(ctx context.Context) error {
embeddingsIndex := NewSimpleRetriever()
embeddingsStream := e.embeddingSrv.GetEmbeddings(ctx, types.KindNode)
embeddingsStream := e.embeddingSrv.GetAllEmbeddings(ctx)

for embeddingsStream.Next() {
embedding := embeddingsStream.Item()
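
The process loop above zips a name-sorted resource stream against the stored embeddings and re-embeds a resource only when its content hash no longer matches. Below is a minimal, self-contained sketch of that reconcile-by-hash idea, using sorted slices and sha256 in place of Teleport's stream and EmbeddingHash helpers; all names here are illustrative, not part of the commit.

package main

import (
	"crypto/sha256"
	"fmt"
)

type resource struct {
	name string
	data string // serialized YAML in the real processor
}

type storedEmbedding struct {
	id   string
	hash [32]byte
}

// reconcile walks two name-sorted slices in lockstep, mimicking ZipStreams:
// resources with no stored embedding, or whose hash changed, are re-embedded.
func reconcile(resources []resource, stored []storedEmbedding) []string {
	var toEmbed []string
	i, j := 0, 0
	for i < len(resources) {
		switch {
		case j >= len(stored) || resources[i].name < stored[j].id:
			// New resource: no embedding stored yet.
			toEmbed = append(toEmbed, resources[i].name)
			i++
		case resources[i].name > stored[j].id:
			// Stale embedding with no matching resource: skip it.
			j++
		default:
			// Same key: re-embed only if the serialized content changed.
			if sha256.Sum256([]byte(resources[i].data)) != stored[j].hash {
				toEmbed = append(toEmbed, resources[i].name)
			}
			i++
			j++
		}
	}
	return toEmbed
}

func main() {
	resources := []resource{{"app-1", "v2"}, {"db-1", "v1"}, {"node-1", "v1"}}
	stored := []storedEmbedding{
		{"db-1", sha256.Sum256([]byte("v1"))},   // unchanged
		{"node-1", sha256.Sum256([]byte("v0"))}, // content changed
	}
	fmt.Println(reconcile(resources, stored)) // [app-1 node-1]
}
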