Skip to content

Commit

Permalink
Implement graph api endpoint to trigger embedding upload
Browse files Browse the repository at this point in the history
  • Loading branch information
shedrachokonofua committed Jul 20, 2024
1 parent 78d1ac1 commit 93ff643
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 20 deletions.
5 changes: 5 additions & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,8 @@ tasks:
dir: connector/graph
cmds:
- poetry run python -m grpc_tools.protoc -Igraph/proto=../../proto --python_out=. --pyi_out=. --grpc_python_out=. ../../proto/lute.proto

"graph:notebooks":
dir: connector/graph
cmds:
- poetry run jupyter lab
28 changes: 28 additions & 0 deletions connector/graph/graph/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import uvicorn
from fastapi import FastAPI

from graph import db
from graph.logger import logger
from graph.lute import LuteClient
from graph.settings import API_PORT

Expand Down Expand Up @@ -33,6 +35,32 @@ async def root():
}


@app.post("/embeddings/albums/sync")
async def sync_album_embeddings(relationship_weights: db.AlbumRelationWeights):
    # Generate album embeddings with the caller-supplied relationship weights,
    # then hand them to the lute client in fixed-size batches. The response
    # carries the count returned by the bulk upload call.
    documents = db.generate_album_embeddings("lute_graph", relationship_weights)
    logger.info(
        "Generated embeddings", extra={"props": {"embedding_count": len(documents)}}
    )

    batch_size = 1500

    async def upload_generator():
        # Yield consecutive slices of at most `batch_size` documents; the
        # logged "cursor" is the offset of the first document in the slice.
        for offset in range(0, len(documents), batch_size):
            chunk = documents[offset : offset + batch_size]
            logger.info(
                "Uploading embeddings batch",
                extra={"props": {"batch_size": len(chunk), "cursor": offset}},
            )
            yield chunk

    uploaded_count = await lute_client.bulk_upload_embeddings(upload_generator())
    return {"node_count": uploaded_count}


