Skip to content

Commit

Permalink
Create collections with async context manager
Browse files Browse the repository at this point in the history
With the current version of pytest-asyncio we're using,
there's an issue with using async fixtures cached at
different scopes when they need the same event loop
scope.

See: pytest-dev/pytest-asyncio#871

An API breaking change that fixes this is available in 0.24,
but fixing this with a context manager here to avoid increasing
the blast radius.
  • Loading branch information
lossyrob committed Oct 7, 2024
1 parent 515b3a3 commit 3bbfdf7
Showing 1 changed file with 68 additions and 57 deletions.
125 changes: 68 additions & 57 deletions python/tests/integration/connectors/memory/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Annotated, Any

import pandas as pd
Expand All @@ -10,12 +11,10 @@
from pydantic import BaseModel

from semantic_kernel.connectors.memory.postgres import PostgresStore
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
from semantic_kernel.connectors.memory.postgres.postgres_settings import PostgresSettings
from semantic_kernel.data.const import DistanceFunction, IndexKind
from semantic_kernel.data.vector_store_model_decorator import vectorstoremodel
from semantic_kernel.data.vector_store_model_definition import VectorStoreRecordDefinition
from semantic_kernel.data.vector_store_record_collection import VectorStoreRecordCollection
from semantic_kernel.data.vector_store_record_fields import (
VectorStoreRecordDataField,
VectorStoreRecordKeyField,
Expand Down Expand Up @@ -85,14 +84,22 @@ async def vector_store() -> AsyncGenerator[PostgresStore, None]:
yield PostgresStore(connection_pool=pool)


@pytest_asyncio.fixture(scope="function")
async def simple_collection(vector_store: PostgresStore):
@asynccontextmanager
async def create_simple_collection(vector_store: PostgresStore):
"""Returns a collection with a unique name that is deleted after the context.
This can be moved to use a fixture with scope=function and loop_scope=session
after upgrade to pytest-asyncio 0.24. With the current version, the fixture
would both cache and use the event loop of the declared scope.
"""
suffix = str(uuid.uuid4()).replace("-", "")[:8]
collection_id = f"test_collection_{suffix}"
collection = vector_store.get_collection(collection_id, SimpleDataModel)
await collection.create_collection()
yield collection
await collection.delete_collection()
try:
yield collection
finally:
await collection.delete_collection()


def test_create_store(vector_store):
Expand All @@ -118,37 +125,40 @@ async def test_create_does_collection_exist_and_delete(vector_store: PostgresSto


@pytest.mark.asyncio(scope="session")
async def test_list_collection_names(vector_store, simple_collection):
simple_collection_id = simple_collection.collection_name
result = await vector_store.list_collection_names()
assert simple_collection_id in result
async def test_list_collection_names(vector_store):
async with create_simple_collection(vector_store) as simple_collection:
simple_collection_id = simple_collection.collection_name
result = await vector_store.list_collection_names()
assert simple_collection_id in result


@pytest.mark.asyncio(scope="session")
async def test_upsert_get_and_delete(simple_collection: PostgresCollection):
async def test_upsert_get_and_delete(vector_store: PostgresStore):
record = SimpleDataModel(id=1, embedding=[1.1, 2.2, 3.3], data={"key": "value"})
async with create_simple_collection(vector_store) as simple_collection:
result_before_upsert = await simple_collection.get(1)
assert result_before_upsert is None

result_before_upsert = await simple_collection.get(1)
assert result_before_upsert is None

await simple_collection.upsert(record)
result = await simple_collection.get(1)
assert result is not None
assert result.id == record.id
assert result.embedding == record.embedding
assert result.data == record.data

# Check that the table has an index
connection_pool = simple_collection.connection_pool
async with connection_pool.connection() as conn, conn.cursor() as cur:
await cur.execute("SELECT indexname FROM pg_indexes WHERE tablename = %s", (simple_collection.collection_name,))
rows = await cur.fetchall()
index_names = [index[0] for index in rows]
assert any("embedding_idx" in index_name for index_name in index_names)

await simple_collection.delete(1)
result_after_delete = await simple_collection.get(1)
assert result_after_delete is None
await simple_collection.upsert(record)
result = await simple_collection.get(1)
assert result is not None
assert result.id == record.id
assert result.embedding == record.embedding
assert result.data == record.data

# Check that the table has an index
connection_pool = simple_collection.connection_pool
async with connection_pool.connection() as conn, conn.cursor() as cur:
await cur.execute(
"SELECT indexname FROM pg_indexes WHERE tablename = %s", (simple_collection.collection_name,)
)
rows = await cur.fetchall()
index_names = [index[0] for index in rows]
assert any("embedding_idx" in index_name for index_name in index_names)

await simple_collection.delete(1)
result_after_delete = await simple_collection.get(1)
assert result_after_delete is None


@pytest.mark.asyncio(scope="session")
Expand Down Expand Up @@ -182,28 +192,29 @@ async def test_upsert_get_and_delete_pandas(vector_store):


@pytest.mark.asyncio(scope="session")
async def test_upsert_get_and_delete_batch(simple_collection: VectorStoreRecordCollection):
record1 = SimpleDataModel(id=1, embedding=[1.1, 2.2, 3.3], data={"key": "value"})
record2 = SimpleDataModel(id=2, embedding=[4.4, 5.5, 6.6], data={"key": "value"})

result_before_upsert = await simple_collection.get_batch([1, 2])
assert result_before_upsert is None

await simple_collection.upsert_batch([record1, record2])
# Test get_batch for the two existing keys and one non-existing key;
# this should return only the two existing records.
result = await simple_collection.get_batch([1, 2, 3])
assert result is not None
assert len(result) == 2
assert result[0] is not None
assert result[0].id == record1.id
assert result[0].embedding == record1.embedding
assert result[0].data == record1.data
assert result[1] is not None
assert result[1].id == record2.id
assert result[1].embedding == record2.embedding
assert result[1].data == record2.data

await simple_collection.delete_batch([1, 2])
result_after_delete = await simple_collection.get_batch([1, 2])
assert result_after_delete is None
async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
async with create_simple_collection(vector_store) as simple_collection:
record1 = SimpleDataModel(id=1, embedding=[1.1, 2.2, 3.3], data={"key": "value"})
record2 = SimpleDataModel(id=2, embedding=[4.4, 5.5, 6.6], data={"key": "value"})

result_before_upsert = await simple_collection.get_batch([1, 2])
assert result_before_upsert is None

await simple_collection.upsert_batch([record1, record2])
# Test get_batch for the two existing keys and one non-existing key;
# this should return only the two existing records.
result = await simple_collection.get_batch([1, 2, 3])
assert result is not None
assert len(result) == 2
assert result[0] is not None
assert result[0].id == record1.id
assert result[0].embedding == record1.embedding
assert result[0].data == record1.data
assert result[1] is not None
assert result[1].id == record2.id
assert result[1].embedding == record2.embedding
assert result[1].data == record2.data

await simple_collection.delete_batch([1, 2])
result_after_delete = await simple_collection.get_batch([1, 2])
assert result_after_delete is None

0 comments on commit 3bbfdf7

Please sign in to comment.