From b5e850bdc5f0d35c9b2842318c6f7938388a6af8 Mon Sep 17 00:00:00 2001 From: allenyuchen Date: Wed, 9 Aug 2023 23:05:38 -0500 Subject: [PATCH] update async extension and add tests --- pyproject.toml | 3 ++ sqlmodel/ext/asyncio/__init__.py | 2 ++ sqlmodel/ext/asyncio/engine.py | 10 ++++++ sqlmodel/ext/asyncio/session.py | 4 +-- tests/test_async.py | 52 ++++++++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 sqlmodel/ext/asyncio/engine.py create mode 100644 tests/test_async.py diff --git a/pyproject.toml b/pyproject.toml index e40272715..9561cd276 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ fastapi = "^0.68.1" requests = "^2.26.0" autoflake = "^1.4" isort = "^5.9.3" +testcontainers = "^3.7.1" +psycopg2-binary = "^2.9.7" +asyncpg = "^0.28.0" [build-system] requires = ["poetry-core"] diff --git a/sqlmodel/ext/asyncio/__init__.py b/sqlmodel/ext/asyncio/__init__.py index e69de29bb..0af81880e 100644 --- a/sqlmodel/ext/asyncio/__init__.py +++ b/sqlmodel/ext/asyncio/__init__.py @@ -0,0 +1,2 @@ +from .engine import create_async_engine as create_async_engine +from .session import AsyncSession as AsyncSession diff --git a/sqlmodel/ext/asyncio/engine.py b/sqlmodel/ext/asyncio/engine.py new file mode 100644 index 000000000..92c0dff37 --- /dev/null +++ b/sqlmodel/ext/asyncio/engine.py @@ -0,0 +1,10 @@ +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine + + +# create_async_engine by default already has future set to be true. +# Porting this over to sqlmodel to make it easier to use. +def create_async_engine(*args: Any, **kwargs: Any) -> AsyncEngine: + return _create_async_engine(*args, **kwargs) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 80267b25e..79dae568a 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -9,7 +9,7 @@ from ...engine.result import ScalarResult from ...orm.session import Session -from ...sql.expression import Select +from ...sql.expression import Select, SelectOfScalar _T = TypeVar("_T") @@ -42,7 +42,7 @@ def __init__( async def exec( self, - statement: Union[Select[_T], Executable[_T]], + statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]], params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, bind_arguments: Optional[Mapping[str, Any]] = None, diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 000000000..64b7a1035 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,52 @@ +import asyncio +from typing import Generator, Optional + +import pytest +from sqlmodel import Field, SQLModel, select +from sqlmodel.ext.asyncio import AsyncSession, create_async_engine +from testcontainers.postgres import PostgresContainer + + +# The first time this test is run, it will download the postgres image which can take +# a while. Subsequent runs will be much faster. +@pytest.fixture(scope="module") +def postgres_container_url() -> Generator[str, None, None]: + with PostgresContainer("postgres:13") as postgres: + postgres.driver = "asyncpg" + yield postgres.get_connection_url() + + +async def _test_async_create(postgres_container_url: str) -> None: + class Hero(SQLModel, table=True): + # SQLModel.metadata is a singleton and the Hero Class has already been defined. + # If I flush the metadata during this test, it will cause test_enum to fail + # because in that file, the model isn't defined within a function. For now, the + # workaround is to set extend_existing to True. In the future, test setup and + # teardown should be refactored to avoid this issue. + __table_args__ = {"extend_existing": True} + + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + hero_create = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_async_engine(postgres_container_url) + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + async with AsyncSession(engine) as session: + session.add(hero_create) + await session.commit() + await session.refresh(hero_create) + + async with AsyncSession(engine) as session: + statement = select(Hero).where(Hero.name == "Deadpond") + results = await session.exec(statement) + hero_query = results.one() + assert hero_create == hero_query + + +def test_async_create(postgres_container_url: str) -> None: + asyncio.run(_test_async_create(postgres_container_url))