feat: support datasets in other (sub) DBs #95

Merged on Oct 4, 2022 (1 commit)

55 changes: 51 additions & 4 deletions src/preset_cli/api/clients/superset.py
@@ -236,6 +236,17 @@ def run_query(
         """
         Run a SQL query, returning a Pandas dataframe.
         """
+        payload = self._run_query(database_id, sql, schema, limit)
+
+        return pd.DataFrame(payload["data"])
+
+    def _run_query(
+        self,
+        database_id: int,
+        sql: str,
+        schema: Optional[str] = None,
+        limit: int = 1000,
+    ) -> Dict[str, Any]:
         url = self.baseurl / "superset/sql_json/"
         data = {
             "client_id": shortid()[:10],
@@ -254,7 +265,6 @@ def run_query(
         }
         headers = {
             "Accept": "application/json",
-            "Content-Type": "application/json",
         }
         self.session.headers.update(headers)
 
@@ -264,7 +274,7 @@ def run_query(
 
         payload = response.json()
 
-        return pd.DataFrame(payload["data"])
+        return payload
 
     def get_data(  # pylint: disable=too-many-locals, too-many-arguments
         self,
@@ -363,7 +373,6 @@ def get_data(  # pylint: disable=too-many-locals, too-many-arguments
 
         headers = {
             "Accept": "application/json",
-            "Content-Type": "application/json",
         }
         self.session.headers.update(headers)
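
Both Content-Type deletions above remove the same session-wide header, likely because the sqllab_viz call added below posts form data rather than JSON: requests infers the correct Content-Type from how the body is passed, but a header pinned on the session overrides that inference on every later call. A minimal sketch of the behaviour (the example.org URL is a placeholder; nothing is sent over the network):

import requests

session = requests.Session()

# requests infers the Content-Type from how the body is passed...
json_request = session.prepare_request(
    requests.Request("POST", "https://example.org/", json={"a": 1}),
)
form_request = session.prepare_request(
    requests.Request("POST", "https://example.org/", data={"data": "{}"}),
)
assert json_request.headers["Content-Type"] == "application/json"
assert form_request.headers["Content-Type"] == "application/x-www-form-urlencoded"

# ...but a Content-Type pinned on the session wins during header merging,
# which would mislabel a form-encoded POST like the one to /superset/sqllab_viz/.
session.headers["Content-Type"] = "application/json"
pinned = session.prepare_request(
    requests.Request("POST", "https://example.org/", data={"data": "{}"}),
)
assert pinned.headers["Content-Type"] == "application/json"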

@@ -504,7 +513,45 @@ def create_dataset(self, **kwargs: Any) -> Any:
         """
         Create a dataset.
         """
-        return self.create_resource("dataset", **kwargs)
+        if "sql" not in kwargs:
+            return self.create_resource("dataset", **kwargs)
+
+        # run a query to determine the column types
+        payload = self._run_query(
+            database_id=kwargs["database"],
+            sql=kwargs["sql"],
+            schema=kwargs["schema"],
+            limit=1,
+        )
+
+        # now add the virtual dataset
+        columns = payload["columns"]
+        for column in columns:
+            column["column_name"] = column["name"]
+            column["groupby"] = True
+            if column["is_dttm"]:
+                column["type_generic"] = 2
+            elif column["type"].lower() == "string":
+                column["type_generic"] = 1
+            else:
+                column["type_generic"] = 0
+        payload = {
+            "sql": kwargs["sql"],
+            "dbId": kwargs["database"],
+            "schema": kwargs["schema"],
+            "datasourceName": kwargs["table_name"],
+            "columns": columns,
+        }
+        data = {"data": json.dumps(payload)}
+
+        url = self.baseurl / "superset/sqllab_viz/"
+        _logger.debug("POST %s\n%s", url, json.dumps(data, indent=4))
+        response = self.session.post(url, data=data)
+        validate_response(response)
+
+        payload = response.json()
+
+        return payload["data"]
 
     def update_dataset(
         self,
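
The hard-coded type_generic values appear to mirror Superset's GenericDataType enum, where 0 is numeric, 1 is string, and 2 is temporal (an assumption about Superset internals, not something this diff defines). A standalone sketch of the same mapping:

from enum import IntEnum

class GenericDataType(IntEnum):
    # Assumed to mirror superset.utils.core.GenericDataType.
    NUMERIC = 0
    STRING = 1
    TEMPORAL = 2

def generic_type(column: dict) -> int:
    # Same branching as the create_dataset code above.
    if column["is_dttm"]:
        return GenericDataType.TEMPORAL
    if column["type"].lower() == "string":
        return GenericDataType.STRING
    return GenericDataType.NUMERIC

assert generic_type({"type": "DATETIME", "is_dttm": True}) == 2
assert generic_type({"type": "STRING", "is_dttm": False}) == 1
assert generic_type({"type": "INTEGER", "is_dttm": False}) == 0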
42 changes: 36 additions & 6 deletions src/preset_cli/cli/superset/sync/dbt/datasets.py
@@ -6,8 +6,10 @@
 
 import json
 import logging
-from typing import Any, List
+from typing import Any, Dict, List
 
+from sqlalchemy.engine import create_engine
+from sqlalchemy.engine.url import make_url
 from yarl import URL
 
 from preset_cli.api.clients.dbt import MetricSchema, ModelSchema
@@ -18,6 +20,38 @@
 _logger = logging.getLogger(__name__)
 
 
+def create_dataset(
+    client: SupersetClient,
+    database: Dict[str, Any],
+    model: ModelSchema,
+) -> Dict[str, Any]:
+    """
+    Create a physical or virtual dataset.
+
+    Virtual datasets are created when the table database is different from the main
+    database, for systems that support cross-database queries (Trino, BigQuery, etc.).
+    """
+    url = make_url(database["sqlalchemy_uri"])
+    if model["database"] == url.database:
+        kwargs = {
+            "database": database["id"],
+            "schema": model["schema"],
+            "table_name": model["name"],
+        }
+    else:
+        engine = create_engine(url)
+        quote = engine.dialect.identifier_preparer.quote
+        source = ".".join(quote(model[key]) for key in ("database", "schema", "name"))
+        kwargs = {
+            "database": database["id"],
+            "schema": model["schema"],
+            "table_name": model["name"],
+            "sql": f"SELECT * FROM {source}",
+        }
+
+    return client.create_dataset(**kwargs)
+
+
 def sync_datasets(  # pylint: disable=too-many-locals, too-many-branches, too-many-arguments
     client: SupersetClient,
     models: List[ModelSchema],
@@ -49,11 +83,7 @@ def sync_datasets(  # pylint: disable=too-many-locals, too-many-branches, too-many-arguments
         else:
             _logger.info("Creating dataset %s", model["unique_id"])
             try:
-                dataset = client.create_dataset(
-                    database=database["id"],
-                    schema=model["schema"],
-                    table_name=model["name"],
-                )
+                dataset = create_dataset(client, database, model)
             except Exception:  # pylint: disable=broad-except
                 # Superset can't add tables from different BigQuery projects
                 continue
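
The helper leans on the SQLAlchemy dialect to quote each part of the fully qualified name, so identifiers that need escaping (such as BigQuery's dbt-tutorial project) come out correctly for the target database. A runnable sketch of the same quoting logic, using the built-in SQLite dialect as a stand-in so no extra driver is needed:

from sqlalchemy.engine import create_engine

engine = create_engine("sqlite://")  # stands in for the target database
quote = engine.dialect.identifier_preparer.quote
source = ".".join(quote(part) for part in ("dbt-tutorial", "jaffle_shop", "customers"))
print(f"SELECT * FROM {source}")
# SELECT * FROM "dbt-tutorial".jaffle_shop.customers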
127 changes: 126 additions & 1 deletion tests/api/clients/superset_test.py
@@ -1,10 +1,11 @@
 """
 Tests for ``preset_cli.api.clients.superset``.
 """
-# pylint: disable=too-many-lines, trailing-whitespace
+# pylint: disable=too-many-lines, trailing-whitespace, line-too-long, use-implicit-booleaness-not-comparison
 
 import json
 from io import BytesIO
+from urllib.parse import unquote_plus
 from uuid import UUID
 from zipfile import ZipFile, is_zipfile
 
@@ -2653,3 +2654,127 @@ def test_update_role(requests_mock: Mocker) -> None:
         requests_mock.last_request.text
         == "name=Old+Role&user=2&permissions=1&permissions=2"
     )
+
+
+def test_create_virtual_dataset(requests_mock: Mocker) -> None:
+    """
+    Test the ``create_dataset`` method with virtual datasets.
+    """
+    requests_mock.post(
+        "https://superset.example.org/superset/sql_json/",
+        json={
+            "query_id": 137,
+            "status": "success",
+            "data": [
+                {"ID": 20, "FIRST_NAME": "Anna", "LAST_NAME": "A.", "ds": "2022-01-01"},
+            ],
+            "columns": [
+                {"name": "ID", "type": "INTEGER", "is_dttm": False},
+                {"name": "FIRST_NAME", "type": "STRING", "is_dttm": False},
+                {"name": "LAST_NAME", "type": "STRING", "is_dttm": False},
+                {"name": "ds", "type": "DATETIME", "is_dttm": True},
+            ],
+            "selected_columns": [
+                {"name": "ID", "type": "INTEGER", "is_dttm": False},
+                {"name": "FIRST_NAME", "type": "STRING", "is_dttm": False},
+                {"name": "LAST_NAME", "type": "STRING", "is_dttm": False},
+                {"name": "ds", "type": "DATETIME", "is_dttm": True},
+            ],
+            "expanded_columns": [],
+            "query": {
+                "changedOn": "2022-10-04T00:54:22.174889",
+                "changed_on": "2022-10-04T00:54:22.174889",
+                "dbId": 6,
+                "db": "jaffle_shop_dev",
+                "endDttm": 1664844864497.491,
+                "errorMessage": None,
+                "executedSql": "-- 6dcd92a04feb50f14bbcf07c661680ba\nSELECT * FROM `dbt-tutorial`.jaffle_shop.customers LIMIT 2\n-- 6dcd92a04feb50f14bbcf07c661680ba",
+                "id": "eJfI9pxnh",
+                "queryId": 137,
+                "limit": 1,
+                "limitingFactor": "QUERY",
+                "progress": 100,
+                "rows": 1,
+                "schema": "dbt_beto",
+                "ctas": False,
+                "serverId": 137,
+                "sql": "SELECT * FROM `dbt-tutorial`.jaffle_shop.customers LIMIT 1;",
+                "sqlEditorId": "8",
+                "startDttm": 1664844861997.288000,
+                "state": "success",
+                "tab": "Query dbt_beto.customers",
+                "tempSchema": None,
+                "tempTable": None,
+                "userId": 2,
+                "user": "Beto Ferreira De Almeida",
+                "resultsKey": "313ec42b-3b76-40c7-8e90-31ed549174dd",
+                "trackingUrl": None,
+                "extra": {
+                    "progress": None,
+                    "columns": [
+                        {"name": "ID", "type": "INTEGER", "is_dttm": False},
+                        {"name": "FIRST_NAME", "type": "STRING", "is_dttm": False},
+                        {"name": "LAST_NAME", "type": "STRING", "is_dttm": False},
+                        {"name": "ds", "type": "DATETIME", "is_dttm": True},
+                    ],
+                },
+            },
+        },
+    )
+    requests_mock.post(
+        "https://superset.example.org/superset/sqllab_viz/",
+        json={"data": [1, 2, 3]},
+    )
+
+    auth = Auth()
+    client = SupersetClient("https://superset.example.org/", auth)
+
+    client.create_dataset(
+        database=1,
+        schema="public",
+        sql="SELECT * FROM `dbt-tutorial`.jaffle_shop.customers LIMIT 1;",
+        table_name="test virtual",
+    )
+
+    assert json.loads(
+        unquote_plus(requests_mock.last_request.text.split("=", 1)[1]),
+    ) == {
+        "sql": "SELECT * FROM `dbt-tutorial`.jaffle_shop.customers LIMIT 1;",
+        "dbId": 1,
+        "schema": "public",
+        "datasourceName": "test virtual",
+        "columns": [
+            {
+                "name": "ID",
+                "type": "INTEGER",
+                "is_dttm": False,
+                "column_name": "ID",
+                "groupby": True,
+                "type_generic": 0,
+            },
+            {
+                "name": "FIRST_NAME",
+                "type": "STRING",
+                "is_dttm": False,
+                "column_name": "FIRST_NAME",
+                "groupby": True,
+                "type_generic": 1,
+            },
+            {
+                "name": "LAST_NAME",
+                "type": "STRING",
+                "is_dttm": False,
+                "column_name": "LAST_NAME",
+                "groupby": True,
+                "type_generic": 1,
+            },
+            {
+                "name": "ds",
+                "type": "DATETIME",
+                "is_dttm": True,
+                "column_name": "ds",
+                "groupby": True,
+                "type_generic": 2,
+            },
+        ],
+    }
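
The decoding in the final assertion mirrors how the client encodes the request: the payload travels as a single form field named data holding a JSON string. A sketch of the round trip, independent of the client (the field name and payload here are illustrative):

import json
from urllib.parse import unquote_plus, urlencode

payload = {"sql": "SELECT 1", "dbId": 1}
body = urlencode({"data": json.dumps(payload)})  # what session.post(url, data=...) sends
decoded = json.loads(unquote_plus(body.split("=", 1)[1]))  # what the test undoes
assert decoded == payload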
59 changes: 55 additions & 4 deletions tests/cli/superset/sync/dbt/datasets_test.py
@@ -11,7 +11,7 @@
 from pytest_mock import MockerFixture
 
 from preset_cli.api.clients.dbt import MetricSchema, ModelSchema
-from preset_cli.cli.superset.sync.dbt.datasets import sync_datasets
+from preset_cli.cli.superset.sync.dbt.datasets import create_dataset, sync_datasets
 
 metric_schema = MetricSchema()
 metrics: List[MetricSchema] = [
@@ -58,7 +58,7 @@ def test_sync_datasets_new(mocker: MockerFixture) -> None:
         client=client,
         models=models,
         metrics=metrics,
-        database={"id": 1},
+        database={"id": 1, "sqlalchemy_uri": "postgresql://user@host/examples_dev"},
         disallow_edits=False,
         external_url_prefix="",
     )
@@ -122,7 +122,7 @@ def test_sync_datasets_no_metrics(mocker: MockerFixture) -> None:
         client=client,
         models=models,
         metrics=[],
-        database={"id": 1},
+        database={"id": 1, "sqlalchemy_uri": "postgresql://user@host/examples_dev"},
         disallow_edits=False,
         external_url_prefix="",
     )
@@ -177,7 +177,7 @@ def test_sync_datasets_new_bq_error(mocker: MockerFixture) -> None:
         client=client,
         models=models,
         metrics=metrics,
-        database={"id": 1},
+        database={"id": 1, "sqlalchemy_uri": "postgresql://user@host/examples_dev"},
         disallow_edits=False,
         external_url_prefix="",
     )
@@ -380,3 +380,54 @@ def test_sync_datasets_no_columns(mocker: MockerFixture) -> None:
             ),
         ],
     )
+
+
+def test_create_dataset_physical(mocker: MockerFixture) -> None:
+    """
+    Test ``create_dataset`` for physical datasets.
+    """
+    client = mocker.MagicMock()
+
+    create_dataset(
+        client,
+        {
+            "id": 1,
+            "schema": "public",
+            "name": "Database",
+            "sqlalchemy_uri": "postgresql://user@host/examples_dev",
+        },
+        models[0],
+    )
+    client.create_dataset.assert_called_with(
+        database=1,
+        schema="public",
+        table_name="messages_channels",
+    )
+
+
+def test_create_dataset_virtual(mocker: MockerFixture) -> None:
+    """
+    Test ``create_dataset`` for virtual datasets.
+    """
+    create_engine = mocker.patch(
+        "preset_cli.cli.superset.sync.dbt.datasets.create_engine",
+    )
+    create_engine().dialect.identifier_preparer.quote = lambda token: token
+    client = mocker.MagicMock()
+
+    create_dataset(
+        client,
+        {
+            "id": 1,
+            "schema": "public",
+            "name": "Database",
+            "sqlalchemy_uri": "postgresql://user@host/examples",
+        },
+        models[0],
+    )
+    client.create_dataset.assert_called_with(
+        database=1,
+        schema="public",
+        table_name="messages_channels",
+        sql="SELECT * FROM examples_dev.public.messages_channels",
+    )
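
One subtlety in test_create_dataset_virtual: calling the patched create_engine once before assigning quote works because a MagicMock returns the same return_value object on every call, so the engine constructed later inside create_dataset picks up the stubbed identity quoting. A sketch of the same pattern with plain unittest.mock:

from unittest.mock import MagicMock

create_engine = MagicMock()
create_engine().dialect.identifier_preparer.quote = lambda token: token

# Any later call, with any arguments, yields the same stubbed engine.
engine = create_engine("postgresql://user@host/examples")
assert engine.dialect.identifier_preparer.quote("messages_channels") == "messages_channels"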