Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support distance threshold override #202

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def check(
num_results: int = 1,
return_fields: Optional[List[str]] = None,
filter_expression: Optional[FilterExpression] = None,
distance_threshold: Optional[float] = None,
) -> List[Dict[str, Any]]:
"""Checks the semantic cache for results similar to the specified prompt
or vector.
Expand All @@ -255,6 +256,8 @@ def check(
filter_expression (Optional[FilterExpression]) : Optional filter expression
that can be used to filter cache results. Defaults to None and
the full cache will be searched.
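distance_threshold (Optional[float]): Per-call override for the semantic
vector distance threshold. Defaults to None; when not provided (None, or
any falsy value such as 0), the threshold configured on the cache
instance is used instead.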

Returns:
List[Dict[str, Any]]: A list of dicts containing the requested
Expand All @@ -274,9 +277,12 @@ def check(
if not (prompt or vector):
raise ValueError("Either prompt or vector must be specified.")

# overrides
distance_threshold = distance_threshold or self._distance_threshold
return_fields = return_fields or self.return_fields
vector = vector or self._vectorize_prompt(prompt)

self._check_vector_dims(vector)
return_fields = return_fields or self.return_fields

if not isinstance(return_fields, list):
raise TypeError("return_fields must be a list of field names")
Expand All @@ -285,7 +291,7 @@ def check(
vector=vector,
vector_field_name=self.vector_field_name,
return_fields=self.return_fields,
distance_threshold=self._distance_threshold,
distance_threshold=distance_threshold,
num_results=num_results,
return_score=True,
filter_expression=filter_expression,
Expand Down
12 changes: 9 additions & 3 deletions redisvl/extensions/session_manager/semantic_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def get_relevant(
fall_back: bool = False,
session_tag: Optional[str] = None,
raw: bool = False,
distance_threshold: Optional[float] = None,
) -> Union[List[str], List[Dict[str, str]]]:
"""Searches the chat history for information semantically related to
the specified prompt.
Expand All @@ -151,10 +152,12 @@ def get_relevant(
as_text (bool): Whether to return the prompts and responses as text
or as JSON
top_k (int): The number of previous messages to return. Default is 5.
fall_back (bool): Whether to drop back to recent conversation history
if no relevant context is found.
session_tag (Optional[str]): Tag to be added to entries to link to a specific
session. Defaults to instance uuid.
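distance_threshold (Optional[float]): Per-call override for the semantic
vector distance threshold. Defaults to None; when not provided (None, or
any falsy value such as 0), the threshold configured on the session
instance is used instead.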
fall_back (bool): Whether to drop back to recent conversation history
if no relevant context is found.
raw (bool): Whether to return the full Redis hash entry or just the
message.

Expand All @@ -169,6 +172,9 @@ def get_relevant(
if top_k == 0:
return []

# override distance threshold
distance_threshold = distance_threshold or self._distance_threshold

return_fields = [
self.session_field_name,
self.role_field_name,
Expand All @@ -187,7 +193,7 @@ def get_relevant(
vector=self._vectorizer.embed(prompt),
vector_field_name=self.vector_field_name,
return_fields=return_fields,
distance_threshold=self._distance_threshold,
distance_threshold=distance_threshold,
num_results=top_k,
return_score=True,
filter_expression=session_filter,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer):
vector = vectorizer.embed(prompt)

cache.store(prompt, response, vector=vector)
check_result = cache.check(vector=vector)
check_result = cache.check(vector=vector, distance_threshold=0.4)

assert len(check_result) == 1
print(check_result, flush=True)
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,10 @@ def test_semantic_add_and_get_relevant(semantic_session):
semantic_session.set_distance_threshold(0.5)
default_context = semantic_session.get_relevant("list of fruits and vegetables")
assert len(default_context) == 5 # 2 pairs of prompt:response, and system
assert default_context == semantic_session.get_relevant(
"list of fruits and vegetables",
distance_threshold=0.5
)

# test tool calls can also be returned
context = semantic_session.get_relevant("winter sports like skiing")
Expand Down
Loading