Skip to content

Commit

Permalink
support dynamic distance threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhutcherson committed Aug 16, 2024
1 parent 9147133 commit cf5a512
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
10 changes: 8 additions & 2 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def check(
num_results: int = 1,
return_fields: Optional[List[str]] = None,
filter_expression: Optional[FilterExpression] = None,
distance_threshold: Optional[float] = None,
) -> List[Dict[str, Any]]:
"""Checks the semantic cache for results similar to the specified prompt
or vector.
Expand All @@ -255,6 +256,8 @@ def check(
filter_expression (Optional[FilterExpression]) : Optional filter expression
that can be used to filter cache results. Defaults to None and
the full cache will be searched.
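distance_threshold (Optional[float]): Per-call override for the semantic
vector distance threshold. Defaults to None, in which case the
cache's configured distance threshold is used.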
Returns:
List[Dict[str, Any]]: A list of dicts containing the requested
Expand All @@ -274,9 +277,12 @@ def check(
if not (prompt or vector):
raise ValueError("Either prompt or vector must be specified.")

# overrides
distance_threshold = distance_threshold or self._distance_threshold
return_fields = return_fields or self.return_fields
vector = vector or self._vectorize_prompt(prompt)

self._check_vector_dims(vector)
return_fields = return_fields or self.return_fields

if not isinstance(return_fields, list):
raise TypeError("return_fields must be a list of field names")
Expand All @@ -285,7 +291,7 @@ def check(
vector=vector,
vector_field_name=self.vector_field_name,
return_fields=self.return_fields,
distance_threshold=self._distance_threshold,
distance_threshold=distance_threshold,
num_results=num_results,
return_score=True,
filter_expression=filter_expression,
Expand Down
12 changes: 9 additions & 3 deletions redisvl/extensions/session_manager/semantic_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def get_relevant(
fall_back: bool = False,
session_tag: Optional[str] = None,
raw: bool = False,
distance_threshold: Optional[float] = None,
) -> Union[List[str], List[Dict[str, str]]]:
"""Searches the chat history for information semantically related to
the specified prompt.
Expand All @@ -151,10 +152,12 @@ def get_relevant(
as_text (bool): Whether to return the prompts and responses as text
or as JSON
top_k (int): The number of previous messages to return. Default is 5.
fall_back (bool): Whether to drop back to recent conversation history
if no relevant context is found.
session_tag (Optional[str]): Tag to be added to entries to link to a specific
session. Defaults to instance uuid.
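distance_threshold (Optional[float]): Per-call override for the semantic
vector distance threshold. Defaults to None, in which case the
session manager's configured distance threshold is used.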
fall_back (bool): Whether to drop back to recent conversation history
if no relevant context is found.
raw (bool): Whether to return the full Redis hash entry or just the
message.
Expand All @@ -169,6 +172,9 @@ def get_relevant(
if top_k == 0:
return []

# override distance threshold
distance_threshold = distance_threshold or self._distance_threshold

return_fields = [
self.session_field_name,
self.role_field_name,
Expand All @@ -187,7 +193,7 @@ def get_relevant(
vector=self._vectorizer.embed(prompt),
vector_field_name=self.vector_field_name,
return_fields=return_fields,
distance_threshold=self._distance_threshold,
distance_threshold=distance_threshold,
num_results=top_k,
return_score=True,
filter_expression=session_filter,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer):
vector = vectorizer.embed(prompt)

cache.store(prompt, response, vector=vector)
check_result = cache.check(vector=vector)
check_result = cache.check(vector=vector, distance_threshold=0.4)

assert len(check_result) == 1
print(check_result, flush=True)
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,10 @@ def test_semantic_add_and_get_relevant(semantic_session):
semantic_session.set_distance_threshold(0.5)
default_context = semantic_session.get_relevant("list of fruits and vegetables")
assert len(default_context) == 5 # 2 pairs of prompt:response, and system
assert default_context == semantic_session.get_relevant(
"list of fruits and vegetables",
distance_threshold=0.5
)

# test tool calls can also be returned
context = semantic_session.get_relevant("winter sports like skiing")
Expand Down

0 comments on commit cf5a512

Please sign in to comment.