-
Notifications
You must be signed in to change notification settings - Fork 13.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(oauth): adding necessary changes to support bigquery oauth #30674
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,14 +22,15 @@ import { useState } from 'react'; | |
import Collapse from 'src/components/Collapse'; | ||
import { Input } from 'src/components/Input'; | ||
import { FormItem } from 'src/components/Form'; | ||
import { FieldPropTypes } from '../../types'; | ||
import { FieldPropTypes, Engines } from '../../types'; | ||
|
||
interface OAuth2ClientInfo { | ||
id: string; | ||
secret: string; | ||
authorization_request_uri: string; | ||
token_request_uri: string; | ||
scope: string; | ||
project_id?: string; | ||
} | ||
|
||
export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { | ||
|
@@ -42,6 +43,7 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { | |
token_request_uri: | ||
encryptedExtra.oauth2_client_info?.token_request_uri || '', | ||
scope: encryptedExtra.oauth2_client_info?.scope || '', | ||
project_id: encryptedExtra.oauth2_client_info?.project_id || '', | ||
}); | ||
|
||
if (db?.engine_information?.supports_oauth2 !== true) { | ||
|
@@ -106,6 +108,15 @@ export const OAuth2ClientField = ({ changeMethods, db }: FieldPropTypes) => { | |
onChange={handleChange('scope')} | ||
/> | ||
</FormItem> | ||
{db.engine === Engines.BigQuery && ( | ||
<FormItem label="Project ID"> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's localize the label |
||
<Input | ||
data-test="project_id" | ||
value={oauth2ClientInfo.project_id} | ||
onChange={handleChange('project_id')} | ||
/> | ||
</FormItem> | ||
)} | ||
</Collapse.Panel> | ||
</Collapse> | ||
); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,7 +93,7 @@ | |
self._context = context | ||
self._uri = uri | ||
|
||
def run(self) -> None: # pylint: disable=too-many-statements | ||
def run(self) -> None: # pylint: disable=too-many-statements,too-many-branches | ||
self.validate() | ||
ex_str = "" | ||
ssh_tunnel = self._properties.get("ssh_tunnel") | ||
|
@@ -225,6 +225,10 @@ | |
# bubble up the exception to return proper status code | ||
raise | ||
except Exception as ex: | ||
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BigQuery raises an error when running |
||
ex | ||
): | ||
database.start_oauth2_dance() | ||
event_logger.log_with_context( | ||
action=get_log_connection_action( | ||
"test_connection_error", ssh_tunnel, ex | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# Field has been moved outside of the schemas.py file to | ||
# allow for it to be imported from outside of app_context | ||
class EncryptedField: # pylint: disable=too-few-public-methods | ||
""" | ||
A database field that should be stored in encrypted_extra. | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1691,10 +1691,13 @@ def select_star( # pylint: disable=too-many-arguments | |
return sql | ||
|
||
@classmethod | ||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: | ||
def estimate_statement_cost( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Database is required to create the BigQuery Client when using OAuth2 |
||
cls, database: Database, statement: str, cursor: Any | ||
) -> dict[str, Any]: | ||
""" | ||
Generate a SQL query that estimates the cost of a given statement. | ||
|
||
:param database: A Database object | ||
:param statement: A single SQL statement | ||
:param cursor: Cursor instance | ||
:return: Dictionary with different costs | ||
|
@@ -1765,6 +1768,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments | |
cursor = conn.cursor() | ||
return [ | ||
cls.estimate_statement_cost( | ||
database, | ||
cls.process_statement(statement, database), | ||
cursor, | ||
) | ||
|
@@ -1793,8 +1797,9 @@ def get_url_for_impersonation( | |
return url | ||
|
||
@classmethod | ||
def update_impersonation_config( | ||
def update_impersonation_config( # pylint: disable=too-many-arguments | ||
cls, | ||
database: Database, | ||
connect_args: dict[str, Any], | ||
uri: str, | ||
username: str | None, | ||
|
@@ -1804,6 +1809,7 @@ def update_impersonation_config( | |
Update a configuration dictionary | ||
that can set the correct properties for impersonating users | ||
|
||
:param connect_args: a Database object | ||
:param connect_args: config to be updated | ||
:param uri: URI | ||
:param username: Effective username | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -351,7 +351,16 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: | |
return True | ||
|
||
@classmethod | ||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: | ||
def estimate_statement_cost( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. aligning other specs to have database, and adding missing docstring |
||
cls, database: Database, statement: str, cursor: Any | ||
) -> dict[str, Any]: | ||
""" | ||
Run a SQL query that estimates the cost of a given statement. | ||
:param database: A Database object | ||
:param statement: A single SQL statement | ||
:param cursor: Cursor instance | ||
:return: JSON response from Trino | ||
""" | ||
sql = f"EXPLAIN {statement}" | ||
cursor.execute(sql) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -365,9 +365,12 @@ def get_schema_from_engine_params( | |
return parse.unquote(database.split("/")[1]) | ||
|
||
@classmethod | ||
def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: | ||
def estimate_statement_cost( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. adding database param |
||
cls, database: Database, statement: str, cursor: Any | ||
) -> dict[str, Any]: | ||
""" | ||
Run a SQL query that estimates the cost of a given statement. | ||
:param database: A Database object | ||
:param statement: A single SQL statement | ||
:param cursor: Cursor instance | ||
:return: JSON response from Trino | ||
|
@@ -945,8 +948,9 @@ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: | |
return version is not None and Version(version) >= Version("0.319") | ||
|
||
@classmethod | ||
def update_impersonation_config( | ||
def update_impersonation_config( # pylint: disable=too-many-arguments | ||
cls, | ||
database: Database, | ||
connect_args: dict[str, Any], | ||
uri: str, | ||
username: str | None, | ||
|
@@ -955,6 +959,8 @@ def update_impersonation_config( | |
""" | ||
Update a configuration dictionary | ||
that can set the correct properties for impersonating users | ||
|
||
:param connect_args: the Database object | ||
:param connect_args: config to be updated | ||
:param uri: URI string | ||
:param username: Effective username | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
from copy import deepcopy | ||
from datetime import datetime | ||
from functools import lru_cache | ||
from inspect import signature | ||
from typing import Any, Callable, cast, TYPE_CHECKING | ||
|
||
import numpy | ||
|
@@ -510,12 +511,13 @@ def _get_sqla_engine( # pylint: disable=too-many-locals | |
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url)) | ||
|
||
if self.impersonate_user: | ||
self.db_engine_spec.update_impersonation_config( | ||
connect_args, | ||
str(sqlalchemy_url), | ||
effective_username, | ||
access_token, | ||
) | ||
# PR #30674 changed the signature of the method to include database. | ||
# This ensures that the change is backwards compatible | ||
sig = signature(self.db_engine_spec.update_impersonation_config) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have sufficient test coverage for this case? |
||
args = [connect_args, str(sqlalchemy_url), effective_username, access_token] | ||
if "database" in sig.parameters: | ||
args.insert(0, self) | ||
self.db_engine_spec.update_impersonation_config(*args) | ||
|
||
if connect_args: | ||
params["connect_args"] = connect_args | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
from datetime import datetime, timedelta, timezone | ||
from typing import Any, TYPE_CHECKING | ||
from urllib.parse import unquote | ||
|
||
import backoff | ||
import jwt | ||
|
@@ -169,6 +170,8 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State: | |
""" | ||
Decode the OAuth2 state. | ||
""" | ||
# Before escaping periods, the % need to be escaped | ||
encoded_state = unquote(encoded_state) | ||
# Google OAuth2 needs periods to be escaped. | ||
encoded_state = encoded_state.replace("%2E", ".") | ||
|
||
|
@@ -192,3 +195,4 @@ class OAuth2ClientConfigSchema(Schema): | |
) | ||
authorization_request_uri = fields.String(required=True) | ||
token_request_uri = fields.String(required=True) | ||
project_id = fields.String(required=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BigQuery takes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think ideally we'd want to store the project id in the URI, to make it compatible with non-oauth2 use cases. And then later when you need you can grab it from |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is BigQuery specific, it's better to not add it here, otherwise the interface loses its purpose. Ideally we want to keep only the common information that every oauth2 client has.
Can we have this in a separate attribute instead?