Skip to content

Commit

Permalink
chore: clean up session object
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Sep 23, 2022
1 parent a1c5a50 commit 2e08f5c
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 163 deletions.
42 changes: 13 additions & 29 deletions src/preset_cli/api/clients/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,17 @@ def __init__(self, baseurl: Union[str, URL], auth: Auth):
# convert to URL if necessary
self.baseurl = URL(baseurl)
self.auth = auth
self.auth.headers.update(
{
"User-Agent": "Preset CLI",
"X-Client-Version": __version__,
},
)

self.session = auth.get_session()
self.session.headers.update(auth.get_headers())
self.session.headers["User-Agent"] = "Preset CLI"
self.session.headers["X-Client-Version"] = __version__

def get_teams(self) -> List[Any]:
"""
Retrieve all teams based on membership.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
response = session.get(self.baseurl / "api/v1/teams/", headers=headers)
response = self.session.get(self.baseurl / "api/v1/teams/")
validate_response(response)

payload = response.json()
Expand All @@ -58,11 +55,8 @@ def get_workspaces(self, team_name: str) -> List[Any]:
"""
Retrieve all workspaces for a given team.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
response = session.get(
response = self.session.get(
self.baseurl / "api/v1/teams" / team_name / "workspaces/",
headers=headers,
)
validate_response(response)

Expand All @@ -80,13 +74,9 @@ def invite_users(
"""
Invite users to teams.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()

for team in teams:
response = session.post(
response = self.session.post(
self.baseurl / "api/v1/teams" / team / "invites/many",
headers=headers,
json={
"invites": [
{"team_role_id": role_id, "email": email} for email in emails
Expand All @@ -100,9 +90,6 @@ def export_users(self, workspace_url: URL) -> Iterator[UserType]:
"""
Return all users from a given workspace.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()

team_name: Optional[str] = None
workspace_id: Optional[int] = None

Expand All @@ -124,7 +111,7 @@ def export_users(self, workspace_url: URL) -> Iterator[UserType]:
/ str(workspace_id)
/ "memberships"
)
response = session.get(url, headers=headers)
response = self.session.get(url)
team_members: List[UserType] = [
{
"id": 0,
Expand All @@ -139,7 +126,7 @@ def export_users(self, workspace_url: URL) -> Iterator[UserType]:

# TODO (betodealmeida): improve this
url = workspace_url / "roles/add"
response = session.get(url, headers=headers)
response = self.session.get(url)
soup = BeautifulSoup(response.text, features="html.parser")
select = soup.find("select", id="user")
ids = {
Expand All @@ -158,9 +145,6 @@ def import_users(self, teams: List[str], users: List[UserType]) -> None:
"""
Import users by adding them via SCIM.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()

for team in teams:
url = self.baseurl / "api/v1/teams" / team / "scim/v2/Users"
for user in users:
Expand All @@ -186,7 +170,7 @@ def import_users(self, teams: List[str], users: List[UserType]) -> None:
"givenName": user["first_name"],
},
}
headers["Content-Type"] = "application/scim+json"
headers["Accept"] = "application/scim+json"
response = session.post(url, json=payload, headers=headers)
self.session.headers["Content-Type"] = "application/scim+json"
self.session.headers["Accept"] = "application/scim+json"
response = self.session.post(url, json=payload)
validate_response(response)
99 changes: 23 additions & 76 deletions src/preset_cli/api/clients/superset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

import pandas as pd
import prison
import requests
import yaml
from bs4 import BeautifulSoup
from yarl import URL
Expand Down Expand Up @@ -221,6 +220,7 @@ def __init__(self, baseurl: Union[str, URL], auth: Auth):
self.session = auth.get_session()
self.session.headers.update(auth.get_headers())
self.session.headers["Referer"] = str(self.baseurl)
self.session.headers["User-Agent"] = f"Apache Superset Client ({__version__})"

def run_query(
self,
Expand Down Expand Up @@ -251,13 +251,10 @@ def run_query(
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"User-Agent": f"Apache Superset Client ({__version__})",
"Referer": str(self.baseurl),
}
headers.update(self.auth.get_headers())
self.session.headers.update(headers)

session = self.auth.get_session()
response = session.post(url, json=data, headers=headers)
response = self.session.post(url, json=data)
validate_response(response)

payload = response.json()
Expand Down Expand Up @@ -362,13 +359,10 @@ def get_data( # pylint: disable=too-many-locals, too-many-arguments
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"User-Agent": f"Apache Superset Client ({__version__})",
"Referer": str(self.baseurl),
}
headers.update(self.auth.get_headers())
self.session.headers.update(headers)

session = self.auth.get_session()
response = session.post(url, json=data, headers=headers)
response = self.session.post(url, json=data)
validate_response(response)

payload = response.json()
Expand All @@ -381,10 +375,7 @@ def get_resource(self, resource_name: str, resource_id: int) -> Any:
"""
url = self.baseurl / "api/v1" / resource_name / str(resource_id)

session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)
response = session.get(url, headers=headers)
response = self.session.get(url)
validate_response(response)

resource = response.json()
Expand Down Expand Up @@ -417,10 +408,7 @@ def get_resources(self, resource_name: str, **kwargs: Any) -> List[Any]:
)
url = self.baseurl / "api/v1" / resource_name / "" % {"q": query}

session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)
response = session.get(url, headers=headers)
response = self.session.get(url)
validate_response(response)

payload = response.json()
Expand All @@ -439,10 +427,7 @@ def create_resource(self, resource_name: str, **kwargs: Any) -> Any:
"""
url = self.baseurl / "api/v1" / resource_name / ""

session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)
response = session.post(url, json=kwargs, headers=headers)
response = self.session.post(url, json=kwargs)
validate_response(response)

resource = response.json()
Expand All @@ -463,10 +448,7 @@ def update_resource(
if query_args:
url %= query_args

session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)
response = session.put(url, json=kwargs, headers=headers)
response = self.session.put(url, json=kwargs)
validate_response(response)

resource = response.json()
Expand Down Expand Up @@ -567,17 +549,14 @@ def export_zip(self, resource_name: str, ids: List[int]) -> BytesIO:
"""
Export one or more of a resource.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)
url = self.baseurl / "api/v1" / resource_name / "export/"

