Skip to content

Commit

Permalink
add support for connection kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Jul 29, 2024
1 parent 521548a commit 5249360
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 22 deletions.
26 changes: 4 additions & 22 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
overwrite: bool = False,
connection_kwargs: Dict[str, Any] = {},
**kwargs,
):
"""Initialize the SemanticRouter.
Expand All @@ -100,7 +101,8 @@ def __init__(
redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
**kwargs: Additional arguments.
connection_kwargs (Dict[str, Any]): The connection arguments
for the redis client. Defaults to empty {}.
"""
# Set vectorizer default
if vectorizer is None:
Expand All @@ -115,7 +117,7 @@ def __init__(
vectorizer=vectorizer,
routing_config=routing_config,
)
self._initialize_index(redis_client, redis_url, overwrite)
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)

def _initialize_index(
self,
Expand Down Expand Up @@ -477,19 +479,12 @@ def clear(self) -> None:
def from_dict(
cls,
data: Dict[str, Any],
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
overwrite: bool = False,
**kwargs,
) -> "SemanticRouter":
"""Create a SemanticRouter from a dictionary.
Args:
data (Dict[str, Any]): The dictionary containing the semantic router data.
redis_client (Optional[Redis]): Redis client for connection.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
overwrite (bool): Whether to overwrite existing index.
**kwargs: Additional arguments.
Returns:
SemanticRouter: The semantic router instance.
Expand Down Expand Up @@ -531,9 +526,6 @@ def from_dict(
routes=routes,
vectorizer=vectorizer,
routing_config=routing_config,
redis_client=redis_client,
redis_url=redis_url,
overwrite=overwrite,
**kwargs,
)

Expand Down Expand Up @@ -563,19 +555,12 @@ def to_dict(self) -> Dict[str, Any]:
def from_yaml(
cls,
file_path: str,
redis_client: Optional[Redis] = None,
redis_url: str = "redis://localhost:6379",
overwrite: bool = False,
**kwargs,
) -> "SemanticRouter":
"""Create a SemanticRouter from a YAML file.
Args:
file_path (str): The path to the YAML file.
redis_client (Optional[Redis]): Redis client for connection.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
overwrite (bool): Whether to overwrite existing index.
**kwargs: Additional arguments.
Returns:
SemanticRouter: The semantic router instance.
Expand All @@ -601,9 +586,6 @@ def from_yaml(
yaml_data = yaml.safe_load(f)
return cls.from_dict(
yaml_data,
redis_client=redis_client,
redis_url=redis_url,
overwrite=overwrite,
**kwargs,
)

Expand Down
10 changes: 10 additions & 0 deletions tests/integration/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def test_specify_redis_client(client):
assert isinstance(session._client, type(client))


def test_specify_redis_url(client):
session = StandardSessionManager(
name="test_app",
session_tag="abc",
user_tag="123",
redis_url="redis://localhost:6379",
)
assert isinstance(session._client, type(client))


def test_standard_bad_connection_info():
with pytest.raises(ConnectionError):
StandardSessionManager(
Expand Down

0 comments on commit 5249360

Please sign in to comment.