Skip to content

Commit

Permalink
fix(dbt): Support syncing derived metrics on legacy dbt versions (#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vitor-Avila authored Mar 18, 2024
1 parent 73db090 commit 1a0e0ac
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 17 deletions.
1 change: 1 addition & 0 deletions src/preset_cli/api/clients/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ class MetricSchema(PostelSchema):
# dbt >= 1.3
calculation_method = fields.String()
expression = fields.String()
dialect = fields.String()


class MFMetricType(str, Enum):
Expand Down
6 changes: 4 additions & 2 deletions src/preset_cli/cli/superset/sync/dbt/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many

with open(profiles, encoding="utf-8") as input_:
config = yaml.safe_load(input_)
dialect = MFSQLEngine(config[project]["outputs"][target]["type"].upper())
dialect = config[project]["outputs"][target]["type"]
mf_dialect = MFSQLEngine(dialect.upper())

model_schema = ModelSchema()
models = []
Expand Down Expand Up @@ -215,8 +216,9 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many
# conform to the same schema that dbt Cloud uses for metrics
config["dependsOn"] = config.pop("depends_on")["nodes"]
config["uniqueId"] = config.pop("unique_id")
config["dialect"] = dialect
og_metrics.append(metric_schema.load(config))
elif sl_metric := get_sl_metric(config, model_map, dialect):
elif sl_metric := get_sl_metric(config, model_map, mf_dialect):
sl_metrics.append(sl_metric)

superset_metrics = get_superset_metrics_per_model(og_metrics, sl_metrics)
Expand Down
30 changes: 16 additions & 14 deletions src/preset_cli/cli/superset/sync/dbt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
}


def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> str:
def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str:
"""
Return a SQL expression for a given dbt metric using sqlglot.
"""
if unique_id not in metrics:
raise Exception(f"Invalid metric {unique_id}")
if metric_name not in metrics:
raise Exception(f"Invalid metric {metric_name}")

metric = metrics[unique_id]
metric = metrics[metric_name]
if "calculation_method" in metric:
# dbt >= 1.3
type_ = metric["calculation_method"]
Expand Down Expand Up @@ -77,13 +77,16 @@ def get_metric_expression(unique_id: str, metrics: Dict[str, MetricSchema]) -> s
return f"COUNT(DISTINCT {sql})"

if type_ in {"expression", "derived"}:
expression = sqlglot.parse_one(sql)
expression = sqlglot.parse_one(sql, dialect=metric["dialect"])
tokens = expression.find_all(exp.Column)

for token in tokens:
if token.sql() in metrics:
parent_sql = get_metric_expression(token.sql(), metrics)
parent_expression = sqlglot.parse_one(parent_sql)
parent_expression = sqlglot.parse_one(
parent_sql,
dialect=metric["dialect"],
)
token.replace(parent_expression)

return expression.sql()
Expand Down Expand Up @@ -167,23 +170,22 @@ def get_metric_models(unique_id: str, metrics: List[MetricSchema]) -> Set[str]:


def get_metric_definition(
unique_id: str,
metric_name: str,
metrics: List[MetricSchema],
) -> SupersetMetricDefinition:
"""
Build a Superset metric definition from an OG (< 1.6) dbt metric.
"""
metric_map = {metric["unique_id"]: metric for metric in metrics}
metric = metric_map[unique_id]
name = metric["name"]
metric_map = {metric["name"]: metric for metric in metrics}
metric = metric_map[metric_name]
meta = metric.get("meta", {})
kwargs = meta.pop("superset", {})

return {
"expression": get_metric_expression(unique_id, metric_map),
"metric_name": name,
"expression": get_metric_expression(metric_name, metric_map),
"metric_name": metric_name,
"metric_type": (metric.get("type") or metric.get("calculation_method")),
"verbose_name": metric.get("label", name),
"verbose_name": metric.get("label", metric_name),
"description": metric.get("description", ""),
"extra": json.dumps(meta),
**kwargs, # type: ignore
Expand All @@ -209,7 +211,7 @@ def get_superset_metrics_per_model(
continue

metric_definition = get_metric_definition(
metric["unique_id"],
metric["name"],
og_metrics,
)
model = metric_models.pop()
Expand Down
108 changes: 107 additions & 1 deletion tests/cli/superset/sync/dbt/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,28 @@ def test_get_metric_expression() -> None:
{"field": "company_name", "operator": "!=", "value": "'Acme, Inc'"},
{"field": "signup_date", "operator": ">=", "value": "'2020-01-01'"},
],
"dialect": "postgres",
},
),
"two": metric_schema.load(
{
"type": "count_distinct",
"sql": "user_id",
"dialect": "postgres",
},
),
"three": metric_schema.load(
{
"type": "expression",
"sql": "one - two",
"dialect": "postgres",
},
),
"four": metric_schema.load(
{
"type": "hllsketch",
"sql": "user_id",
"dialect": "postgres",
},
),
"load_fill_by_weight": metric_schema.load(
Expand All @@ -72,6 +76,7 @@ def test_get_metric_expression() -> None:
"sql": "load_weight_lbs / load_weight_capacity_lbs",
"type": "derived",
"unique_id": "metric.breakthrough_dw.load_fill_by_weight",
"dialect": "postgres",
},
),
}
Expand All @@ -97,7 +102,7 @@ def test_get_metric_expression() -> None:
get_metric_expression("four", metrics)
assert str(excinfo.value) == (
"Unable to generate metric expression from: "
"{'sql': 'user_id', 'type': 'hllsketch'}"
"{'dialect': 'postgres', 'sql': 'user_id', 'type': 'hllsketch'}"
)

