From 9f706a19818ae40eb5c1a927dcbd91805825727e Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 29 Nov 2022 16:11:53 -0800 Subject: [PATCH] fix (dbt): derived metrics --- .../cli/superset/sync/dbt/command.py | 2 +- .../cli/superset/sync/dbt/datasets.py | 14 ++- .../cli/superset/sync/dbt/metrics.py | 63 ++++++++++- tests/cli/superset/sync/dbt/command_test.py | 5 + tests/cli/superset/sync/dbt/metrics_test.py | 103 +++++++++++++++++- 5 files changed, 176 insertions(+), 11 deletions(-) diff --git a/src/preset_cli/cli/superset/sync/dbt/command.py b/src/preset_cli/cli/superset/sync/dbt/command.py index 92604999..aa90a06a 100644 --- a/src/preset_cli/cli/superset/sync/dbt/command.py +++ b/src/preset_cli/cli/superset/sync/dbt/command.py @@ -150,7 +150,7 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-locals for config in configs["metrics"].values(): # conform to the same schema that dbt Cloud uses for metrics config["dependsOn"] = config["depends_on"]["nodes"] - config["uniqueID"] = config["unique_id"] + config["uniqueId"] = config["unique_id"] metrics.append(metric_schema.load(config, unknown=EXCLUDE)) try: diff --git a/src/preset_cli/cli/superset/sync/dbt/datasets.py b/src/preset_cli/cli/superset/sync/dbt/datasets.py index d9485871..e8ba1c47 100644 --- a/src/preset_cli/cli/superset/sync/dbt/datasets.py +++ b/src/preset_cli/cli/superset/sync/dbt/datasets.py @@ -16,7 +16,10 @@ from preset_cli.api.clients.dbt import MetricSchema, ModelSchema from preset_cli.api.clients.superset import SupersetClient from preset_cli.api.operators import OneToMany -from preset_cli.cli.superset.sync.dbt.metrics import get_metric_expression +from preset_cli.cli.superset.sync.dbt.metrics import ( + get_metric_expression, + get_metrics_for_model, +) _logger = logging.getLogger(__name__) @@ -109,9 +112,7 @@ def sync_datasets( # pylint: disable=too-many-locals, too-many-branches, too-ma dataset_metrics = [] model_metrics = { - metric["name"]: metric - for metric in metrics - if model["unique_id"] in metric["depends_on"] + metric["name"]: metric for metric in get_metrics_for_model(model, metrics) } for name, metric in model_metrics.items(): meta = metric.get("meta", {}) @@ -120,7 +121,10 @@ def sync_datasets( # pylint: disable=too-many-locals, too-many-branches, too-ma { "expression": get_metric_expression(name, model_metrics), "metric_name": name, - "metric_type": metric["type"], + "metric_type": ( + metric.get("type") # dbt < 1.3 + or metric.get("calculation_method") # dbt >= 1.3 + ), "verbose_name": metric.get("label", name), "description": metric.get("description", ""), "extra": json.dumps(meta), diff --git a/src/preset_cli/cli/superset/sync/dbt/metrics.py b/src/preset_cli/cli/superset/sync/dbt/metrics.py index ea4304bc..da41bf13 100644 --- a/src/preset_cli/cli/superset/sync/dbt/metrics.py +++ b/src/preset_cli/cli/superset/sync/dbt/metrics.py @@ -6,12 +6,15 @@ # pylint: disable=consider-using-f-string +import logging from functools import partial from typing import Dict, List from jinja2 import Template -from preset_cli.api.clients.dbt import FilterSchema, MetricSchema +from preset_cli.api.clients.dbt import FilterSchema, MetricSchema, ModelSchema + +_logger = logging.getLogger(__name__) def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str: @@ -22,8 +25,16 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> raise Exception(f"Invalid metric {metric_name}") metric = metrics[metric_name] - type_ = metric["type"] - sql = metric["sql"] + if "calculation_method" in metric: + # dbt >= 1.3 + type_ = metric["calculation_method"] + sql = metric["expression"] + expression = "derived" + else: + # dbt < 1.3 + type_ = metric["type"] + sql = metric["sql"] + expression = "expression" if metric.get("filters"): sql = apply_filters(sql, metric["filters"]) @@ -43,7 +54,7 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> if type_ == "count_distinct": return f"COUNT(DISTINCT {sql})" - if type_ == "expression": + if type_ == expression: template = Template(sql) return template.render(metric=partial(get_metric_expression, metrics=metrics)) @@ -59,3 +70,47 @@ def apply_filters(sql: str, filters: List[FilterSchema]) -> str: "{field} {operator} {value}".format(**filter_) for filter_ in filters ) return f"CASE WHEN {condition} THEN {sql} END" + + +def is_derived(metric: MetricSchema) -> bool: + """ + Return if the metric is derived. + """ + return ( + metric.get("calculation_method") == "derived" # dbt >= 1.3 + or metric.get("type") == "expression" # dbt < 1.3 + ) + + +def get_metrics_for_model( + model: ModelSchema, + metrics: List[MetricSchema], +) -> List[MetricSchema]: + """ + Given a list of metrics, return those that are based on a given model. + """ + metric_map = {metric["unique_id"]: metric for metric in metrics} + related_metrics = [] + + for metric in metrics: + parents = set() + queue = [metric] + while queue: + node = queue.pop() + depends_on = node["depends_on"] + if is_derived(node): + queue.extend(metric_map[parent] for parent in depends_on) + else: + parents.update(depends_on) + + if len(parents) > 1: + _logger.warning( + "Metric %s cannot be calculated because it depends on multiple models", + metric["name"], + ) + break + + if model["unique_id"] == parents.pop(): + related_metrics.append(metric) + + return related_metrics diff --git a/tests/cli/superset/sync/dbt/command_test.py b/tests/cli/superset/sync/dbt/command_test.py index 0da712a4..4b70abd3 100644 --- a/tests/cli/superset/sync/dbt/command_test.py +++ b/tests/cli/superset/sync/dbt/command_test.py @@ -101,6 +101,7 @@ def test_dbt_core(mocker: MockerFixture, fs: FakeFilesystem) -> None: "name": "cnt", "sql": "*", "type": "count", + "unique_id": "metric.superset_examples.cnt", }, ] sync_datasets.assert_called_with( @@ -284,6 +285,7 @@ def test_dbt(mocker: MockerFixture, fs: FakeFilesystem) -> None: "name": "cnt", "sql": "*", "type": "count", + "unique_id": "metric.superset_examples.cnt", }, ] sync_datasets.assert_called_with( @@ -467,6 +469,7 @@ def test_dbt_cloud(mocker: MockerFixture) -> None: "name": "cnt", "sql": "*", "type": "count", + "unique_id": "metric.superset_examples.cnt", }, ] dbt_client.get_metrics.return_value = metrics @@ -534,6 +537,7 @@ def test_dbt_cloud_no_job_id(mocker: MockerFixture) -> None: "name": "cnt", "sql": "*", "type": "count", + "unique_id": "metric.superset_examples.cnt", }, ] dbt_client.get_metrics.return_value = metrics @@ -853,6 +857,7 @@ def test_dbt_cloud_exposures_only(mocker: MockerFixture, fs: FakeFilesystem) -> "name": "cnt", "sql": "*", "type": "count", + "unique_id": "metric.superset_examples.cnt", }, ] dbt_client.get_metrics.return_value = metrics diff --git a/tests/cli/superset/sync/dbt/metrics_test.py b/tests/cli/superset/sync/dbt/metrics_test.py index b875b21b..4e4e781c 100644 --- a/tests/cli/superset/sync/dbt/metrics_test.py +++ b/tests/cli/superset/sync/dbt/metrics_test.py @@ -5,9 +5,13 @@ from typing import Dict import pytest +from pytest_mock import MockerFixture from preset_cli.api.clients.dbt import MetricSchema -from preset_cli.cli.superset.sync.dbt.metrics import get_metric_expression +from preset_cli.cli.superset.sync.dbt.metrics import ( + get_metric_expression, + get_metrics_for_model, +) def test_get_metric_expression() -> None: @@ -70,3 +74,100 @@ def test_get_metric_expression() -> None: with pytest.raises(Exception) as excinfo: get_metric_expression("five", metrics) assert str(excinfo.value) == "Invalid metric five" + + +def test_get_metric_expression_new_schema() -> None: + """ + Test ``get_metric_expression`` with the dbt 1.3 schema. + + See https://docs.getdbt.com/guides/migration/versions/upgrading-to-v1.3#for-users-of-dbt-metrics + """ + metric_schema = MetricSchema() + metrics: Dict[str, MetricSchema] = { + "one": metric_schema.load( + { + "calculation_method": "count", + "expression": "user_id", + "filters": [ + {"field": "is_paying", "operator": "is", "value": "true"}, + {"field": "lifetime_value", "operator": ">=", "value": "100"}, + {"field": "company_name", "operator": "!=", "value": "'Acme, Inc'"}, + {"field": "signup_date", "operator": ">=", "value": "'2020-01-01'"}, + ], + }, + ), + } + assert get_metric_expression("one", metrics) == ( + "COUNT(CASE WHEN is_paying is true AND lifetime_value >= 100 AND " + "company_name != 'Acme, Inc' AND signup_date >= '2020-01-01' THEN user_id END)" + ) + + +def test_get_metrics_for_model(mocker: MockerFixture) -> None: + """ + Test ``get_metrics_for_model``. + """ + _logger = mocker.patch("preset_cli.cli.superset.sync.dbt.metrics._logger") + + metrics = [ + { + "unique_id": "metric.superset.a", + "depends_on": ["model.superset.table"], + "name": "a", + }, + { + "unique_id": "metric.superset.b", + "depends_on": ["model.superset.table"], + "name": "b", + }, + { + "unique_id": "metric.superset.c", + "depends_on": ["model.superset.other_table"], + "name": "c", + }, + { + "unique_id": "metric.superset.d", + "depends_on": ["metric.superset.a", "metric.superset.b"], + "name": "d", + "calculation_method": "derived", + }, + { + "unique_id": "metric.superset.e", + "depends_on": ["metric.superset.a", "metric.superset.c"], + "name": "e", + "calculation_method": "derived", + }, + ] + + model = {"unique_id": "model.superset.table"} + assert get_metrics_for_model(model, metrics) == [ # type: ignore + { + "unique_id": "metric.superset.a", + "depends_on": ["model.superset.table"], + "name": "a", + }, + { + "unique_id": "metric.superset.b", + "depends_on": ["model.superset.table"], + "name": "b", + }, + { + "unique_id": "metric.superset.d", + "depends_on": ["metric.superset.a", "metric.superset.b"], + "name": "d", + "calculation_method": "derived", + }, + ] + _logger.warning.assert_called_with( + "Metric %s cannot be calculated because it depends on multiple models", + "e", + ) + + model = {"unique_id": "model.superset.other_table"} + assert get_metrics_for_model(model, metrics) == [ # type: ignore + { + "unique_id": "metric.superset.c", + "depends_on": ["model.superset.other_table"], + "name": "c", + }, + ]