Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor search index to improve connection handling #192

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/openai_qna.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@
"client = redis.Redis.from_url(\"redis://localhost:6379\")\n",
"schema = IndexSchema.from_yaml(\"wiki_schema.yaml\")\n",
"\n",
"index = AsyncSearchIndex(schema, client)\n",
"index = await AsyncSearchIndex(schema).set_client(client)\n",
"\n",
"await index.create()"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/getting_started_01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@
"client = Redis.from_url(\"redis://localhost:6379\")\n",
"\n",
"index = AsyncSearchIndex.from_dict(schema)\n",
"index.set_client(client)"
"await index.set_client(client)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion redisvl/extensions/session_manager/standard_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self._client = RedisConnectionFactory.get_redis_connection(
redis_url, **connection_kwargs
)
RedisConnectionFactory.validate_redis(self._client)
RedisConnectionFactory.validate_sync_redis(self._client)

self.set_scope(session_tag, user_tag)

Expand Down
168 changes: 105 additions & 63 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,24 @@ def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
RedisConnectionFactory.validate_redis(self._redis_client, self._lib_name)
RedisConnectionFactory.validate_sync_redis(
self._redis_client, self._lib_name
)
return result

return wrapper

return decorator


def setup_async_redis():
def decorator(func):
@wraps(func)
async def wrapper(self, *args, **kwargs):
result = await func(self, *args, **kwargs)
await RedisConnectionFactory.validate_async_redis(
self._redis_client, self._lib_name
)
return result

return wrapper
Expand Down Expand Up @@ -140,41 +157,10 @@ class BaseSearchIndex:
StorageType.JSON: JsonStorage,
}

def __init__(
self,
schema: IndexSchema,
redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None,
redis_url: Optional[str] = None,
connection_args: Dict[str, Any] = {},
**kwargs,
):
"""Initialize the RedisVL search index with a schema, Redis client
(or URL string with other connection args), connection_args, and other
kwargs.

Args:
schema (IndexSchema): Index schema object.
redis_client(Union[redis.Redis, aredis.Redis], optional): An
instantiated redis client.
redis_url (str, optional): The URL of the Redis server to
connect to.
connection_args (Dict[str, Any], optional): Redis client connection
args.
"""
# final validation on schema object
if not isinstance(schema, IndexSchema):
raise ValueError("Must provide a valid IndexSchema object")

self.schema = schema

self._lib_name: Optional[str] = kwargs.pop("lib_name", None)
schema: IndexSchema

# set up redis connection
self._redis_client: Optional[Union[redis.Redis, aredis.Redis]] = None
if redis_client is not None:
self.set_client(redis_client)
elif redis_url is not None:
self.connect(redis_url, **connection_args)
def __init__(*args, **kwargs):
pass

@property
def _storage(self) -> BaseStorage:
Expand Down Expand Up @@ -237,8 +223,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):

Args:
schema_dict (Dict[str, Any]): A dictionary containing the schema.
connection_args (Dict[str, Any], optional): Redis client connection
args.

Returns:
SearchIndex: A RedisVL SearchIndex object.
Expand All @@ -262,14 +246,6 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
schema = IndexSchema.from_dict(schema_dict)
return cls(schema=schema, **kwargs)

def connect(self, redis_url: Optional[str] = None, **kwargs):
"""Connect to Redis at a given URL."""
raise NotImplementedError

def set_client(self, client: Union[redis.Redis, aredis.Redis]):
"""Manually set the Redis client to use with the search index."""
raise NotImplementedError

def disconnect(self):
"""Disconnect from the Redis database."""
self._redis_client = None
Expand Down Expand Up @@ -323,6 +299,43 @@ class SearchIndex(BaseSearchIndex):

"""

def __init__(
self,
schema: IndexSchema,
redis_client: Optional[redis.Redis] = None,
redis_url: Optional[str] = None,
connection_args: Dict[str, Any] = {},
**kwargs,
):
"""Initialize the RedisVL search index with a schema, Redis client
(or URL string with other connection args), connection_args, and other
kwargs.

Args:
schema (IndexSchema): Index schema object.
redis_client(Optional[redis.Redis]): An
instantiated redis client.
redis_url (Optional[str]): The URL of the Redis server to
connect to.
connection_args (Dict[str, Any], optional): Redis client connection
args.
"""
# final validation on schema object
if not isinstance(schema, IndexSchema):
raise ValueError("Must provide a valid IndexSchema object")

self.schema = schema

self._lib_name: Optional[str] = kwargs.pop("lib_name", None)

# set up redis connection
self._redis_client: Optional[redis.Redis] = None

if redis_client is not None:
self.set_client(redis_client)
elif redis_url is not None:
self.connect(redis_url, **connection_args)

@classmethod
def from_existing(
cls,
Expand All @@ -342,7 +355,7 @@ def from_existing(
)

# Validate modules
installed_modules = RedisConnectionFactory._get_modules(redis_client)
installed_modules = RedisConnectionFactory.get_modules(redis_client)
validate_modules(installed_modules, [{"name": "search", "ver": 20810}])

# Fetch index info and convert to schema
Expand Down Expand Up @@ -380,15 +393,15 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
return self.set_client(client)

@setup_redis()
def set_client(self, client: redis.Redis, **kwargs):
def set_client(self, redis_client: redis.Redis, **kwargs):
"""Manually set the Redis client to use with the search index.

