Skip to content

Commit

Permalink
bigquery oauth backend
Browse files Browse the repository at this point in the history
  • Loading branch information
fisjac committed Oct 22, 2024
1 parent bad48d0 commit ad1e536
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 12 deletions.
6 changes: 5 additions & 1 deletion superset/commands/database/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, data: dict[str, Any]):
self._context = context
self._uri = uri

def run(self) -> None: # pylint: disable=too-many-statements
def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches
self.validate()
ex_str = ""
ssh_tunnel = self._properties.get("ssh_tunnel")
Expand Down Expand Up @@ -225,6 +225,10 @@ def ping(engine: Engine) -> bool:
# bubble up the exception to return proper status code
raise
except Exception as ex:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(
ex
):
database.start_oauth2_dance()
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,10 +1691,13 @@ def select_star( # pylint: disable=too-many-arguments
return sql

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs
Expand Down Expand Up @@ -1765,6 +1768,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
cursor = conn.cursor()
return [
cls.estimate_statement_cost(
database,
cls.process_statement(statement, database),
cursor,
)
Expand Down Expand Up @@ -1793,8 +1797,9 @@ def get_url_for_impersonation(
return url

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -1804,6 +1809,7 @@ def update_impersonation_config(
Update a configuration dictionary
that can set the correct properties for impersonating users
:param connect_args: a Database object
:param connect_args: config to be updated
:param uri: URI
:param username: Effective username
Expand Down
11 changes: 10 additions & 1 deletion superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,16 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
"""
sql = f"EXPLAIN {statement}"
cursor.execute(sql)

Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,12 @@ def get_schema_from_engine_params(
return parse.unquote(database.split("/")[1])

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
Expand Down Expand Up @@ -945,8 +948,9 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return version is not None and Version(version) >= Version("0.319")

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -955,6 +959,8 @@ def update_impersonation_config(
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param connect_args: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username
Expand Down
23 changes: 17 additions & 6 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import signature
from typing import Any, Callable, cast, TYPE_CHECKING

import numpy
Expand Down Expand Up @@ -510,12 +511,22 @@ def _get_sqla_engine( # pylint: disable=too-many-locals
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args,
str(sqlalchemy_url),
effective_username,
access_token,
)
# Checking if the function signature can accept database as a param
if "database" in signature(self.db_engine_spec.update_impersonation_config):
self.db_engine_spec.update_impersonation_config(
self,
connect_args,
str(sqlalchemy_url),
effective_username,
access_token,
)
else:
self.db_engine_spec.update_impersonation_config(
connect_args,
str(sqlalchemy_url),
effective_username,
access_token,
)

if connect_args:
params["connect_args"] = connect_args
Expand Down
1 change: 1 addition & 0 deletions superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,4 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
project_id = fields.String(required=False)

0 comments on commit ad1e536

Please sign in to comment.