diff --git a/wren-ai-service/src/pipelines/sql_explanation/generation.py b/wren-ai-service/src/pipelines/sql_explanation/generation.py index dd1ba73bf..87a391b8e 100644 --- a/wren-ai-service/src/pipelines/sql_explanation/generation.py +++ b/wren-ai-service/src/pipelines/sql_explanation/generation.py @@ -135,36 +135,57 @@ def _collect_relations(relation, result, cte_names, top_level: bool = True): return results -def _compose_sql_expression_of_select_type(select_items: List[Dict]) -> Dict: +def _compose_sql_expression_of_select_type( + select_items: List[Dict], selected_data_sources: List[List[Dict]] +) -> Dict: + def _is_select_item_existed_in_selected_data_sources( + select_item, selected_data_sources + ): + for selected_data_source in selected_data_sources: + for data_source in selected_data_source: + if ( + "exprSources" in select_item + and select_item["exprSources"] + and select_item["exprSources"][0]["sourceDataset"] + == data_source["sourceDataset"] + and select_item["exprSources"][0]["sourceColumn"] + == data_source["sourceColumn"] + ): + return True + return False + result = { "withFunctionCallOrMathematicalOperation": [], "withoutFunctionCallOrMathematicalOperation": [], } for select_item in select_items: - if ( - select_item["properties"]["includeFunctionCall"] == "true" - or select_item["properties"]["includeMathematicalOperation"] == "true" + if not _is_select_item_existed_in_selected_data_sources( + select_item, selected_data_sources ): - result["withFunctionCallOrMathematicalOperation"].append( - { - "values": { - "alias": select_item["alias"], - "expression": select_item["expression"], - }, - "id": select_item.get("id", ""), - } - ) - else: - result["withoutFunctionCallOrMathematicalOperation"].append( - { - "values": { - "alias": select_item["alias"], - "expression": select_item["expression"], - }, - "id": select_item.get("id", ""), - } - ) + if ( + select_item["properties"]["includeFunctionCall"] == "true" + or select_item["properties"]["includeMathematicalOperation"] == "true" + ): + result["withFunctionCallOrMathematicalOperation"].append( + { + "values": { + "alias": select_item["alias"], + "expression": select_item["expression"], + }, + "id": select_item.get("id", ""), + } + ) + else: + result["withoutFunctionCallOrMathematicalOperation"].append( + { + "values": { + "alias": select_item["alias"], + "expression": select_item["expression"], + }, + "id": select_item.get("id", ""), + } + ) return result @@ -210,6 +231,7 @@ class SQLAnalysisPreprocessor: def run( self, cte_names: List[str], + selected_data_sources: List[List[Dict]], sql_analysis_results: List[Dict], ) -> Dict[str, List[Dict]]: preprocessed_sql_analysis_results = [] @@ -245,7 +267,7 @@ def run( preprocessed_sql_analysis_result[ "selectItems" ] = _compose_sql_expression_of_select_type( - sql_analysis_result["selectItems"] + sql_analysis_result["selectItems"], selected_data_sources ) else: preprocessed_sql_analysis_result["selectItems"] = { @@ -419,6 +441,7 @@ def run( @timer @observe(capture_input=False) def preprocess( + selected_data_sources: List[List[dict]], sql_analysis_results: List[dict], cte_names: List[str], pre_processor: SQLAnalysisPreprocessor, @@ -426,7 +449,7 @@ def preprocess( logger.debug( f"sql_analysis_results: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode()}" ) - return pre_processor.run(cte_names, sql_analysis_results) + return pre_processor.run(cte_names, selected_data_sources, sql_analysis_results) @timer @@ -561,6 +584,7 @@ def visualize( self, question: str, cte_names: List[str], + selected_data_sources: List[List[dict]], step_with_analysis_results: StepWithAnalysisResults, ) -> None: destination = "outputs/pipelines/sql_explanation" @@ -573,6 +597,7 @@ def visualize( inputs={ "question": question, "cte_names": cte_names, + "selected_data_sources": selected_data_sources, "sql": step_with_analysis_results.sql, "sql_analysis_results": step_with_analysis_results.sql_analysis_results, "sql_summary": step_with_analysis_results.summary, @@ -591,6 +616,7 @@ async def run( self, question: str, cte_names: List[str], + selected_data_sources: List[List[dict]], step_with_analysis_results: StepWithAnalysisResults, ): logger.info("SQL Explanation Generation pipeline is running...") @@ -600,6 +626,7 @@ async def run( inputs={ "question": question, "cte_names": cte_names, + "selected_data_sources": selected_data_sources, "sql": step_with_analysis_results.sql, "sql_analysis_results": step_with_analysis_results.sql_analysis_results, "sql_summary": step_with_analysis_results.summary, diff --git a/wren-ai-service/src/web/v1/services/sql_explanation.py b/wren-ai-service/src/web/v1/services/sql_explanation.py index ca96fac09..862757d09 100644 --- a/wren-ai-service/src/web/v1/services/sql_explanation.py +++ b/wren-ai-service/src/web/v1/services/sql_explanation.py @@ -75,11 +75,14 @@ async def sql_explanation(self, sql_explanation_request: SQLExplanationRequest): async def _task( question: str, cte_names: List[str], + selected_data_sources: List[List[dict]], step_with_analysis_results: StepWithAnalysisResults, + i: int, ): return await self._pipelines["generation"].run( question=question, cte_names=cte_names, + selected_data_sources=selected_data_sources[:i], step_with_analysis_results=step_with_analysis_results, ) @@ -87,13 +90,26 @@ async def _task( step_with_analysis_results.cte_name for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results ] + selected_data_sources = [ + [ + select_item["exprSources"][0] + for analysis_result in step_with_analysis_results.sql_analysis_results + for select_item in analysis_result.get("selectItems", []) + if select_item.get("exprSources", []) + ] + for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results + ] tasks = [ _task( sql_explanation_request.question, cte_names, + selected_data_sources, step_with_analysis_results, + i, + ) + for i, step_with_analysis_results in enumerate( + sql_explanation_request.steps_with_analysis_results ) - for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results ] generation_results = await asyncio.gather(*tasks)