diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 464f6cf2b9c8d..5b8ba457bafa3 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging import re from datetime import datetime @@ -27,6 +29,7 @@ from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.errors import SupersetErrorType +from superset.models.sql_types.mssql_sql_types import GUID from superset.utils.core import GenericDataType logger = logging.getLogger(__name__) @@ -87,6 +90,11 @@ class MssqlEngineSpec(BaseEngineSpec): SMALLDATETIME(), GenericDataType.TEMPORAL, ), + ( + re.compile(r"^uniqueidentifier.*", re.IGNORECASE), + GUID(), + GenericDataType.STRING, + ), ) custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { diff --git a/superset/models/sql_types/mssql_sql_types.py b/superset/models/sql_types/mssql_sql_types.py new file mode 100644 index 0000000000000..add40e31006ad --- /dev/null +++ b/superset/models/sql_types/mssql_sql_types.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=abstract-method +import uuid +from typing import Any, Optional + +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.sql.sqltypes import CHAR +from sqlalchemy.sql.visitors import Visitable +from sqlalchemy.types import TypeDecorator + +# _compiler_dispatch is defined to help with type compilation + + +class GUID(TypeDecorator): + """ + A type for SQL Server's uniqueidentifier, stored as stringified UUIDs. + """ + + impl = CHAR + + @property + def python_type(self) -> type[uuid.UUID]: + """The Python type for this SQL type is `uuid.UUID`.""" + return uuid.UUID + + @classmethod + def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: + """Return the SQL type for the GUID type, which is CHAR(36) in SQL Server.""" + return "CHAR(36)" + + def process_bind_param(self, value: str, dialect: Dialect) -> Optional[str]: + """Prepare the UUID value for binding to the database.""" + if value is None: + return None + if not isinstance(value, uuid.UUID): + return str(uuid.UUID(value)) # Convert to string UUID if needed + return str(value) + + def process_result_value( + self, value: Optional[str], dialect: Dialect + ) -> Optional[uuid.UUID]: + """Convert the string back to a UUID when retrieving from the database.""" + if value is None: + return None + return uuid.UUID(value) diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 38a5603e4ec93..0a3760a47f1fa 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -27,6 +27,7 @@ from sqlalchemy.types import String, TypeEngine, UnicodeText from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.models.sql_types.mssql_sql_types import GUID from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -46,6 +47,7 @@ ("NCHAR(10)", UnicodeText, None, GenericDataType.STRING, False), ("NVARCHAR(10)", UnicodeText, None, GenericDataType.STRING, False), ("NTEXT", UnicodeText, None, GenericDataType.STRING, False), + ("uniqueidentifier", GUID, None, GenericDataType.STRING, False), ], ) def test_get_column_spec(