buf = BytesIO()
with ZipFile(buf, "w") as bundle:
while ids:
page, ids = ids[:MAX_IDS_IN_EXPORT], ids[MAX_IDS_IN_EXPORT:]
params = {"q": prison.dumps(page)}
response = session.get(url, params=params, headers=headers)
response = self.session.get(url, params=params)
validate_response(response)

# write files from response to main ZIP bundle
Expand All @@ -596,16 +575,13 @@ def get_uuids(self, resource_name: str) -> Dict[int, UUID]:
Still method is very inneficient, but it's the only way to get the mapping
between IDs and UUIDs in older versions of Superset.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)
url = self.baseurl / "api/v1" / resource_name / "export/"

uuids: Dict[int, UUID] = {}
for resource in self.get_resources(resource_name):
id_ = resource["id"]
params = {"q": prison.dumps([id_])}
response = session.get(url, params=params, headers=headers)
response = self.session.get(url, params=params)

with ZipFile(BytesIO(response.content)) as export:
for name in export.namelist():
Expand All @@ -627,15 +603,11 @@ def import_zip(
"""
url = self.baseurl / "api/v1" / resource_name / "import/"

session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)
headers["Accept"] = "application/json"
response = session.post(
self.session.headers.update({"Accept": "application/json"})
response = self.session.post(
url,
files=dict(formData=data),
data=dict(overwrite=json.dumps(overwrite)),
headers=headers,
)
validate_response(response)

Expand All @@ -647,16 +619,12 @@ def export_users(self) -> Iterator[UserType]:
"""
Return all users.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)

# For on-premise OSS Superset we can fetch the list of users by crawling the
# ``/users/list/`` page. For a Preset workspace we need custom logic to talk
# to Manager.
response = session.get(self.baseurl / "users/list/", headers=headers)
response = self.session.get(self.baseurl / "users/list/")
if response.ok:
return self._export_users_superset(session, headers)
return self._export_users_superset()
return self._export_users_preset()

