Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(oauth): adding necessary changes to support bigquery oauth #30674

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ import { useState } from 'react';
import Collapse from 'src/components/Collapse';
import { Input } from 'src/components/Input';
import { FormItem } from 'src/components/Form';
import { FieldPropTypes } from '../../types';
import { FieldPropTypes, Engines } from '../../types';

interface OAuth2ClientInfo {
id: string;
secret: string;
authorization_request_uri: string;
token_request_uri: string;
scope: string;
project_id?: string;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is BigQuery specific, it's better to not add it here, otherwise the interface loses its purpose. Ideally we want to keep only the common information that every oauth2 client has.

Can we have this in a separate attribute instead?

}

export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
Expand All @@ -42,6 +43,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
token_request_uri:
encryptedExtra.oauth2_client_info?.token_request_uri || '',
scope: encryptedExtra.oauth2_client_info?.scope || '',
project_id: encryptedExtra.oauth2_client_info?.project_id || '',
});

if (db?.engine_information?.supports_oauth2 !== true) {
Expand Down Expand Up @@ -106,6 +108,15 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => {
onChange={handleChange('scope')}
/>
</FormItem>
{db.engine === Engines.BigQuery && (
<FormItem label="Project ID">
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's localize the label

<Input
data-test="project_id"
value={oauth2ClientInfo.project_id}
onChange={handleChange('project_id')}
/>
</FormItem>
)}
</Collapse.Panel>
</Collapse>
);
Expand Down
1 change: 1 addition & 0 deletions superset-frontend/src/features/databases/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ export enum ConfigurationMethod {

/**
 * Engine identifiers that can be compared against a database's `engine`
 * value to enable engine-specific behavior (e.g. BigQuery-only form fields).
 */
export enum Engines {
GSheet = 'gsheets',
BigQuery = 'bigquery',
Snowflake = 'snowflake',
}

Expand Down
6 changes: 5 additions & 1 deletion superset/commands/database/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
self._context = context
self._uri = uri

def run(self) -> None: # pylint: disable=too-many-statements
def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches
self.validate()
ex_str = ""
ssh_tunnel = self._properties.get("ssh_tunnel")
Expand Down Expand Up @@ -225,6 +225,10 @@
# bubble up the exception to return proper status code
raise
except Exception as ex:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(
Copy link
Contributor Author

@fisjac fisjac Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BigQuery raises an error when running create_engine(), necessitating another needs_oauth2(ex) check so that the proper OAuth2 exceptions trigger the OAuth2 dance. Most DBs allow an engine to be created without valid credentials and instead raise the exception on engine.connect().

ex
):
database.start_oauth2_dance()

Check warning on line 231 in superset/commands/database/test_connection.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/database/test_connection.py#L231

Added line #L231 was not covered by tests
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
Expand Down
7 changes: 1 addition & 6 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
SSHTunnelMissingCredentials,
)
from superset.constants import PASSWORD_MASK
from superset.databases.types import EncryptedField
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
Expand Down Expand Up @@ -941,12 +942,6 @@ def validate_ssh_tunnel_credentials(
return


class EncryptedField: # pylint: disable=too-few-public-methods
"""
A database field that should be stored in encrypted_extra.
"""


class EncryptedString(EncryptedField, fields.String):
pass

Expand Down
23 changes: 23 additions & 0 deletions superset/databases/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Field has been moved outside of the schemas.py file to
# allow for it to be imported from outside of app_context
class EncryptedField: # pylint: disable=too-few-public-methods
    """
    A database field that should be stored in encrypted_extra.

    The class declares no attributes or methods of its own; it serves as a
    marker base that is combined with a concrete field type.
    """
10 changes: 8 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,10 +1691,13 @@ def select_star( # pylint: disable=too-many-arguments
return sql

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Database is required to create the BigQuery Client when using OAuth2

cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.

:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with different costs
Expand Down Expand Up @@ -1765,6 +1768,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
cursor = conn.cursor()
return [
cls.estimate_statement_cost(
database,
cls.process_statement(statement, database),
cursor,
)
Expand Down Expand Up @@ -1793,8 +1797,9 @@ def get_url_for_impersonation(
return url

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -1804,6 +1809,7 @@ def update_impersonation_config(
Update a configuration dictionary
that can set the correct properties for impersonating users

:param database: a Database object
:param connect_args: config to be updated
:param uri: URI
:param username: Effective username
Expand Down
12 changes: 8 additions & 4 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@
pandas_gbq.to_gbq(df, **to_gbq_kwargs)

@classmethod
def _get_client(cls, engine: Engine) -> bigquery.Client:
def _get_client(
cls,
engine: Engine,
database: Database, # pylint: disable=unused-argument
) -> bigquery.Client:
"""
Return the BigQuery client associated with an engine.
"""
Expand Down Expand Up @@ -453,7 +457,7 @@
catalog=catalog,
schema=schema,
) as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)

Check warning on line 460 in superset/db_engine_specs/bigquery.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/bigquery.py#L460

Added line #L460 was not covered by tests
return [
cls.custom_estimate_statement_cost(
cls.process_statement(statement, database),
Expand All @@ -477,7 +481,7 @@
return project

with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)
return client.project

@classmethod
Expand All @@ -493,7 +497,7 @@
"""
engine: Engine
with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
client = cls._get_client(engine, database)

Check warning on line 500 in superset/db_engine_specs/bigquery.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/bigquery.py#L500

Added line #L500 was not covered by tests
projects = client.list_projects()

return {project.project_id for project in projects}
Expand Down
4 changes: 3 additions & 1 deletion superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,9 @@ def get_url_for_impersonation(
return url

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -547,6 +548,7 @@ def update_impersonation_config(
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database Object
:param connect_args:
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
Expand Down
11 changes: 10 additions & 1 deletion superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,16 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aligning other specs to have database, and adding missing docstring

cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: Dictionary with the start-up cost and total cost
"""
sql = f"EXPLAIN {statement}"
cursor.execute(sql)

Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,12 @@ def get_schema_from_engine_params(
return parse.unquote(database.split("/")[1])

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
def estimate_statement_cost(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding database param

cls, database: Database, statement: str, cursor: Any
) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param database: A Database object
:param statement: A single SQL statement
:param cursor: Cursor instance
:return: JSON response from Trino
Expand Down Expand Up @@ -945,8 +948,9 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return version is not None and Version(version) >= Version("0.319")

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -955,6 +959,8 @@ def update_impersonation_config(
"""
Update a configuration dictionary
that can set the correct properties for impersonating users

:param database: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username
Expand Down
4 changes: 3 additions & 1 deletion superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def get_extra_table_metadata(
return metadata

@classmethod
def update_impersonation_config(
def update_impersonation_config( # pylint: disable=too-many-arguments
cls,
database: Database,
connect_args: dict[str, Any],
uri: str,
username: str | None,
Expand All @@ -126,6 +127,7 @@ def update_impersonation_config(
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param database: the Database object
:param connect_args: config to be updated
:param uri: URI string
:param username: Effective username
Expand Down
14 changes: 8 additions & 6 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import signature
from typing import Any, Callable, cast, TYPE_CHECKING

import numpy
Expand Down Expand Up @@ -510,12 +511,13 @@ def _get_sqla_engine( # pylint: disable=too-many-locals
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args,
str(sqlalchemy_url),
effective_username,
access_token,
)
# PR #30674 changed the signature of the method to include database.
# This ensures that the change is backwards compatible
sig = signature(self.db_engine_spec.update_impersonation_config)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have sufficient test coverage for this case?

args = [connect_args, str(sqlalchemy_url), effective_username, access_token]
if "database" in sig.parameters:
args.insert(0, self)
self.db_engine_spec.update_impersonation_config(*args)

if connect_args:
params["connect_args"] = connect_args
Expand Down
4 changes: 4 additions & 0 deletions superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from datetime import datetime, timedelta, timezone
from typing import Any, TYPE_CHECKING
from urllib.parse import unquote

import backoff
import jwt
Expand Down Expand Up @@ -169,6 +170,8 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State:
"""
Decode the OAuth2 state.
"""
# Before unescaping periods, the percent-encoding needs to be decoded
encoded_state = unquote(encoded_state)
# Google OAuth2 needs periods to be escaped.
encoded_state = encoded_state.replace("%2E", ".")

Expand All @@ -192,3 +195,4 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
project_id = fields.String(required=False)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BigQuery takes project_id as a param in its OAuth2 parameters

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ideally we'd want to store the project id in the URI, to make it compatible with non-oauth2 use cases. And then later when you need you can grab it from database.sqlalchemy_uri.database.

6 changes: 4 additions & 2 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ def test_estimate_statement_cost_select_star(self):
DB Eng Specs (postgres): Test estimate_statement_cost select star
"""

database = mock.Mock()
cursor = mock.Mock()
cursor.fetchone.return_value = (
"Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)",
)
sql = "SELECT * FROM birth_names"
results = PostgresEngineSpec.estimate_statement_cost(sql, cursor)
results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)
assert results == {"Start-up cost": 0.0, "Total cost": 1537.91}

def test_estimate_statement_invalid_syntax(self):
Expand All @@ -165,6 +166,7 @@ def test_estimate_statement_invalid_syntax(self):
"""
from psycopg2 import errors

database = mock.Mock()
cursor = mock.Mock()
cursor.execute.side_effect = errors.SyntaxError(
"""
Expand All @@ -175,7 +177,7 @@ def test_estimate_statement_invalid_syntax(self):
)
sql = "DROP TABLE birth_names"
with self.assertRaises(errors.SyntaxError):
PostgresEngineSpec.estimate_statement_cost(sql, cursor)
PostgresEngineSpec.estimate_statement_cost(database, sql, cursor)

def test_query_cost_formatter_example_costs(self):
"""
Expand Down
8 changes: 6 additions & 2 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,22 +905,26 @@ def test_select_star_presto_expand_data(
)

def test_estimate_statement_cost(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
estimate_json = {"a": "b"}
mock_cursor.fetchone.return_value = [
'{"a": "b"}',
]
result = PrestoEngineSpec.estimate_statement_cost(
"SELECT * FROM brth_names", mock_cursor
mock_database,
"SELECT * FROM brth_names",
mock_cursor,
)
assert result == estimate_json

def test_estimate_statement_cost_invalid_syntax(self):
mock_database = mock.MagicMock()
mock_cursor = mock.MagicMock()
mock_cursor.execute.side_effect = Exception()
with self.assertRaises(Exception):
PrestoEngineSpec.estimate_statement_cost(
"DROP TABLE brth_names", mock_cursor
mock_database, "DROP TABLE brth_names", mock_cursor
)

def test_get_create_view(self):
Expand Down
Loading