From 2e08f5c0b760ad70a281883693812303b4390084 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 22 Sep 2022 18:14:35 -0700 Subject: [PATCH] chore: clean up session object --- src/preset_cli/api/clients/preset.py | 42 +++------ src/preset_cli/api/clients/superset.py | 99 +++++--------------- tests/api/clients/superset_test.py | 121 +++++++++++++------------ 3 files changed, 99 insertions(+), 163 deletions(-) diff --git a/src/preset_cli/api/clients/preset.py b/src/preset_cli/api/clients/preset.py index 8b0ac6b4..8d811a22 100644 --- a/src/preset_cli/api/clients/preset.py +++ b/src/preset_cli/api/clients/preset.py @@ -33,20 +33,17 @@ def __init__(self, baseurl: Union[str, URL], auth: Auth): # convert to URL if necessary self.baseurl = URL(baseurl) self.auth = auth - self.auth.headers.update( - { - "User-Agent": "Preset CLI", - "X-Client-Version": __version__, - }, - ) + + self.session = auth.get_session() + self.session.headers.update(auth.get_headers()) + self.session.headers["User-Agent"] = "Preset CLI" + self.session.headers["X-Client-Version"] = __version__ def get_teams(self) -> List[Any]: """ Retrieve all teams based on membership. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - response = session.get(self.baseurl / "api/v1/teams/", headers=headers) + response = self.session.get(self.baseurl / "api/v1/teams/") validate_response(response) payload = response.json() @@ -58,11 +55,8 @@ def get_workspaces(self, team_name: str) -> List[Any]: """ Retrieve all workspaces for a given team. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - response = session.get( + response = self.session.get( self.baseurl / "api/v1/teams" / team_name / "workspaces/", - headers=headers, ) validate_response(response) @@ -80,13 +74,9 @@ def invite_users( """ Invite users to teams. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - for team in teams: - response = session.post( + response = self.session.post( self.baseurl / "api/v1/teams" / team / "invites/many", - headers=headers, json={ "invites": [ {"team_role_id": role_id, "email": email} for email in emails @@ -100,9 +90,6 @@ def export_users(self, workspace_url: URL) -> Iterator[UserType]: """ Return all users from a given workspace. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - team_name: Optional[str] = None workspace_id: Optional[int] = None @@ -124,7 +111,7 @@ def export_users(self, workspace_url: URL) -> Iterator[UserType]: / str(workspace_id) / "memberships" ) - response = session.get(url, headers=headers) + response = self.session.get(url) team_members: List[UserType] = [ { "id": 0, @@ -139,7 +126,7 @@ def export_users(self, workspace_url: URL) -> Iterator[UserType]: # TODO (betodealmeida): improve this url = workspace_url / "roles/add" - response = session.get(url, headers=headers) + response = self.session.get(url) soup = BeautifulSoup(response.text, features="html.parser") select = soup.find("select", id="user") ids = { @@ -158,9 +145,6 @@ def import_users(self, teams: List[str], users: List[UserType]) -> None: """ Import users by adding them via SCIM. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - for team in teams: url = self.baseurl / "api/v1/teams" / team / "scim/v2/Users" for user in users: @@ -186,7 +170,7 @@ def import_users(self, teams: List[str], users: List[UserType]) -> None: "givenName": user["first_name"], }, } - headers["Content-Type"] = "application/scim+json" - headers["Accept"] = "application/scim+json" - response = session.post(url, json=payload, headers=headers) + self.session.headers["Content-Type"] = "application/scim+json" + self.session.headers["Accept"] = "application/scim+json" + response = self.session.post(url, json=payload) validate_response(response) diff --git a/src/preset_cli/api/clients/superset.py b/src/preset_cli/api/clients/superset.py index a2c45718..40f6ad50 100644 --- a/src/preset_cli/api/clients/superset.py +++ b/src/preset_cli/api/clients/superset.py @@ -42,7 +42,6 @@ import pandas as pd import prison -import requests import yaml from bs4 import BeautifulSoup from yarl import URL @@ -221,6 +220,7 @@ def __init__(self, baseurl: Union[str, URL], auth: Auth): self.session = auth.get_session() self.session.headers.update(auth.get_headers()) self.session.headers["Referer"] = str(self.baseurl) + self.session.headers["User-Agent"] = f"Apache Superset Client ({__version__})" def run_query( self, @@ -251,13 +251,10 @@ def run_query( headers = { "Accept": "application/json", "Content-Type": "application/json", - "User-Agent": f"Apache Superset Client ({__version__})", - "Referer": str(self.baseurl), } - headers.update(self.auth.get_headers()) + self.session.headers.update(headers) - session = self.auth.get_session() - response = session.post(url, json=data, headers=headers) + response = self.session.post(url, json=data) validate_response(response) payload = response.json() @@ -362,13 +359,10 @@ def get_data( # pylint: disable=too-many-locals, too-many-arguments headers = { "Accept": "application/json", "Content-Type": "application/json", - "User-Agent": f"Apache Superset Client ({__version__})", - "Referer": str(self.baseurl), } - headers.update(self.auth.get_headers()) + self.session.headers.update(headers) - session = self.auth.get_session() - response = session.post(url, json=data, headers=headers) + response = self.session.post(url, json=data) validate_response(response) payload = response.json() @@ -381,10 +375,7 @@ def get_resource(self, resource_name: str, resource_id: int) -> Any: """ url = self.baseurl / "api/v1" / resource_name / str(resource_id) - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - response = session.get(url, headers=headers) + response = self.session.get(url) validate_response(response) resource = response.json() @@ -417,10 +408,7 @@ def get_resources(self, resource_name: str, **kwargs: Any) -> List[Any]: ) url = self.baseurl / "api/v1" / resource_name / "" % {"q": query} - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - response = session.get(url, headers=headers) + response = self.session.get(url) validate_response(response) payload = response.json() @@ -439,10 +427,7 @@ def create_resource(self, resource_name: str, **kwargs: Any) -> Any: """ url = self.baseurl / "api/v1" / resource_name / "" - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - response = session.post(url, json=kwargs, headers=headers) + response = self.session.post(url, json=kwargs) validate_response(response) resource = response.json() @@ -463,10 +448,7 @@ def update_resource( if query_args: url %= query_args - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - response = session.put(url, json=kwargs, headers=headers) + response = self.session.put(url, json=kwargs) validate_response(response) resource = response.json() @@ -567,9 +549,6 @@ def export_zip(self, resource_name: str, ids: List[int]) -> BytesIO: """ Export one or more of a resource. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) url = self.baseurl / "api/v1" / resource_name / "export/" buf = BytesIO() @@ -577,7 +556,7 @@ def export_zip(self, resource_name: str, ids: List[int]) -> BytesIO: while ids: page, ids = ids[:MAX_IDS_IN_EXPORT], ids[MAX_IDS_IN_EXPORT:] params = {"q": prison.dumps(page)} - response = session.get(url, params=params, headers=headers) + response = self.session.get(url, params=params) validate_response(response) # write files from response to main ZIP bundle @@ -596,16 +575,13 @@ def get_uuids(self, resource_name: str) -> Dict[int, UUID]: Still method is very inneficient, but it's the only way to get the mapping between IDs and UUIDs in older versions of Superset. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) url = self.baseurl / "api/v1" / resource_name / "export/" uuids: Dict[int, UUID] = {} for resource in self.get_resources(resource_name): id_ = resource["id"] params = {"q": prison.dumps([id_])} - response = session.get(url, params=params, headers=headers) + response = self.session.get(url, params=params) with ZipFile(BytesIO(response.content)) as export: for name in export.namelist(): @@ -627,15 +603,11 @@ def import_zip( """ url = self.baseurl / "api/v1" / resource_name / "import/" - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - headers["Accept"] = "application/json" - response = session.post( + self.session.headers.update({"Accept": "application/json"}) + response = self.session.post( url, files=dict(formData=data), data=dict(overwrite=json.dumps(overwrite)), - headers=headers, ) validate_response(response) @@ -647,16 +619,12 @@ def export_users(self) -> Iterator[UserType]: """ Return all users. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - # For on-premise OSS Superset we can fetch the list of users by crawling the # ``/users/list/`` page. For a Preset workspace we need custom logic to talk # to Manager. - response = session.get(self.baseurl / "users/list/", headers=headers) + response = self.session.get(self.baseurl / "users/list/") if response.ok: - return self._export_users_superset(session, headers) + return self._export_users_superset() return self._export_users_preset() def _export_users_preset(self) -> Iterator[UserType]: @@ -667,20 +635,12 @@ def _export_users_preset(self) -> Iterator[UserType]: client = PresetClient("https://manage.app.preset.io/", self.auth) return client.export_users(self.baseurl) - def _export_users_superset( - self, - session: requests.Session, - headers: Dict[str, str], - ) -> Iterator[UserType]: + def _export_users_superset(self) -> Iterator[UserType]: """ Return all users from a standalone Superset instance. Since this is not exposed via an API we need to crawl the CRUD page. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - page = 0 while True: params = { @@ -690,7 +650,7 @@ def _export_users_superset( url = self.baseurl / "users/list/" page += 1 - response = session.get(url, params=params, headers=headers) + response = self.session.get(url, params=params) soup = BeautifulSoup(response.text, features="html.parser") table = soup.find_all("table")[1] trs = table.find_all("tr") @@ -712,10 +672,6 @@ def export_roles(self) -> Iterator[RoleType]: """ Return all roles. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - page = 0 while True: params = { @@ -725,7 +681,7 @@ def export_roles(self) -> Iterator[RoleType]: url = self.baseurl / "roles/list/" page += 1 - response = session.get(url, params=params, headers=headers) + response = self.session.get(url, params=params) soup = BeautifulSoup(response.text, features="html.parser") table = soup.find_all("table")[1] trs = table.find_all("tr") @@ -738,7 +694,7 @@ def export_roles(self) -> Iterator[RoleType]: role_id = int(tds[0].find("input").attrs["id"]) role_url = self.baseurl / "roles/show" / str(role_id) - response = session.get(role_url, headers=headers) + response = self.session.get(role_url) soup = BeautifulSoup(response.text, features="html.parser") table = soup.find_all("table")[-1] keys: List[Tuple[str, Callable[[Any], Any]]] = [ @@ -757,10 +713,6 @@ def export_rls(self) -> Iterator[RuleType]: """ Return all RLS rules. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - page = 0 while True: params = { @@ -770,7 +722,7 @@ def export_rls(self) -> Iterator[RuleType]: url = self.baseurl / "rowlevelsecurityfiltersmodelview/list/" page += 1 - response = session.get(url, params=params, headers=headers) + response = self.session.get(url, params=params) soup = BeautifulSoup(response.text, features="html.parser") try: table = soup.find_all("table")[1] @@ -792,7 +744,7 @@ def export_rls(self) -> Iterator[RuleType]: / str(rule_id) ) - response = session.get(rule_url, headers=headers) + response = self.session.get(rule_url) soup = BeautifulSoup(response.text, features="html.parser") table = soup.find("table") keys: List[Tuple[str, Callable[[Any], Any]]] = [ @@ -860,10 +812,6 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals """ Import a given RLS rule. """ - session = self.auth.get_session() - headers = self.auth.get_headers() - headers["Referer"] = str(self.baseurl) - table_ids: List[int] = [] for table in rls["tables"]: if "." in table: @@ -882,7 +830,7 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals for role in rls["roles"]: params = {"_flt_0_name": role} url = self.baseurl / "roles/list/" - response = session.get(url, params=params, headers=headers) + response = self.session.get(url, params=params) soup = BeautifulSoup(response.text, features="html.parser") trs = soup.find_all("table")[1].find_all("tr") if len(trs) == 1: @@ -897,7 +845,7 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals role_ids.append(id_) url = self.baseurl / "rowlevelsecurityfiltersmodelview/add" - response = session.post( + response = self.session.post( url, data={ "name": rls["name"], @@ -908,7 +856,6 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals "group_key": rls["group_key"], "clause": rls["clause"], }, - headers=headers, ) validate_response(response) diff --git a/tests/api/clients/superset_test.py b/tests/api/clients/superset_test.py index 799a2bc5..b059eb7a 100644 --- a/tests/api/clients/superset_test.py +++ b/tests/api/clients/superset_test.py @@ -5,7 +5,6 @@ import json from io import BytesIO -from unittest import mock from uuid import UUID from zipfile import ZipFile, is_zipfile @@ -523,10 +522,11 @@ def test_get_data(requests_mock: Mocker) -> None: } -def test_get_data_parameters(mocker: MockerFixture) -> None: +def test_get_data_parameters(mocker: MockerFixture, requests_mock: Mocker) -> None: """ Test different parameters passed to ``get_data``. """ + """ auth = mocker.MagicMock() session = auth.get_session() session.get().json.return_value = { @@ -542,8 +542,29 @@ def test_get_data_parameters(mocker: MockerFixture) -> None: }, ], } + """ + requests_mock.get( + "https://superset.example.org/api/v1/dataset/27", + json={ + "result": { + "columns": [], + "metrics": [], + }, + }, + ) + post_mock = requests_mock.post( + "https://superset.example.org/api/v1/chart/data", + json={ + "result": [ + { + "data": [{"a": 1}], + }, + ], + }, + ) mocker.patch("preset_cli.api.clients.superset.uuid4", return_value=1234) + auth = Auth() client = SupersetClient("https://superset.example.org/", auth) client.get_data( 27, @@ -554,64 +575,48 @@ def test_get_data_parameters(mocker: MockerFixture) -> None: granularity="P1M", ) - session.post.assert_has_calls( - [ - mock.call(), - mock.call( - URL("https://superset.example.org/api/v1/chart/data"), - json={ - "datasource": {"id": 27, "type": "table"}, - "force": False, - "queries": [ - { - "annotation_layers": [], - "applied_time_extras": {}, - "columns": [{"label": "name", "sqlExpression": "name"}], - "custom_form_data": {}, - "custom_params": {}, - "extras": { - "having": "", - "having_druid": [], - "time_grain_sqla": "P1M", - "where": "", - }, - "filters": [], - "granularity": "ts", - "is_timeseries": True, - "metrics": [ - { - "aggregate": None, - "column": None, - "expressionType": "SQL", - "hasCustomLabel": False, - "isNew": False, - "label": "cnt", - "optionName": "metric_1234", - "sqlExpression": "cnt", - }, - ], - "order_desc": True, - "orderby": [], - "row_limit": 10000, - "time_range": "No filter", - "timeseries_limit": 0, - "url_params": {}, - }, - ], - "result_format": "json", - "result_type": "full", - }, - headers={ - "Accept": "application/json", - "Content-Type": "application/json", - "User-Agent": f"Apache Superset Client ({__version__})", - "Referer": "https://superset.example.org/", + assert post_mock.last_request.json() == { + "datasource": {"id": 27, "type": "table"}, + "force": False, + "queries": [ + { + "annotation_layers": [], + "applied_time_extras": {}, + "columns": [{"label": "name", "sqlExpression": "name"}], + "custom_form_data": {}, + "custom_params": {}, + "extras": { + "having": "", + "having_druid": [], + "where": "", + "time_grain_sqla": "P1M", }, - ), - mock.call().ok.__bool__(), - mock.call().json(), + "filters": [], + "is_timeseries": True, + "metrics": [ + { + "aggregate": None, + "column": None, + "expressionType": "SQL", + "hasCustomLabel": False, + "isNew": False, + "label": "cnt", + "optionName": "metric_1234", + "sqlExpression": "cnt", + }, + ], + "order_desc": True, + "orderby": [], + "row_limit": 10000, + "time_range": "No filter", + "timeseries_limit": 0, + "url_params": {}, + "granularity": "ts", + }, ], - ) + "result_format": "json", + "result_type": "full", + } def test_get_data_time_column_error(requests_mock: Mocker) -> None: