adds filter fields to cache key name hash #224

Merged · merged 3 commits on Sep 30, 2024
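In short: entry_id was previously derived from the prompt alone, so two entries stored with an identical prompt but different filters (e.g. per-user scoping) collided on the same entry_id, and hence the same cache key. Folding the filters into the hash keeps those keys distinct. A minimal standalone sketch of the fixed behavior, with hashify reimplemented inline for illustration (the real helper lives in redisvl.redis.utils and is shown in the diff below):

    import hashlib
    from typing import Any, Dict, Optional

    def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str:
        # Append a deterministic "keyvalue" string for the sorted filter
        # items before hashing, mirroring the merged implementation.
        if extras:
            content += " ".join(str(k) + str(v) for k, v in sorted(extras.items()))
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    prompt = "What is the phone number linked to my account?"
    assert hashify(prompt, {"user_id": "gabs"}) != hashify(prompt, {"user_id": "bart"})
    assert hashify(prompt, {"user_id": "gabs"}) != hashify(prompt)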
13 changes: 0 additions & 13 deletions redisvl/extensions/llmcache/base.py
@@ -1,7 +1,5 @@
 from typing import Any, Dict, List, Optional

-from redisvl.redis.utils import hashify
-

 class BaseLLMCache:
     def __init__(self, ttl: Optional[int] = None):
@@ -79,14 +77,3 @@ async def astore(
"""Async store the specified key-value pair in the cache along with
metadata."""
raise NotImplementedError

def hash_input(self, prompt: str) -> str:
"""Hashes the input prompt using SHA256.

Args:
prompt (str): Input string to be hashed.

Returns:
str: Hashed string.
"""
return hashify(prompt)
2 changes: 1 addition & 1 deletion redisvl/extensions/llmcache/schema.py
@@ -32,7 +32,7 @@ class CacheEntry(BaseModel):
     def generate_id(cls, values):
         # Ensure entry_id is set
         if not values.get("entry_id"):
-            values["entry_id"] = hashify(values["prompt"])
+            values["entry_id"] = hashify(values["prompt"], values.get("filters"))
         return values

     @validator("metadata")
9 changes: 6 additions & 3 deletions redisvl/redis/utils.py
@@ -1,5 +1,5 @@
 import hashlib
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import numpy as np

@@ -40,6 +40,9 @@ def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]:
     return np.frombuffer(buffer, dtype=dtype).tolist()


-def hashify(content: str) -> str:
-    """Create a secure hash of some arbitrary input text."""
+def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str:
+    """Create a secure hash of some arbitrary input text and optional dictionary."""
+    if extras:
Review comment (Collaborator):
Small preference here to keep hashify dumb and move the filter-string processing logic into the generate_id method, passing the final string to hashify.
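A rough sketch of that alternative, for illustration only (this is not what was merged; the validator decorator is elided, same as in the schema.py hunk above):

    def generate_id(cls, values):
        # Ensure entry_id is set; filters are folded into the string here,
        # so hashify stays a plain string -> digest helper.
        if not values.get("entry_id"):
            content = values["prompt"]
            filters = values.get("filters")
            if filters:
                content += " ".join(str(k) + str(v) for k, v in sorted(filters.items()))
            values["entry_id"] = hashify(content)
        return values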

+        extra_string = " ".join([str(k) + str(v) for k, v in sorted(extras.items())])
+        content = content + extra_string
     return hashlib.sha256(content.encode("utf-8")).hexdigest()
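Worth noting: because extras is sorted by key before joining, the digest does not depend on dict insertion order. A quick check against the merged helper:

    from redisvl.redis.utils import hashify

    a = hashify("same prompt", {"user_id": "bart", "zip_code": 90210})
    b = hashify("same prompt", {"zip_code": 90210, "user_id": "bart"})
    assert a == b  # sorted(extras.items()) normalizes key order
    assert a != hashify("same prompt")  # filters change the key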
43 changes: 43 additions & 0 deletions tests/integration/test_llmcache.py
@@ -800,3 +800,46 @@ def test_index_updating(redis_url):
         filter_expression=tag_filter,
     )
     assert len(response) == 1
+
+
+def test_no_key_collision_on_identical_prompts(redis_url):
+    private_cache = SemanticCache(
+        name="private_cache",
+        redis_url=redis_url,
+        filterable_fields=[
+            {"name": "user_id", "type": "tag"},
+            {"name": "zip_code", "type": "numeric"},
+        ],
+    )
+
+    private_cache.store(
+        prompt="What is the phone number linked to my account?",
+        response="The number on file is 123-555-0000",
+        filters={"user_id": "gabs"},
+    )
+
+    private_cache.store(
+        prompt="What's the phone number linked in my account?",
+        response="The number on file is 123-555-9999",
+        filters={"user_id": "cerioni", "zip_code": 90210},
+    )
+
+    private_cache.store(
+        prompt="What's the phone number linked in my account?",
+        response="The number on file is 123-555-1111",
+        filters={"user_id": "bart"},
+    )
+
+    results = private_cache.check(
+        "What's the phone number linked in my account?", num_results=5
+    )
+    assert len(results) == 3
+
+    zip_code_filter = Num("zip_code") != 90210
+    filtered_results = private_cache.check(
+        "what's the phone number linked in my account?",
+        num_results=5,
+        filter_expression=zip_code_filter,
+    )
+    assert len(filtered_results) == 2
2 changes: 1 addition & 1 deletion tests/unit/test_llmcache_schema.py
@@ -48,7 +48,7 @@ def test_cache_entry_to_dict():
         filters={"category": "technology"},
     )
     result = entry.to_dict()
-    assert result["entry_id"] == hashify("What is AI?")
+    assert result["entry_id"] == hashify("What is AI?", {"category": "technology"})
     assert result["metadata"] == json.dumps({"author": "John"})
     assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3])
     assert result["category"] == "technology"