Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: infer dbt models from dataset #132

Merged
merged 2 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
ignore = E203, E266, E501, W503, F403, F401
ignore = E203, E266, E501, W503, F403, F401, W293
max-line-length = 79
max-complexity = 18
select = B,C,E,F,W,T4,B9
3 changes: 3 additions & 0 deletions src/preset_cli/api/clients/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,7 @@ def change_workspace_role(
self.session.put(url, json=payload)

def get_base_url(self, version: Optional[str] = "v1") -> URL:
"""
Return the base URL for API calls.
"""
return self.baseurl / version
7 changes: 5 additions & 2 deletions src/preset_cli/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def is_help() -> bool:
@click.option("--loglevel", default="INFO")
@click.version_option()
@click.pass_context
def preset_cli( # pylint: disable=too-many-branches, too-many-locals, too-many-arguments
def preset_cli( # pylint: disable=too-many-branches, too-many-locals, too-many-arguments, too-many-statements
ctx: click.core.Context,
baseurl: str,
api_token: Optional[str],
Expand Down Expand Up @@ -153,7 +153,10 @@ def preset_cli( # pylint: disable=too-many-branches, too-many-locals, too-many-
api_token = input("API token: ")
api_secret = getpass.getpass("API secret: ")
store_credentials(
api_token, api_secret, manager_api_url, credentials_path
api_token,
api_secret,
manager_api_url,
credentials_path,
)

api_token = cast(str, api_token)
Expand Down
2 changes: 1 addition & 1 deletion src/preset_cli/cli/superset/sync/dbt/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def dbt_core( # pylint: disable=too-many-arguments, too-many-locals
)
if exposures:
exposures = os.path.expanduser(exposures)
sync_exposures(client, Path(exposures), datasets)
sync_exposures(client, Path(exposures), datasets, models)


def get_account_id(client: DBTClient) -> int:
Expand Down
50 changes: 41 additions & 9 deletions src/preset_cli/cli/superset/sync/dbt/exposures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,53 @@

import json
from pathlib import Path
from typing import Any, List
from typing import Any, Dict, List, NamedTuple, Optional

import yaml

from preset_cli.api.clients.dbt import ModelSchema
from preset_cli.api.clients.superset import SupersetClient

# XXX: DashboardResponseType and DatasetResponseType


def get_chart_depends_on(client: SupersetClient, chart: Any) -> List[str]:
class ModelKey(NamedTuple):
"""
Model key, so they can be mapped from datasets.
"""

schema: Optional[str]
table: str


def get_chart_depends_on(
client: SupersetClient,
chart: Any,
model_map: Dict[ModelKey, str],
) -> List[str]:
"""
Get all the dbt dependencies for a given chart.
"""

query_context = json.loads(chart["query_context"])
dataset_id = query_context["datasource"]["id"]
dataset = client.get_dataset(dataset_id)
extra = json.loads(dataset["result"]["extra"] or "{}")
dataset = client.get_dataset(dataset_id)["result"]
extra = json.loads(dataset["extra"] or "{}")
if "depends_on" in extra:
return [extra["depends_on"]]

key = ModelKey(dataset["schema"], dataset["table_name"])
if dataset["datasource_type"] == "table" and key in model_map:
return [model_map[key]]

return []


def get_dashboard_depends_on(client: SupersetClient, dashboard: Any) -> List[str]:
def get_dashboard_depends_on(
client: SupersetClient,
dashboard: Any,
model_map: Dict[ModelKey, str],
) -> List[str]:
"""
Get all the dbt dependencies for a given dashboard.
"""
Expand All @@ -44,13 +66,17 @@ def get_dashboard_depends_on(client: SupersetClient, dashboard: Any) -> List[str

depends_on = []
for dataset in payload["result"]:
full_dataset = client.get_dataset(int(dataset["id"]))
full_dataset = client.get_dataset(int(dataset["id"]))["result"]
try:
extra = json.loads(full_dataset["result"]["extra"] or "{}")
extra = json.loads(full_dataset["extra"] or "{}")
except json.decoder.JSONDecodeError:
extra = {}

key = ModelKey(full_dataset["schema"], full_dataset["table_name"])
if "depends_on" in extra:
depends_on.append(extra["depends_on"])
elif full_dataset["datasource_type"] == "table" and key in model_map:
depends_on.append(model_map[key])

return depends_on

Expand All @@ -59,6 +85,7 @@ def sync_exposures( # pylint: disable=too-many-locals
client: SupersetClient,
exposures_path: Path,
datasets: List[Any],
models: List[ModelSchema],
) -> None:
"""
Write dashboards back to dbt as exposures.
Expand All @@ -67,6 +94,11 @@ def sync_exposures( # pylint: disable=too-many-locals
charts_ids = set()
dashboards_ids = set()

model_map = {
ModelKey(model["schema"], model["name"]): f'ref({model["name"]})'
for model in models
}

for dataset in datasets:
url = client.baseurl / "api/v1/dataset" / str(dataset["id"]) / "related_objects"

Expand Down Expand Up @@ -94,7 +126,7 @@ def sync_exposures( # pylint: disable=too-many-locals
% {"form_data": json.dumps({"slice_id": chart_id})},
),
"description": chart["description"] or "",
"depends_on": get_chart_depends_on(client, chart),
"depends_on": get_chart_depends_on(client, chart, model_map),
"owner": {
"name": first_owner["first_name"] + " " + first_owner["last_name"],
"email": first_owner.get("email", "unknown"),
Expand All @@ -113,7 +145,7 @@ def sync_exposures( # pylint: disable=too-many-locals
else "low",
"url": str(client.baseurl / dashboard["url"].lstrip("/")),
"description": "",
"depends_on": get_dashboard_depends_on(client, dashboard),
"depends_on": get_dashboard_depends_on(client, dashboard, model_map),
"owner": {
"name": first_owner["first_name"] + " " + first_owner["last_name"],
"email": first_owner.get("email", "unknown"),
Expand Down
4 changes: 3 additions & 1 deletion tests/cli/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def test_cmd_handling_failed_creds(
)
assert result.exit_code == 1
get_access_token.assert_called_with(
URL("https://api.app.preset.io/"), "API_TOKEN", "API_SECRET"
URL("https://api.app.preset.io/"),
"API_TOKEN",
"API_SECRET",
)


Expand Down
4 changes: 2 additions & 2 deletions tests/cli/superset/sync/dbt/command_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_dbt_core(mocker: MockerFixture, fs: FakeFilesystem) -> None:
False,
"",
)
sync_exposures.assert_called_with(client, exposures, sync_datasets())
sync_exposures.assert_called_with(client, exposures, sync_datasets(), models)


def test_dbt_core_dbt_project(mocker: MockerFixture, fs: FakeFilesystem) -> None:
Expand Down Expand Up @@ -289,7 +289,7 @@ def test_dbt(mocker: MockerFixture, fs: FakeFilesystem) -> None:
False,
"",
)
sync_exposures.assert_called_with(client, exposures, sync_datasets())
sync_exposures.assert_called_with(client, exposures, sync_datasets(), models)


def test_dbt_core_no_exposures(mocker: MockerFixture, fs: FakeFilesystem) -> None:
Expand Down
58 changes: 50 additions & 8 deletions tests/cli/superset/sync/dbt/exposures_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from yarl import URL

from preset_cli.cli.superset.sync.dbt.exposures import (
ModelKey,
get_chart_depends_on,
get_dashboard_depends_on,
sync_exposures,
Expand Down Expand Up @@ -503,7 +504,7 @@ def test_get_dashboard_depends_on(mocker: MockerFixture) -> None:
session = client.auth.get_session()
session.get().json.return_value = datasets_response

depends_on = get_dashboard_depends_on(client, dashboard_response["result"])
depends_on = get_dashboard_depends_on(client, dashboard_response["result"], {})
assert depends_on == ["ref('messages_channels')"]


Expand All @@ -518,7 +519,7 @@ def test_get_dashboard_depends_on_no_extra(mocker: MockerFixture) -> None:
session = client.auth.get_session()
session.get().json.return_value = datasets_response

depends_on = get_dashboard_depends_on(client, dashboard_response["result"])
depends_on = get_dashboard_depends_on(client, dashboard_response["result"], {})
assert not depends_on


Expand All @@ -533,7 +534,7 @@ def test_get_dashboard_depends_on_invalid_extra(mocker: MockerFixture) -> None:
session = client.auth.get_session()
session.get().json.return_value = datasets_response

depends_on = get_dashboard_depends_on(client, dashboard_response["result"])
depends_on = get_dashboard_depends_on(client, dashboard_response["result"], {})
assert not depends_on


Expand All @@ -544,7 +545,7 @@ def test_get_chart_depends_on(mocker: MockerFixture) -> None:
client = mocker.MagicMock()
client.get_dataset.return_value = dataset_response

depends_on = get_chart_depends_on(client, chart_response["result"])
depends_on = get_chart_depends_on(client, chart_response["result"], {})
assert depends_on == ["ref('messages_channels')"]


Expand All @@ -557,7 +558,7 @@ def test_get_chart_depends_on_no_extra(mocker: MockerFixture) -> None:
modified_dataset_response["result"]["extra"] = None # type: ignore
client.get_dataset.return_value = modified_dataset_response

depends_on = get_chart_depends_on(client, chart_response["result"])
depends_on = get_chart_depends_on(client, chart_response["result"], {})
assert not depends_on


Expand Down Expand Up @@ -585,7 +586,7 @@ def test_sync_exposures(mocker: MockerFixture, fs: FakeFilesystem) -> None:
)

datasets = [dataset_response["result"]]
sync_exposures(client, exposures, datasets)
sync_exposures(client, exposures, datasets, [])

with open(exposures, encoding="utf-8") as input_:
contents = yaml.load(input_, Loader=yaml.SafeLoader)
Expand Down Expand Up @@ -622,7 +623,7 @@ def test_sync_exposures_no_charts_no_dashboards(
fs: FakeFilesystem,
) -> None:
"""
Test ``sync_exposures`` when no dashboads use the datasets.
Test ``sync_exposures`` when no dashboards use the datasets.
"""
root = Path("/path/to/root")
fs.create_dir(root / "models")
Expand All @@ -637,11 +638,52 @@ def test_sync_exposures_no_charts_no_dashboards(
session.get().json.return_value = no_related_objects_response

datasets = [dataset_response["result"]]
sync_exposures(client, exposures, datasets)
sync_exposures(client, exposures, datasets, [])

with open(exposures, encoding="utf-8") as input_:
contents = yaml.load(input_, Loader=yaml.SafeLoader)
assert contents == {
"version": 2,
"exposures": [],
}


def test_get_chart_depends_on_from_dataset(mocker: MockerFixture) -> None:
"""
Test ``sync_exposures`` when datasets don't have model metadata.

This is the case when users created the datasets manually, pointing them to dbt
models, but still want to sync exposures back to dbt.
"""
client = mocker.MagicMock()
modified_dataset_response = copy.deepcopy(dataset_response)
modified_dataset_response["result"]["extra"] = None # type: ignore
client.get_dataset.return_value = modified_dataset_response

key = ModelKey("public", "messages_channels")
depends_on = get_chart_depends_on(
client,
chart_response["result"],
{key: "ref(messages_channels)"},
)
assert depends_on == ["ref(messages_channels)"]


def test_get_dashboard_depends_on_from_dataset(mocker: MockerFixture) -> None:
"""
Test ``get_dashboard_depends_on`` when dataset don't have model metadata.
"""
client = mocker.MagicMock()
modified_dataset_response = copy.deepcopy(dataset_response)
modified_dataset_response["result"]["extra"] = None # type: ignore
client.get_dataset.return_value = modified_dataset_response
session = client.auth.get_session()
session.get().json.return_value = datasets_response

key = ModelKey("public", "messages_channels")
depends_on = get_dashboard_depends_on(
client,
dashboard_response["result"],
{key: "ref(messages_channels)"},
)
assert depends_on == ["ref(messages_channels)"]