Skip to content

Commit

Permalink
Expose aggregation API from SearchIndex (#230)
Browse files Browse the repository at this point in the history
In order to support more advanced queries, we expose the `aggregate`
method to pass through to the core Redis FT.AGGREGATE API. This PR also
simplifies and standardizes error handling for Redis
searches/aggregations on the index.
  • Loading branch information
tylerhutcherson authored Oct 8, 2024
1 parent 3c74dee commit 19dedcb
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 69 deletions.
4 changes: 4 additions & 0 deletions redisvl/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ class RedisVLException(Exception):

class RedisModuleVersionError(RedisVLException):
"""Invalid module versions installed"""


class RedisSearchError(RedisVLException):
"""Error while performing a search or aggregate request"""
12 changes: 6 additions & 6 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ def _classify_route(
)

try:
aggregation_result: AggregateResult = self._index.client.ft( # type: ignore
self._index.name
).aggregate(aggregate_request, vector_range_query.params)
aggregation_result: AggregateResult = self._index.aggregate(
aggregate_request, vector_range_query.params
)
except ResponseError as e:
if "VSS is not yet supported on FT.AGGREGATE" in str(e):
raise RuntimeError(
Expand Down Expand Up @@ -308,9 +308,9 @@ def _classify_multi_route(
)

try:
aggregation_result: AggregateResult = self._index.client.ft( # type: ignore
self._index.name
).aggregate(aggregate_request, vector_range_query.params)
aggregation_result: AggregateResult = self._index.aggregate(
aggregate_request, vector_range_query.params
)
except ResponseError as e:
if "VSS is not yet supported on FT.AGGREGATE" in str(e):
raise RuntimeError(
Expand Down
112 changes: 55 additions & 57 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)

if TYPE_CHECKING:
from redis.commands.search.aggregation import AggregateResult
from redis.commands.search.document import Document
from redis.commands.search.result import Result
from redisvl.query.query import BaseQuery
Expand All @@ -25,7 +26,7 @@
import redis.asyncio as aredis
from redis.commands.search.indexDefinition import IndexDefinition

from redisvl.exceptions import RedisModuleVersionError
from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
from redisvl.query import BaseQuery, CountQuery, FilterQuery
from redisvl.query.filter import FilterExpression
Expand Down Expand Up @@ -123,36 +124,6 @@ async def wrapper(self, *args, **kwargs):
return decorator


def check_index_exists():
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.exists():
raise RuntimeError(
f"Index has not been created. Must be created before calling {func.__name__}"
)
return func(self, *args, **kwargs)

return wrapper

return decorator


def check_async_index_exists():
def decorator(func):
@wraps(func)
async def wrapper(self, *args, **kwargs):
if not await self.exists():
raise ValueError(
f"Index has not been created. Must be created before calling {func.__name__}"
)
return await func(self, *args, **kwargs)

return wrapper

return decorator


class BaseSearchIndex:
"""Base search engine class"""

Expand Down Expand Up @@ -486,7 +457,6 @@ def create(self, overwrite: bool = False, drop: bool = False) -> None:
logger.exception("Error while trying to create the index")
raise

@check_index_exists()
def delete(self, drop: bool = True):
"""Delete the search index while optionally dropping all keys associated
with the index.
Expand All @@ -502,8 +472,8 @@ def delete(self, drop: bool = True):
self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
delete_documents=drop
)
except:
logger.exception("Error while deleting index")
except Exception as e:
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e

def clear(self) -> int:
"""Clear all keys in Redis associated with the index, leaving the index
Expand Down Expand Up @@ -629,13 +599,29 @@ def fetch(self, id: str) -> Optional[Dict[str, Any]]:
return convert_bytes(obj[0])
return None

@check_index_exists()
def aggregate(self, *args, **kwargs) -> "AggregateResult":
"""Perform an aggregation operation against the index.
Wrapper around the aggregation API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().aggregate() method.
Returns:
Result: Raw Redis aggregation results.
"""
try:
return self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
*args, **kwargs
)
except Exception as e:
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e

def search(self, *args, **kwargs) -> "Result":
"""Perform a search against the index.
Wrapper around redis.search.Search that adds the index name
to the search query and passes along the rest of the arguments
to the redis-py ft.search() method.
Wrapper around the search API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().search() method.
Returns:
Result: Raw Redis search results.
Expand All @@ -644,9 +630,8 @@ def search(self, *args, **kwargs) -> "Result":
return self._redis_client.ft(self.schema.index.name).search( # type: ignore
*args, **kwargs
)
except:
logger.exception("Error while searching")
raise
except Exception as e:
raise RedisSearchError(f"Error while searching: {str(e)}") from e

def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Execute a query and process results."""
Expand Down Expand Up @@ -752,11 +737,11 @@ def _info(name: str, redis_client: redis.Redis) -> Dict[str, Any]:
"""Run FT.INFO to fetch information about the index."""
try:
return convert_bytes(redis_client.ft(name).info()) # type: ignore
except:
logger.exception(f"Error while fetching {name} index info")
raise
except Exception as e:
raise RedisSearchError(
f"Error while fetching {name} index info: {str(e)}"
) from e

@check_index_exists()
def info(self, name: Optional[str] = None) -> Dict[str, Any]:
"""Get information about the index.
Expand Down Expand Up @@ -1010,7 +995,6 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
logger.exception("Error while trying to create the index")
raise

@check_async_index_exists()
async def delete(self, drop: bool = True):
"""Delete the search index.
Expand All @@ -1025,9 +1009,8 @@ async def delete(self, drop: bool = True):
await self._redis_client.ft(self.schema.index.name).dropindex( # type: ignore
delete_documents=drop
)
except:
logger.exception("Error while deleting index")
raise
except Exception as e:
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e

async def clear(self) -> int:
"""Clear all keys in Redis associated with the index, leaving the index
Expand Down Expand Up @@ -1152,7 +1135,23 @@ async def fetch(self, id: str) -> Optional[Dict[str, Any]]:
return convert_bytes(obj[0])
return None

@check_async_index_exists()
async def aggregate(self, *args, **kwargs) -> "AggregateResult":
"""Perform an aggregation operation against the index.
Wrapper around the aggregation API that adds the index name
to the query and passes along the rest of the arguments
to the redis-py ft().aggregate() method.
Returns:
Result: Raw Redis aggregation results.
"""
try:
return await self._redis_client.ft(self.schema.index.name).aggregate( # type: ignore
*args, **kwargs
)
except Exception as e:
raise RedisSearchError(f"Error while aggregating: {str(e)}") from e

async def search(self, *args, **kwargs) -> "Result":
"""Perform a search on this index.
Expand All @@ -1167,9 +1166,8 @@ async def search(self, *args, **kwargs) -> "Result":
return await self._redis_client.ft(self.schema.index.name).search( # type: ignore
*args, **kwargs
)
except:
logger.exception("Error while searching")
raise
except Exception as e:
raise RedisSearchError(f"Error while searching: {str(e)}") from e

async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Asynchronously execute a query and process results."""
Expand Down Expand Up @@ -1275,11 +1273,11 @@ async def exists(self) -> bool:
async def _info(name: str, redis_client: aredis.Redis) -> Dict[str, Any]:
try:
return convert_bytes(await redis_client.ft(name).info()) # type: ignore
except:
logger.exception(f"Error while fetching {name} index info")
raise
except Exception as e:
raise RedisSearchError(
f"Error while fetching {name} index info: {str(e)}"
) from e

@check_async_index_exists()
async def info(self, name: Optional[str] = None) -> Dict[str, Any]:
"""Get information about the index.
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/test_async_search_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from redisvl.exceptions import RedisSearchError
from redisvl.index import AsyncSearchIndex
from redisvl.query import VectorQuery
from redisvl.redis.utils import convert_bytes
Expand Down Expand Up @@ -291,7 +292,7 @@ async def test_check_index_exists_before_delete(async_client, async_index):
await async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)
with pytest.raises(ValueError):
with pytest.raises(RedisSearchError):
await async_index.delete()


