diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 9215e1545bb21..82980cef161d8 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1533,6 +1533,7 @@ def adhoc_metric_to_sqla( expression = self._process_sql_expression( expression=metric["sqlExpression"], database_id=self.database_id, + engine=self.database.backend, schema=self.schema, template_processor=template_processor, ) @@ -1566,6 +1567,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals expression = self._process_sql_expression( expression=col["sqlExpression"], database_id=self.database_id, + engine=self.database.backend, schema=self.schema, template_processor=template_processor, ) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 4085d3a0aabc7..51808f9a46b31 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -63,6 +63,7 @@ ColumnNotFoundException, QueryClauseValidationException, QueryObjectValidationError, + SupersetParseError, SupersetSecurityException, ) from superset.extensions import feature_flag_manager @@ -112,6 +113,7 @@ def validate_adhoc_subquery( sql: str, database_id: int, + engine: str, default_schema: str, ) -> str: """ @@ -126,7 +128,12 @@ def validate_adhoc_subquery( """ statements = [] for statement in sqlparse.parse(sql): - if has_table_query(statement): + try: + has_table = has_table_query(str(statement), engine) + except SupersetParseError: + has_table = True + + if has_table: if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): raise SupersetSecurityException( SupersetError( @@ -135,7 +142,9 @@ def validate_adhoc_subquery( level=ErrorLevel.ERROR, ) ) + # TODO (betodealmeida): reimplement with sqlglot statement = insert_rls_in_predicate(statement, database_id, default_schema) + statements.append(statement) return ";\n".join(str(statement) for statement in statements) @@ -810,10 +819,11 @@ def get_sqla_row_level_filters( # for datasources of type query return [] - def _process_sql_expression( + def _process_sql_expression( # pylint: disable=too-many-arguments self, expression: Optional[str], database_id: int, + engine: str, schema: str, template_processor: Optional[BaseTemplateProcessor], ) -> Optional[str]: @@ -823,6 +833,7 @@ def _process_sql_expression( expression = validate_adhoc_subquery( expression, database_id, + engine, schema, ) try: @@ -1108,6 +1119,7 @@ def adhoc_metric_to_sqla( expression = self._process_sql_expression( expression=metric["sqlExpression"], database_id=self.database_id, + engine=self.database.backend, schema=self.schema, template_processor=template_processor, ) @@ -1551,6 +1563,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma col["sqlExpression"] = self._process_sql_expression( expression=col["sqlExpression"], database_id=self.database_id, + engine=self.database.backend, schema=self.schema, template_processor=template_processor, ) @@ -1613,6 +1626,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma selected = validate_adhoc_subquery( selected, self.database_id, + self.database.backend, self.schema, ) outer = literal_column(f"({selected})") @@ -1639,6 +1653,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma selected = validate_adhoc_subquery( _sql, self.database_id, + self.database.backend, self.schema, ) @@ -1915,6 +1930,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma where = self._process_sql_expression( expression=where, database_id=self.database_id, + engine=self.database.backend, schema=self.schema, template_processor=template_processor, ) @@ -1933,6 +1949,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma having = self._process_sql_expression( expression=having, database_id=self.database_id, + engine=self.database.backend, schema=self.schema, template_processor=template_processor, ) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 6f25a5a66058c..1702601d0f244 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -374,6 +374,7 @@ def adhoc_column_to_sqla( expression = self._process_sql_expression( expression=col["sqlExpression"], database_id=self.database_id, + engine=self.database.backend, schema=self.schema, template_processor=template_processor, ) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 1581b0c6e79e6..cb457cd4f5c51 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -64,6 +64,7 @@ extract_tables_from_statement, SQLGLOT_DIALECTS, SQLScript, + SQLStatement, Table, ) from superset.utils.backports import StrEnum @@ -570,46 +571,31 @@ class InsertRLSState(StrEnum): FOUND_TABLE = "FOUND_TABLE" -def has_table_query(token_list: TokenList) -> bool: +def has_table_query(expression: str, engine: str) -> bool: """ Return if a statement has a query reading from a table. - >>> has_table_query(sqlparse.parse("COUNT(*)")[0]) + >>> has_table_query("COUNT(*)", "postgresql") False - >>> has_table_query(sqlparse.parse("SELECT * FROM table")[0]) + >>> has_table_query("SELECT * FROM table", "postgresql") True Note that queries reading from constant values return false: - >>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0]) + >>> has_table_query("SELECT * FROM (SELECT 1)", "postgresql") False """ - state = InsertRLSState.SCANNING - for token in token_list.tokens: - # Ignore comments - if isinstance(token, sqlparse.sql.Comment): - continue - - # Recurse into child token list - if isinstance(token, TokenList) and has_table_query(token): - return True - - # Found a source keyword (FROM/JOIN) - if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]): - state = InsertRLSState.SEEN_SOURCE - - # Found identifier/keyword after FROM/JOIN - elif state == InsertRLSState.SEEN_SOURCE and ( - isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword - ): - return True + # Remove trailing semicolon. + expression = expression.strip().rstrip(";") - # Found nothing, leaving source - elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace: - state = InsertRLSState.SCANNING + # Wrap the expression in parentheses if it's not already. + if not expression.startswith("("): + expression = f"({expression})" - return False + sql = f"SELECT {expression}" + statement = SQLStatement(sql, engine) + return any(statement.tables) def add_table_name(rls: TokenList, table: str) -> None: diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index ec45c8c57e882..ab13fc4dafb60 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -42,6 +42,7 @@ get_main_database, ) from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase +from tests.integration_tests.conftest import with_feature_flags from tests.integration_tests.constants import ADMIN_USERNAME from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 @@ -585,6 +586,7 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data assert "INCORRECT SQL" in rv.json.get("error") +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset): uri = ( f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table" @@ -649,6 +651,7 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset): assert rv.json["result"]["rowcount"] == 0 +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset): uri = ( f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table" @@ -669,6 +672,7 @@ def test_get_samples_with_time_filter(test_client, login_as_admin, physical_data assert rv.json["result"]["total_count"] == 2 +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) def test_get_samples_with_multiple_filters( test_client, login_as_admin, physical_dataset ): diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 4822b690ed84c..d77523c71c9b0 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -42,7 +42,11 @@ ) from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR from tests.integration_tests.base_tests import SupersetTestCase -from tests.integration_tests.conftest import only_postgresql, only_sqlite +from tests.integration_tests.conftest import ( + only_postgresql, + only_sqlite, + with_feature_flags, +) from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, # noqa: F401 load_birth_names_data, # noqa: F401 @@ -858,6 +862,7 @@ def test_non_time_column_with_time_grain(app_context, physical_dataset): assert df["COL2 ALIAS"][0] == "a" +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) def test_special_chars_in_column_name(app_context, physical_dataset): qc = QueryContextFactory().create( datasource={ diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 23d51de64cdeb..44d52c7f6e8fb 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1286,46 +1286,66 @@ def test_sqlparse_issue_652(): @pytest.mark.parametrize( - "sql,expected", + ("engine", "sql", "expected"), [ - ("SELECT * FROM table", True), - ("SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", True), - ("(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", True), - ("COUNT(*)", False), - ("SELECT a FROM (SELECT 1 AS a)", False), - ("SELECT a FROM (SELECT 1 AS a) JOIN table", True), - ("SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", False), - ("SELECT * FROM other_table", True), - ("extract(HOUR from from_unixtime(hour_ts)", False), - ("(SELECT * FROM table)", True), - ("(SELECT COUNT(DISTINCT name) from birth_names)", True), + ("postgresql", "extract(HOUR from from_unixtime(hour_ts))", False), + ("postgresql", "SELECT * FROM table", True), + ("postgresql", "(SELECT * FROM table)", True), ( + "postgresql", + "SELECT a FROM (SELECT 1 AS a) JOIN (SELECT * FROM table)", + True, + ), + ( + "postgresql", + "(SELECT COUNT(DISTINCT name) AS foo FROM birth_names)", + True, + ), + ("postgresql", "COUNT(*)", False), + ("postgresql", "SELECT a FROM (SELECT 1 AS a)", False), + ("postgresql", "SELECT a FROM (SELECT 1 AS a) JOIN table", True), + ( + "postgresql", + "SELECT * FROM (SELECT 1 AS foo, 2 AS bar) ORDER BY foo ASC, bar", + False, + ), + ("postgresql", "SELECT * FROM other_table", True), + ("postgresql", "(SELECT COUNT(DISTINCT name) from birth_names)", True), + ( + "postgresql", "(SELECT table_name FROM information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)", True, ), ( + "postgresql", "(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)", True, ), ( + "postgresql", "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;", True, ), ( + "postgresql", "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table", True, ), + ( + "postgresql", + "((select users.id from (select 'majorie' as a) b, users where b.a = users.name and users.name in ('majorie') limit 1) like 'U%')", + True, + ), ], ) -def test_has_table_query(sql: str, expected: bool) -> None: +def test_has_table_query(engine: str, sql: str, expected: bool) -> None: """ Test if a given statement queries a table. This is used to prevent ad-hoc metrics from querying unauthorized tables, bypassing row-level security. """ - statement = sqlparse.parse(sql)[0] - assert has_table_query(statement) == expected + assert has_table_query(sql, engine) == expected @pytest.mark.parametrize(