feat(dbt): sync metrics from MetricFlow #256

Merged · 5 commits · Jan 9, 2024
Changes from 4 commits
194 changes: 164 additions & 30 deletions src/preset_cli/api/clients/dbt.py
@@ -24,7 +24,10 @@
_logger = logging.getLogger(__name__)

REST_ENDPOINT = URL("https://cloud.getdbt.com/")
-GRAPHQL_ENDPOINT = URL("https://metadata.cloud.getdbt.com/graphql")
+METADATA_GRAPHQL_ENDPOINT = URL("https://metadata.cloud.getdbt.com/graphql")
+SEMANTIC_LAYER_GRAPHQL_ENDPOINT = URL(
+    "https://semantic-layer.cloud.getdbt.com/api/graphql",
+)


class PostelSchema(Schema):
@@ -472,7 +475,7 @@ class TimeSchema(PostelSchema):

class StringOrSchema(fields.Field):
"""
-    Dynamic schema constructor for fields that could have a string or another schema
+    Dynamic schema constructor for fields that could have a string or another schema.
"""

def __init__(self, nested_schema, *args, **kwargs):
@@ -587,6 +590,50 @@ class MetricSchema(PostelSchema):
expression = fields.String()


class MFMetricType(str, Enum):
"""
Type of the MetricFlow metric.
"""

SIMPLE = "SIMPLE"
RATIO = "RATIO"
CUMULATIVE = "CUMULATIVE"
DERIVED = "DERIVED"


class MFMetricSchema(PostelSchema):
"""
Schema for a MetricFlow metric.
"""

name = fields.String()
description = fields.String()
type = PostelEnumField(MFMetricType)


class MFSQLEngine(str, Enum):
"""
Databases supported by MetricFlow.
"""

BIGQUERY = "BIGQUERY"
DUCKDB = "DUCKDB"
REDSHIFT = "REDSHIFT"
POSTGRES = "POSTGRES"
SNOWFLAKE = "SNOWFLAKE"
DATABRICKS = "DATABRICKS"


class MFMetricWithSQLSchema(MFMetricSchema):
"""
MetricFlow metric with dialect and SQL, as well as model.
"""

sql = fields.String()
dialect = PostelEnumField(MFSQLEngine)
model = fields.String()


class DataResponse(TypedDict):
"""
Type for the GraphQL response.
@@ -602,7 +649,10 @@ class DBTClient: # pylint: disable=too-few-public-methods
"""

def __init__(self, auth: Auth):
-        self.graphql_client = GraphqlClient(endpoint=GRAPHQL_ENDPOINT)
+        self.metadata_graphql_client = GraphqlClient(endpoint=METADATA_GRAPHQL_ENDPOINT)
+        self.semantic_layer_graphql_client = GraphqlClient(
+            endpoint=SEMANTIC_LAYER_GRAPHQL_ENDPOINT,
+        )
self.baseurl = REST_ENDPOINT

self.session = auth.session
@@ -611,16 +661,6 @@ def __init__(self, auth: Auth):
self.session.headers["X-Client-Version"] = __version__
self.session.headers["X-dbt-partner-source"] = "preset"

-    def execute(self, query: str, **variables: Any) -> DataResponse:
-        """
-        Run a GraphQL query.
-        """
-        return self.graphql_client.execute(
-            query=query,
-            variables=variables,
-            headers=self.session.headers,
-        )

def get_accounts(self) -> List[AccountSchema]:
"""
List all accounts.
@@ -683,37 +723,46 @@ def get_models(self, job_id: int) -> List[ModelSchema]:
Fetch all available models.
"""
query = """
[Author comment] This is being deprecated, so I converted to the recommended query.

-            query ($jobId: Int!) {
-                models(jobId: $jobId) {
-                    uniqueId
-                    dependsOn
-                    childrenL1
-                    name
-                    database
-                    schema
-                    description
-                    meta
-                    tags
-                    columns {
+            query Models($jobId: BigInt!) {
+                job(id: $jobId) {
+                    models {
+                        uniqueId
+                        dependsOn
+                        childrenL1
+                        name
+                        database
+                        schema
+                        description
+                        meta
+                        tags
+                        columns {
+                            name
+                            description
+                            type
+                        }
+                    }
+                }
+            }
"""
-        payload = self.execute(query, jobId=job_id)
+        payload = self.metadata_graphql_client.execute(
+            query=query,
+            variables={"jobId": job_id},
+            headers=self.session.headers,
+        )

model_schema = ModelSchema()
-        models = [model_schema.load(model) for model in payload["data"]["models"]]
+        models = [
+            model_schema.load(model) for model in payload["data"]["job"]["models"]
+        ]

return models

-    def get_og_metrics(self, job_id: int) -> List[Any]:
+    def get_og_metrics(self, job_id: int) -> List[MetricSchema]:
"""
Fetch all available metrics.
"""
query = """
-            query ($jobId: Int!) {
+            query GetMetrics($jobId: Int!) {
metrics(jobId: $jobId) {
uniqueId
name
@@ -731,13 +780,98 @@ def get_og_metrics(self, job_id: int) -> List[Any]:
}
}
"""
-        payload = self.execute(query, jobId=job_id)
+        payload = self.metadata_graphql_client.execute(
+            query=query,
+            variables={"jobId": job_id},
+            headers=self.session.headers,
+        )

metric_schema = MetricSchema()
metrics = [metric_schema.load(metric) for metric in payload["data"]["metrics"]]

return metrics

def get_sl_metrics(self, environment_id: int) -> List[MFMetricSchema]:
"""
Fetch all available metrics.
"""
query = """
query GetMetrics($environmentId: BigInt!) {
metrics(environmentId: $environmentId) {
name
description
type
}
}
"""
payload = self.semantic_layer_graphql_client.execute(
query=query,
variables={"environmentId": environment_id},
headers=self.session.headers,
)

metric_schema = MFMetricSchema()
metrics = [metric_schema.load(metric) for metric in payload["data"]["metrics"]]

return metrics

def get_sl_metric_sql(self, metric: str, environment_id: int) -> Optional[str]:
"""
Fetch metric SQL.

We fetch one metric at a time because if one metric fails to compile, the entire
query fails.
"""
query = """
mutation CompileSql($environmentId: BigInt!, $metricsInput: [MetricInput!]) {
compileSql(
environmentId: $environmentId
metrics: $metricsInput
groupBy: []
) {
sql
}
}
"""
payload = self.semantic_layer_graphql_client.execute(
query=query,
variables={
"environmentId": environment_id,
"metricsInput": [{"name": metric}],
},
headers=self.session.headers,
)

if payload["data"] is None:
errors = "\n\n".join(
error["message"] for error in payload.get("errors", [])
)
_logger.warning("Unable to convert metric %s: %s", metric, errors)
return None

return payload["data"]["compileSql"]["sql"]

def get_sl_dialect(self, environment_id: int) -> MFSQLEngine:
"""
Get the dialect used in the MetricFlow project.
"""
query = """
query GetEnvironmentInfo($environmentId: BigInt!) {
environmentInfo(environmentId: $environmentId) {
dialect
}
}
"""
payload = self.semantic_layer_graphql_client.execute(
query=query,
variables={"environmentId": environment_id},
headers=self.session.headers,
)

return MFSQLEngine(payload["data"]["environmentInfo"]["dialect"])

# def get_sl_metric_sql(self,

def get_database_name(self, job_id: int) -> str:
"""
Return the database name.
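The three semantic-layer methods added above are designed to compose: resolve the dialect once per environment, list the metrics, then compile each metric's SQL one request at a time so a single failing metric is skipped rather than failing the whole batch. A minimal sketch of that flow — `auth` stands in for any `Auth` object accepted by `DBTClient`, `environment_id` for a real dbt Cloud environment, and `fetch_compiled_sl_metrics` is a hypothetical helper, not part of this PR:

```python
from typing import Any, Dict, List

from preset_cli.api.clients.dbt import DBTClient


def fetch_compiled_sl_metrics(auth, environment_id: int) -> List[Dict[str, Any]]:
    """
    Compile every MetricFlow metric in an environment to SQL.
    """
    client = DBTClient(auth)

    # The dialect is a property of the environment, so fetch it once.
    dialect = client.get_sl_dialect(environment_id)

    compiled = []
    for metric in client.get_sl_metrics(environment_id):
        # One request per metric: a metric that fails to compile logs a
        # warning and returns None instead of aborting the whole sync.
        sql = client.get_sl_metric_sql(metric["name"], environment_id)
        if sql is None:
            continue
        compiled.append({**metric, "sql": sql, "dialect": dialect.value})

    return compiled
```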
55 changes: 43 additions & 12 deletions src/preset_cli/cli/superset/sync/dbt/command.py
@@ -6,20 +6,30 @@
import sys
import warnings
from pathlib import Path
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

import click
import yaml
from yarl import URL

-from preset_cli.api.clients.dbt import DBTClient, JobSchema, MetricSchema, ModelSchema
+from preset_cli.api.clients.dbt import (
+    DBTClient,
+    JobSchema,
+    MetricSchema,
+    MFMetricWithSQLSchema,
+    ModelSchema,
+)
from preset_cli.api.clients.superset import SupersetClient
from preset_cli.auth.token import TokenAuth
from preset_cli.cli.superset.sync.dbt.databases import sync_database
from preset_cli.cli.superset.sync.dbt.datasets import sync_datasets
from preset_cli.cli.superset.sync.dbt.exposures import ModelKey, sync_exposures
from preset_cli.cli.superset.sync.dbt.lib import apply_select
-from preset_cli.cli.superset.sync.dbt.metrics import get_superset_metrics_per_model
+from preset_cli.cli.superset.sync.dbt.metrics import (
+    MultipleModelsError,
+    get_model_from_sql,
+    get_superset_metrics_per_model,
+)
from preset_cli.exceptions import DatabaseNotFoundError


@@ -181,10 +191,7 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-branches, too-many
config["columns"] = list(config["columns"].values())
models.append(model_schema.load(config))
models = apply_select(models, select, exclude)
-    model_map = {
-        ModelKey(model["schema"], model["name"]): f"ref('{model['name']}')"
-        for model in models
-    }
+    model_map = {ModelKey(model["schema"], model["name"]): model for model in models}

if exposures_only:
datasets = [
@@ -439,13 +446,37 @@ def dbt_cloud( # pylint: disable=too-many-arguments, too-many-locals

models = dbt_client.get_models(job["id"])
models = apply_select(models, select, exclude)
-    model_map = {
-        ModelKey(model["schema"], model["name"]): f"ref('{model['name']}')"
-        for model in models
-    }
+    model_map = {ModelKey(model["schema"], model["name"]): model for model in models}

# original dbt <= 1.6 metrics
og_metrics = dbt_client.get_og_metrics(job["id"])
-    superset_metrics = get_superset_metrics_per_model(og_metrics)

# MetricFlow metrics
dialect = dbt_client.get_sl_dialect(job["environment_id"])
mf_metric_schema = MFMetricWithSQLSchema()
sl_metrics: List[MFMetricWithSQLSchema] = []
for metric in dbt_client.get_sl_metrics(job["environment_id"]):
sql = dbt_client.get_sl_metric_sql(metric["name"], job["environment_id"])
if sql is not None:
try:
model = get_model_from_sql(sql, dialect, model_map)
except MultipleModelsError:
continue

sl_metrics.append(
mf_metric_schema.load(
{
"name": metric["name"],
"type": metric["type"],
"description": metric["description"],
"sql": sql,
"dialect": dialect.value,
"model": model["unique_id"],
},
),
)

superset_metrics = get_superset_metrics_per_model(og_metrics, sl_metrics)

if exposures_only:
datasets = [
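For reference, each entry appended to `sl_metrics` in the loop above is just the metric payload extended with the compiled SQL, the dialect, and the owning model, validated through `MFMetricWithSQLSchema`. A hypothetical row — every value below is invented for illustration:

```python
from preset_cli.api.clients.dbt import MFMetricWithSQLSchema

# All values are made up; "model" holds the dbt unique_id that
# get_model_from_sql resolved from the compiled SQL.
row = MFMetricWithSQLSchema().load(
    {
        "name": "revenue",
        "type": "SIMPLE",
        "description": "Total revenue.",
        "sql": "SELECT SUM(amount) AS revenue FROM analytics.orders",
        "dialect": "POSTGRES",
        "model": "model.jaffle_shop.orders",
    },
)
```

Metrics whose compiled SQL spans more than one model raise `MultipleModelsError` and are skipped, since each Superset metric has to be attached to a single dataset.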