Expand All @@ -307,7 +308,7 @@ async def test_check_index_exists_before_search(async_client, async_index):
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(ValueError):
with pytest.raises(RedisSearchError):
await async_index.search(query.query, query_params=query.params)


Expand All @@ -317,5 +318,5 @@ async def test_check_index_exists_before_info(async_client, async_index):
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)

with pytest.raises(ValueError):
with pytest.raises(RedisSearchError):
await async_index.info()
7 changes: 4 additions & 3 deletions tests/integration/test_search_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from redisvl.exceptions import RedisSearchError
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.redis.connection import RedisConnectionFactory, validate_modules
Expand Down Expand Up @@ -251,7 +252,7 @@ def test_check_index_exists_before_delete(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)
with pytest.raises(RuntimeError):
with pytest.raises(RedisSearchError):
index.delete()


Expand All @@ -266,7 +267,7 @@ def test_check_index_exists_before_search(client, index):
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(RuntimeError):
with pytest.raises(RedisSearchError):
index.search(query.query, query_params=query.params)


Expand All @@ -275,7 +276,7 @@ def test_check_index_exists_before_info(client, index):
index.create(overwrite=True, drop=True)
index.delete(drop=True)

with pytest.raises(RuntimeError):
with pytest.raises(RedisSearchError):
index.info()


Expand Down

0 comments on commit 19dedcb

Please sign in to comment.