Skip to content

Commit

Permalink
update async extension and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PookieBuns committed Aug 10, 2023
1 parent 088164e commit b5e850b
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 2 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ fastapi = "^0.68.1"
requests = "^2.26.0"
autoflake = "^1.4"
isort = "^5.9.3"
testcontainers = "^3.7.1"
psycopg2-binary = "^2.9.7"
asyncpg = "^0.28.0"

[build-system]
requires = ["poetry-core"]
Expand Down
2 changes: 2 additions & 0 deletions sqlmodel/ext/asyncio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .engine import create_async_engine as create_async_engine
from .session import AsyncSession as AsyncSession
10 changes: 10 additions & 0 deletions sqlmodel/ext/asyncio/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Any

from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine


# create_async_engine by default already has future set to be true.
# Porting this over to sqlmodel to make it easier to use.
def create_async_engine(*args: Any, **kwargs: Any) -> AsyncEngine:
return _create_async_engine(*args, **kwargs)
4 changes: 2 additions & 2 deletions sqlmodel/ext/asyncio/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ...engine.result import ScalarResult
from ...orm.session import Session
from ...sql.expression import Select
from ...sql.expression import Select, SelectOfScalar

_T = TypeVar("_T")

Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(

async def exec(
self,
statement: Union[Select[_T], Executable[_T]],
statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
bind_arguments: Optional[Mapping[str, Any]] = None,
Expand Down
52 changes: 52 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio
from typing import Generator, Optional

import pytest
from sqlmodel import Field, SQLModel, select
from sqlmodel.ext.asyncio import AsyncSession, create_async_engine
from testcontainers.postgres import PostgresContainer


# The first time this test is run, it will download the postgres image which can take
# a while. Subsequent runs will be much faster.
@pytest.fixture(scope="module")
def postgres_container_url() -> Generator[str, None, None]:
with PostgresContainer("postgres:13") as postgres:
postgres.driver = "asyncpg"
yield postgres.get_connection_url()


async def _test_async_create(postgres_container_url: str) -> None:
class Hero(SQLModel, table=True):
# SQLModel.metadata is a singleton and the Hero Class has already been defined.
# If I flush the metadata during this test, it will cause test_enum to fail
# because in that file, the model isn't defined within a function. For now, the
# workaround is to set extend_existing to True. In the future, test setup and
# teardown should be refactored to avoid this issue.
__table_args__ = {"extend_existing": True}

id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None

hero_create = Hero(name="Deadpond", secret_name="Dive Wilson")

engine = create_async_engine(postgres_container_url)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

async with AsyncSession(engine) as session:
session.add(hero_create)
await session.commit()
await session.refresh(hero_create)

async with AsyncSession(engine) as session:
statement = select(Hero).where(Hero.name == "Deadpond")
results = await session.exec(statement)
hero_query = results.one()
assert hero_create == hero_query


def test_async_create(postgres_container_url: str) -> None:
asyncio.run(_test_async_create(postgres_container_url))

0 comments on commit b5e850b

Please sign in to comment.