Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix (dbt): derived metrics #154

Merged
merged 4 commits into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/preset_cli/cli/superset/sync/dbt/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-locals
for config in configs["metrics"].values():
# conform to the same schema that dbt Cloud uses for metrics
config["dependsOn"] = config["depends_on"]["nodes"]
config["uniqueID"] = config["unique_id"]
config["uniqueId"] = config["unique_id"]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a bug, but since the metric unique ID was not needed anywhere it was not caught. With the changes in this PR we use the metric ID to traverse the DAG upstream to find parent models.

metrics.append(metric_schema.load(config, unknown=EXCLUDE))

try:
Expand Down
14 changes: 9 additions & 5 deletions src/preset_cli/cli/superset/sync/dbt/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from preset_cli.api.clients.dbt import MetricSchema, ModelSchema
from preset_cli.api.clients.superset import SupersetClient
from preset_cli.api.operators import OneToMany
from preset_cli.cli.superset.sync.dbt.metrics import get_metric_expression
from preset_cli.cli.superset.sync.dbt.metrics import (
get_metric_expression,
get_metrics_for_model,
)

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,9 +112,7 @@ def sync_datasets( # pylint: disable=too-many-locals, too-many-branches, too-ma

dataset_metrics = []
model_metrics = {
metric["name"]: metric
for metric in metrics
if model["unique_id"] in metric["depends_on"]
Comment on lines -112 to -114
Copy link
Member Author

@betodealmeida betodealmeida Nov 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was too naive, and didn't account for how derived metrics depend on other metrics, instead of depending on models. I replaced it with a smarter function get_metrics_for_model that traverses the DAG upstream.

metric["name"]: metric for metric in get_metrics_for_model(model, metrics)
}
for name, metric in model_metrics.items():
meta = metric.get("meta", {})
Expand All @@ -120,7 +121,10 @@ def sync_datasets( # pylint: disable=too-many-locals, too-many-branches, too-ma
{
"expression": get_metric_expression(name, model_metrics),
"metric_name": name,
"metric_type": metric["type"],
"metric_type": (
metric.get("type") # dbt < 1.3
or metric.get("calculation_method") # dbt >= 1.3
),
"verbose_name": metric.get("label", name),
"description": metric.get("description", ""),
"extra": json.dumps(meta),
Expand Down
63 changes: 59 additions & 4 deletions src/preset_cli/cli/superset/sync/dbt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@

# pylint: disable=consider-using-f-string

import logging
from functools import partial
from typing import Dict, List

from jinja2 import Template

from preset_cli.api.clients.dbt import FilterSchema, MetricSchema
from preset_cli.api.clients.dbt import FilterSchema, MetricSchema, ModelSchema

_logger = logging.getLogger(__name__)


def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) -> str:
Expand All @@ -22,8 +25,16 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) ->
raise Exception(f"Invalid metric {metric_name}")

metric = metrics[metric_name]
type_ = metric["type"]
sql = metric["sql"]
if "calculation_method" in metric:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: have it inline

legacy = "calculation_method" in metric
type_ = metric["calculation_method"] if legacy else metric["type"]
sql = metric["expression"] if legacy else metric["sql"]
expression = "derived" if legacy else "expression"

# dbt >= 1.3
type_ = metric["calculation_method"]
sql = metric["expression"]
expression = "derived"
else:
# dbt < 1.3
type_ = metric["type"]
sql = metric["sql"]
expression = "expression"
Comment on lines +28 to +37

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice


if metric.get("filters"):
sql = apply_filters(sql, metric["filters"])
Expand All @@ -43,7 +54,7 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) ->
if type_ == "count_distinct":
return f"COUNT(DISTINCT {sql})"

if type_ == "expression":
if type_ == expression:
template = Template(sql)
return template.render(metric=partial(get_metric_expression, metrics=metrics))

Expand All @@ -59,3 +70,47 @@ def apply_filters(sql: str, filters: List[FilterSchema]) -> str:
"{field} {operator} {value}".format(**filter_) for filter_ in filters
)
return f"CASE WHEN {condition} THEN {sql} END"


def is_derived(metric: "MetricSchema") -> bool:
    """
    Return if the metric is derived.

    Derived metrics are built on top of other metrics instead of referencing
    a model directly. The marker differs by dbt version.
    """
    return (
        metric.get("calculation_method") == "derived"  # dbt >= 1.3
        or metric.get("type") == "expression"  # dbt < 1.3
    )


def get_metrics_for_model(
    model: "ModelSchema",
    metrics: List["MetricSchema"],
) -> List["MetricSchema"]:
    """
    Given a list of metrics, return those that are based on a given model.

    Derived metrics don't list models in ``depends_on``; they list other
    metrics. For each metric we therefore walk the dependency DAG upstream
    until we reach non-derived metrics, whose ``depends_on`` entries are
    model unique IDs.
    """
    metric_map = {metric["unique_id"]: metric for metric in metrics}
    related_metrics = []

    for metric in metrics:
        parents = set()
        queue = [metric]
        while queue:
            node = queue.pop()
            depends_on = node["depends_on"]
            if is_derived(node):
                # derived metrics point at other metrics; keep walking up
                queue.extend(metric_map[parent] for parent in depends_on)
            else:
                # non-derived metrics point directly at models
                parents.update(depends_on)

        if len(parents) > 1:
            _logger.warning(
                "Metric %s cannot be calculated because it depends on multiple models",
                metric["name"],
            )
            # skip only this metric; keep evaluating the remaining ones
            # (a ``break`` here would silently drop all subsequent metrics,
            # and falling through would pop an arbitrary parent)
            continue

        # guard against metrics with no resolvable parent model, which would
        # otherwise raise ``KeyError`` on ``set.pop()``
        if parents and model["unique_id"] == parents.pop():
            related_metrics.append(metric)

    return related_metrics
