From ad2998598f0802f81815214cc3cc0b9ee9196938 Mon Sep 17 00:00:00 2001 From: "Michael S. Molina" <70410625+michael-s-molina@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:30:53 -0300 Subject: [PATCH] fix: Pre-query normalization with custom SQL (#30389) --- superset/connectors/sqla/models.py | 4 +-- .../unit_tests/connectors/sqla/models_test.py | 26 ++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 77243d41f8d7c..9215e1545bb21 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1688,10 +1688,10 @@ def _normalize_prequery_result_type( if isinstance(value, np.generic): value = value.item() - column_ = columns_by_name[dimension] + column_ = columns_by_name.get(dimension) db_extra: dict[str, Any] = self.database.get_extra() - if column_.type and column_.is_temporal and isinstance(value, str): + if column_ and column_.type and column_.is_temporal and isinstance(value, str): sql = self.db_engine_spec.convert_dttm( column_.type, dateutil.parser.parse(value), db_extra=db_extra ) diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py index 3fa32228ca48e..013d03e7e4cff 100644 --- a/tests/unit_tests/connectors/sqla/models_test.py +++ b/tests/unit_tests/connectors/sqla/models_test.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. +import pandas as pd import pytest from pytest_mock import MockerFixture from sqlalchemy import create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.session import Session -from superset.connectors.sqla.models import SqlaTable +from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.daos.dataset import DatasetDAO from superset.exceptions import OAuth2RedirectError from superset.models.core import Database @@ -263,3 +264,26 @@ def test_dataset_uniqueness(session: Session) -> None: database, Table("table", "schema", "some_catalog"), ) + + +def test_normalize_prequery_result_type_custom_sql() -> None: + """ + Test that the `_normalize_prequery_result_type` can hanndle custom SQL. + """ + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=Database(database_name="my_db", sqlalchemy_uri="sqlite://"), + ) + row: pd.Series = { + "custom_sql": "Car", + } + dimension: str = "custom_sql" + columns_by_name: dict[str, TableColumn] = { + "product_line": TableColumn(column_name="product_line"), + } + assert ( + sqla_table._normalize_prequery_result_type(row, dimension, columns_by_name) + == "Car" + )