Skip to content

Commit

Permalink
Standardize redis init in extensions (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson authored Jul 30, 2024
1 parent 877f4f2 commit cb61457
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 67 deletions.
17 changes: 8 additions & 9 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
vectorizer: Optional[BaseVectorizer] = None,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
connection_args: Dict[str, Any] = {},
connection_kwargs: Dict[str, Any] = {},
**kwargs,
):
"""Semantic Cache for Large Language Models.
Expand All @@ -43,14 +43,13 @@ def __init__(
cache. Defaults to 0.1.
ttl (Optional[int], optional): The time-to-live for records cached
in Redis. Defaults to None.
vectorizer (BaseVectorizer, optional): The vectorizer for the cache.
vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
Defaults to HFTextVectorizer.
redis_client(Redis, optional): A redis client connection instance.
redis_client(Optional[Redis], optional): A redis client connection instance.
Defaults to None.
redis_url (str, optional): The redis url. Defaults to
"redis://localhost:6379".
connection_args (Dict[str, Any], optional): The connection arguments
for the redis client. Defaults to None.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
connection_kwargs (Dict[str, Any]): The connection arguments
for the redis client. Defaults to empty {}.
Raises:
TypeError: If an invalid vectorizer is provided.
Expand Down Expand Up @@ -96,8 +95,8 @@ def __init__(
# handle redis connection
if redis_client:
self._index.set_client(redis_client)
else:
self._index.connect(redis_url=redis_url, **connection_args)
elif redis_url:
self._index.connect(redis_url=redis_url, **connection_kwargs)

# initialize other components
self.default_return_fields = [
Expand Down
34 changes: 7 additions & 27 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ def __init__(
vectorizer: Optional[BaseVectorizer] = None,
routing_config: Optional[RoutingConfig] = None,
redis_client: Optional[Redis] = None,
redis_url: Optional[str] = None,
redis_url: str = "redis://localhost:6379",
overwrite: bool = False,
connection_kwargs: Dict[str, Any] = {},
**kwargs,
):
"""Initialize the SemanticRouter.
Expand All @@ -98,9 +99,10 @@ def __init__(
vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to default HFTextVectorizer.
routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to the default RoutingConfig.
redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None.
redis_url (Optional[str], optional): Redis URL for connection. Defaults to None.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
**kwargs: Additional arguments.
connection_kwargs (Dict[str, Any]): The connection arguments
for the redis client. Defaults to empty {}.
"""
# Set vectorizer default
if vectorizer is None:
Expand All @@ -115,12 +117,12 @@ def __init__(
vectorizer=vectorizer,
routing_config=routing_config,
)
self._initialize_index(redis_client, redis_url, overwrite)
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)

def _initialize_index(
self,
redis_client: Optional[Redis] = None,
redis_url: Optional[str] = None,
redis_url: str = "redis://localhost:6379",
overwrite: bool = False,
**connection_kwargs,
):
Expand All @@ -132,8 +134,6 @@ def _initialize_index(
self._index.set_client(redis_client)
elif redis_url:
self._index.connect(redis_url=redis_url, **connection_kwargs)
else:
raise ValueError("Must provide either a redis client or redis url string.")

existed = self._index.exists()
self._index.create(overwrite=overwrite)
Expand Down Expand Up @@ -479,19 +479,12 @@ def clear(self) -> None:
def from_dict(
cls,
data: Dict[str, Any],
redis_client: Optional[Redis] = None,
redis_url: Optional[str] = None,
overwrite: bool = False,
**kwargs,
) -> "SemanticRouter":
"""Create a SemanticRouter from a dictionary.
Args:
data (Dict[str, Any]): The dictionary containing the semantic router data.
redis_client (Optional[Redis]): Redis client for connection.
redis_url (Optional[str]): Redis URL for connection.
overwrite (bool): Whether to overwrite existing index.
**kwargs: Additional arguments.
Returns:
SemanticRouter: The semantic router instance.
Expand Down Expand Up @@ -533,9 +526,6 @@ def from_dict(
routes=routes,
vectorizer=vectorizer,
routing_config=routing_config,
redis_client=redis_client,
redis_url=redis_url,
overwrite=overwrite,
**kwargs,
)

Expand Down Expand Up @@ -565,19 +555,12 @@ def to_dict(self) -> Dict[str, Any]:
def from_yaml(
cls,
file_path: str,
redis_client: Optional[Redis] = None,
redis_url: Optional[str] = None,
overwrite: bool = False,
**kwargs,
) -> "SemanticRouter":
"""Create a SemanticRouter from a YAML file.
Args:
file_path (str): The path to the YAML file.
redis_client (Optional[Redis]): Redis client for connection.
redis_url (Optional[str]): Redis URL for connection.
overwrite (bool): Whether to overwrite existing index.
**kwargs: Additional arguments.
Returns:
SemanticRouter: The semantic router instance.
Expand All @@ -603,9 +586,6 @@ def from_yaml(
yaml_data = yaml.safe_load(f)
return cls.from_dict(
yaml_data,
redis_client=redis_client,
redis_url=redis_url,
overwrite=overwrite,
**kwargs,
)

Expand Down
15 changes: 10 additions & 5 deletions redisvl/extensions/session_manager/semantic_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from time import time
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from redis import Redis

Expand Down Expand Up @@ -27,6 +27,8 @@ def __init__(
distance_threshold: float = 0.3,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
connection_kwargs: Dict[str, Any] = {},
**kwargs,
):
"""Initialize session memory with index
Expand All @@ -43,12 +45,14 @@ def __init__(
user_tag (str): Tag to be added to entries to link to a specific user.
prefix (Optional[str]): Prefix for the keys for this session data.
Defaults to None and will be replaced with the index name.
vectorizer (Vectorizer): The vectorizer to create embeddings with.
vectorizer (Optional[BaseVectorizer]): The vectorizer used to create embeddings.
distance_threshold (float): The maximum semantic distance to be
included in the context. Defaults to 0.3.
redis_client (Optional[Redis]): A Redis client instance. Defaults to
None.
redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
connection_kwargs (Dict[str, Any]): The connection arguments
for the redis client. Defaults to empty {}.
The proposed schema will support a single vector embedding constructed
from either the prompt or response in a single string.
Expand Down Expand Up @@ -89,10 +93,11 @@ def __init__(

self._index = SearchIndex(schema=schema)

# handle redis connection
if redis_client:
self._index.set_client(redis_client)
else:
self._index.connect(redis_url=redis_url)
elif redis_url:
self._index.connect(redis_url=redis_url, **connection_kwargs)

self._index.create(overwrite=False)

Expand Down
19 changes: 14 additions & 5 deletions redisvl/extensions/session_manager/standard_session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
from time import time
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from redis import Redis

from redisvl.extensions.session_manager import BaseSessionManager
from redisvl.redis.connection import RedisConnectionFactory


class StandardSessionManager(BaseSessionManager):
Expand All @@ -16,6 +17,8 @@ def __init__(
user_tag: str,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
connection_kwargs: Dict[str, Any] = {},
**kwargs,
):
"""Initialize session memory
Expand All @@ -31,18 +34,24 @@ def __init__(
user_tag (str): Tag to be added to entries to link to a specific user.
redis_client (Optional[Redis]): A Redis client instance. Defaults to
None.
redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
connection_kwargs (Dict[str, Any]): The connection arguments
for the redis client. Defaults to empty {}.
The proposed schema will support a single combined vector embedding
constructed from the prompt & response in a single string.
"""
super().__init__(name, session_tag, user_tag)

# handle redis connection
if redis_client:
self._client = redis_client
else:
self._client = Redis.from_url(redis_url)
elif redis_url:
self._client = RedisConnectionFactory.get_redis_connection(
redis_url, **connection_kwargs
)
RedisConnectionFactory.validate_redis(self._client)

self.set_scope(session_tag, user_tag)

Expand All @@ -51,7 +60,7 @@ def set_scope(
session_tag: Optional[str] = None,
user_tag: Optional[str] = None,
) -> None:
"""Set the filter to apply to querries based on the desired scope.
"""Set the filter to apply to queries based on the desired scope.
This new scope persists until another call to set_scope is made, or if
scope is specified in calls to get_recent.
Expand Down
26 changes: 11 additions & 15 deletions tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from time import sleep

import pytest
from redis.exceptions import ConnectionError

from redisvl.extensions.llmcache import SemanticCache
from redisvl.index.index import SearchIndex
Expand Down Expand Up @@ -40,19 +41,17 @@ def cache_with_ttl(vectorizer, redis_url):


@pytest.fixture
def cache_with_redis_client(vectorizer, client, redis_url):
def cache_with_redis_client(vectorizer, client):
cache_instance = SemanticCache(
vectorizer=vectorizer,
redis_client=client,
distance_threshold=0.2,
redis_url=redis_url,
)
yield cache_instance
cache_instance.clear() # Clear cache after each test
cache_instance._index.delete(True) # Clean up index


# # Test handling invalid input for check method
def test_bad_ttl(cache):
with pytest.raises(ValueError):
cache.set_ttl(2.5)
Expand All @@ -76,7 +75,6 @@ def test_reset_ttl(cache):
assert cache.ttl is None


# Test basic store and check functionality
def test_store_and_check(cache, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand All @@ -91,7 +89,6 @@ def test_store_and_check(cache, vectorizer):
assert "metadata" not in check_result[0]


# Test clearing the cache
def test_clear(cache, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand Down Expand Up @@ -139,7 +136,6 @@ def test_check_no_match(cache, vectorizer):
assert len(check_result) == 0


# Test handling invalid input for check method
def test_check_invalid_input(cache):
with pytest.raises(ValueError):
cache.check()
Expand All @@ -148,7 +144,15 @@ def test_check_invalid_input(cache):
cache.check(prompt="test", return_fields="bad value")


# Test storing with metadata
def test_bad_connection_info(vectorizer):
with pytest.raises(ConnectionError):
SemanticCache(
vectorizer=vectorizer,
distance_threshold=0.2,
redis_url="redis://localhost:6389",
)


def test_store_with_metadata(cache, vectorizer):
prompt = "This is another test prompt."
response = "This is another test response."
Expand All @@ -165,7 +169,6 @@ def test_store_with_metadata(cache, vectorizer):
assert check_result[0]["prompt"] == prompt


# Test storing with invalid metadata
def test_store_with_invalid_metadata(cache, vectorizer):
prompt = "This is another test prompt."
response = "This is another test response."
Expand All @@ -179,7 +182,6 @@ def test_store_with_invalid_metadata(cache, vectorizer):
cache.store(prompt, response, vector=vector, metadata=metadata)


# Test setting and getting the distance threshold
def test_distance_threshold(cache):
initial_threshold = cache.distance_threshold
new_threshold = 0.1
Expand All @@ -189,14 +191,12 @@ def test_distance_threshold(cache):
assert cache.distance_threshold != initial_threshold


# Test out of range distance threshold
def test_distance_threshold_out_of_range(cache):
out_of_range_threshold = -1
with pytest.raises(ValueError):
cache.set_threshold(out_of_range_threshold)


# Test storing and retrieving multiple items
def test_multiple_items(cache, vectorizer):
prompts_responses = {
"prompt1": "response1",
Expand All @@ -217,12 +217,10 @@ def test_multiple_items(cache, vectorizer):
assert "metadata" not in check_result[0]


# Test retrieving underlying SearchIndex for the cache.
def test_get_index(cache):
assert isinstance(cache.index, SearchIndex)


# Test basic functionality with cache created with user-provided Redis client
def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
Expand All @@ -237,13 +235,11 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
assert "metadata" not in check_result[0]


# Test deleting the cache
def test_delete(cache_no_cleanup):
cache_no_cleanup.delete()
assert not cache_no_cleanup.index.exists()


# Test we can only store and check vectors of correct dimensions
def test_vector_size(cache, vectorizer):
prompt = "This is test prompt."
response = "This is a test response."
Expand Down
Loading

0 comments on commit cb61457

Please sign in to comment.