diff --git a/src/preset_cli/api/clients/dbt.py b/src/preset_cli/api/clients/dbt.py index 5b9b25e8..a3b3059f 100644 --- a/src/preset_cli/api/clients/dbt.py +++ b/src/preset_cli/api/clients/dbt.py @@ -590,6 +590,7 @@ class MetricSchema(PostelSchema): calculation_method = fields.String() expression = fields.String() dialect = fields.String() + skip_parsing = fields.Boolean(allow_none=True) class MFMetricType(str, Enum): diff --git a/src/preset_cli/cli/superset/sync/dbt/metrics.py b/src/preset_cli/cli/superset/sync/dbt/metrics.py index 35ae1b0a..5b5ad269 100644 --- a/src/preset_cli/cli/superset/sync/dbt/metrics.py +++ b/src/preset_cli/cli/superset/sync/dbt/metrics.py @@ -8,11 +8,12 @@ import json import logging +import re from collections import defaultdict from typing import Dict, List, Optional, Set import sqlglot -from sqlglot import Expression, exp, parse_one +from sqlglot import Expression, ParseError, exp, parse_one from sqlglot.expressions import ( Alias, Case, @@ -49,6 +50,7 @@ } +# pylint: disable=too-many-locals def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str: """ Return a SQL expression for a given dbt metric using sqlglot. @@ -87,7 +89,19 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> return f"COUNT(DISTINCT {sql})" if type_ in {"expression", "derived"}: - expression = sqlglot.parse_one(sql, dialect=metric["dialect"]) + try: + expression = sqlglot.parse_one(sql, dialect=metric["dialect"]) + except ParseError: + for parent_metric in metric["depends_on"]: + parent_metric_name = parent_metric.split(".")[-1] + pattern = r"\b" + re.escape(parent_metric_name) + r"\b" + parent_metric_syntax = get_metric_expression( + parent_metric_name, + metrics, + ) + sql = re.sub(pattern, parent_metric_syntax, sql) + return sql + tokens = expression.find_all(exp.Column) for token in tokens: @@ -192,7 +206,11 @@ def get_metric_definition( kwargs = meta.pop("superset", {}) return { - "expression": get_metric_expression(metric_name, metric_map), + "expression": ( + get_metric_expression(metric_name, metric_map) + if not metric.get("skip_parsing") + else metric.get("expression") or metric.get("sql") + ), "metric_name": metric_name, "metric_type": (metric.get("type") or metric.get("calculation_method")), "verbose_name": metric.get("label", metric_name), @@ -212,6 +230,20 @@ def get_superset_metrics_per_model( superset_metrics = defaultdict(list) for metric in og_metrics: metric_models = get_metric_models(metric["unique_id"], og_metrics) + + # dbt supports creating derived metrics with raw syntax + if len(metric_models) == 0: + try: + metric_models.add(metric["meta"]["superset"].pop("model")) + metric["skip_parsing"] = True + except KeyError: + _logger.warning( + "Metric %s cannot be calculated because it's not associated with any model." + " Please specify the model under metric.meta.superset.model.", + metric["name"], + ) + continue + if len(metric_models) != 1: _logger.warning( "Metric %s cannot be calculated because it depends on multiple models: %s", diff --git a/tests/cli/superset/sync/dbt/metrics_test.py b/tests/cli/superset/sync/dbt/metrics_test.py index fc91abab..9cc602a0 100644 --- a/tests/cli/superset/sync/dbt/metrics_test.py +++ b/tests/cli/superset/sync/dbt/metrics_test.py @@ -2,7 +2,7 @@ Tests for metrics. """ -# pylint: disable=line-too-long +# pylint: disable=line-too-long, too-many-lines from typing import Dict @@ -908,3 +908,134 @@ def test_get_superset_metrics_per_model() -> None: }, ], } + + +def test_get_superset_metrics_per_model_og_derived( + caplog: pytest.CaptureFixture[str], +) -> None: + """ + Tests for the ``get_superset_metrics_per_model`` function + with derived OG metrics. + """ + og_metric_schema = MetricSchema() + + og_metrics = [ + og_metric_schema.load( + { + "name": "sales", + "unique_id": "sales", + "depends_on": ["orders"], + "calculation_method": "sum", + "expression": "1", + }, + ), + og_metric_schema.load( + { + "name": "derived_metric_missing_model_info", + "unique_id": "derived_metric_missing_model_info", + "depends_on": [], + "calculation_method": "derived", + "expression": "price_each * 1.2", + }, + ), + og_metric_schema.load( + { + "name": "derived_metric_model_from_meta", + "unique_id": "derived_metric_model_from_meta", + "depends_on": [], + "calculation_method": "derived", + "expression": "(SUM(price_each)) * 1.2", + "meta": {"superset": {"model": "customers"}}, + }, + ), + og_metric_schema.load( + { + "name": "derived_metric_with_jinja", + "unique_id": "derived_metric_with_jinja", + "depends_on": [], + "calculation_method": "derived", + "expression": """ +SUM( + {% for x in filter_values('x_values') %} + {{ + x_values }} + {% endfor %} +) +""", + "meta": {"superset": {"model": "customers"}}, + }, + ), + og_metric_schema.load( + { + "name": "derived_metric_with_jinja_and_other_metric", + "unique_id": "derived_metric_with_jinja_and_other_metric", + "depends_on": ["sales"], + "dialect": "postgres", + "calculation_method": "derived", + "expression": """ +SUM( + {% for x in filter_values('x_values') %} + {{ my_sales + sales }} + {% endfor %} +) +""", + }, + ), + ] + + result = get_superset_metrics_per_model(og_metrics, []) + output_content = caplog.text + assert ( + "Metric derived_metric_missing_model_info cannot be calculated because it's not associated with any model" + in output_content + ) + + assert result == { + "customers": [ + { + "expression": "(SUM(price_each)) * 1.2", + "metric_name": "derived_metric_model_from_meta", + "metric_type": "derived", + "verbose_name": "derived_metric_model_from_meta", + "description": "", + "extra": "{}", + }, + { + "expression": """ +SUM( + {% for x in filter_values('x_values') %} + {{ + x_values }} + {% endfor %} +) +""", + "metric_name": "derived_metric_with_jinja", + "metric_type": "derived", + "verbose_name": "derived_metric_with_jinja", + "description": "", + "extra": "{}", + }, + ], + "orders": [ + { + "description": "", + "expression": "SUM(1)", + "extra": "{}", + "metric_name": "sales", + "metric_type": "sum", + "verbose_name": "sales", + }, + { + "expression": """ +SUM( + {% for x in filter_values('x_values') %} + {{ my_sales + SUM(1) }} + {% endfor %} +) +""", + "metric_name": "derived_metric_with_jinja_and_other_metric", + "metric_type": "derived", + "verbose_name": "derived_metric_with_jinja_and_other_metric", + "description": "", + "extra": "{}", + }, + ], + }