From 6e42328d55f160738fcaf77aceecca6b9f525048 Mon Sep 17 00:00:00 2001 From: Jack Fisher Date: Tue, 22 Oct 2024 11:28:32 -0500 Subject: [PATCH 1/2] bigquery oauth --- .../OAuth2ClientField.tsx | 13 ++++++++++- .../src/features/databases/types.ts | 1 + superset/commands/database/test_connection.py | 6 ++++- superset/databases/schemas.py | 7 +----- superset/databases/types.py | 23 +++++++++++++++++++ superset/db_engine_specs/base.py | 10 ++++++-- superset/db_engine_specs/bigquery.py | 12 ++++++---- superset/db_engine_specs/hive.py | 4 +++- superset/db_engine_specs/postgres.py | 11 ++++++++- superset/db_engine_specs/presto.py | 10 ++++++-- superset/db_engine_specs/trino.py | 4 +++- superset/models/core.py | 14 ++++++----- superset/utils/oauth2.py | 1 + .../db_engine_specs/postgres_tests.py | 6 +++-- .../db_engine_specs/presto_tests.py | 8 +++++-- 15 files changed, 101 insertions(+), 29 deletions(-) create mode 100644 superset/databases/types.py diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx index ee0ffdeb33f51..d3488a3673bce 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/OAuth2ClientField.tsx @@ -22,7 +22,7 @@ import { useState } from 'react'; import Collapse from 'src/components/Collapse'; import { Input } from 'src/components/Input'; import { FormItem } from 'src/components/Form'; -import { FieldPropTypes } from '../../types'; +import { FieldPropTypes, Engines } from '../../types'; interface OAuth2ClientInfo { id: string; @@ -30,6 +30,7 @@ interface OAuth2ClientInfo { authorization_request_uri: string; token_request_uri: string; scope: string; + project_id?: string; } export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { @@ -42,6 
+43,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { token_request_uri: encryptedExtra.oauth2_client_info?.token_request_uri || '', scope: encryptedExtra.oauth2_client_info?.scope || '', + project_id: encryptedExtra.oauth2_client_info?.project_id || '', }); if (db?.engine_information?.supports_oauth2 !== true) { @@ -106,6 +108,15 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { onChange={handleChange('scope')} /> + {db.engine === Engines.BigQuery && ( + + + + )} ); diff --git a/superset-frontend/src/features/databases/types.ts b/superset-frontend/src/features/databases/types.ts index c0ca1c5d3508e..a6600f5136cca 100644 --- a/superset-frontend/src/features/databases/types.ts +++ b/superset-frontend/src/features/databases/types.ts @@ -232,6 +232,7 @@ export enum ConfigurationMethod { export enum Engines { GSheet = 'gsheets', + BigQuery = 'bigquery', Snowflake = 'snowflake', } diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 8aef6c1359b5e..7330446d47ed6 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -93,7 +93,7 @@ def __init__(self, data: dict[str, Any]): self._context = context self._uri = uri - def run(self) -> None: # pylint: disable=too-many-statements + def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches self.validate() ex_str = "" ssh_tunnel = self._properties.get("ssh_tunnel") @@ -225,6 +225,10 @@ def ping(engine: Engine) -> bool: # bubble up the exception to return proper status code raise except Exception as ex: + if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2( + ex + ): + database.start_oauth2_dance() event_logger.log_with_context( action=get_log_connection_action( "test_connection_error", ssh_tunnel, ex diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 27eb043eb131b..13ae0a0af1c30 
100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -47,6 +47,7 @@ SSHTunnelMissingCredentials, ) from superset.constants import PASSWORD_MASK +from superset.databases.types import EncryptedField from superset.databases.utils import make_url_safe from superset.db_engine_specs import get_engine_spec from superset.exceptions import CertificateException, SupersetSecurityException @@ -941,12 +942,6 @@ def validate_ssh_tunnel_credentials( return -class EncryptedField: # pylint: disable=too-few-public-methods - """ - A database field that should be stored in encrypted_extra. - """ - - class EncryptedString(EncryptedField, fields.String): pass diff --git a/superset/databases/types.py b/superset/databases/types.py new file mode 100644 index 0000000000000..91b703151655c --- /dev/null +++ b/superset/databases/types.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Field has been moved outside of the schemas.py file to +# allow for it to be imported from outside of app_context +class EncryptedField: # pylint: disable=too-few-public-methods + """ + A database field that should be stored in encrypted_extra. 
+ """ diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index dcdfff6c3f242..8622a76a50b6d 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1691,10 +1691,13 @@ def select_star( # pylint: disable=too-many-arguments return sql @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. + :param database: A Database object :param statement: A single SQL statement :param cursor: Cursor instance :return: Dictionary with different costs @@ -1765,6 +1768,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments cursor = conn.cursor() return [ cls.estimate_statement_cost( + database, cls.process_statement(statement, database), cursor, ) @@ -1793,8 +1797,9 @@ def get_url_for_impersonation( return url @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -1804,6 +1809,7 @@ def update_impersonation_config( Update a configuration dictionary that can set the correct properties for impersonating users + :param database: a Database object :param connect_args: config to be updated :param uri: URI :param username: Effective username diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 11175d7957445..70bc4bc845390 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -409,7 +409,11 @@ def df_to_sql( pandas_gbq.to_gbq(df, **to_gbq_kwargs) @classmethod - def _get_client(cls, engine: Engine) -> bigquery.Client: + def _get_client( + cls, + engine: Engine, + database: Database, # pylint: disable=unused-argument + ) -> bigquery.Client: """ Return the
BigQuery client associated with an engine. """ @@ -453,7 +457,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments catalog=catalog, schema=schema, ) as engine: - client = cls._get_client(engine) + client = cls._get_client(engine, database) return [ cls.custom_estimate_statement_cost( cls.process_statement(statement, database), @@ -477,7 +481,7 @@ def get_default_catalog(cls, database: Database) -> str | None: return project with database.get_sqla_engine() as engine: - client = cls._get_client(engine) + client = cls._get_client(engine, database) return client.project @classmethod @@ -493,7 +497,7 @@ def get_catalog_names( """ engine: Engine with database.get_sqla_engine() as engine: - client = cls._get_client(engine) + client = cls._get_client(engine, database) projects = client.list_projects() return {project.project_id for project in projects} diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index e3cf128b7a2c6..6288866db93ee 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -537,8 +537,9 @@ def get_url_for_impersonation( return url @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -547,6 +548,7 @@ def update_impersonation_config( """ Update a configuration dictionary that can set the correct properties for impersonating users + :param database: the Database Object :param connect_args: :param uri: URI string :param impersonate_user: Flag indicating if impersonation is enabled diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 70373927d521b..6281c6b3b0ff3 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -351,7 +351,16 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def 
estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: + """ + Run a SQL query that estimates the cost of a given statement. + :param database: A Database object + :param statement: A single SQL statement + :param cursor: Cursor instance + :return: Dictionary with the start-up and total costs + """ sql = f"EXPLAIN {statement}" cursor.execute(sql) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index f0664564f872c..df5e1c643fa1f 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -365,9 +365,12 @@ def get_schema_from_engine_params( return parse.unquote(database.split("/")[1]) @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: + def estimate_statement_cost( + cls, database: Database, statement: str, cursor: Any + ) -> dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement.
+ :param database: A Database object :param statement: A single SQL statement :param cursor: Cursor instance :return: JSON response from Trino @@ -945,8 +948,9 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return version is not None and Version(version) >= Version("0.319") @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -955,6 +959,8 @@ def update_impersonation_config( """ Update a configuration dictionary that can set the correct properties for impersonating users + + :param database: the Database object :param connect_args: config to be updated :param uri: URI string :param username: Effective username diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 49615c39cba52..c473528217b5e 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -116,8 +116,9 @@ def get_extra_table_metadata( return metadata @classmethod - def update_impersonation_config( + def update_impersonation_config( # pylint: disable=too-many-arguments cls, + database: Database, connect_args: dict[str, Any], uri: str, username: str | None, @@ -126,6 +127,7 @@ def update_impersonation_config( """ Update a configuration dictionary that can set the correct properties for impersonating users + :param database: the Database object :param connect_args: config to be updated :param uri: URI string :param username: Effective username diff --git a/superset/models/core.py b/superset/models/core.py index 5d3a6ea74ddab..0027e65b3919c 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -29,6 +29,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache +from inspect import signature from typing import Any, Callable, cast, TYPE_CHECKING import numpy @@ -510,12 +511,13 @@ def _get_sqla_engine( # pylint:
disable=too-many-locals logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url)) if self.impersonate_user: - self.db_engine_spec.update_impersonation_config( - connect_args, - str(sqlalchemy_url), - effective_username, - access_token, - ) + # PR #30674 changed the signature of the method to include database. + # This ensures that the change is backwards compatible + sig = signature(self.db_engine_spec.update_impersonation_config) + args = [connect_args, str(sqlalchemy_url), effective_username, access_token] + if "database" in sig.parameters: + args.insert(0, self) + self.db_engine_spec.update_impersonation_config(*args) if connect_args: params["connect_args"] = connect_args diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index b889ef83c5e75..08a081862f7d8 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -192,3 +192,4 @@ class OAuth2ClientConfigSchema(Schema): ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True) + project_id = fields.String(required=False) diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index e4f9462d63069..a5ef1cdecab59 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -151,12 +151,13 @@ def test_estimate_statement_cost_select_star(self): DB Eng Specs (postgres): Test estimate_statement_cost select star """ + database = mock.Mock() cursor = mock.Mock() cursor.fetchone.return_value = ( "Seq Scan on birth_names (cost=0.00..1537.91 rows=75691 width=46)", ) sql = "SELECT * FROM birth_names" - results = PostgresEngineSpec.estimate_statement_cost(sql, cursor) + results = PostgresEngineSpec.estimate_statement_cost(database, sql, cursor) assert results == {"Start-up cost": 0.0, "Total cost": 1537.91} def test_estimate_statement_invalid_syntax(self): @@ -165,6 +166,7 @@ def 
test_estimate_statement_invalid_syntax(self): """ from psycopg2 import errors + database = mock.Mock() cursor = mock.Mock() cursor.execute.side_effect = errors.SyntaxError( """ @@ -175,7 +177,7 @@ def test_estimate_statement_invalid_syntax(self): ) sql = "DROP TABLE birth_names" with self.assertRaises(errors.SyntaxError): - PostgresEngineSpec.estimate_statement_cost(sql, cursor) + PostgresEngineSpec.estimate_statement_cost(database, sql, cursor) def test_query_cost_formatter_example_costs(self): """ diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 798e31ee431a4..94e3ea62721a4 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -905,22 +905,26 @@ def test_select_star_presto_expand_data( ) def test_estimate_statement_cost(self): + mock_database = mock.MagicMock() mock_cursor = mock.MagicMock() estimate_json = {"a": "b"} mock_cursor.fetchone.return_value = [ '{"a": "b"}', ] result = PrestoEngineSpec.estimate_statement_cost( - "SELECT * FROM brth_names", mock_cursor + mock_database, + "SELECT * FROM brth_names", + mock_cursor, ) assert result == estimate_json def test_estimate_statement_cost_invalid_syntax(self): + mock_database = mock.MagicMock() mock_cursor = mock.MagicMock() mock_cursor.execute.side_effect = Exception() with self.assertRaises(Exception): PrestoEngineSpec.estimate_statement_cost( - "DROP TABLE brth_names", mock_cursor + mock_database, "DROP TABLE brth_names", mock_cursor ) def test_get_create_view(self): From 25cdd3a57a0c2c452771c4e709752d106c7851c9 Mon Sep 17 00:00:00 2001 From: Jack Fisher Date: Thu, 24 Oct 2024 17:14:31 -0500 Subject: [PATCH 2/2] fixing double encoding bug for OAuth JWTs --- superset/utils/oauth2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 08a081862f7d8..43416d4681eda 100644 --- 
a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -19,6 +19,7 @@ from datetime import datetime, timedelta, timezone from typing import Any, TYPE_CHECKING +from urllib.parse import unquote import backoff import jwt @@ -169,6 +170,8 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State: """ Decode the OAuth2 state. """ + # Before unescaping periods, the percent-encoding needs to be decoded + encoded_state = unquote(encoded_state) # Google OAuth2 needs periods to be escaped. encoded_state = encoded_state.replace("%2E", ".")