Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: CTE queries with non-SELECT statements #25014

Merged
merged 5 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,53 @@ def tables(self) -> set[Table]:
def limit(self) -> Optional[int]:
    """Return the value stored in ``self._limit``."""
    return self._limit

def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
if "with" not in parsed:
return []
return parsed["with"].get("cte_tables", [])

def _check_cte_is_select(self, oxide_parse: list[dict[str, Any]]) -> bool:
"""
Check if a oxide parsed CTE contains only SELECT statements

:param oxide_parse: parsed CTE
:return: True if CTE is a SELECT statement
"""
for query in oxide_parse:
parsed_query = query["Query"]
cte_tables = self._get_cte_tables(parsed_query)
for cte_table in cte_tables:
is_select = all(
key == "Select" for key in cte_table["query"]["body"].keys()
)
if not is_select:
return False
return True

def is_select(self) -> bool:
# make sure we strip comments; prevents a bug with comments in the CTE
parsed = sqlparse.parse(self.strip_comments())

# Check if this is a CTE
if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE:
if sqloxide_parse is not None:
try:
if not self._check_cte_is_select(
sqloxide_parse(self.strip_comments(), dialect="ansi")
):
return False
except ValueError:
# sqloxide was not able to parse the query, so let's continue with
# sqlparse
pass
inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or []
# Check if the inner CTE is a not a SELECT
if any(token.ttype == DDL for token in inner_cte) or any(
token.ttype == DML and token.normalized != "SELECT"
for token in inner_cte
):
return False

if parsed[0].get_type() == "SELECT":
return True

Expand All @@ -241,6 +285,17 @@ def is_select(self) -> bool:
token.ttype == DML and token.normalized == "SELECT" for token in parsed[0]
)

def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
    """
    Find the first parenthesized group nested directly inside an identifier.

    :param tokens: tokens to scan, in order
    :return: the child tokens of the first ``Parenthesis`` group found
        directly under a token accepted by ``self._is_identifier``, or None
        when there is no such group
    """
    for candidate in tokens:
        if not self._is_identifier(candidate):
            continue
        for child in candidate.tokens:
            if not isinstance(child, Parenthesis):
                continue
            if child.is_group:
                return child.tokens
    return None

def is_valid_ctas(self) -> bool:
    """
    Tell whether the last parsed statement is a SELECT.

    :return: True if sqlparse reports the type of the final statement of the
        comment-stripped SQL as ``SELECT``
    """
    statements = sqlparse.parse(self.strip_comments())
    last_statement = statements[-1]
    return last_statement.get_type() == "SELECT"
Expand Down
81 changes: 81 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,87 @@ def test_cte_is_select_lowercase() -> None:
assert sql.is_select()


def test_cte_insert_is_not_select() -> None:
    """
    A CTE whose body is an INSERT ... RETURNING is not reported as a SELECT.
    """
    statement = """WITH foo AS(
INSERT INTO foo (id) VALUES (1) RETURNING 1
) select * FROM foo f"""
    assert ParsedQuery(statement).is_select() is False


def test_cte_delete_is_not_select() -> None:
    """
    A CTE whose body is a DELETE ... RETURNING is not reported as a SELECT.
    """
    statement = """WITH foo AS(
DELETE FROM foo RETURNING *
) select * FROM foo f"""
    assert ParsedQuery(statement).is_select() is False


def test_cte_is_not_select_lowercase() -> None:
    """
    A CTE whose body is a lowercase insert is still not reported as a SELECT.
    """
    statement = """WITH foo AS(
insert into foo (id) values (1) RETURNING 1
) select * FROM foo f"""
    assert ParsedQuery(statement).is_select() is False


def test_cte_with_multiple_selects() -> None:
    """
    A query with several CTEs that are all SELECTs is reported as a SELECT.
    """
    statement = (
        "WITH a AS ( select * from foo1 ), b as (select * from foo2) SELECT * FROM a;"
    )
    assert ParsedQuery(statement).is_select()


def test_cte_with_multiple_with_non_select() -> None:
    """
    A query with several CTEs is not a SELECT when any CTE body is an
    UPDATE or INSERT, whatever its position among the CTEs.
    """
    statements = (
        # non-SELECT in the second CTE
        """WITH a AS (
select * from foo1
), b as (
update foo2 set id=2
) SELECT * FROM a""",
        # non-SELECT in the first CTE
        """WITH a AS (
update foo2 set name=2
),
b as (
select * from foo1
) SELECT * FROM a""",
        # non-SELECT in both CTEs
        """WITH a AS (
update foo2 set name=2
),
b as (
update foo1 set name=2
) SELECT * FROM a""",
        # INSERT without RETURNING in the first CTE
        """WITH a AS (
INSERT INTO foo (id) VALUES (1)
),
b as (
select 1
) SELECT * FROM a""",
    )
    for statement in statements:
        assert ParsedQuery(statement).is_select() is False


def test_unknown_select() -> None:
"""
Test that `is_select` works when sqlparse fails to identify the type.
Expand Down
Loading