Skip to content

Commit

Permalink
bigquery oauth
Browse files Browse the repository at this point in the history
  • Loading branch information
fisjac committed Oct 24, 2024
1 parent bad48d0 commit 6e42328
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ import { useState } from 'react';
import Collapse from 'src/components/Collapse';
import { Input } from 'src/components/Input';
import { FormItem } from 'src/components/Form';
import { FieldPropTypes } from '../../types';
import { FieldPropTypes, Engines } from '../../types';

/**
 * OAuth2 client configuration fields edited by the OAuth2ClientField form.
 * The keys mirror those read from `encryptedExtra.oauth2_client_info`.
 */
interface OAuth2ClientInfo {
  id: string;
  secret: string;
  authorization_request_uri: string;
  token_request_uri: string;
  scope: string;
  // Optional: its input is only rendered when `db.engine` is `Engines.BigQuery`.
  project_id?: string;
}

export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
Expand All @@ -42,6 +43,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
token_request_uri:
encryptedExtra.oauth2_client_info?.token_request_uri || '',
scope: encryptedExtra.oauth2_client_info?.scope || '',
project_id: encryptedExtra.oauth2_client_info?.project_id || '',
});

if (db?.engine_information?.supports_oauth2 !== true) {
Expand Down Expand Up @@ -106,6 +108,15 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
onChange={handleChange('scope')}
/>
</FormItem>
{db.engine === Engines.BigQuery && (
<FormItem label="Project ID">
<Input
data-test="project_id"
value={oauth2ClientInfo.project_id}
onChange={handleChange('project_id')}
/>
</FormItem>
)}
</Collapse.Panel>
</Collapse>
);
Expand Down
1 change: 1 addition & 0 deletions superset-frontend/src/features/databases/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ export enum ConfigurationMethod {

/**
 * Engine identifiers that are compared against `db.engine`, e.g. to decide
 * whether engine-specific form fields (such as the BigQuery project ID)
 * should be rendered.
 */
export enum Engines {
  GSheet = 'gsheets',
  BigQuery = 'bigquery',
  Snowflake = 'snowflake',
}

Expand Down
6 changes: 5 additions & 1 deletion superset/commands/database/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, data: dict[str, Any]):
self._context = context
self._uri = uri

def run(self) -> None: # pylint: disable=too-many-statements
def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches
self.validate()
ex_str = ""
ssh_tunnel = self._properties.get("ssh_tunnel")
Expand Down Expand Up @@ -225,6 +225,10 @@ def ping(engine: Engine) -> bool:
# bubble up the exception to return proper status code
raise
except Exception as ex:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(
ex
):
database.start_oauth2_dance()

Check warning on line 231 in superset/commands/database/test_connection.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/database/test_connection.py#L231

