From b8ec48e46fc0696724b2b76569eb96986b4ddda1 Mon Sep 17 00:00:00 2001
From: ChihYu Yeh
Date: Mon, 15 Jul 2024 13:28:11 +0800
Subject: [PATCH] add data preview for data curation app

---
 wren-ai-service/eval/data_curation/app.py   |  9 ++++-
 wren-ai-service/eval/data_curation/utils.py | 41 ++++++++++++++++++++-
 2 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/wren-ai-service/eval/data_curation/app.py b/wren-ai-service/eval/data_curation/app.py
index b4a58d19b..bda3dbc2b 100644
--- a/wren-ai-service/eval/data_curation/app.py
+++ b/wren-ai-service/eval/data_curation/app.py
@@ -5,6 +5,7 @@ from datetime import datetime
 
 import orjson
+import pandas as pd
 import streamlit as st
 import tomlkit
 from openai import AsyncClient
 
@@ -72,6 +73,7 @@ def on_click_generate_question_sql_pairs(llm_client: AsyncClient):
             st.session_state["llm_model"],
             st.session_state["mdl_json"],
             st.session_state["custom_instructions_for_llm"],
+            st.session_state["data_source"],
         )
     )
 
@@ -289,9 +291,14 @@ def on_click_remove_candidate_dataset_button(i: int):
                 on_change=on_change_sql,
                 args=(i, f"sql_{i}"),
             )
-
            if st.session_state["llm_question_sql_pairs"][i]["is_valid"]:
                st.success("SQL is valid")
+                st.dataframe(
+                    pd.DataFrame(
+                        question_sql_pair["data"]["data"],
+                        columns=question_sql_pair["data"]["columns"],
+                    )
+                )
            else:
                st.error(
                    f"SQL is invalid: {st.session_state["llm_question_sql_pairs"][i]["error"]}"
diff --git a/wren-ai-service/eval/data_curation/utils.py b/wren-ai-service/eval/data_curation/utils.py
index a794adb71..a85db03ea 100644
--- a/wren-ai-service/eval/data_curation/utils.py
+++ b/wren-ai-service/eval/data_curation/utils.py
@@ -305,6 +305,7 @@ async def get_question_sql_pairs(
     llm_model: str,
     mdl_json: dict,
     custom_instructions: str,
+    data_source: str,
     num_pairs: int = 10,
 ) -> list[dict]:
     messages = [
@@ -359,9 +360,12 @@ async def get_question_sql_pairs(
         question_sql_pairs = await get_validated_question_sql_pairs(results)
         sqls = [question_sql_pair["sql"] for question_sql_pair in question_sql_pairs]
         contexts = await get_contexts_from_sqls(sqls)
+        sqls_data = await get_data_from_wren_engine(sqls, data_source)
         return [
-            {**quesiton_sql_pair, "context": context}
-            for quesiton_sql_pair, context in zip(question_sql_pairs, contexts)
+            {**quesiton_sql_pair, "context": context, "data": sql_data}
+            for quesiton_sql_pair, context, sql_data in zip(
+                question_sql_pairs, contexts, sqls_data
+            )
         ]
     except Exception as e:
         st.error(f"Error generating question-sql-pairs: {e}")
@@ -376,6 +380,39 @@ def prettify_sql(sql: str) -> str:
     )
 
 
+async def get_data_from_wren_engine(sqls: List[str], data_source: str):
+    assert data_source in DATA_SOURCES, f"Invalid data source: {data_source}"
+
+    async def _get_data(sql: str, data_source: str):
+        async with aiohttp.request(
+            "POST",
+            f"{WREN_IBIS_ENDPOINT}/v2/connector/{data_source}/query",
+            json={
+                "sql": add_quotes(sql),
+                "manifestStr": base64.b64encode(
+                    orjson.dumps(st.session_state["mdl_json"])
+                ).decode(),
+                "connectionInfo": st.session_state["connection_info"],
+            },
+            timeout=aiohttp.ClientTimeout(total=60),
+        ) as response:
+            if response.status != 200:
+                return {"data": [], "columns": []}
+
+            data = await response.json()
+            column_names = [f"{i}_{col}" for i, col in enumerate(data["columns"])]
+
+            return {"data": data["data"], "columns": column_names}
+
+    async with aiohttp.ClientSession():
+        tasks = []
+        for sql in sqls:
+            task = asyncio.ensure_future(_get_data(sql, data_source))
+            tasks.append(task)
+
+        return await asyncio.gather(*tasks)
+
+
 @st.cache_data
 def get_eval_dataset_in_toml_string(mdl: dict, dataset: list) -> str:
     doc = tomlkit.document()
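
Note on the preview payload (this note and the sketch are not part of the patch): the new get_data_from_wren_engine helper reads only "columns" and "data" from the Wren Ibis query response and prefixes each column name with its position before app.py renders the preview. Below is a minimal sketch of that reshaping using a hypothetical response body; the guess that the prefix exists to keep duplicate column names distinct is an assumption, not stated in the commit.

# Hedged sketch, assuming a response body with the two fields the helper reads.
import pandas as pd

response_json = {
    "columns": ["city", "country"],                      # hypothetical result
    "data": [["Taipei", "Taiwan"], ["Tokyo", "Japan"]],  # hypothetical rows
}

# Same transformation as in get_data_from_wren_engine: index-prefix the names,
# presumably so duplicate column names stay unique in the preview table.
column_names = [f"{i}_{col}" for i, col in enumerate(response_json["columns"])]
preview = {"data": response_json["data"], "columns": column_names}

# app.py shows this via st.dataframe; plain pandas is used here to stay runnable.
print(pd.DataFrame(preview["data"], columns=preview["columns"]))
# columns come out as 0_city, 1_country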