Use pydantic for cache entries and hits #195

Merged: 4 commits, Aug 1, 2024
35 changes: 18 additions & 17 deletions docs/user_guide/llmcache_03.ipynb
@@ -83,7 +83,6 @@
"\n",
"llmcache = SemanticCache(\n",
" name=\"llmcache\", # underlying search index name\n",
" prefix=\"llmcache\", # redis key prefix for hash entries\n",
" redis_url=\"redis://localhost:6379\", # redis connection url string\n",
" distance_threshold=0.1 # semantic cache distance threshold\n",
")"
@@ -107,13 +106,15 @@
"│ llmcache │ HASH │ ['llmcache'] │ [] │ 0 │\n",
"╰──────────────┴────────────────┴──────────────┴─────────────────┴────────────╯\n",
"Index Fields:\n",
"╭───────────────┬───────────────┬────────┬────────────────┬────────────────╮\n",
"│ Name │ Attribute │ Type │ Field Option │ Option Value │\n",
"├───────────────┼───────────────┼────────┼────────────────┼────────────────┤\n",
"│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │\n",
"│ response │ response │ TEXT │ WEIGHT │ 1 │\n",
"│ prompt_vector │ prompt_vector │ VECTOR │ │ │\n",
"╰───────────────┴───────────────┴────────┴────────────────┴────────────────╯\n"
"╭───────────────┬───────────────┬─────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n",
"│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n",
"├───────────────┼───────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n",
"│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n",
"│ response │ response │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n",
"│ inserted_at │ inserted_at │ NUMERIC │ │ │ │ │ │ │ │ │\n",
"│ updated_at │ updated_at │ NUMERIC │ │ │ │ │ │ │ │ │\n",
"│ prompt_vector │ prompt_vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n",
"╰───────────────┴───────────────┴─────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n"
]
}
],
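The regenerated listing reflects the schema defined in the new redisvl/extensions/llmcache/schema.py below: two numeric timestamp fields (inserted_at, updated_at) plus the vector field attributes (FLAT, FLOAT32, 768 dims, cosine). A hedged way to spot-check the same information programmatically, assuming SemanticCache exposes its underlying SearchIndex as .index and that info() surfaces the raw FT.INFO payload:

info = llmcache.index.info()
print(info["index_name"])       # expected: llmcache
print(len(info["attributes"]))  # expected: 5 (prompt, response, inserted_at, updated_at, prompt_vector)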
@@ -208,7 +209,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[{'id': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545', 'vector_distance': '9.53674316406e-07', 'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}}]\n"
"[{'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}, 'key': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545'}]\n"
]
}
],
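With this change the hit dict is shaped by the new CacheHit model: the raw Redis id and vector_distance no longer appear in the default output, and the entry's Redis key is returned under key. A minimal lookup sketch, reusing the store/check API from this user guide:

llmcache.store(
    prompt="What is the capital of France?",
    response="Paris",
    metadata={"city": "Paris", "country": "france"},
)

hits = llmcache.check(prompt="What actually is the capital of France?")
if hits:
    print(hits[0]["response"])  # -> Paris
    print(hits[0]["key"])       # -> llmcache:<hash of the prompt>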
@@ -384,7 +385,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -408,14 +409,14 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Without caching, a call to openAI to answer this simple question took 1.460299015045166 seconds.\n"
"Without caching, a call to openAI to answer this simple question took 0.9312698841094971 seconds.\n"
]
},
{
@@ -424,7 +425,7 @@
"'llmcache:67e0f6e28fe2a61c0022fd42bf734bb8ffe49d3e375fd69d692574295a20fc1a'"
]
},
"execution_count": 18,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@@ -451,8 +452,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Avg time taken with LLM cache enabled: 0.2560166358947754\n",
"Percentage of time saved: 82.47%\n"
"Avg time taken with LLM cache enabled: 0.4896167993545532\n",
"Percentage of time saved: 47.42%\n"
]
}
],
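The refreshed timings are internally consistent; a quick check of the arithmetic behind the reported figure, using the numbers printed above:

uncached = 0.9312698841094971  # seconds for the uncached openAI call above
cached = 0.4896167993545532    # avg seconds with the LLM cache enabled

saved = (uncached - cached) / uncached * 100
print(f"Percentage of time saved: {saved:.2f}%")  # -> 47.42%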
@@ -515,7 +516,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -540,7 +541,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.14"
},
"orig_nbformat": 4
},
1 change: 0 additions & 1 deletion redisvl/extensions/llmcache/base.py
@@ -1,4 +1,3 @@
- import json
from typing import Any, Dict, List, Optional

from redisvl.redis.utils import hashify
128 changes: 128 additions & 0 deletions redisvl/extensions/llmcache/schema.py
@@ -0,0 +1,128 @@
from typing import Any, Dict, List, Optional

from pydantic.v1 import BaseModel, Field, root_validator, validator

from redisvl.redis.utils import array_to_buffer, hashify
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp, deserialize, serialize


class CacheEntry(BaseModel):
"""A single cache entry in Redis"""

entry_id: Optional[str] = Field(default=None)
"""Cache entry identifier"""
prompt: str
"""Input prompt or question cached in Redis"""
response: str
"""Response or answer to the question, cached in Redis"""
prompt_vector: List[float]
"""Text embedding representation of the prompt"""
inserted_at: float = Field(default_factory=current_timestamp)
"""Timestamp of when the entry was added to the cache"""
updated_at: float = Field(default_factory=current_timestamp)
"""Timestamp of when the entry was updated in the cache"""
metadata: Optional[Dict[str, Any]] = Field(default=None)
"""Optional metadata stored on the cache entry"""
filters: Optional[Dict[str, Any]] = Field(default=None)
"""Optional filter data stored on the cache entry for customizing retrieval"""

@root_validator(pre=True)
@classmethod
def generate_id(cls, values):
# Ensure entry_id is set
if not values.get("entry_id"):
values["entry_id"] = hashify(values["prompt"])
return values

@validator("metadata")
def non_empty_metadata(cls, v):
if v is not None and not isinstance(v, dict):
raise TypeError("Metadata must be a dictionary.")
return v

def to_dict(self) -> Dict:
data = self.dict(exclude_none=True)
data["prompt_vector"] = array_to_buffer(self.prompt_vector)
if self.metadata:
data["metadata"] = serialize(self.metadata)
if self.filters:
data.update(self.filters)
del data["filters"]
return data


class CacheHit(BaseModel):
"""A cache hit based on some input query"""

entry_id: str
"""Cache entry identifier"""
prompt: str
"""Input prompt or question cached in Redis"""
response: str
"""Response or answer to the question, cached in Redis"""
vector_distance: float
"""The semantic distance between the query vector and the stored prompt vector"""
inserted_at: float
"""Timestamp of when the entry was added to the cache"""
updated_at: float
"""Timestamp of when the entry was updated in the cache"""
metadata: Optional[Dict[str, Any]] = Field(default=None)
"""Optional metadata stored on the cache entry"""
filters: Optional[Dict[str, Any]] = Field(default=None)
"""Optional filter data stored on the cache entry for customizing retrieval"""

@root_validator(pre=True)
@classmethod
def validate_cache_hit(cls, values):
# Deserialize metadata if necessary
if "metadata" in values and isinstance(values["metadata"], str):
values["metadata"] = deserialize(values["metadata"])

# Separate filters from other fields
known_fields = set(cls.__fields__.keys())
filters = {k: v for k, v in values.items() if k not in known_fields}

# Add filters to values
if filters:
values["filters"] = filters

# Remove filter fields from the main values
for k in filters:
values.pop(k)

return values

def to_dict(self) -> Dict:
data = self.dict(exclude_none=True)
if self.filters:
data.update(self.filters)
del data["filters"]

return data


class SemanticCacheIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, prefix: str, vector_dims: int):

return cls(
index={"name": name, "prefix": prefix}, # type: ignore
fields=[ # type: ignore
{"name": "prompt", "type": "text"},
{"name": "response", "type": "text"},
{"name": "inserted_at", "type": "numeric"},
{"name": "updated_at", "type": "numeric"},
{
"name": "prompt_vector",
"type": "vector",
"attrs": {
"dims": vector_dims,
"datatype": "float32",
"distance_metric": "cosine",
"algorithm": "flat",
},
},
],
)
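A short usage sketch tying the three new pieces together; the stub vector and the user filter below are illustrative only, not part of this PR:

from redisvl.extensions.llmcache.schema import (
    CacheEntry,
    CacheHit,
    SemanticCacheIndexSchema,
)

# Index schema built from parameters (a hypothetical 768-dim embedder).
schema = SemanticCacheIndexSchema.from_params("llmcache", "llmcache", 768)

# New entry: entry_id is derived from the prompt hash by the root validator,
# and inserted_at/updated_at default to the current timestamp.
entry = CacheEntry(
    prompt="What is the capital of France?",
    response="Paris",
    prompt_vector=[0.1, 0.2, 0.3],  # stub, not a real embedding
    metadata={"city": "Paris"},
)
data = entry.to_dict()  # prompt_vector -> byte buffer, metadata -> serialized string

# A hit coming back from Redis: unknown fields are folded into `filters`
# by validate_cache_hit before field assignment.
hit = CacheHit(
    entry_id=entry.entry_id,
    prompt=entry.prompt,
    response=entry.response,
    vector_distance=0.02,
    inserted_at=entry.inserted_at,
    updated_at=entry.updated_at,
    user="abc",  # not a model field, so it lands in hit.filters
)
assert hit.filters == {"user": "abc"}
assert "user" in hit.to_dict()  # to_dict flattens filters back into the payload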