Skip to content

Commit

Permalink
documentation cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Aug 1, 2024
1 parent 704038b commit 28c7484
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 24 deletions.
20 changes: 20 additions & 0 deletions redisvl/extensions/llmcache/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@


class CacheEntry(BaseModel):
"""A single cache entry in Redis"""

entry_id: Optional[str] = Field(default=None)
"""Cache entry identifier"""
prompt: str
"""Input prompt or question cached in Redis"""
response: str
"""Response or answer to the question, cached in Redis"""
prompt_vector: List[float]
"""Text embedding representation of the prompt"""
inserted_at: float = Field(default_factory=current_timestamp)
"""Timestamp of when the entry was added to the cache"""
updated_at: float = Field(default_factory=current_timestamp)
"""Timestamp of when the entry was updated in the cache"""
metadata: Optional[Dict[str, Any]] = Field(default=None)
"""Optional metadata stored on the cache entry"""
filters: Optional[Dict[str, Any]] = Field(default=None)
"""Optional filter data stored on the cache entry for customizing retrieval"""

@root_validator(pre=True)
@classmethod
Expand Down Expand Up @@ -43,14 +53,24 @@ def to_dict(self) -> Dict:


class CacheHit(BaseModel):
"""A cache hit based on some input query"""

entry_id: str
"""Cache entry identifier"""
prompt: str
"""Input prompt or question cached in Redis"""
response: str
"""Response or answer to the question, cached in Redis"""
vector_distance: float
"""The semantic distance between the query vector and the stored prompt vector"""
inserted_at: float
"""Timestamp of when the entry was added to the cache"""
updated_at: float
"""Timestamp of when the entry was updated in the cache"""
metadata: Optional[Dict[str, Any]] = Field(default=None)
"""Optional metadata stored on the cache entry"""
filters: Optional[Dict[str, Any]] = Field(default=None)
"""Optional filter data stored on the cache entry for customizing retrieval"""

@root_validator(pre=True)
@classmethod
Expand Down
57 changes: 33 additions & 24 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

from redis import Redis

Expand All @@ -10,20 +10,15 @@
)
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.query.filter import FilterExpression, Tag
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.utils import (
current_timestamp,
deserialize,
serialize,
validate_vector_dims,
)
from redisvl.query.filter import FilterExpression
from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer


class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""

redis_key_field_name: str = "key"
entry_id_field_name: str = "entry_id"
prompt_field_name: str = "prompt"
response_field_name: str = "response"
Expand Down Expand Up @@ -55,6 +50,8 @@ def __init__(
in Redis. Defaults to None.
vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
Defaults to HFTextVectorizer.
filterable_fields (Optional[List[Dict[str, Any]]]): An optional list of RedisVL fields
that can be used to customize cache retrieval with filters.
redis_client(Optional[Redis], optional): A redis client connection instance.
Defaults to None.
redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
Expand All @@ -81,9 +78,6 @@ def __init__(
model="sentence-transformers/all-mpnet-base-v2"
)

# Create semantic cache schema
schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims)

# Process fields
self.return_fields = [
self.entry_id_field_name,
Expand All @@ -94,18 +88,9 @@ def __init__(
self.metadata_field_name,
]

if filterable_fields is not None:
for filter_field in filterable_fields:
if (
filter_field["name"] in self.return_fields
or filter_field["name"] == "key"
):
raise ValueError(
f'{filter_field["name"]} is a reserved field name for the semantic cache schema'
)
schema.add_field(filter_field)
# Add to return fields too
self.return_fields.append(filter_field["name"])
# Create semantic cache schema and index
schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims)
schema = self._modify_schema(schema, filterable_fields)

self._index = SearchIndex(schema=schema)

Expand All @@ -120,6 +105,30 @@ def __init__(
self.set_threshold(distance_threshold)
self._index.create(overwrite=False)

def _modify_schema(
self,
schema: SemanticCacheIndexSchema,
filterable_fields: Optional[List[Dict[str, Any]]] = None,
) -> SemanticCacheIndexSchema:
"""Modify the base cache schema using the provided filterable fields"""

if filterable_fields is not None:
protected_field_names = set(
self.return_fields + [self.redis_key_field_name]
)
for filter_field in filterable_fields:
field_name = filter_field["name"]
if field_name in protected_field_names:
raise ValueError(
f"{field_name} is a reserved field name for the semantic cache schema"
)
# Add to schema
schema.add_field(filter_field)
# Add to return fields too
self.return_fields.append(field_name)

return schema

@property
def index(self) -> SearchIndex:
"""The underlying SearchIndex for the cache.
Expand Down

0 comments on commit 28c7484

Please sign in to comment.