diff --git a/src/preset_cli/api/clients/dbt.py b/src/preset_cli/api/clients/dbt.py index f8e87b4b..5b9b25e8 100644 --- a/src/preset_cli/api/clients/dbt.py +++ b/src/preset_cli/api/clients/dbt.py @@ -589,6 +589,7 @@ class MetricSchema(PostelSchema): # dbt >= 1.3 calculation_method = fields.String() expression = fields.String() + dialect = fields.String() class MFMetricType(str, Enum): diff --git a/src/preset_cli/cli/superset/sync/dbt/command.py b/src/preset_cli/cli/superset/sync/dbt/command.py index 16378b23..cd4d5ea0 100644 --- a/src/preset_cli/cli/superset/sync/dbt/command.py +++ b/src/preset_cli/cli/superset/sync/dbt/command.py @@ -184,7 +184,8 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many with open(profiles, encoding="utf-8") as input_: config = yaml.safe_load(input_) - dialect = MFSQLEngine(config[project]["outputs"][target]["type"].upper()) + dialect = config[project]["outputs"][target]["type"] + mf_dialect = MFSQLEngine(dialect.upper()) model_schema = ModelSchema() models = [] @@ -215,8 +216,9 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many # conform to the same schema that dbt Cloud uses for metrics config["dependsOn"] = config.pop("depends_on")["nodes"] config["uniqueId"] = config.pop("unique_id") + config["dialect"] = dialect og_metrics.append(metric_schema.load(config)) - elif sl_metric := get_sl_metric(config, model_map, dialect): + elif sl_metric := get_sl_metric(config, model_map, mf_dialect): sl_metrics.append(sl_metric) superset_metrics = get_superset_metrics_per_model(og_metrics, sl_metrics) diff --git a/src/preset_cli/cli/superset/sync/dbt/metrics.py b/src/preset_cli/cli/superset/sync/dbt/metrics.py index a9cc160f..25a2e8c8 100644 --- a/src/preset_cli/cli/superset/sync/dbt/metrics.py +++ b/src/preset_cli/cli/superset/sync/dbt/metrics.py @@ -39,14 +39,14 @@ } -def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> str: +def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str: """ Return a SQL expression for a given dbt metric using sqlglot. """ - if unique_id not in metrics: - raise Exception(f"Invalid metric {unique_id}") + if metric_name not in metrics: + raise Exception(f"Invalid metric {metric_name}") - metric = metrics[unique_id] + metric = metrics[metric_name] if "calculation_method" in metric: # dbt >= 1.3 type_ = metric["calculation_method"] @@ -77,13 +77,16 @@ def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> s return f"COUNT(DISTINCT {sql})" if type_ in {"expression", "derived"}: - expression = sqlglot.parse_one(sql) + expression = sqlglot.parse_one(sql, dialect=metric["dialect"]) tokens = expression.find_all(exp.Column) for token in tokens: if token.sql() in metrics: parent_sql = get_metric_expression(token.sql(), metrics) - parent_expression = sqlglot.parse_one(parent_sql) + parent_expression = sqlglot.parse_one( + parent_sql, + dialect=metric["dialect"], + ) token.replace(parent_expression) return expression.sql() @@ -167,23 +170,22 @@ def get_metric_models(unique_id: str, metrics: List[MetricSchema]) -> Set[str]: def get_metric_definition( - unique_id: str, + metric_name: str, metrics: List[MetricSchema], ) -> SupersetMetricDefinition: """ Build a Superset metric definition from an OG (< 1.6) dbt metric. """ - metric_map = {metric["unique_id"]: metric for metric in metrics} - metric = metric_map[unique_id] - name = metric["name"] + metric_map = {metric["name"]: metric for metric in metrics} + metric = metric_map[metric_name] meta = metric.get("meta", {}) kwargs = meta.pop("superset", {}) return { - "expression": get_metric_expression(unique_id, metric_map), - "metric_name": name, + "expression": get_metric_expression(metric_name, metric_map), + "metric_name": metric_name, "metric_type": (metric.get("type") or metric.get("calculation_method")), - "verbose_name": metric.get("label", name), + "verbose_name": metric.get("label", metric_name), "description": metric.get("description", ""), "extra": json.dumps(meta), **kwargs, # type: ignore @@ -209,7 +211,7 @@ def get_superset_metrics_per_model( continue metric_definition = get_metric_definition( - metric["unique_id"], + metric["name"], og_metrics, ) model = metric_models.pop() diff --git a/tests/cli/superset/sync/dbt/metrics_test.py b/tests/cli/superset/sync/dbt/metrics_test.py index 813b3039..0268abbc 100644 --- a/tests/cli/superset/sync/dbt/metrics_test.py +++ b/tests/cli/superset/sync/dbt/metrics_test.py @@ -38,24 +38,28 @@ def test_get_metric_expression() -> None: {"field": "company_name", "operator": "!=", "value": "'Acme, Inc'"}, {"field": "signup_date", "operator": ">=", "value": "'2020-01-01'"}, ], + "dialect": "postgres", }, ), "two": metric_schema.load( { "type": "count_distinct", "sql": "user_id", + "dialect": "postgres", }, ), "three": metric_schema.load( { "type": "expression", "sql": "one - two", + "dialect": "postgres", }, ), "four": metric_schema.load( { "type": "hllsketch", "sql": "user_id", + "dialect": "postgres", }, ), "load_fill_by_weight": metric_schema.load( @@ -72,6 +76,7 @@ def test_get_metric_expression() -> None: "sql": "load_weight_lbs / load_weight_capacity_lbs", "type": "derived", "unique_id": "metric.breakthrough_dw.load_fill_by_weight", + "dialect": "postgres", }, ), } @@ -97,7 +102,7 @@ def test_get_metric_expression() -> None: get_metric_expression("four", metrics) assert str(excinfo.value) == ( "Unable to generate metric expression from: " - "{'sql': 'user_id', 'type': 'hllsketch'}" + "{'dialect': 'postgres', 'sql': 'user_id', 'type': 'hllsketch'}" ) with pytest.raises(Exception) as excinfo: @@ -136,6 +141,107 @@ def test_get_metric_expression_new_schema() -> None: ) +def test_get_metric_expression_derived_legacy() -> None: + """ + Test ``get_metric_expression`` with derived metrics created using a legacy dbt version. + """ + metric_schema = MetricSchema() + metrics: Dict[str, MetricSchema] = { + "revenue_verbose_name_from_dbt": metric_schema.load( + { + "name": "revenue_verbose_name_from_dbt", + "expression": "price_each", + "description": "revenue.", + "calculation_method": "sum", + "unique_id": "metric.postgres.revenue_verbose_name_from_dbt", + "label": "Sales Revenue Metric and this is the dbt label", + "depends_on": ["model.postgres.vehicle_sales"], + "metrics": [], + "created_at": 1701101973.269536, + "resource_type": "metric", + "fqn": ["postgres", "revenue_verbose_name_from_dbt"], + "model": "ref('vehicle_sales')", + "path": "schema.yml", + "package_name": "postgres", + "original_file_path": "models/schema.yml", + "refs": [{"name": "vehicle_sales", "package": None, "version": None}], + "time_grains": [], + "model_unique_id": None, + "dialect": "postgres", + }, + ), + "derived_metric": metric_schema.load( + { + "name": "derived_metric", + "expression": "revenue_verbose_name_from_dbt * 1.1", + "description": "", + "calculation_method": "derived", + "unique_id": "metric.postgres.derived_metric", + "label": "Dervied Metric", + "depends_on": ["metric.postgres.revenue_verbose_name_from_dbt"], + "metrics": [["revenue_verbose_name_from_dbt"]], + "created_at": 1704299520.144628, + "resource_type": "metric", + "fqn": ["postgres", "derived_metric"], + "model": None, + "path": "schema.yml", + "package_name": "bigquery", + "original_file_path": "models/schema.yml", + "refs": [], + "time_grains": [], + "model_unique_id": None, + "config": {"enabled": True, "group": None}, + "dialect": "bigquery", + }, + ), + "another_derived_metric": metric_schema.load( + { + "name": "another_derived_metric", + "expression": """ +SAFE_DIVIDE( + SUM( + IF( + `product_line` = "Classic Cars", + price_each * 0.80, + price_each * 0.70 + ) + ), + revenue_verbose_name_from_dbt + ) +""", + "description": "", + "dialect": "bigquery", + "calculation_method": "derived", + "unique_id": "metric.postgres.another_derived_metric", + "label": "Another Dervied Metric", + "depends_on": ["metric.postgres.revenue_verbose_name_from_dbt"], + "metrics": [["revenue_verbose_name_from_dbt"]], + "created_at": 1704299520.144628, + "resource_type": "metric", + "fqn": ["postgres", "derived_metric"], + "model": None, + "path": "schema.yml", + "package_name": "postgres", + "original_file_path": "models/schema.yml", + "refs": [], + "time_grains": [], + "model_unique_id": None, + "config": {"enabled": True, "group": None}, + }, + ), + } + unique_id = "derived_metric" + result = get_metric_expression(unique_id, metrics) + assert result == "SUM(price_each) * 1.1" + + unique_id = "another_derived_metric" + result = get_metric_expression(unique_id, metrics) + assert ( + result + == "SAFE_DIVIDE(SUM(CASE WHEN \"product_line\" = 'Classic Cars' THEN price_each * 0.80 ELSE price_each * 0.70 END), SUM(price_each))" + ) + + def test_get_metrics_for_model(mocker: MockerFixture) -> None: """ Test ``get_metrics_for_model``.