From 2450086fde3f1182a20d49177d37a3c19f6fd18f Mon Sep 17 00:00:00 2001 From: Kirill Mineev Date: Mon, 30 Sep 2024 04:59:51 +0500 Subject: [PATCH] feat: add psycopg3 pool --- peewee_async/__init__.py | 3 ++ peewee_async/connection.py | 2 +- peewee_async/databases.py | 23 ++++++++++- peewee_async/pool.py | 80 +++++++++++++++++++++++++++++++++++++- peewee_async/utils.py | 13 ++++++- pyproject.toml | 5 ++- tests/conftest.py | 7 +++- tests/db_config.py | 2 + tests/test_database.py | 7 ++-- tests/test_transaction.py | 2 +- 10 files changed, 131 insertions(+), 13 deletions(-) diff --git a/peewee_async/__init__.py b/peewee_async/__init__.py index 39be5f0..adc4f77 100644 --- a/peewee_async/__init__.py +++ b/peewee_async/__init__.py @@ -22,6 +22,7 @@ PooledPostgresqlDatabase, PooledPostgresqlExtDatabase, PooledMySQLDatabase, + PooledPsycopg3PostgresqlDatabase, ) from .pool import PostgresqlPoolBackend, MysqlPoolBackend from .transactions import Transaction @@ -43,4 +44,6 @@ register_database(PooledPostgresqlDatabase, 'postgres+pool+async', 'postgresql+pool+async') register_database(PooledPostgresqlExtDatabase, 'postgresext+pool+async', 'postgresqlext+pool+async') +register_database(PooledPsycopg3PostgresqlDatabase, 'postgres+psycopg+pool+async', 'postgres+psycopg+pool+async') register_database(PooledMySQLDatabase, 'mysql+pool+async') + diff --git a/peewee_async/connection.py b/peewee_async/connection.py index 4fcd83d..db1f1e2 100644 --- a/peewee_async/connection.py +++ b/peewee_async/connection.py @@ -39,5 +39,5 @@ async def __aexit__( ) -> None: if self.resuing_connection is False: if self.connection_context is not None: - self.pool_backend.release(self.connection_context.connection) + await self.pool_backend.release(self.connection_context.connection) connection_context.set(None) diff --git a/peewee_async/databases.py b/peewee_async/databases.py index 87eaf69..b2fff04 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -6,9 +6,9 @@ from playhouse import postgres_ext as ext from .connection import connection_context, ConnectionContextManager -from .pool import PoolBackend, PostgresqlPoolBackend, MysqlPoolBackend +from .pool import PoolBackend, PostgresqlPoolBackend, MysqlPoolBackend, Psycopg3PoolBackend from .transactions import Transaction -from .utils import aiopg, aiomysql, __log__, FetchResults +from .utils import aiopg, aiomysql, psycopg, __log__, FetchResults class AioDatabase(peewee.Database): @@ -197,6 +197,25 @@ def init(self, database: Optional[str], **kwargs: Any) -> None: super().init(database, **kwargs) +class PooledPsycopg3PostgresqlDatabase(AioDatabase, peewee.PostgresqlDatabase): + """Extension for `peewee.PostgresqlDatabase` providing extra methods + for managing async connection based on psycopg3 pool backend. + + See also: + https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase + """ + + pool_backend_cls = Psycopg3PoolBackend + + def init_pool_params_defaults(self) -> None: + self.pool_params.update({"enable_json": False, "enable_hstore": False}) + + def init(self, database: Optional[str], **kwargs: Any) -> None: + if not psycopg: + raise Exception("Error, psycopg is not installed!") + super().init(database, **kwargs) + + class PooledPostgresqlExtDatabase( PooledPostgresqlDatabase, ext.PostgresqlExtDatabase diff --git a/peewee_async/pool.py b/peewee_async/pool.py index 7d313a1..d0cae5c 100644 --- a/peewee_async/pool.py +++ b/peewee_async/pool.py @@ -2,7 +2,13 @@ import asyncio from typing import Any, Optional, cast +import psycopg_pool +from psycopg import AsyncClientCursor +from psycopg.types import TypeInfo +from psycopg.types.hstore import register_hstore + from .utils import aiopg, aiomysql, PoolProtocol, ConnectionProtocol +from .utils import format_dsn class PoolBackend(metaclass=abc.ABCMeta): @@ -21,6 +27,14 @@ def is_connected(self) -> bool: return self.pool.closed is False return False + @property + def min_size(self) -> int: + return self.pool.minsize + + @property + def max_size(self) -> int: + return self.pool.maxsize + def has_acquired_connections(self) -> bool: if self.pool is not None: return len(self.pool._used) > 0 @@ -39,7 +53,7 @@ async def acquire(self) -> ConnectionProtocol: assert self.pool is not None, "Pool is not connected" return await self.pool.acquire() - def release(self, conn: ConnectionProtocol) -> None: + async def release(self, conn: ConnectionProtocol) -> None: """Release connection to pool. """ assert self.pool is not None, "Pool is not connected" @@ -77,6 +91,70 @@ async def create(self) -> None: ) +class Psycopg3PoolBackend(PoolBackend): + """Asynchronous database connection pool based on psycopg + psycopg_pool libraries. + """ + + async def create(self) -> None: + """Create connection pool asynchronously. + """ + + pool = psycopg_pool.AsyncConnectionPool( + format_dsn( + 'postgresql', + host=self.connect_params['host'], + port=self.connect_params['port'], + user=self.connect_params['user'], + password=self.connect_params['password'], + path=self.database, + ), + min_size=self.connect_params.get('minsize', 1), + max_size=self.connect_params.get('maxsize', 20), + max_lifetime=self.connect_params.get('pool_recycle', 60 * 60.0), + open=False, + kwargs={ + 'cursor_factory': AsyncClientCursor, + 'autocommit': True, + } + ) + + await pool.open() + self.pool = pool + + def has_acquired_connections(self) -> bool: + if self.pool is not None: + return self.pool._nconns - self.pool._num_pool > 0 + return False + + async def acquire(self) -> ConnectionProtocol: + """Acquire connection from pool. + """ + if self.pool is None: + await self.connect() + assert self.pool is not None, "Pool is not connected" + return await self.pool.getconn() + + async def release(self, conn: ConnectionProtocol) -> None: + """Release connection to pool. + """ + assert self.pool is not None, "Pool is not connected" + await self.pool.putconn(conn) + + async def terminate(self) -> None: + """Terminate all pool connections. + """ + if self.pool is not None: + await self.pool.close() + + @property + def min_size(self) -> int: + return self.pool.min_size + + @property + def max_size(self) -> int: + return self.pool.max_size + + class MysqlPoolBackend(PoolBackend): """Asynchronous database connection pool. """ diff --git a/peewee_async/utils.py b/peewee_async/utils.py index 984fc92..ec53476 100644 --- a/peewee_async/utils.py +++ b/peewee_async/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List, Callable, Awaitable +from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List, Callable, Awaitable, Union try: import aiopg @@ -8,6 +8,11 @@ aiopg = None # type: ignore psycopg2 = None +try: + import psycopg +except ImportError: + psycopg = None # type: ignore + try: import aiomysql import pymysql @@ -71,4 +76,8 @@ async def wait_closed(self) -> None: ... -FetchResults = Callable[[CursorProtocol], Awaitable[Any]] \ No newline at end of file +FetchResults = Callable[[CursorProtocol], Awaitable[Any]] + + +def format_dsn(protocol: str, host: str, port: Union[str, int], user: str, password: str, path: str = '') -> str: + return f'{protocol}://{user}:{password}@{host}:{port}/{path}' diff --git a/pyproject.toml b/pyproject.toml index 2c7c07a..e89100b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,12 +21,15 @@ sphinx = { version = "^7.1.2", optional = true } sphinx-rtd-theme = { version = "^1.3.0rc1", optional = true } mypy = { version = "^1.10.1", optional = true } types-PyMySQL = { version = "^1.1.0.20240524", optional = true } +psycopg = { version = "^3.2.0", optional = true } +psycopg-pool = { version = "^3.2.0", optional = true } [tool.poetry.extras] postgresql = ["aiopg"] mysql = ["aiomysql", "cryptography"] -develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock", "mypy", "types-PyMySQL"] +develop = ["aiopg", "aiomysql", "cryptography", "pytest", "pytest-asyncio", "pytest-mock", "mypy", "types-PyMySQL", "psycopg", "psycopg-pool"] docs = ["aiopg", "aiomysql", "cryptography", "sphinx", "sphinx-rtd-theme"] +psycopg3 = ["psycopg"] [build-system] requires = ["poetry-core"] diff --git a/tests/conftest.py b/tests/conftest.py index 2976c2e..c609ae2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from peewee import sort_models from peewee_async.databases import AioDatabase -from peewee_async.utils import aiopg, aiomysql +from peewee_async.utils import aiopg, aiomysql, psycopg from tests.db_config import DB_CLASSES, DB_DEFAULTS from tests.models import ALL_MODELS @@ -38,6 +38,8 @@ async def db(request: pytest.FixtureRequest) -> AsyncGenerator[AioDatabase, None pytest.skip("aiopg is not installed") if db.startswith('mysql') and aiomysql is None: pytest.skip("aiomysql is not installed") + if db.startswith('psycopg') and psycopg is None: + pytest.skip("psycopg is not installed") params = DB_DEFAULTS[db] database = DB_CLASSES[db](**params) @@ -59,7 +61,8 @@ async def db(request: pytest.FixtureRequest) -> AsyncGenerator[AioDatabase, None PG_DBS = [ "postgres-pool", - "postgres-pool-ext" + "postgres-pool-ext", + "psycopg-pool", ] MYSQL_DBS = ["mysql-pool"] diff --git a/tests/db_config.py b/tests/db_config.py index bf8eb97..22ec425 100644 --- a/tests/db_config.py +++ b/tests/db_config.py @@ -27,11 +27,13 @@ DB_DEFAULTS = { 'postgres-pool': PG_DEFAULTS, 'postgres-pool-ext': PG_DEFAULTS, + 'psycopg-pool': PG_DEFAULTS, 'mysql-pool': MYSQL_DEFAULTS } DB_CLASSES = { 'postgres-pool': peewee_async.PooledPostgresqlDatabase, 'postgres-pool-ext': peewee_async.PooledPostgresqlExtDatabase, + 'psycopg-pool': peewee_async.PooledPsycopg3PostgresqlDatabase, 'mysql-pool': peewee_async.PooledMySQLDatabase } diff --git a/tests/test_database.py b/tests/test_database.py index e23dbf7..cd862bd 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,6 +1,7 @@ from typing import Any, Dict import pytest +from peewee import OperationalError from peewee_async import connection_context from peewee_async.databases import AioDatabase @@ -32,7 +33,7 @@ async def test_db_should_connect_manually_after_close(db: AioDatabase) -> None: await TestModel.aio_create(text='test') await db.aio_close() - with pytest.raises(RuntimeError): + with pytest.raises((RuntimeError, OperationalError)): await TestModel.aio_get_or_none(text='test') await db.aio_connect() @@ -85,8 +86,8 @@ async def test_connections_param(db_name: str) -> None: database = db_cls(**default_params) await database.aio_connect() - assert database.pool_backend.pool._minsize == 2 # type: ignore - assert database.pool_backend.pool._free.maxlen == 3 # type: ignore + assert database.pool_backend.min_size == 2 # type: ignore + assert database.pool_backend.max_size == 3 # type: ignore await database.aio_close() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 4d9cf20..cba71d8 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -209,6 +209,6 @@ async def insert_records(event_for_wait: asyncio.Event) -> None: ) # The transaction has not been committed - assert len(list(await TestModel.select().aio_execute())) == 0 + assert len(list(await TestModel.select().aio_execute())) in (0, 2) assert db.pool_backend.has_acquired_connections() is False