Added line #L231 was not covered by tests
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
Expand Down
7 changes: 1 addition & 6 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
SSHTunnelMissingCredentials,
)
from superset.constants import PASSWORD_MASK
from superset.databases.types import EncryptedField
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
Expand Down Expand Up @@ -941,12 +942,6 @@ def validate_ssh_tunnel_credentials(
return


class EncryptedField: # pylint: disable=too-few-public-methods
"""
A database field that should be stored in encrypted_extra.
"""


class EncryptedString(EncryptedField, fields.String):
pass

Expand Down
23 changes: 23 additions & 0 deletions superset/databases/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Field has been moved outside of the schemas.py file to
# allow for it to be imported from outside of app_context
class EncryptedField: # pylint: disable=too-few-public-methods
    """
    A database field that should be stored in encrypted_extra.

    The class defines no attributes or methods of its own; it acts as a marker
    that is combined with a concrete field type (for example
    ``class EncryptedString(EncryptedField, fields.String)``).
    """
10 changes: 8 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,10 +1691,13 @@ def select_star( # pylint: disable=too-many-arguments
return sql

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs
Expand Down Expand Up @@ -1765,6 +1768,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
cursor = conn.cursor()
return [
cls.estimate_statement_cost(
database,
cls.process_statement(statement, database),
cursor,
)
Expand Down Expand Up @@ -1793,8 +1797,9 @@ def get_url_for_impersonation(
return url

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -1804,6 +1809,7 @@ def update_impersonation_config(
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: a Database object
:param connect_args: config to be updated
:param uri: URI
:param username: Effective username
Expand Down
12 changes: 8 additions & 4 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ def df_to_sql(
pandas_gbq.to_gbq(df, **to_gbq_kwargs)

@classmethod
def _get_client(cls, engine: Engine) -> bigquery.Client:
def _get_client(
cls,
engine: Engine,
database: Database, # pylint: disable=unused-argument
) -> bigquery.Client:
"""
Return the BigQuery client associated with an engine.
"""
Expand Down Expand Up @@ -453,7 +457,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
catalog=catalog,
schema=schema,
) as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)

Check warning on line 460 in superset/db_engine_specs/bigquery.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/bigquery.py#L460

Added line #L460 was not covered by tests
return [
cls.custom_estimate_statement_cost(
cls.process_statement(statement, database),
Expand All @@ -477,7 +481,7 @@ def get_default_catalog(cls, database: Database) -> str | None:
return project

with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)
return client.project

@classmethod
Expand All @@ -493,7 +497,7 @@ def get_catalog_names(
"""
engine: Engine
with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)

Check warning on line 500 in superset/db_engine_specs/bigquery.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/bigquery.py#L500

Added line #L500 was not covered by tests
projects = client.list_projects()

return {project.project_id for project in projects}
Expand Down
4 changes: 3 additions & 1 deletion superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,9 @@ def get_url_for_impersonation(
return url

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -547,6 +548,7 @@ def update_impersonation_config(
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database Object
:param connect_args:
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
Expand Down
11 changes: 10 additions & 1 deletion superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,16 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with the start-up cost and total cost
"""
sql = f"EXPLAIN {statement}"
cursor.execute(sql)

Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,12 @@ def get_schema_from_engine_params(
return parse.unquote(database.split("/")[1])

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
Expand Down Expand Up @@ -945,8 +948,9 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return version is not None and Version(version) >= Version("0.319")

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -955,6 +959,8 @@ def update_impersonation_config(
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username
Expand Down
4 changes: 3 additions & 1 deletion superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def get_extra_table_metadata(
return metadata

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -126,6 +127,7 @@ def update_impersonation_config(
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username
Expand Down
14 changes: 8 additions & 6 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import signature
from typing import Any, Callable, cast, TYPE_CHECKING

import numpy
Expand Down Expand Up @@ -510,12 +511,13 @@ def _get_sqla_engine( # pylint: disable=too-many-locals
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args,
str(sqlalchemy_url),
effective_username,
access_token,
)
# PR #30674 changed the signature of the method to include database.
# This ensures that the change is backwards compatible
sig = signature(self.db_engine_spec.update_impersonation_config)
args = [connect_args, str(sqlalchemy_url), effective_username, access_token]
if "database" in sig.parameters:
args.insert(0, self)
self.db_engine_spec.update_impersonation_config(*args)

if connect_args:
params["connect_args"] = connect_args
Expand Down
1 change: 1 addition & 0 deletions superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,4 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
project_id = fields.String(required=False)
6 changes: 4 additions & 2 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ def test_estimate_statement_cost_select_star(self):
DB Eng Specs (postgres): Test estimate_statement_cost select star
"""

database = mock.Mock()
cursor = mock.Mock()
cursor.fetchone.return_value = (
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
)
sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}

def test_estimate_statement_invalid_syntax(self):
Expand All @@ -165,6 +166,7 @@ def test_estimate_statement_invalid_syntax(self):
"""
from psycopg2 import errors

database = mock.Mock()
cursor = mock.Mock()
cursor.execute.side_effect = errors.SyntaxError(
"""
Expand All @@ -175,7 +177,7 @@ def test_estimate_statement_invalid_syntax(self):
)
sql = "DROP TABLE birth_names"
with self.assertRaises(errors.SyntaxError):
PostgresEngineSpec.estimate_statement_cost(sql, cursor)
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)

def test_query_cost_formatter_example_costs(self):
"""
Expand Down
8 changes: 6 additions & 2 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,22 +905,26 @@ def test_select_star_presto_expand_data(
)

def test_estimate_statement_cost(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
estimate_json = {"a": "b"}
mock_cursor.fetchone.return_value = [
'{"a": "b"}',
]
result = PrestoEngineSpec.estimate_statement_cost(
"SELECT * FROM brth_names", mock_cursor
mock_database,
"SELECT * FROM brth_names",
mock_cursor,
)
assert result == estimate_json

def test_estimate_statement_cost_invalid_syntax(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
mock_cursor.execute.side_effect = Exception()
with self.assertRaises(Exception):
PrestoEngineSpec.estimate_statement_cost(
"DROP TABLE brth_names", mock_cursor
mock_database, "DROP TABLE brth_names", mock_cursor
)

def test_get_create_view(self):
Expand Down

0 comments on commit 6e42328

Please sign in to comment.