Skip to content

Commit

Permalink
chore(wren-ai-service): update indexing async components and update tests to async version (#549)
Browse files Browse the repository at this point in the history

* update

* fix conflict

* fix conflict

* update

* fix conflict

* remove unused code

* fix bug

* fix conflict

* fix conflict

* fix demo ui

* update

* add sql regenerations api boilerplate

* fix conflicts

* update

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* update sql explanation api and pipeline

* update

* update sql explanation api

* refine sql explanation pipeline

* fix pipeline

* fix conflict

* fix sql formatting

* resolve conflict

* fix bug

* resolve conflict

* fix conflict

* rebase

* make sql_regeneration async

* fix broken import

* fix async await

* use logger.exception instead of logger.error

* fix bugs

* simplify pipeline

* refine prompt

* update sql_explanation api by allowing passing multiple steps of sqls

* remove redundant code

* fix bug

* update ui

* update

* update ui

* update

* fix conflict

* orjson dump and formatting for debug messages

* fix tests

* fix conflict

* fix bugs

* fix bug

* fix conflict

* update

* fix conflict

* update sql explanation results

* fix groupByKeys bug

* update

* update

* update groupByKeys

* update engine configs

* add OTHERS error code

* refine ui: use sidebar

* fix conflict

* fix conflict

* fix bug

* fix imports

* allow users to choose openai llm

* update

* update prompt

* fix bug

* fix tests

* fix bug

* fix conflicts

* update prompt and fix bugs

* update

* fix bug

* fix

* fix engine as wren_ui

* remove unused dataset

* fix sql explanation

* fix groupByKey id

* update

* change EngineConfig location and update .env.dev.example

* give defaults to EngineConfig

* update

* fix

* add async

* add async document writer

* add pytest-asyncio and update tests

* update

* revert

* fix

* fix
  • Loading branch information
cyyeh committed Jul 24, 2024
1 parent 5ab1a4d commit ca8bd54
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 115 deletions.
20 changes: 19 additions & 1 deletion wren-ai-service/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ ragas-haystack = "==0.1.3"
psycopg2-binary = "==2.9.9"
setuptools = "==70.0.0"
locust = "==2.28.0"
pytest-asyncio = "==0.23.8"

[tool.poetry.group.eval.dependencies]
tomlkit = "==0.13.0"
Expand Down
53 changes: 36 additions & 17 deletions wren-ai-service/src/pipelines/indexing/indexing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import json
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import orjson
from hamilton import base
Expand Down Expand Up @@ -35,14 +36,15 @@ def __init__(self, stores: List[DocumentStore]) -> None:
self._stores = stores

@component.output_types(mdl=str)
def run(self, mdl: str) -> str:
def _clear_documents(store: DocumentStore) -> None:
ids = [str(i) for i in range(store.count_documents())]
async def run(self, mdl: str) -> str:
async def _clear_documents(store: DocumentStore) -> None:
document_count = await store.count_documents()
ids = [str(i) for i in range(document_count)]
if ids:
store.delete_documents(ids)
await store.delete_documents(ids)

logger.info("Ask Indexing pipeline is clearing old documents...")
[_clear_documents(store) for store in self._stores]
await asyncio.gather(*[_clear_documents(store) for store in self._stores])
return {"mdl": mdl}


Expand Down Expand Up @@ -315,12 +317,29 @@ def _convert_metrics(self, metrics: List[Dict[str, Any]]) -> List[str]:
return ddl_commands


@component
class AsyncDocumentWriter(DocumentWriter):
    """Drop-in async variant of Haystack's ``DocumentWriter``.

    Awaits the underlying document store's ``write_documents`` coroutine
    instead of calling it synchronously; otherwise mirrors the parent
    component's contract and output shape.
    """

    @component.output_types(documents_written=int)
    async def run(
        self, documents: List[Document], policy: Optional[DuplicatePolicy] = None
    ):
        # Fall back to the policy configured on the component when the
        # caller does not override it per invocation.
        effective_policy = self.policy if policy is None else policy

        written = await self.document_store.write_documents(
            documents=documents, policy=effective_policy
        )
        return {"documents_written": written}


## Start of Pipeline
@timer
@async_timer
@observe(capture_input=False, capture_output=False)
def clean_document_store(mdl_str: str, cleaner: DocumentCleaner) -> Dict[str, Any]:
async def clean_document_store(
mdl_str: str, cleaner: DocumentCleaner
) -> Dict[str, Any]:
logger.debug(f"input in clean_document_store: {mdl_str}")
return cleaner.run(mdl=mdl_str)
return await cleaner.run(mdl=mdl_str)


@timer
Expand Down Expand Up @@ -357,10 +376,10 @@ async def embed_ddl(
return await ddl_embedder.run(documents=convert_to_ddl["documents"])


@timer
@async_timer
@observe(capture_input=False)
def write_ddl(embed_ddl: Dict[str, Any], ddl_writer: DocumentWriter) -> None:
return ddl_writer.run(documents=embed_ddl["documents"])
async def write_ddl(embed_ddl: Dict[str, Any], ddl_writer: DocumentWriter) -> None:
return await ddl_writer.run(documents=embed_ddl["documents"])


@timer
Expand All @@ -385,10 +404,10 @@ async def embed_view(
return await view_embedder.run(documents=convert_to_view["documents"])


@timer
@async_timer
@observe(capture_input=False)
def write_view(embed_view: Dict[str, Any], view_writer: DocumentWriter) -> None:
return view_writer.run(documents=embed_view["documents"])
async def write_view(embed_view: Dict[str, Any], view_writer: DocumentWriter) -> None:
return await view_writer.run(documents=embed_view["documents"])


## End of Pipeline
Expand All @@ -408,13 +427,13 @@ def __init__(

self.ddl_converter = DDLConverter()
self.ddl_embedder = embedder_provider.get_document_embedder()
self.ddl_writer = DocumentWriter(
self.ddl_writer = AsyncDocumentWriter(
document_store=ddl_store,
policy=DuplicatePolicy.OVERWRITE,
)
self.view_converter = ViewConverter()
self.view_embedder = embedder_provider.get_document_embedder()
self.view_writer = DocumentWriter(
self.view_writer = AsyncDocumentWriter(
document_store=view_store,
policy=DuplicatePolicy.OVERWRITE,
)
Expand Down
75 changes: 74 additions & 1 deletion wren-ai-service/src/providers/document_store/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@
import numpy as np
import qdrant_client
from haystack import Document, component
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils import Secret
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from haystack_integrations.document_stores.qdrant import (
QdrantDocumentStore,
document_store,
)
from haystack_integrations.document_stores.qdrant.converters import (
DENSE_VECTORS_NAME,
convert_haystack_documents_to_qdrant_points,
convert_id,
convert_qdrant_point_to_haystack_document,
)
from haystack_integrations.document_stores.qdrant.filters import (
convert_filters_to_qdrant,
)
from qdrant_client.http import models as rest
from tqdm import tqdm

from src.core.provider import DocumentStoreProvider
from src.providers.loader import get_default_embedding_model_dim, provider
Expand Down Expand Up @@ -156,6 +163,72 @@ async def _query_by_embedding(
document.score = score
return results

async def delete_documents(self, ids: List[str]):
    """Asynchronously delete the points matching *ids* from the collection."""
    point_ids = [convert_id(doc_id) for doc_id in ids]
    try:
        await self.async_client.delete(
            collection_name=self.index,
            points_selector=point_ids,
            wait=self.wait_result_from_api,
        )
    except KeyError:
        # Same behaviour as the sync store: deleting unknown ids is
        # logged as a warning, not raised to the caller.
        logger.warning(
            "Called QdrantDocumentStore.delete_documents() on a non-existing ID",
        )

async def count_documents(self) -> int:
    """Return the number of points currently stored in the collection."""
    count_result = await self.async_client.count(collection_name=self.index)
    return count_result.count

async def write_documents(
    self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL
):
    """Asynchronously upsert *documents* into the Qdrant collection.

    Mirrors the sync ``QdrantDocumentStore.write_documents`` flow
    (validation, collection setup, duplicate handling, batched upload)
    but awaits the async client's ``upsert``.

    Returns the number of documents written (0 for an empty input list).
    Raises ``ValueError`` if any element is not a ``Document``.
    """
    for doc in documents:
        if not isinstance(doc, Document):
            msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of {type(doc)}."
            raise ValueError(msg)

    # Ensure the target collection exists with the expected vector config
    # before attempting any writes.
    self._set_up_collection(
        self.index,
        self.embedding_dim,
        False,
        self.similarity,
        self.use_sparse_embeddings,
    )

    if not documents:
        logger.warning(
            "Calling QdrantDocumentStore.write_documents() with empty list"
        )
        # Return 0 (not None) so every path honours the declared int
        # result that AsyncDocumentWriter.run reports as documents_written.
        return 0

    document_objects = self._handle_duplicate_documents(
        documents=documents,
        index=self.index,
        policy=policy,
    )

    batched_documents = document_store.get_batches_from_generator(
        document_objects, self.write_batch_size
    )
    with tqdm(
        total=len(document_objects), disable=not self.progress_bar
    ) as progress_bar:
        for document_batch in batched_documents:
            batch = convert_haystack_documents_to_qdrant_points(
                document_batch,
                embedding_field=self.embedding_field,
                use_sparse_embeddings=self.use_sparse_embeddings,
            )

            await self.async_client.upsert(
                collection_name=self.index,
                points=batch,
                wait=self.wait_result_from_api,
            )

            # Advance by the actual batch size so the bar does not
            # overshoot `total` on the final, possibly-partial batch.
            progress_bar.update(len(document_batch))
    return len(document_objects)


class AsyncQdrantEmbeddingRetriever(QdrantEmbeddingRetriever):
def __init__(
Expand Down
Loading

0 comments on commit ca8bd54

Please sign in to comment.