Skip to content

Commit

Permalink
chore(wren-ai-service): minor refactor (#643)
Browse files Browse the repository at this point in the history
* fix conflict

* update

* add sql summary

* fix bug

* fix conflict

* remove pandasai

* update

* fix bug

* refactor pipeline

* refactor

* update

* refactor

* fix langfuse metadata

* update

* add pandas
  • Loading branch information
cyyeh committed Sep 4, 2024
1 parent 2d13ddd commit 322c34a
Show file tree
Hide file tree
Showing 14 changed files with 488 additions and 437 deletions.
8 changes: 8 additions & 0 deletions docker/.env.ai.example
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,11 @@ EMBEDDER_OLLAMA_URL=http://host.docker.internal:11434
DOCUMENT_STORE_PROVIDER=qdrant

QDRANT_HOST=qdrant


## Langfuse: https://langfuse.com/
# leave LANGFUSE_ENABLE empty to disable Langfuse
LANGFUSE_ENABLE=
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
LANGFUSE_HOST=https://cloud.langfuse.com
13 changes: 9 additions & 4 deletions wren-ai-service/.env.dev.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ TABLE_RETRIEVAL_SIZE=10
TABLE_COLUMN_RETRIEVAL_SIZE=1000

## LLM
LLM_PROVIDER=openai_llm # openai_llm, azure_openai_llm, ollama_llm
GENERATION_MODEL=gpt-4o-mini # gpt-4o-mini, gpt-4o, gpt-4-turbo, gpt-3.5-turbo
# openai_llm, azure_openai_llm, ollama_llm
LLM_PROVIDER=openai_llm
# gpt-4o-mini, gpt-4o, gpt-4-turbo, gpt-3.5-turbo
GENERATION_MODEL=gpt-4o-mini

# openai or openai-api-compatible
LLM_OPENAI_API_KEY=sk-1234567890
Expand All @@ -24,7 +26,8 @@ LLM_OLLAMA_URL=http://localhost:11434


## EMBEDDER
EMBEDDER_PROVIDER=openai_embedder # openai_embedder, azure_openai_embedder, ollama_embedder
# openai_embedder, azure_openai_embedder, ollama_embedder
EMBEDDER_PROVIDER=openai_embedder
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_MODEL_DIMENSION=3072

Expand All @@ -47,7 +50,8 @@ DOCUMENT_STORE_PROVIDER=qdrant
QDRANT_HOST=http://localhost:6333
QDRANT_API_KEY=

ENGINE=wren_ui # wren_ui, wren_ibis, wren_engine
# wren_ui, wren_ibis, wren_engine
ENGINE=wren_ui

## when using wren_ui as the engine
WREN_UI_ENDPOINT=http://localhost:3000
Expand All @@ -67,6 +71,7 @@ WREN_ENGINE_MANIFEST=
# Evaluation
DATASET_NAME=book_2

# Langfuse: leave LANGFUSE_ENABLE empty to disable it
LANGFUSE_ENABLE=
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
Expand Down
809 changes: 412 additions & 397 deletions wren-ai-service/poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions wren-ai-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,26 @@ backoff = "==2.2.1"
tqdm = "==4.66.4"
numpy = "==1.26.4"
sqlparse = "==0.5.0"
sqlglot = "==22.5.0"
orjson = "==3.10.3"
sf-hamilton = {version = "==1.69.0", extras = ["visualization"]}
aiohttp = {extras = ["speedups"], version = "==3.10.2"}
ollama-haystack = "==0.0.6"
langfuse = "==2.43.3"
ollama = "==0.2.1"
toml = "==0.10.2"
sqlglot = "==25.18.0"

[tool.poetry.group.dev.dependencies]
pre-commit = "==3.7.1"
streamlit = "==1.37.0"
watchdog = "==4.0.0"
pandas = "==2.2.2"

[tool.poetry.group.eval.dependencies]
tomlkit = "==0.13.0"
deepeval = "==1.0.6"
streamlit-tags = "==1.2.8"
gitpython = "==3.1.43"
pandas = "==2.2.2"

[tool.poetry.group.demo.dependencies]
requests = "==2.32.2"
Expand Down
12 changes: 8 additions & 4 deletions wren-ai-service/src/pipelines/ask/followup_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,24 @@ def prompt(

@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
async def generate_sql_in_followup(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(
generate: dict,
generate_sql_in_followup: dict,
post_processor: GenerationPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
f"generate_sql_in_followup: {orjson.dumps(generate_sql_in_followup, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_in_followup.get("replies"), project_id=project_id
)
return await post_processor.run(generate.get("replies"), project_id=project_id)


## End of Pipeline
Expand Down Expand Up @@ -184,6 +186,7 @@ def visualize(
query: str,
contexts: List[str],
history: AskRequest.AskResponseDetails,
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/ask"
if not Path(destination).exists():
Expand All @@ -200,6 +203,7 @@ def visualize(
"documents": contexts,
"history": history,
"alert": TEXT_TO_SQL_RULES,
"project_id": project_id,
},
show_legend=True,
orient="LR",
Expand Down
10 changes: 6 additions & 4 deletions wren-ai-service/src/pipelines/ask/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,22 +112,22 @@ def prompt(

@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
async def generate_sql(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(
generate: dict,
generate_sql: dict,
post_processor: GenerationPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
f"generate_sql: {orjson.dumps(generate_sql, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(generate.get("replies"), project_id=project_id)
return await post_processor.run(generate_sql.get("replies"), project_id=project_id)


## End of Pipeline
Expand All @@ -154,6 +154,7 @@ def visualize(
query: str,
contexts: List[str],
exclude: List[Dict],
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/ask"
if not Path(destination).exists():
Expand All @@ -170,6 +171,7 @@ def visualize(
"generator": self.generator,
"prompt_builder": self.prompt_builder,
"post_processor": self.post_processor,
"project_id": project_id,
},
show_legend=True,
orient="LR",
Expand Down
12 changes: 8 additions & 4 deletions wren-ai-service/src/pipelines/ask/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,24 @@ def prompt(

@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
async def generate_sql_correction(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(
generate: dict,
generate_sql_correction: dict,
post_processor: GenerationPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
f"generate_sql_correction: {orjson.dumps(generate_sql_correction, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_correction.get("replies"), project_id=project_id
)
return await post_processor.run(generate.get("replies"), project_id=project_id)


## End of Pipeline
Expand Down Expand Up @@ -125,6 +127,7 @@ def visualize(
self,
contexts: List[Document],
invalid_generation_results: List[Dict[str, str]],
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/ask"
if not Path(destination).exists():
Expand All @@ -140,6 +143,7 @@ def visualize(
"generator": self.generator,
"prompt_builder": self.prompt_builder,
"post_processor": self.post_processor,
"project_id": project_id,
},
show_legend=True,
orient="LR",
Expand Down
8 changes: 4 additions & 4 deletions wren-ai-service/src/pipelines/ask/sql_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ def prompt(

@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
async def generate_sql_summary(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


@timer
def post_process(
generate: dict,
generate_sql_summary: dict,
sqls: List[str],
post_processor: SQLSummaryPostProcessor,
) -> dict:
logger.debug(
f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
f"generate_sql_summary: {orjson.dumps(generate_sql_summary, option=orjson.OPT_INDENT_2).decode()}"
)
return post_processor.run(sqls, generate.get("replies"))
return post_processor.run(sqls, generate_sql_summary.get("replies"))


## End of Pipeline
Expand Down
13 changes: 8 additions & 5 deletions wren-ai-service/src/pipelines/ask_details/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,24 @@ def prompt(sql: str, prompt_builder: PromptBuilder) -> dict:

@async_timer
@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
async def generate_sql_details(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def post_process(
generate: dict,
generate_sql_details: dict,
post_processor: GenerationPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate: {orjson.dumps(generate, option=orjson.OPT_INDENT_2).decode()}"
f"generate_sql_details: {orjson.dumps(generate_sql_details, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_details.get("replies"), project_id=project_id
)
return await post_processor.run(generate.get("replies"), project_id=project_id)


## End of Pipeline
Expand All @@ -93,7 +95,7 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(self, sql: str) -> None:
def visualize(self, sql: str, project_id: str | None = None) -> None:
destination = "outputs/pipelines/ask_details"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)
Expand All @@ -106,6 +108,7 @@ def visualize(self, sql: str) -> None:
"generator": self.generator,
"prompt_builder": self.prompt_builder,
"post_processor": self.post_processor,
"project_id": project_id,
},
show_legend=True,
orient="LR",
Expand Down
8 changes: 4 additions & 4 deletions wren-ai-service/src/pipelines/sql_explanation/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def prompts(

@async_timer
@observe(as_type="generation", capture_input=False)
async def generates(prompts: List[dict], generator: Any) -> List[dict]:
async def generate_sql_explanation(prompts: List[dict], generator: Any) -> List[dict]:
logger.debug(
f"prompts: {orjson.dumps(prompts, option=orjson.OPT_INDENT_2).decode()}"
)
Expand All @@ -484,19 +484,19 @@ async def _task(prompt: str, generator: Any):
@timer
@observe(capture_input=False)
def post_process(
generates: List[dict],
generate_sql_explanation: List[dict],
preprocess: dict,
post_processor: GenerationPostProcessor,
) -> dict:
logger.debug(
f"generates: {orjson.dumps(generates, option=orjson.OPT_INDENT_2).decode()}"
f"generate_sql_explanation: {orjson.dumps(generate_sql_explanation, option=orjson.OPT_INDENT_2).decode()}"
)
logger.debug(
f"preprocess: {orjson.dumps(preprocess, option=orjson.OPT_INDENT_2).decode()}"
)

return post_processor.run(
generates,
generate_sql_explanation,
preprocess["preprocessed_sql_analysis_results"],
)

Expand Down
10 changes: 6 additions & 4 deletions wren-ai-service/src/pipelines/sql_regeneration/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def sql_regeneration_prompt(

@async_timer
@observe(as_type="generation", capture_input=False)
async def sql_regeneration_generate(
async def generate_sql_regeneration(
sql_regeneration_prompt: dict,
sql_regeneration_generator: Any,
) -> dict:
Expand All @@ -93,15 +93,15 @@ async def sql_regeneration_generate(
@async_timer
@observe(capture_input=False)
async def sql_regeneration_post_process(
sql_regeneration_generate: dict,
generate_sql_regeneration: dict,
sql_regeneration_post_processor: GenerationPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"sql_regeneration_generate: {orjson.dumps(sql_regeneration_generate, option=orjson.OPT_INDENT_2).decode()}"
f"generate_sql_regeneration: {orjson.dumps(generate_sql_regeneration, option=orjson.OPT_INDENT_2).decode()}"
)
return await sql_regeneration_post_processor.run(
replies=sql_regeneration_generate.get("replies"),
replies=generate_sql_regeneration.get("replies"),
project_id=project_id,
)

Expand Down Expand Up @@ -132,6 +132,7 @@ def visualize(
self,
description: str,
steps: List[SQLExplanationWithUserCorrections],
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/sql_regeneration"
if not Path(destination).exists():
Expand All @@ -147,6 +148,7 @@ def visualize(
"sql_regeneration_prompt_builder": self.sql_regeneration_prompt_builder,
"sql_regeneration_generator": self.sql_regeneration_generator,
"sql_regeneration_post_processor": self.sql_regeneration_post_processor,
"project_id": project_id,
},
show_legend=True,
orient="LR",
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ async def ask(
"ask_result": {},
"metadata": {
"error_type": "",
"error_message": "",
},
}

Expand Down Expand Up @@ -305,7 +306,7 @@ async def ask(
except Exception as e:
logger.exception(f"ask pipeline - OTHERS: {e}")

self._ask_results[query_id] = AskResultResponse(
self._ask_results[ask_request.query_id] = AskResultResponse(
status="failed",
error=AskResultResponse.AskError(
code="OTHERS",
Expand All @@ -314,6 +315,7 @@ async def ask(
)

results["metadata"]["error_type"] = "OTHERS"
results["metadata"]["error_message"] = str(e)
return results

def stop_ask(
Expand Down
Loading

0 comments on commit 322c34a

Please sign in to comment.