fix (dbt): derived metrics
betodealmeida committed Nov 30, 2022
1 parent 1905869 commit 9f706a1
Showing 5 changed files with 176 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/preset_cli/cli/superset/sync/dbt/command.py
@@ -150,7 +150,7 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-locals
     for config in configs["metrics"].values():
         # conform to the same schema that dbt Cloud uses for metrics
         config["dependsOn"] = config["depends_on"]["nodes"]
-        config["uniqueID"] = config["unique_id"]
+        config["uniqueId"] = config["unique_id"]
         metrics.append(metric_schema.load(config, unknown=EXCLUDE))
 
     try:
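Context for the one-character fix above: dbt Core manifests use snake_case keys, and the command conforms them to the camelCase names that the dbt Cloud metric schema uses. A minimal sketch of that step, shaped after the test fixtures further down; the `depends_on` model ID is a made-up example and not part of this commit:

```python
from marshmallow import EXCLUDE

from preset_cli.api.clients.dbt import MetricSchema

metric_schema = MetricSchema()

# Manifest-style entry (snake_case); the model unique ID is hypothetical.
config = {
    "name": "cnt",
    "sql": "*",
    "type": "count",
    "unique_id": "metric.superset_examples.cnt",
    "depends_on": {"nodes": ["model.superset_examples.messages_channels"]},
}

# Conform to the camelCase keys used by dbt Cloud; the fix above corrects
# the second key from "uniqueID" to "uniqueId".
config["dependsOn"] = config["depends_on"]["nodes"]
config["uniqueId"] = config["unique_id"]

metric = metric_schema.load(config, unknown=EXCLUDE)
```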
14 changes: 9 additions & 5 deletions src/preset_cli/cli/superset/sync/dbt/datasets.py
@@ -16,7 +16,10 @@
 from preset_cli.api.clients.dbt import MetricSchema, ModelSchema
 from preset_cli.api.clients.superset import SupersetClient
 from preset_cli.api.operators import OneToMany
-from preset_cli.cli.superset.sync.dbt.metrics import get_metric_expression
+from preset_cli.cli.superset.sync.dbt.metrics import (
+    get_metric_expression,
+    get_metrics_for_model,
+)
 
 _logger = logging.getLogger(__name__)
 
@@ -109,9 +112,7 @@ def sync_datasets( # pylint: disable=too-many-locals, too-many-branches, too-ma
 
         dataset_metrics = []
         model_metrics = {
-            metric["name"]: metric
-            for metric in metrics
-            if model["unique_id"] in metric["depends_on"]
+            metric["name"]: metric for metric in get_metrics_for_model(model, metrics)
         }
         for name, metric in model_metrics.items():
             meta = metric.get("meta", {})
@@ -120,7 +121,10 @@ def sync_datasets( # pylint: disable=too-many-locals, too-many-branches, too-ma
                 {
                     "expression": get_metric_expression(name, model_metrics),
                     "metric_name": name,
-                    "metric_type": metric["type"],
+                    "metric_type": (
+                        metric.get("type")  # dbt < 1.3
+                        or metric.get("calculation_method")  # dbt >= 1.3
+                    ),
                     "verbose_name": metric.get("label", name),
                     "description": metric.get("description", ""),
                     "extra": json.dumps(meta),
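The `metric_type` fallback above can be read as a tiny standalone helper; a sketch under the assumption that only one of the two keys is present (the helper name is ours, not part of the codebase):

```python
from typing import Any, Dict, Optional


def metric_type(metric: Dict[str, Any]) -> Optional[str]:
    """Return the aggregation type of a metric across dbt versions."""
    # dbt < 1.3 exposes the aggregation as "type";
    # dbt >= 1.3 renamed it to "calculation_method".
    return metric.get("type") or metric.get("calculation_method")


assert metric_type({"type": "count", "sql": "*"}) == "count"
assert metric_type({"calculation_method": "count", "expression": "*"}) == "count"
```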
63 changes: 59 additions & 4 deletions src/preset_cli/cli/superset/sync/dbt/metrics.py
@@ -6,12 +6,15 @@
 
 # pylint: disable=consider-using-f-string
 
+import logging
 from functools import partial
 from typing import Dict, List
 
 from jinja2 import Template
 
-from preset_cli.api.clients.dbt import FilterSchema, MetricSchema
+from preset_cli.api.clients.dbt import FilterSchema, MetricSchema, ModelSchema
+
+_logger = logging.getLogger(__name__)
 
 
 def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str:
@@ -22,8 +25,16 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str:
         raise Exception(f"Invalid metric {metric_name}")
 
     metric = metrics[metric_name]
-    type_ = metric["type"]
-    sql = metric["sql"]
+    if "calculation_method" in metric:
+        # dbt >= 1.3
+        type_ = metric["calculation_method"]
+        sql = metric["expression"]
+        expression = "derived"
+    else:
+        # dbt < 1.3
+        type_ = metric["type"]
+        sql = metric["sql"]
+        expression = "expression"
 
     if metric.get("filters"):
         sql = apply_filters(sql, metric["filters"])
@@ -43,7 +54,7 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str:
     if type_ == "count_distinct":
         return f"COUNT(DISTINCT {sql})"
 
-    if type_ == "expression":
+    if type_ == expression:
         template = Template(sql)
         return template.render(metric=partial(get_metric_expression, metrics=metrics))
 
@@ -59,3 +70,47 @@ def apply_filters(sql: str, filters: List[FilterSchema]) -> str:
         "{field} {operator} {value}".format(**filter_) for filter_ in filters
     )
     return f"CASE WHEN {condition} THEN {sql} END"
+
+
+def is_derived(metric: MetricSchema) -> bool:
+    """
+    Return if the metric is derived.
+    """
+    return (
+        metric.get("calculation_method") == "derived"  # dbt >= 1.3
+        or metric.get("type") == "expression"  # dbt < 1.3
+    )
+
+
+def get_metrics_for_model(
+    model: ModelSchema,
+    metrics: List[MetricSchema],
+) -> List[MetricSchema]:
+    """
+    Given a list of metrics, return those that are based on a given model.
+    """
+    metric_map = {metric["unique_id"]: metric for metric in metrics}
+    related_metrics = []
+
+    for metric in metrics:
+        parents = set()
+        queue = [metric]
+        while queue:
+            node = queue.pop()
+            depends_on = node["depends_on"]
+            if is_derived(node):
+                queue.extend(metric_map[parent] for parent in depends_on)
+            else:
+                parents.update(depends_on)
+
+        if len(parents) > 1:
+            _logger.warning(
+                "Metric %s cannot be calculated because it depends on multiple models",
+                metric["name"],
+            )
+            break
+
+        if model["unique_id"] == parents.pop():
+            related_metrics.append(metric)
+
+    return related_metrics
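A rough usage sketch of the dbt >= 1.3 handling added in this file, assuming the pre-existing ``count``/``count_distinct`` branches earlier in ``get_metric_expression``; the metric names and expressions are invented for illustration:

```python
from preset_cli.api.clients.dbt import MetricSchema
from preset_cli.cli.superset.sync.dbt.metrics import get_metric_expression

metric_schema = MetricSchema()
metrics = {
    "orders": metric_schema.load(
        {"calculation_method": "count", "expression": "order_id"},
    ),
    "users": metric_schema.load(
        {"calculation_method": "count_distinct", "expression": "user_id"},
    ),
    "orders_per_user": metric_schema.load(
        {
            "calculation_method": "derived",
            "expression": "{{ metric('orders') }} / {{ metric('users') }}",
        },
    ),
}

# "derived" metrics now hit the ``type_ == expression`` branch, so the Jinja
# template is rendered and each ``metric(...)`` call recurses into
# ``get_metric_expression`` for the parent metric.
assert get_metric_expression("orders_per_user", metrics) == (
    "COUNT(order_id) / COUNT(DISTINCT user_id)"
)
```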
5 changes: 5 additions & 0 deletions tests/cli/superset/sync/dbt/command_test.py
@@ -101,6 +101,7 @@ def test_dbt_core(mocker: MockerFixture, fs: FakeFilesystem) -> None:
             "name": "cnt",
             "sql": "*",
             "type": "count",
+            "unique_id": "metric.superset_examples.cnt",
         },
     ]
     sync_datasets.assert_called_with(
@@ -284,6 +285,7 @@ def test_dbt(mocker: MockerFixture, fs: FakeFilesystem) -> None:
             "name": "cnt",
             "sql": "*",
             "type": "count",
+            "unique_id": "metric.superset_examples.cnt",
         },
     ]
     sync_datasets.assert_called_with(
@@ -467,6 +469,7 @@ def test_dbt_cloud(mocker: MockerFixture) -> None:
             "name": "cnt",
             "sql": "*",
             "type": "count",
+            "unique_id": "metric.superset_examples.cnt",
         },
     ]
     dbt_client.get_metrics.return_value = metrics
@@ -534,6 +537,7 @@ def test_dbt_cloud_no_job_id(mocker: MockerFixture) -> None:
             "name": "cnt",
             "sql": "*",
             "type": "count",
+            "unique_id": "metric.superset_examples.cnt",
         },
     ]
     dbt_client.get_metrics.return_value = metrics
@@ -853,6 +857,7 @@ def test_dbt_cloud_exposures_only(mocker: MockerFixture, fs: FakeFilesystem) -> None:
             "name": "cnt",
             "sql": "*",
             "type": "count",
+            "unique_id": "metric.superset_examples.cnt",
         },
     ]
     dbt_client.get_metrics.return_value = metrics
103 changes: 102 additions & 1 deletion tests/cli/superset/sync/dbt/metrics_test.py
@@ -5,9 +5,13 @@
 from typing import Dict
 
 import pytest
+from pytest_mock import MockerFixture
 
 from preset_cli.api.clients.dbt import MetricSchema
-from preset_cli.cli.superset.sync.dbt.metrics import get_metric_expression
+from preset_cli.cli.superset.sync.dbt.metrics import (
+    get_metric_expression,
+    get_metrics_for_model,
+)
 
 
 def test_get_metric_expression() -> None:
@@ -70,3 +74,100 @@ def test_get_metric_expression() -> None:
     with pytest.raises(Exception) as excinfo:
         get_metric_expression("five", metrics)
     assert str(excinfo.value) == "Invalid metric five"
+
+
+def test_get_metric_expression_new_schema() -> None:
+    """
+    Test ``get_metric_expression`` with the dbt 1.3 schema.
+    See https://docs.getdbt.com/guides/migration/versions/upgrading-to-v1.3#for-users-of-dbt-metrics
+    """
+    metric_schema = MetricSchema()
+    metrics: Dict[str, MetricSchema] = {
+        "one": metric_schema.load(
+            {
+                "calculation_method": "count",
+                "expression": "user_id",
+                "filters": [
+                    {"field": "is_paying", "operator": "is", "value": "true"},
+                    {"field": "lifetime_value", "operator": ">=", "value": "100"},
+                    {"field": "company_name", "operator": "!=", "value": "'Acme, Inc'"},
+                    {"field": "signup_date", "operator": ">=", "value": "'2020-01-01'"},
+                ],
+            },
+        ),
+    }
+    assert get_metric_expression("one", metrics) == (
+        "COUNT(CASE WHEN is_paying is true AND lifetime_value >= 100 AND "
+        "company_name != 'Acme, Inc' AND signup_date >= '2020-01-01' THEN user_id END)"
+    )
+
+
+def test_get_metrics_for_model(mocker: MockerFixture) -> None:
+    """
+    Test ``get_metrics_for_model``.
+    """
+    _logger = mocker.patch("preset_cli.cli.superset.sync.dbt.metrics._logger")
+
+    metrics = [
+        {
+            "unique_id": "metric.superset.a",
+            "depends_on": ["model.superset.table"],
+            "name": "a",
+        },
+        {
+            "unique_id": "metric.superset.b",
+            "depends_on": ["model.superset.table"],
+            "name": "b",
+        },
+        {
+            "unique_id": "metric.superset.c",
+            "depends_on": ["model.superset.other_table"],
+            "name": "c",
+        },
+        {
+            "unique_id": "metric.superset.d",
+            "depends_on": ["metric.superset.a", "metric.superset.b"],
+            "name": "d",
+            "calculation_method": "derived",
+        },
+        {
+            "unique_id": "metric.superset.e",
+            "depends_on": ["metric.superset.a", "metric.superset.c"],
+            "name": "e",
+            "calculation_method": "derived",
+        },
+    ]
+
+    model = {"unique_id": "model.superset.table"}
+    assert get_metrics_for_model(model, metrics) == [  # type: ignore
+        {
+            "unique_id": "metric.superset.a",
+            "depends_on": ["model.superset.table"],
+            "name": "a",
+        },
+        {
+            "unique_id": "metric.superset.b",
+            "depends_on": ["model.superset.table"],
+            "name": "b",
+        },
+        {
+            "unique_id": "metric.superset.d",
+            "depends_on": ["metric.superset.a", "metric.superset.b"],
+            "name": "d",
+            "calculation_method": "derived",
+        },
+    ]
+    _logger.warning.assert_called_with(
+        "Metric %s cannot be calculated because it depends on multiple models",
+        "e",
+    )
+
+    model = {"unique_id": "model.superset.other_table"}
+    assert get_metrics_for_model(model, metrics) == [  # type: ignore
+        {
+            "unique_id": "metric.superset.c",
+            "depends_on": ["model.superset.other_table"],
+            "name": "c",
+        },
+    ]