async def run():
config = uvicorn.Config(
app,
Expand Down
111 changes: 103 additions & 8 deletions connector/graph/graph/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from graphdatascience import GraphDataScience

from graph.logger import logger
from graph.models import AlbumRelationWeights, EmbeddingDocument
from graph.proto import lute_pb2
from graph.settings import NEO4J_URL

gds = GraphDataScience(NEO4J_URL)

def setup_indexes(gds: GraphDataScience):

def setup_indexes():
statements = [
"""
CREATE CONSTRAINT album_file_name IF NOT EXISTS FOR (a:Album)
Expand Down Expand Up @@ -38,14 +42,16 @@ def setup_indexes(gds: GraphDataScience):
gds.run_cypher(statement)


def update_graph(gds: GraphDataScience, albums: list[tuple[str, lute_pb2.ParsedAlbum]]):
def update_graph(albums: list[tuple[str, lute_pb2.ParsedAlbum]]):
start = time()
relationship_count = 0

logger.info(
"Building graph update",
extra={
"album_count": len(albums),
"props": {
"album_count": len(albums),
}
},
)
artists = {}
Expand Down Expand Up @@ -171,7 +177,7 @@ def update_graph(gds: GraphDataScience, albums: list[tuple[str, lute_pb2.ParsedA
UNWIND $album_genres AS album_genre
MATCH (album:Album {file_name: album_genre.album_file_name})
MATCH (genre:Genre {name: album_genre.genre})
MERGE (album)-[:GENRE {weight: 2}]->(genre)
MERGE (album)-[:GENRE {weight: 3}]->(genre)
""",
{
"album_genres": [
Expand Down Expand Up @@ -236,9 +242,98 @@ def update_graph(gds: GraphDataScience, albums: list[tuple[str, lute_pb2.ParsedA
logger.info(
"Graph updated",
extra={
"album_count": len(albums),
"duration": time() - start,
"node_count": node_count,
"relationship_count": relationship_count,
"props": {
"album_count": len(albums),
"duration": time() - start,
"node_count": node_count,
"relationship_count": relationship_count,
}
},
)


def generate_album_embeddings(
    embedding_key: str,
    weights: AlbumRelationWeights,
) -> list[EmbeddingDocument]:
    """Project the graph, run FastRP on it and return the album embeddings.

    A temporary GDS projection is created over the Album, Genre, Artist,
    Descriptor and Language nodes. GENRE relationships project their stored
    ``weight`` property; every other relationship type projects a ``weight``
    property whose ``defaultValue`` comes from ``weights``. FastRP is then
    streamed over the projection and only rows for ``Album`` nodes are kept.

    Args:
        embedding_key: Value copied into each returned document's
            ``embedding_key`` field.
        weights: Default relationship weights for the DESCRIPTOR, LANGUAGE,
            ALBUM_ARTIST and CREDITED relationship types.

    Returns:
        One ``EmbeddingDocument`` per album row yielded by the query.
    """

    def weighted(default_weight: int) -> dict:
        # Undirected relationship whose `weight` falls back to the given value.
        return {
            "orientation": "UNDIRECTED",
            "properties": {"weight": {"defaultValue": default_weight}},
        }

    node_projection = ["Album", "Genre", "Artist", "Descriptor", "Language"]
    relationship_projection = {
        "GENRE": {"orientation": "UNDIRECTED", "properties": "weight"},
        "DESCRIPTOR": weighted(weights.descriptor),
        "LANGUAGE": weighted(weights.language),
        "ALBUM_ARTIST": weighted(weights.album_artist),
        "CREDITED": weighted(weights.credited),
    }
    # NOTE(review): the id has one-second resolution, so two calls within the
    # same second would use the same projection name - confirm whether that
    # can happen in practice.
    projection_id = f"p_{int(time())}"
    projection, output = gds.graph.project(
        projection_id, node_projection, relationship_projection
    )

    # From here on the projection exists on the server, so it must be dropped
    # on every exit path, including failures while streaming embeddings.
    try:
        logger.info(
            "Created graph projection, generating embeddings",
            extra={
                "props": {
                    "projection_id": projection_id,
                    "node_count": int(projection.node_count()),
                    "relationship_count": int(projection.relationship_count()),
                    "duration_ms": int(output["projectMillis"]),
                }
            },
        )

        start = time()
        result = gds.run_cypher(
            """
            CALL gds.fastRP.stream($projection_id, {
                embeddingDimension: 512,
                randomSeed: 42,
                relationshipWeightProperty: 'weight'
            })
            YIELD nodeId, embedding
            WITH nodeId, embedding, gds.util.asNode(nodeId) AS node
            WHERE node:Album
            RETURN nodeId, node.file_name AS fileName, embedding
            """,
            {
                "projection_id": projection_id,
            },
        )

        logger.info(
            "Generated embeddings",
            extra={
                "props": {
                    "duration": time() - start,
                    "embedding_count": result.shape[0],
                }
            },
        )

        return [
            EmbeddingDocument(
                file_name=row["fileName"],
                embedding=row["embedding"],
                embedding_key=embedding_key,
            )
            for row in result.to_dict("records")
        ]
    finally:
        projection.drop()


def disconnect():
    """Close the module-level ``gds`` client."""
    gds.close()
24 changes: 24 additions & 0 deletions connector/graph/graph/lute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import graph.proto.lute_pb2 as lute_pb2
import graph.proto.lute_pb2_grpc as lute_pb2_grpc
from graph.models import EmbeddingDocument
from graph.settings import LUTE_EVENT_SUBSCRIBER_PREFIX, LUTE_URL

MAX_MESSAGE_LENGTH = 50 * 1024 * 1024
Expand Down Expand Up @@ -83,3 +84,26 @@ async def request_generator():
async for reply in self.event_service.Stream(request_generator()):
yield reply.items
await queue.put(reply.cursor)

async def bulk_upload_embeddings(
    self, embedding_iter: AsyncIterator[list[EmbeddingDocument]]
) -> int:
    """Stream batches of embeddings to the album service.

    Each batch from ``embedding_iter`` becomes one
    ``BulkUploadAlbumEmbeddingsRequest`` on the request stream.

    Args:
        embedding_iter: Async iterator yielding lists of documents to upload.

    Returns:
        The ``count`` field of the service's reply.

    Raises:
        ValueError: If ``self.album_service`` is ``None``.
    """
    if self.album_service is None:
        raise ValueError("Client not initialized")

    async def request_generator():
        # Convert each batch lazily so only one request is built at a time.
        async for batch in embedding_iter:
            yield lute_pb2.BulkUploadAlbumEmbeddingsRequest(
                items=[
                    lute_pb2.BulkUploadAlbumEmbeddingsRequestItem(
                        file_name=doc.file_name,
                        embedding=doc.embedding,
                        embedding_key=doc.embedding_key,
                    )
                    for doc in batch
                ]
            )

    reply = await self.album_service.BulkUploadAlbumEmbeddings(request_generator())
    return reply.count
16 changes: 4 additions & 12 deletions connector/graph/graph/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import asyncio

from graphdatascience import GraphDataScience

from graph import api, db
from graph.logger import logger
from graph.lute import LuteClient
from graph.proto import lute_pb2
from graph.settings import NEO4J_URL


def is_album_parsed_event(item: lute_pb2.EventStreamItem) -> bool:
Expand All @@ -20,11 +17,9 @@ def is_album_parsed_event(item: lute_pb2.EventStreamItem) -> bool:


async def run_graph_sync():
gds = GraphDataScience(NEO4J_URL)
db.setup_indexes(gds)
async with LuteClient() as client:
async for items in client.stream_events("parser", "build", 500):
logger.info("Received events", extra={"event_count": len(items)})
logger.info("Received events", extra={"props": {"event_count": len(items)}})
parsed_albums = [
(
item.payload.event.file_parsed.file_name,
Expand All @@ -35,17 +30,14 @@ async def run_graph_sync():
]

if parsed_albums:
db.update_graph(gds, parsed_albums)
gds.close()
db.update_graph(parsed_albums)


async def run():
    """Set up indexes, then run the API server and the graph sync together.

    The database client is disconnected on every exit path, including when
    index setup or either task raises.
    """
    try:
        db.setup_indexes()
        await asyncio.gather(api.run(), run_graph_sync())
    finally:
        db.disconnect()


def main():
    """Synchronous entry point: drive the ``run`` coroutine to completion."""
    asyncio.run(run())


# if __name__ == "__main__":
# main()
17 changes: 17 additions & 0 deletions connector/graph/graph/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dataclasses import dataclass

from pydantic import BaseModel, PositiveInt


class AlbumRelationWeights(BaseModel):
    # Per-relationship-type weights; every field must be a positive integer
    # and falls back to the default shown when omitted.
    # Documented with comments rather than a docstring so the generated
    # model schema is unchanged.
    album_artist: PositiveInt = 4
    credited: PositiveInt = 2
    descriptor: PositiveInt = 1
    language: PositiveInt = 1


@dataclass
class EmbeddingDocument:
    """An embedding vector for one file, tagged with the key it belongs to."""

    # Name of the file the embedding was generated for.
    file_name: str
    # The embedding vector itself.
    embedding: list[float]
    # Identifier distinguishing this embedding set from others.
    embedding_key: str
1 change: 1 addition & 0 deletions core/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ pub struct EmbeddingProviderSettings {
pub openai: Option<OpenAISettings>,
pub voyageai: Option<VoyageAISettings>,
pub ollama: Option<OllamaSettings>,
pub
}

#[derive(Debug, Clone, Default, Deserialize, PartialEq, Eq)]
Expand Down

0 comments on commit 93ff643

Please sign in to comment.