with pytest.raises(Exception) as excinfo:
Expand Down Expand Up @@ -136,6 +141,107 @@ def test_get_metric_expression_new_schema() -> None:
)


def test_get_metric_expression_derived_legacy() -> None:
"""
Test ``get_metric_expression`` with derived metrics created using a legacy dbt version.
"""
metric_schema = MetricSchema()
metrics: Dict[str, MetricSchema] = {
"revenue_verbose_name_from_dbt": metric_schema.load(
{
"name": "revenue_verbose_name_from_dbt",
"expression": "price_each",
"description": "revenue.",
"calculation_method": "sum",
"unique_id": "metric.postgres.revenue_verbose_name_from_dbt",
"label": "Sales Revenue Metric and this is the dbt label",
"depends_on": ["model.postgres.vehicle_sales"],
"metrics": [],
"created_at": 1701101973.269536,
"resource_type": "metric",
"fqn": ["postgres", "revenue_verbose_name_from_dbt"],
"model": "ref('vehicle_sales')",
"path": "schema.yml",
"package_name": "postgres",
"original_file_path": "models/schema.yml",
"refs": [{"name": "vehicle_sales", "package": None, "version": None}],
"time_grains": [],
"model_unique_id": None,
"dialect": "postgres",
},
),
"derived_metric": metric_schema.load(
{
"name": "derived_metric",
"expression": "revenue_verbose_name_from_dbt * 1.1",
"description": "",
"calculation_method": "derived",
"unique_id": "metric.postgres.derived_metric",
"label": "Dervied Metric",
"depends_on": ["metric.postgres.revenue_verbose_name_from_dbt"],
"metrics": [["revenue_verbose_name_from_dbt"]],
"created_at": 1704299520.144628,
"resource_type": "metric",
"fqn": ["postgres", "derived_metric"],
"model": None,
"path": "schema.yml",
"package_name": "bigquery",
"original_file_path": "models/schema.yml",
"refs": [],
"time_grains": [],
"model_unique_id": None,
"config": {"enabled": True, "group": None},
"dialect": "bigquery",
},
),
"another_derived_metric": metric_schema.load(
{
"name": "another_derived_metric",
"expression": """
SAFE_DIVIDE(
SUM(
IF(
`product_line` = "Classic Cars",
price_each * 0.80,
price_each * 0.70
)
),
revenue_verbose_name_from_dbt
)
""",
"description": "",
"dialect": "bigquery",
"calculation_method": "derived",
"unique_id": "metric.postgres.another_derived_metric",
"label": "Another Dervied Metric",
"depends_on": ["metric.postgres.revenue_verbose_name_from_dbt"],
"metrics": [["revenue_verbose_name_from_dbt"]],
"created_at": 1704299520.144628,
"resource_type": "metric",
"fqn": ["postgres", "derived_metric"],
"model": None,
"path": "schema.yml",
"package_name": "postgres",
"original_file_path": "models/schema.yml",
"refs": [],
"time_grains": [],
"model_unique_id": None,
"config": {"enabled": True, "group": None},
},
),
}
unique_id = "derived_metric"
result = get_metric_expression(unique_id, metrics)
assert result == "SUM(price_each) * 1.1"

unique_id = "another_derived_metric"
result = get_metric_expression(unique_id, metrics)
assert (
result
== "SAFE_DIVIDE(SUM(CASE WHEN \"product_line\" = 'Classic Cars' THEN price_each * 0.80 ELSE price_each * 0.70 END), SUM(price_each))"
)


def test_get_metrics_for_model(mocker: MockerFixture) -> None:
"""
Test ``get_metrics_for_model``.
Expand Down

0 comments on commit 1a0e0ac

Please sign in to comment.