Skip to content

Commit

Permalink
feat(ai-service): feedback loop (#261)
Browse files Browse the repository at this point in the history
* update

* fix conflict

* fix conflict

* update

* fix conflict

* remove unused code

* fix bug

* fix conflict

* fix conflict

* fix demo ui

* update

* add sql regenerations api boilerplate

* fix conflicts

* update

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* update sql explanation api and pipeline

* update

* update sql explanation api

* refine sql explanation pipeline

* fix pipeline

* fix conflict

* fix sql formatting

* resolve conflict

* fix bug

* resolve conflict

* fix conflict

* rebase

* make sql_regeneration async

* fix broken import

* fix async await

* use logger.exception instead of logger.error

* fix bugs

* simplify pipeline

* refine prompt

* update sql_explanation api by allowing passing multiple steps of sqls

* remove redundant code

* fix bug

* update ui

* update

* update ui

* update

* fix conflict

* orjson dump and formatting for debug messages

* fix tests

* fix conflict

* fix bugs

* fix bug

* fix conflict

* update

* fix conflict

* update sql explanation results

* fix groupByKeys bug

* update

* update

* update groupByKeys

* update engine configs

* add OTHERS error code

* refine ui: use sidebar

* fix conflict

* fix conflict

* fix bug

* fix imports

* allow users to choose openai llm

* update

* update prompt

* fix bug

* fix tests

* fix bug

* fix conflicts

* update prompt and fix bugs

* update

* fix bug

* fix

* fix engine as wren_ui

* remove unused dataset

* fix sql explanation

* fix groupByKey id

* update

* change EngineConfig location and update .env.dev.example

* give defaults to EngineConfig

* update

* fix
  • Loading branch information
cyyeh committed Jul 23, 2024
1 parent 3205d73 commit 5ab1a4d
Show file tree
Hide file tree
Showing 40 changed files with 2,377 additions and 509 deletions.
5 changes: 2 additions & 3 deletions wren-ai-service/.env.dev.example
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# app related
WREN_AI_SERVICE_HOST=127.0.0.1
WREN_AI_SERVICE_PORT=5556
WREN_ENGINE_ENDPOINT=http://localhost:8080
WREN_UI_ENDPOINT=http://localhost:3000

## LLM
LLM_PROVIDER=openai_llm # openai_llm, azure_openai_llm, ollama_llm
Expand Down Expand Up @@ -53,12 +51,13 @@ WREN_UI_ENDPOINT=http://localhost:3000
WREN_IBIS_ENDPOINT=http://localhost:8000
WREN_IBIS_SOURCE=bigquery
### this is a base64 encoded string of the MDL
WREN_IBIS_MANIFEST=
WREN_IBIS_MANIFEST=
### this is a base64 encoded string of the connection info
WREN_IBIS_CONNECTION_INFO=

## when using wren_engine as the engine
WREN_ENGINE_ENDPOINT=http://localhost:8080
WREN_ENGINE_MANIFEST=

# Evaluation
DATASET_NAME=book_2
Expand Down
283 changes: 156 additions & 127 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,32 @@
import streamlit as st
from utils import (
DATA_SOURCES,
LLM_MODELS,
ask,
ask_details,
get_default_llm_model,
get_mdl_json,
get_new_mdl_json,
prepare_duckdb,
prepare_semantics,
rerun_wren_engine,
save_mdl_json_file,
show_asks_details_results,
show_asks_results,
show_er_diagram,
update_llm,
)

st.set_page_config(layout="wide")
st.title("Wren AI AI Service Demo")
st.title("Wren AI LLM Service Demo")

llm_model = get_default_llm_model(LLM_MODELS)

if "chosen_llm_model" not in st.session_state:
st.session_state["chosen_llm_model"] = llm_model
if "deployment_id" not in st.session_state:
st.session_state["deployment_id"] = str(uuid.uuid4())
if "chosen_dataset" not in st.session_state:
st.session_state["chosen_dataset"] = "music"
st.session_state["chosen_dataset"] = "ecommerce"
if "dataset_type" not in st.session_state:
st.session_state["dataset_type"] = "duckdb"
if "chosen_models" not in st.session_state:
Expand All @@ -47,141 +53,164 @@
st.session_state["preview_sql"] = None
if "query_history" not in st.session_state:
st.session_state["query_history"] = None
if "sql_explanation_question" not in st.session_state:
st.session_state["sql_explanation_question"] = None
if "sql_explanation_steps_with_analysis" not in st.session_state:
st.session_state["sql_explanation_steps_with_analysis"] = None
if "sql_analysis_results" not in st.session_state:
st.session_state["sql_analysis_results"] = None
if "sql_explanation_results" not in st.session_state:
st.session_state["sql_explanation_results"] = None
if "sql_user_corrections_by_step" not in st.session_state:
st.session_state["sql_user_corrections_by_step"] = []
if "sql_regeneration_results" not in st.session_state:
st.session_state["sql_regeneration_results"] = None


def onchange_demo_dataset():
st.session_state["chosen_dataset"] = st.session_state["choose_demo_dataset"]


if __name__ == "__main__":
col1, col2 = st.columns([2, 4])
def onchange_llm_model():
if (
st.session_state["llm_model_selectbox"]
and st.session_state["chosen_llm_model"]
!= st.session_state["llm_model_selectbox"]
):
st.session_state["chosen_llm_model"] = st.session_state["llm_model_selectbox"]

with col1:
with st.expander("Deploy New Model", expanded=True):
uploaded_file = st.file_uploader(
f"Upload an MDL json file, and the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}",
type="json",
)
st.markdown("or")
chosen_demo_dataset = st.selectbox(
"Select a demo dataset",
key="choose_demo_dataset",
options=["music", "nba", "ecommerce"],
index=0,
on_change=onchange_demo_dataset,
)
update_llm(st.session_state["chosen_llm_model"], st.session_state["mdl_json"])

if uploaded_file is not None:
match = re.match(
r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$",
uploaded_file.name,
)
if not match:
st.error(
f"the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}"
)
st.stop()

data_source = match.group(1)
st.session_state["chosen_dataset"] = uploaded_file.name.split(
f"_{data_source}_mdl.json"
)[0]
st.session_state["dataset_type"] = data_source
st.session_state["mdl_json"] = orjson.loads(
uploaded_file.getvalue().decode("utf-8")
)
save_mdl_json_file(uploaded_file.name, st.session_state["mdl_json"])
elif (
chosen_demo_dataset
and st.session_state["chosen_dataset"] == chosen_demo_dataset
):
st.session_state["chosen_dataset"] = chosen_demo_dataset
st.session_state["dataset_type"] = "duckdb"
st.session_state["mdl_json"] = get_mdl_json(chosen_demo_dataset)

st.markdown("---")

chosen_models = st.multiselect(
"Select data models for AI to generate MDL metadata",
[model["name"] for model in st.session_state["mdl_json"]["models"]],
)
if chosen_models and st.session_state["chosen_models"] != chosen_models:
st.session_state["chosen_models"] = chosen_models
st.session_state["mdl_json"] = get_mdl_json(
st.session_state["chosen_dataset"]
)

ai_generate_metadata_ok = st.button(
"AI Generate MDL Metadata",
disabled=not chosen_models,

st.selectbox(
"Select an OpenAI LLM model",
LLM_MODELS,
index=LLM_MODELS.index(llm_model),
key="llm_model_selectbox",
on_change=onchange_llm_model,
)

with st.sidebar:
st.markdown("## Deploy MDL Model")
uploaded_file = st.file_uploader(
f"Upload an MDL json file, and the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}",
type="json",
)
st.markdown("or")
chosen_demo_dataset = st.selectbox(
"Select a demo dataset",
key="choose_demo_dataset",
options=[
"ecommerce",
"nba",
],
index=0,
on_change=onchange_demo_dataset,
)

if uploaded_file is not None:
match = re.match(
r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$",
uploaded_file.name,
)
if not match:
st.error(
f"the file name must be [xxx]_[datasource]_mdl.json, now we support these datasources: {DATA_SOURCES}"
)
if ai_generate_metadata_ok:
st.session_state["mdl_json"] = get_new_mdl_json(
chosen_models=chosen_models
)

# Display the model using the selected database
st.markdown("MDL Model")
st.json(
body=st.session_state["mdl_json"],
expanded=False,
st.stop()

data_source = match.group(1)
st.session_state["chosen_dataset"] = uploaded_file.name.split(
f"_{data_source}_mdl.json"
)[0]
st.session_state["dataset_type"] = data_source
st.session_state["mdl_json"] = orjson.loads(
uploaded_file.getvalue().decode("utf-8")
)
save_mdl_json_file(uploaded_file.name, st.session_state["mdl_json"])
elif (
chosen_demo_dataset
and st.session_state["chosen_dataset"] == chosen_demo_dataset
):
st.session_state["chosen_dataset"] = chosen_demo_dataset
st.session_state["dataset_type"] = "duckdb"
st.session_state["mdl_json"] = get_mdl_json(chosen_demo_dataset)

st.markdown("---")

if st.session_state["mdl_json"]:
chosen_models = st.multiselect(
"Select data models for AI to generate MDL metadata",
[model["name"] for model in st.session_state["mdl_json"]["models"]],
)
if chosen_models and st.session_state["chosen_models"] != chosen_models:
st.session_state["chosen_models"] = chosen_models
st.session_state["mdl_json"] = get_mdl_json(
st.session_state["chosen_dataset"]
)

show_er_diagram(
st.session_state["mdl_json"]["models"],
st.session_state["mdl_json"]["relationships"],
ai_generate_metadata_ok = st.button(
"AI Generate MDL Metadata",
disabled=not chosen_models,
)
if ai_generate_metadata_ok:
st.session_state["mdl_json"] = get_new_mdl_json(chosen_models=chosen_models)

# Display the model using the selected database
st.markdown("MDL Model")
st.json(
body=st.session_state["mdl_json"],
expanded=False,
)

show_er_diagram(
st.session_state["mdl_json"]["models"],
st.session_state["mdl_json"]["relationships"],
)

deploy_ok = st.button(
"Deploy",
use_container_width=True,
type="primary",
)
# Semantics preparation
if deploy_ok:
rerun_wren_engine(
st.session_state["mdl_json"],
st.session_state["dataset_type"],
st.session_state["chosen_dataset"],
)
prepare_semantics(st.session_state["mdl_json"])

deploy_ok = st.button(
"Deploy the MDL model using the selected database",
type="primary",
)
# Semantics preparation
if deploy_ok:
if st.session_state["dataset_type"] == "duckdb":
prepare_duckdb(st.session_state["chosen_dataset"])

rerun_wren_engine(
st.session_state["mdl_json"], st.session_state["dataset_type"]
)
prepare_semantics(st.session_state["mdl_json"])

query = st.chat_input(
"Ask a question about the database",
disabled=st.session_state["semantics_preparation_status"] != "finished",
)
query = st.chat_input(
"Ask a question about the database",
disabled=st.session_state["semantics_preparation_status"] != "finished",
)

if query:
if st.session_state["asks_results"] and st.session_state["asks_details_result"]:
st.session_state["query_history"] = {
"sql": st.session_state["chosen_query_result"]["sql"],
"summary": st.session_state["chosen_query_result"]["summary"],
"steps": st.session_state["asks_details_result"]["steps"],
}
else:
st.session_state["query_history"] = None

# reset relevant session_states
# st.session_state["query"] = None
st.session_state["asks_results"] = None
st.session_state["chosen_query_result"] = None
st.session_state["asks_details_result"] = None
st.session_state["preview_data_button_index"] = None
st.session_state["preview_sql"] = None

with col2:
if query:
if (
st.session_state["asks_results"]
and st.session_state["asks_details_result"]
):
st.session_state["query_history"] = {
"sql": st.session_state["chosen_query_result"]["sql"],
"summary": st.session_state["chosen_query_result"]["summary"],
"steps": st.session_state["asks_details_result"]["steps"],
}
else:
st.session_state["query_history"] = None

# reset relevant session_states
st.session_state["query"] = None
st.session_state["asks_results"] = None
st.session_state["chosen_query_result"] = None
st.session_state["asks_details_result"] = None
st.session_state["preview_data_button_index"] = None
st.session_state["preview_sql"] = None

ask(query, st.session_state["query_history"])
if st.session_state["asks_results"]:
show_asks_results()
if (
st.session_state["asks_details_result"]
and st.session_state["chosen_query_result"]
):
show_asks_details_results()
elif st.session_state["chosen_query_result"]:
ask_details()
if st.session_state["asks_details_result"]:
show_asks_details_results()
ask(query, st.session_state["query_history"])
if st.session_state["asks_results"]:
show_asks_results()
if st.session_state["asks_details_result"] and st.session_state["chosen_query_result"]:
show_asks_details_results(st.session_state["query"])
elif st.session_state["chosen_query_result"]:
ask_details()
if st.session_state["asks_details_result"]:
show_asks_details_results(st.session_state["query"])
Loading

0 comments on commit 5ab1a4d

Please sign in to comment.