Skip to content

Commit

Permalink
Support distance threshold override (#202)
Browse files Browse the repository at this point in the history
At runtime, we should extend the ability to override the distance
threshold if provided. This is also how the semantic router works.
Parity between the extensions is key here (in terms of functionality)
  • Loading branch information
tylerhutcherson authored Aug 16, 2024
1 parent 38f2fe1 commit 13dcd66
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 6 deletions.
10 changes: 8 additions & 2 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def check(
num_results: int = 1,
return_fields: Optional[List[str]] = None,
filter_expression: Optional[FilterExpression] = None,
distance_threshold: Optional[float] = None,
) -> List[Dict[str, Any]]:
"""Checks the semantic cache for results similar to the specified prompt
or vector.
Expand All @@ -255,6 +256,8 @@ def check(
filter_expression (Optional[FilterExpression]) : Optional filter expression
that can be used to filter cache results. Defaults to None and
the full cache will be searched.
distance_threshold (Optional[float]): The threshold for semantic
vector distance.
Returns:
List[Dict[str, Any]]: A list of dicts containing the requested
Expand All @@ -274,9 +277,12 @@ def check(
if not (prompt or vector):
raise ValueError("Either prompt or vector must be specified.")

# overrides
distance_threshold = distance_threshold or self._distance_threshold
return_fields = return_fields or self.return_fields
vector = vector or self._vectorize_prompt(prompt)

self._check_vector_dims(vector)
return_fields = return_fields or self.return_fields

if not isinstance(return_fields, list):
raise TypeError("return_fields must be a list of field names")
Expand All @@ -285,7 +291,7 @@ def check(
vector=vector,
vector_field_name=self.vector_field_name,
return_fields=self.return_fields,
distance_threshold=self._distance_threshold,
distance_threshold=distance_threshold,
num_results=num_results,
return_score=True,
filter_expression=filter_expression,
Expand Down
12 changes: 9 additions & 3 deletions redisvl/extensions/session_manager/semantic_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def get_relevant(
fall_back: bool = False,
session_tag: Optional[str] = None,
raw: bool = False,
distance_threshold: Optional[float] = None,
) -> Union[List[str], List[Dict[str, str]]]:
"""Searches the chat history for information semantically related to
the specified prompt.
Expand All @@ -151,10 +152,12 @@ def get_relevant(
as_text (bool): Whether to return the prompts and responses as text
or as JSON
top_k (int): The number of previous messages to return. Default is 5.
fall_back (bool): Whether to drop back to recent conversation history
if no relevant context is found.
session_tag (Optional[str]): Tag to be added to entries to link to a specific
session. Defaults to instance uuid.
distance_threshold (Optional[float]): The threshold for semantic
vector distance.
fall_back (bool): Whether to drop back to recent conversation history
if no relevant context is found.
raw (bool): Whether to return the full Redis hash entry or just the
message.
Expand All @@ -169,6 +172,9 @@ def get_relevant(
if top_k == 0:
return []

# override distance threshold
distance_threshold = distance_threshold or self._distance_threshold

return_fields = [
self.session_field_name,
self.role_field_name,
Expand All @@ -187,7 +193,7 @@ def get_relevant(
vector=self._vectorizer.embed(prompt),
vector_field_name=self.vector_field_name,
return_fields=return_fields,
distance_threshold=self._distance_threshold,
distance_threshold=distance_threshold,
num_results=top_k,
return_score=True,
filter_expression=session_filter,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer):
vector = vectorizer.embed(prompt)

cache.store(prompt, response, vector=vector)
check_result = cache.check(vector=vector)
check_result = cache.check(vector=vector, distance_threshold=0.4)

assert len(check_result) == 1
print(check_result, flush=True)
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,10 @@ def test_semantic_add_and_get_relevant(semantic_session):
semantic_session.set_distance_threshold(0.5)
default_context = semantic_session.get_relevant("list of fruits and vegetables")
assert len(default_context) == 5 # 2 pairs of prompt:response, and system
assert default_context == semantic_session.get_relevant(
"list of fruits and vegetables",
distance_threshold=0.5
)

# test tool calls can also be returned
context = semantic_session.get_relevant("winter sports like skiing")
Expand Down

0 comments on commit 13dcd66

Please sign in to comment.