def _export_users_preset(self) -> Iterator[UserType]:
Expand All @@ -667,20 +635,12 @@ def _export_users_preset(self) -> Iterator[UserType]:
client = PresetClient("https://manage.app.preset.io/", self.auth)
return client.export_users(self.baseurl)

def _export_users_superset(
self,
session: requests.Session,
headers: Dict[str, str],
) -> Iterator[UserType]:
def _export_users_superset(self) -> Iterator[UserType]:
"""
Return all users from a standalone Superset instance.
Since this is not exposed via an API we need to crawl the CRUD page.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)

page = 0
while True:
params = {
Expand All @@ -690,7 +650,7 @@ def _export_users_superset(
url = self.baseurl / "users/list/"
page += 1

response = session.get(url, params=params, headers=headers)
response = self.session.get(url, params=params)
soup = BeautifulSoup(response.text, features="html.parser")
table = soup.find_all("table")[1]
trs = table.find_all("tr")
Expand All @@ -712,10 +672,6 @@ def export_roles(self) -> Iterator[RoleType]:
"""
Return all roles.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)

page = 0
while True:
params = {
Expand All @@ -725,7 +681,7 @@ def export_roles(self) -> Iterator[RoleType]:
url = self.baseurl / "roles/list/"
page += 1

response = session.get(url, params=params, headers=headers)
response = self.session.get(url, params=params)
soup = BeautifulSoup(response.text, features="html.parser")
table = soup.find_all("table")[1]
trs = table.find_all("tr")
Expand All @@ -738,7 +694,7 @@ def export_roles(self) -> Iterator[RoleType]:
role_id = int(tds[0].find("input").attrs["id"])
role_url = self.baseurl / "roles/show" / str(role_id)

response = session.get(role_url, headers=headers)
response = self.session.get(role_url)
soup = BeautifulSoup(response.text, features="html.parser")
table = soup.find_all("table")[-1]
keys: List[Tuple[str, Callable[[Any], Any]]] = [
Expand All @@ -757,10 +713,6 @@ def export_rls(self) -> Iterator[RuleType]:
"""
Return all RLS rules.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)

page = 0
while True:
params = {
Expand All @@ -770,7 +722,7 @@ def export_rls(self) -> Iterator[RuleType]:
url = self.baseurl / "rowlevelsecurityfiltersmodelview/list/"
page += 1

response = session.get(url, params=params, headers=headers)
response = self.session.get(url, params=params)
soup = BeautifulSoup(response.text, features="html.parser")
try:
table = soup.find_all("table")[1]
Expand All @@ -792,7 +744,7 @@ def export_rls(self) -> Iterator[RuleType]:
/ str(rule_id)
)

response = session.get(rule_url, headers=headers)
response = self.session.get(rule_url)
soup = BeautifulSoup(response.text, features="html.parser")
table = soup.find("table")
keys: List[Tuple[str, Callable[[Any], Any]]] = [
Expand Down Expand Up @@ -860,10 +812,6 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals
"""
Import a given RLS rule.
"""
session = self.auth.get_session()
headers = self.auth.get_headers()
headers["Referer"] = str(self.baseurl)

table_ids: List[int] = []
for table in rls["tables"]:
if "." in table:
Expand All @@ -882,7 +830,7 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals
for role in rls["roles"]:
params = {"_flt_0_name": role}
url = self.baseurl / "roles/list/"
response = session.get(url, params=params, headers=headers)
response = self.session.get(url, params=params)
soup = BeautifulSoup(response.text, features="html.parser")
trs = soup.find_all("table")[1].find_all("tr")
if len(trs) == 1:
Expand All @@ -897,7 +845,7 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals
role_ids.append(id_)

url = self.baseurl / "rowlevelsecurityfiltersmodelview/add"
response = session.post(
response = self.session.post(
url,
data={
"name": rls["name"],
Expand All @@ -908,7 +856,6 @@ def import_rls(self, rls: RuleType) -> None: # pylint: disable=too-many-locals
"group_key": rls["group_key"],
"clause": rls["clause"],
},
headers=headers,
)
validate_response(response)

Expand Down
Loading

0 comments on commit 2e08f5c

Please sign in to comment.