diff --git a/src/preset_cli/api/clients/dbt.py b/src/preset_cli/api/clients/dbt.py index e3ecea5d..d214d45a 100644 --- a/src/preset_cli/api/clients/dbt.py +++ b/src/preset_cli/api/clients/dbt.py @@ -10,8 +10,9 @@ # pylint: disable=invalid-name, too-few-public-methods +import logging from enum import Enum -from typing import Any, Dict, List, Type, TypedDict +from typing import Any, Dict, List, Optional, Type, TypedDict from marshmallow import INCLUDE, Schema, fields from python_graphql_client import GraphqlClient @@ -20,6 +21,8 @@ from preset_cli import __version__ from preset_cli.auth.main import Auth +_logger = logging.getLogger(__name__) + REST_ENDPOINT = URL("https://cloud.getdbt.com/") GRAPHQL_ENDPOINT = URL("https://metadata.cloud.getdbt.com/graphql") @@ -575,16 +578,14 @@ class DBTClient: # pylint: disable=too-few-public-methods """ def __init__(self, auth: Auth): - self.auth = auth - self.auth.headers.update( - { - "User-Agent": "Preset CLI", - "X-Client-Version": __version__, - }, - ) self.graphql_client = GraphqlClient(endpoint=GRAPHQL_ENDPOINT) self.baseurl = REST_ENDPOINT + self.session = auth.get_session() + self.session.headers.update(auth.get_headers()) + self.session.headers["User-Agent"] = "Preset CLI" + self.session.headers["X-Client-Version"] = __version__ + def execute(self, query: str, **variables: Any) -> DataResponse: """ Run a GraphQL query. @@ -592,16 +593,16 @@ def execute(self, query: str, **variables: Any) -> DataResponse: return self.graphql_client.execute( query=query, variables=variables, - headers=self.auth.get_headers(), + headers=self.session.headers, ) def get_accounts(self) -> List[AccountSchema]: """ List all accounts. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - response = session.get(self.baseurl / "api/v2/accounts/", headers=headers) + url = self.baseurl / "api/v2/accounts/" + _logger.debug("GET %s", url) + response = self.session.get(url) payload = response.json() @@ -612,12 +613,9 @@ def get_projects(self, account_id: int) -> List[ProjectSchema]: """ List all projects. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - response = session.get( - self.baseurl / "api/v2/accounts" / str(account_id) / "projects/", - headers=headers, - ) + url = self.baseurl / "api/v2/accounts" / str(account_id) / "projects/" + _logger.debug("GET %s", url) + response = self.session.get(url) payload = response.json() @@ -626,16 +624,18 @@ def get_projects(self, account_id: int) -> List[ProjectSchema]: return projects - def get_jobs(self, account_id: int) -> List[JobSchema]: + def get_jobs( + self, + account_id: int, + project_id: Optional[int] = None, + ) -> List[JobSchema]: """ - List all jobs. + List all jobs, optionally for a project. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - response = session.get( - self.baseurl / "api/v2/accounts" / str(account_id) / "jobs/", - headers=headers, - ) + url = self.baseurl / "api/v2/accounts" / str(account_id) / "jobs/" + params = {"project_id": project_id} if project_id is not None else {} + _logger.debug("GET %s", url % params) + response = self.session.get(url, params=params) payload = response.json() diff --git a/tests/api/clients/dbt_test.py b/tests/api/clients/dbt_test.py index 2cdf4fbc..fefb5ea9 100644 --- a/tests/api/clients/dbt_test.py +++ b/tests/api/clients/dbt_test.py @@ -2,7 +2,7 @@ Tests for the dbt client. """ -# pylint: disable=missing-class-docstring, invalid-name, line-too-long +# pylint: disable=missing-class-docstring, invalid-name, line-too-long, too-many-lines import datetime from enum import Enum @@ -60,10 +60,7 @@ def test_dbt_client_execute(mocker: MockerFixture) -> None: GraphqlClient().execute.assert_called_with( query=query, variables={"jobId": 1}, - headers={ - "User-Agent": "Preset CLI", - "X-Client-Version": __version__, - }, + headers=client.session.headers, ) @@ -768,6 +765,133 @@ def test_dbt_client_get_jobs(requests_mock: Mocker) -> None: ] +def test_dbt_client_get_jobs_for_project(requests_mock: Mocker) -> None: + """ + Test the ``get_jobs`` method when passing a project. + """ + requests_mock.get( + "https://cloud.getdbt.com/api/v2/accounts/72449/jobs/", + json={ + "status": { + "code": 200, + "is_success": True, + "user_message": "Success!", + "developer_message": "", + }, + "data": [ + { + "execution": {"timeout_seconds": 0}, + "generate_docs": True, + "run_generate_sources": False, + "id": 108380, + "account_id": 72449, + "project_id": 134905, + "environment_id": 107605, + "name": "Test job", + "dbt_version": "1.0.0", + "created_at": "2022-07-25T22:00:11.943460+00:00", + "updated_at": "2022-07-26T22:36:23.862370+00:00", + "execute_steps": ["dbt run", "dbt test"], + "state": 1, + "deactivated": False, + "run_failure_count": 0, + "deferring_job_definition_id": None, + "lifecycle_webhooks": False, + "lifecycle_webhooks_url": None, + "triggers": { + "github_webhook": False, + "git_provider_webhook": False, + "custom_branch_only": False, + "schedule": True, + }, + "settings": {"threads": 4, "target_name": "default"}, + "schedule": { + "cron": "0 * * * *", + "date": {"type": "every_day"}, + "time": {"type": "every_hour", "interval": 1}, + }, + "is_deferrable": False, + "generate_sources": False, + "cron_humanized": "Every hour", + "next_run": "2022-07-26T23:00:00+00:00", + "next_run_humanized": "2 weeks, 2 days", + }, + ], + "extra": { + "filters": {"limit": 100, "offset": 0, "account_id": 72449}, + "order_by": "id", + "pagination": {"count": 1, "total_count": 1}, + }, + }, + ) + auth = Auth() + client = DBTClient(auth) + assert client.get_jobs(72449, project_id=134905) == [ + { + "id": 108380, + "deactivated": False, + "triggers": { + "custom_branch_only": False, + "git_provider_webhook": False, + "github_webhook": False, + "schedule": True, + }, + "next_run_humanized": "2 weeks, 2 days", + "generate_sources": False, + "execution": {"timeout_seconds": 0}, + "environment_id": 107605, + "created_at": datetime.datetime( + 2022, + 7, + 25, + 22, + 0, + 11, + 943460, + tzinfo=datetime.timezone(datetime.timedelta(0), "+0000"), + ), + "account_id": 72449, + "state": 1, + "deferring_job_definition_id": None, + "generate_docs": True, + "cron_humanized": "Every hour", + "next_run": datetime.datetime( + 2022, + 7, + 26, + 23, + 0, + tzinfo=datetime.timezone(datetime.timedelta(0), "+0000"), + ), + "run_failure_count": 0, + "lifecycle_webhooks": False, + "settings": {"threads": 4, "target_name": "default"}, + "execute_steps": ["dbt run", "dbt test"], + "project_id": 134905, + "is_deferrable": False, + "schedule": { + "cron": "0 * * * *", + "time": {"interval": 1, "type": "every_hour"}, + "date": {"type": "every_day"}, + }, + "run_generate_sources": False, + "dbt_version": "1.0.0", + "updated_at": datetime.datetime( + 2022, + 7, + 26, + 22, + 36, + 23, + 862370, + tzinfo=datetime.timezone(datetime.timedelta(0), "+0000"), + ), + "lifecycle_webhooks_url": None, + "name": "Test job", + }, + ] + + def test_dbt_client_get_models(mocker: MockerFixture) -> None: """ Test the ``get_models`` method.