Skip to content

Commit

Permalink
Write for pydantic v1 and v2 compat
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonDeMeester committed Nov 13, 2023
1 parent f590548 commit 7ecbc38
Show file tree
Hide file tree
Showing 8 changed files with 478 additions and 146 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
pydantic-version: ["pydantic-v1", "pydantic-v2"]
fail-fast: false

steps:
Expand Down Expand Up @@ -51,6 +52,12 @@ jobs:
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: python -m poetry install
- name: Install Pydantic v1
if: matrix.pydantic-version == 'pydantic-v1'
run: pip install "pydantic>=1.10.0,<2.0.0"
- name: Install Pydantic v2
if: matrix.pydantic-version == 'pydantic-v2'
run: pip install "pydantic>=2.0.2,<3.0.0"
- name: Lint
run: python -m poetry run bash scripts/lint.sh
- run: mkdir coverage
Expand Down
2 changes: 1 addition & 1 deletion sqlmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sqlalchemy.sql import (
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
)
from sqlalchemy.sql import Subquery as Subquery
from sqlalchemy.sql import alias as alias
from sqlalchemy.sql import all_ as all_
from sqlalchemy.sql import and_ as and_
Expand Down Expand Up @@ -70,7 +71,6 @@
from sqlalchemy.sql import outerjoin as outerjoin
from sqlalchemy.sql import outparam as outparam
from sqlalchemy.sql import over as over
from sqlalchemy.sql import Subquery as Subquery
from sqlalchemy.sql import table as table
from sqlalchemy.sql import tablesample as tablesample
from sqlalchemy.sql import text as text
Expand Down
169 changes: 169 additions & 0 deletions sqlmodel/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
ForwardRef,
Optional,
Type,
TypeVar,
Union,
get_args,
get_origin,
)

from pydantic import VERSION as PYDANTIC_VERSION

IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2


if IS_PYDANTIC_V2:
from pydantic import ConfigDict
from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa
else:
from pydantic import BaseConfig # noqa
from pydantic.fields import ModelField # noqa
from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType # noqa

if TYPE_CHECKING:
from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass


NoArgAnyCallable = Callable[[], Any]
T = TypeVar("T")
InstanceOrType = Union[T, Type[T]]

if IS_PYDANTIC_V2:

class SQLModelConfig(ConfigDict, total=False):
table: Optional[bool]
read_from_attributes: Optional[bool]
registry: Optional[Any]

else:

class SQLModelConfig(BaseConfig):
table: Optional[bool] = None
read_from_attributes: Optional[bool] = None
registry: Optional[Any] = None

def __getitem__(self, item: str) -> Any:
return self.__getattr__(item)

def __setitem__(self, item: str, value: Any) -> None:
return self.__setattr__(item, value)


# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py
def get_model_config(model: type) -> Optional[SQLModelConfig]:
if IS_PYDANTIC_V2:
return getattr(model, "model_config", None)
else:
return getattr(model, "Config", None)


def get_config_value(
model: InstanceOrType["SQLModel"], parameter: str, default: Any = None
) -> Any:
if IS_PYDANTIC_V2:
return model.model_config.get(parameter, default)
else:
return getattr(model.Config, parameter, default)


def set_config_value(
model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None
) -> None:
if IS_PYDANTIC_V2:
model.model_config[parameter] = value # type: ignore
else:
model.Config[v1_parameter or parameter] = value # type: ignore


def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]:
if IS_PYDANTIC_V2:
return model.model_fields # type: ignore
else:
return model.__fields__ # type: ignore


def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]:
if IS_PYDANTIC_V2:
return model.__pydantic_fields_set__
else:
return model.__fields_set__ # type: ignore


def set_fields_set(
new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"]
) -> None:
if IS_PYDANTIC_V2:
object.__setattr__(new_object, "__pydantic_fields_set__", fields)
else:
object.__setattr__(new_object, "__fields_set__", fields)


def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None:
if IS_PYDANTIC_V2:
cls.model_config["read_from_attributes"] = True
else:
cls.__config__.read_with_orm_mode = True # type: ignore


def get_relationship_to(
name: str,
rel_info: "RelationshipInfo",
annotation: Any,
) -> Any:
if IS_PYDANTIC_V2:
relationship_to = get_origin(annotation)
# Direct relationships (e.g. 'Team' or Team) have None as an origin
if relationship_to is None:
relationship_to = annotation
# If Union (e.g. Optional), get the real field
elif relationship_to is Union:
relationship_to = get_args(annotation)[0]
# If a list, then also get the real field
elif relationship_to is list:
relationship_to = get_args(annotation)[0]
if isinstance(relationship_to, ForwardRef):
relationship_to = relationship_to.__forward_arg__
return relationship_to
else:
temp_field = ModelField.infer(
name=name,
value=rel_info,
annotation=annotation,
class_validators=None,
config=SQLModelConfig,
)
relationship_to = temp_field.type_
if isinstance(temp_field.type_, ForwardRef):
relationship_to = temp_field.type_.__forward_arg__
return relationship_to


def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None:
"""
Pydantic v2 without required fields with no optionals cannot do empty initialisations.
This means we cannot do Table() and set fields later.
We go around this by adding a default to everything, being None
Args:
annotations: Dict[str, Any]: The annotations to provide to pydantic
class_dict: Dict[str, Any]: The class dict for the defaults
"""
if IS_PYDANTIC_V2:
from .main import FieldInfo
# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
for key in annotations.keys():
value = class_dict.get(key, PydanticUndefined)
if value is PydanticUndefined:
class_dict[key] = None
elif isinstance(value, FieldInfo):
if (
value.default in (PydanticUndefined, Ellipsis)
) and value.default_factory is None:
# So we can check for nullable
value.original_default = value.default
value.default = None
Loading

0 comments on commit 7ecbc38

Please sign in to comment.