This method configures the search index to use a specific Redis or
Async Redis client. It is useful for cases where an external,
custom-configured client is preferred instead of creating a new one.

Args:
client (redis.Redis): A Redis or Async Redis
redis_client (redis.Redis): A Redis or Async Redis
client instance to be used for the connection.

Raises:
Expand All @@ -404,10 +417,10 @@ def set_client(self, client: redis.Redis, **kwargs):
index.set_client(client)

"""
if not isinstance(client, redis.Redis):
if not isinstance(redis_client, redis.Redis):
raise TypeError("Invalid Redis client instance")

self._redis_client = client
self._redis_client = redis_client

return self

Expand Down Expand Up @@ -759,7 +772,7 @@ class AsyncSearchIndex(BaseSearchIndex):

# initialize the index object with schema from file
index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
index.connect(redis_url="redis://localhost:6379")
await index.connect(redis_url="redis://localhost:6379")

# create the index
await index.create(overwrite=True)
Expand All @@ -772,6 +785,34 @@ class AsyncSearchIndex(BaseSearchIndex):

"""

def __init__(
self,
schema: IndexSchema,
**kwargs,
):
"""Initialize the RedisVL async search index with a schema.

Args:
schema (IndexSchema): Index schema object.
connection_args (Dict[str, Any], optional): Redis client connection
args.
"""
# final validation on schema object
if not isinstance(schema, IndexSchema):
raise ValueError("Must provide a valid IndexSchema object")

self.schema = schema

self._lib_name: Optional[str] = kwargs.pop("lib_name", None)

# set up empty redis connection
self._redis_client: Optional[aredis.Redis] = None

if "redis_client" in kwargs or "redis_url" in kwargs:
logger.warning(
"Must use set_client() or connect() methods to provide a Redis connection to AsyncSearchIndex"
)

@classmethod
async def from_existing(
cls,
Expand All @@ -791,18 +832,18 @@ async def from_existing(
)

# Validate modules
installed_modules = await RedisConnectionFactory._get_modules_async(
redis_client
)
installed_modules = await RedisConnectionFactory.get_modules_async(redis_client)
validate_modules(installed_modules, [{"name": "search", "ver": 20810}])

# Fetch index info and convert to schema
index_info = await cls._info(name, redis_client)
schema_dict = convert_index_info_to_schema(index_info)
schema = IndexSchema.from_dict(schema_dict)
return cls(schema, redis_client, **kwargs)
index = cls(schema, **kwargs)
await index.set_client(redis_client)
return index

def connect(self, redis_url: Optional[str] = None, **kwargs):
async def connect(self, redis_url: Optional[str] = None, **kwargs):
"""Connect to a Redis instance using the provided `redis_url`, falling
back to the `REDIS_URL` environment variable (if available).

Expand All @@ -828,18 +869,18 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
client = RedisConnectionFactory.connect(
redis_url=redis_url, use_async=True, **kwargs
)
return self.set_client(client)
return await self.set_client(client)

@setup_redis()
def set_client(self, client: aredis.Redis):
@setup_async_redis()
async def set_client(self, redis_client: aredis.Redis):
"""Manually set the Redis client to use with the search index.

This method configures the search index to use a specific
Async Redis client. It is useful for cases where an external,
custom-configured client is preferred instead of creating a new one.

Args:
client (aredis.Redis): An Async Redis
redis_client (aredis.Redis): An Async Redis
client instance to be used for the connection.

Raises:
Expand All @@ -853,13 +894,13 @@ def set_client(self, client: aredis.Redis):
# async Redis client and index
client = aredis.Redis.from_url("redis://localhost:6379")
index = AsyncSearchIndex.from_yaml("schemas/schema.yaml")
index.set_client(client)
await index.set_client(client)

"""
if not isinstance(client, aredis.Redis):
if not isinstance(redis_client, aredis.Redis):
raise TypeError("Invalid Redis client instance")

self._redis_client = client
self._redis_client = redis_client

return self

Expand Down Expand Up @@ -889,6 +930,7 @@ async def create(self, overwrite: bool = False, drop: bool = False) -> None:
await index.create(overwrite=True, drop=True)
"""
redis_fields = self.schema.redis_fields

if not redis_fields:
raise ValueError("No fields defined for index")
if not isinstance(overwrite, bool):
Expand Down
Loading
Loading