5 changes: 5 additions & 0 deletions tests/cli/superset/sync/dbt/command_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_dbt_core(mocker: MockerFixture, fs: FakeFilesystem) -> None:
"name": "cnt",
"sql": "*",
"type": "count",
"unique_id": "metric.superset_examples.cnt",
},
]
sync_datasets.assert_called_with(
Expand Down Expand Up @@ -284,6 +285,7 @@ def test_dbt(mocker: MockerFixture, fs: FakeFilesystem) -> None:
"name": "cnt",
"sql": "*",
"type": "count",
"unique_id": "metric.superset_examples.cnt",
},
]
sync_datasets.assert_called_with(
Expand Down Expand Up @@ -467,6 +469,7 @@ def test_dbt_cloud(mocker: MockerFixture) -> None:
"name": "cnt",
"sql": "*",
"type": "count",
"unique_id": "metric.superset_examples.cnt",
},
]
dbt_client.get_metrics.return_value = metrics
Expand Down Expand Up @@ -534,6 +537,7 @@ def test_dbt_cloud_no_job_id(mocker: MockerFixture) -> None:
"name": "cnt",
"sql": "*",
"type": "count",
"unique_id": "metric.superset_examples.cnt",
},
]
dbt_client.get_metrics.return_value = metrics
Expand Down Expand Up @@ -853,6 +857,7 @@ def test_dbt_cloud_exposures_only(mocker: MockerFixture, fs: FakeFilesystem) ->
"name": "cnt",
"sql": "*",
"type": "count",
"unique_id": "metric.superset_examples.cnt",
},
]
dbt_client.get_metrics.return_value = metrics
Expand Down
103 changes: 102 additions & 1 deletion tests/cli/superset/sync/dbt/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
from typing import Dict

import pytest
from pytest_mock import MockerFixture

from preset_cli.api.clients.dbt import MetricSchema
from preset_cli.cli.superset.sync.dbt.metrics import get_metric_expression
from preset_cli.cli.superset.sync.dbt.metrics import (
get_metric_expression,
get_metrics_for_model,
)


def test_get_metric_expression() -> None:
Expand Down Expand Up @@ -70,3 +74,100 @@ def test_get_metric_expression() -> None:
with pytest.raises(Exception) as excinfo:
get_metric_expression("five", metrics)
assert str(excinfo.value) == "Invalid metric five"


def test_get_metric_expression_new_schema() -> None:
    """
    Test ``get_metric_expression`` with the dbt 1.3 schema.

    See https://docs.getdbt.com/guides/migration/versions/upgrading-to-v1.3#for-users-of-dbt-metrics
    """
    schema = MetricSchema()
    filters = [
        {"field": "is_paying", "operator": "is", "value": "true"},
        {"field": "lifetime_value", "operator": ">=", "value": "100"},
        {"field": "company_name", "operator": "!=", "value": "'Acme, Inc'"},
        {"field": "signup_date", "operator": ">=", "value": "'2020-01-01'"},
    ]
    metrics: Dict[str, MetricSchema] = {
        "one": schema.load(
            {
                "calculation_method": "count",
                "expression": "user_id",
                "filters": filters,
            },
        ),
    }
    expected = (
        "COUNT(CASE WHEN is_paying is true AND lifetime_value >= 100 AND "
        "company_name != 'Acme, Inc' AND signup_date >= '2020-01-01' THEN user_id END)"
    )
    assert get_metric_expression("one", metrics) == expected


def test_get_metrics_for_model(mocker: MockerFixture) -> None:
    """
    Test ``get_metrics_for_model``.
    """
    _logger = mocker.patch("preset_cli.cli.superset.sync.dbt.metrics._logger")

    metric_a = {
        "unique_id": "metric.superset.a",
        "depends_on": ["model.superset.table"],
        "name": "a",
    }
    metric_b = {
        "unique_id": "metric.superset.b",
        "depends_on": ["model.superset.table"],
        "name": "b",
    }
    metric_c = {
        "unique_id": "metric.superset.c",
        "depends_on": ["model.superset.other_table"],
        "name": "c",
    }
    # derived metric whose ancestors all resolve to the same model
    metric_d = {
        "unique_id": "metric.superset.d",
        "depends_on": ["metric.superset.a", "metric.superset.b"],
        "name": "d",
        "calculation_method": "derived",
    }
    # derived metric spanning two models; it belongs to neither
    metric_e = {
        "unique_id": "metric.superset.e",
        "depends_on": ["metric.superset.a", "metric.superset.c"],
        "name": "e",
        "calculation_method": "derived",
    }
    metrics = [metric_a, metric_b, metric_c, metric_d, metric_e]

    model = {"unique_id": "model.superset.table"}
    assert get_metrics_for_model(model, metrics) == [  # type: ignore
        metric_a,
        metric_b,
        metric_d,
    ]
    _logger.warning.assert_called_with(
        "Metric %s cannot be calculated because it depends on multiple models",
        "e",
    )

    model = {"unique_id": "model.superset.other_table"}
    assert get_metrics_for_model(model, metrics) == [metric_c]  # type: ignore