
Commit

add data preview for data curation app
cyyeh committed Jul 15, 2024
1 parent 9036519 commit b8ec48e
Showing 2 changed files with 47 additions and 3 deletions.
9 changes: 8 additions & 1 deletion wren-ai-service/eval/data_curation/app.py
@@ -5,6 +5,7 @@
from datetime import datetime

import orjson
import pandas as pd
import streamlit as st
import tomlkit
from openai import AsyncClient
@@ -72,6 +73,7 @@ def on_click_generate_question_sql_pairs(llm_client: AsyncClient):
st.session_state["llm_model"],
st.session_state["mdl_json"],
st.session_state["custom_instructions_for_llm"],
st.session_state["data_source"],
)
)

@@ -289,9 +291,14 @@ def on_click_remove_candidate_dataset_button(i: int):
on_change=on_change_sql,
args=(i, f"sql_{i}"),
)

if st.session_state["llm_question_sql_pairs"][i]["is_valid"]:
st.success("SQL is valid")
st.dataframe(
pd.DataFrame(
question_sql_pair["data"]["data"],
columns=question_sql_pair["data"]["columns"],
)
)
else:
st.error(
f"SQL is invalid: {st.session_state["llm_question_sql_pairs"][i]["error"]}"
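
Note: the preview block added above renders each validated SQL result as a table. A minimal standalone sketch of the same rendering pattern, assuming a result payload shaped like the {"data": ..., "columns": ...} dict produced by get_data_from_wren_engine below (the sample values here are hypothetical):

import pandas as pd
import streamlit as st

# hypothetical payload mirroring the shape the app's preview expects
sample_result = {
    "columns": ["0_id", "1_name"],
    "data": [[1, "alice"], [2, "bob"]],
}

st.dataframe(
    pd.DataFrame(sample_result["data"], columns=sample_result["columns"])
)
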
41 changes: 39 additions & 2 deletions wren-ai-service/eval/data_curation/utils.py
@@ -305,6 +305,7 @@ async def get_question_sql_pairs(
llm_model: str,
mdl_json: dict,
custom_instructions: str,
data_source: str,
num_pairs: int = 10,
) -> list[dict]:
messages = [
@@ -359,9 +360,12 @@
question_sql_pairs = await get_validated_question_sql_pairs(results)
sqls = [question_sql_pair["sql"] for question_sql_pair in question_sql_pairs]
contexts = await get_contexts_from_sqls(sqls)
sqls_data = await get_data_from_wren_engine(sqls, data_source)
return [
{**quesiton_sql_pair, "context": context}
for quesiton_sql_pair, context in zip(question_sql_pairs, contexts)
{**quesiton_sql_pair, "context": context, "data": sql_data}
for quesiton_sql_pair, context, sql_data in zip(
question_sql_pairs, contexts, sqls_data
)
]
except Exception as e:
st.error(f"Error generating question-sql-pairs: {e}")
@@ -376,6 +380,39 @@ def prettify_sql(sql: str) -> str:
)


async def get_data_from_wren_engine(sqls: list[str], data_source: str):
assert data_source in DATA_SOURCES, f"Invalid data source: {data_source}"

async def _get_data(sql: str, data_source: str):
async with aiohttp.request(
"POST",
f"{WREN_IBIS_ENDPOINT}/v2/connector/{data_source}/query",
json={
"sql": add_quotes(sql),
"manifestStr": base64.b64encode(
orjson.dumps(st.session_state["mdl_json"])
).decode(),
"connectionInfo": st.session_state["connection_info"],
},
timeout=aiohttp.ClientTimeout(total=60),
) as response:
if response.status != 200:
return {"data": [], "columns": []}

data = await response.json()
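# prefix each column name with its index so repeated column names stay distinct in the preview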
column_names = [f"{i}_{col}" for i, col in enumerate(data["columns"])]

return {"data": data["data"], "columns": column_names}

async with aiohttp.ClientSession():
tasks = []
for sql in sqls:
task = asyncio.ensure_future(_get_data(sql, data_source))
tasks.append(task)

return await asyncio.gather(*tasks)


@st.cache_data
def get_eval_dataset_in_toml_string(mdl: dict, dataset: list) -> str:
doc = tomlkit.document()
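
Note: a minimal sketch of calling the new helper outside the Streamlit callbacks, for reference. The SQL string and data source value are hypothetical, and st.session_state["mdl_json"] / st.session_state["connection_info"] must already be populated, as they are when the app calls it:

import asyncio

sqls = ["SELECT 1 AS n"]    # hypothetical query
data_source = "bigquery"    # must be one of DATA_SOURCES

results = asyncio.run(get_data_from_wren_engine(sqls, data_source))
for result in results:
    print(result["columns"], result["data"])
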
