Skip to content

Commit

Permalink
mypy formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
justin-cechmanek committed May 6, 2024
1 parent 3d404e7 commit c20d23e
Showing 1 changed file with 70 additions and 82 deletions.
152 changes: 70 additions & 82 deletions redisvl/extensions/session_manager/session.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
import hashlib
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

from redis import Redis

from redisvl.index import SearchIndex
from redisvl.query import FilterQuery, RangeQuery
from redisvl.query.filter import Tag, Num
from redisvl.query.filter import Num, Tag
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.schema import IndexSchema
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer


class SessionManager:
def __init__(
self,
name: str,
session_id: str,
user_id: str,
application_id: str,
scope: str = 'session',
scope: str = "session",
prefix: Optional[str] = None,
vectorizer: Optional[BaseVectorizer] = None,
distance_threshold: float = 0.3,
redis_client: Optional[Redis] = None,
preamble: str = ''
):
""" Initialize session memory with index
preamble: str = "",
):
"""Initialize session memory with index
Session Manager stores the current and previous user text prompts and
LLM responses to allow for enriching future prompts with session
Expand Down Expand Up @@ -88,7 +91,7 @@ def __init__(
"distance_metric": "cosine",
"algorithm": "flat",
},
},
},
]
)

Expand All @@ -104,22 +107,18 @@ def __init__(
self._index.create(overwrite=False)

self._tag_filter = Tag("application_id") == self._application_id
if self._scope == 'user':
if self._scope == "user":
user_filter = Tag("user_id") == self._user_id
self._tag_filter = self._tag_filter & user_filter
if self._scope == 'session':
if self._scope == "session":
session_filter = Tag("session_id") == self._session_id
user_filter = Tag("user_id") == self._user_id
self._tag_filter = self._tag_filter & user_filter & session_filter


def set_scope(
self,
session_id: str = None,
user_id: str = None,
application_id: str = None
) -> None:
""" Set the tag filter to apply to queries based on the desired scope.
self, session_id: Optional[str] = None, user_id: Optional[str] = None, application_id: Optional[str] = None
) -> None:
"""Set the tag filter to apply to queries based on the desired scope.
This new scope persists until another call to set_scope is made, or until
a scope is specified in calls to fetch_recent or fetch_relevant.
Expand All @@ -135,7 +134,7 @@ def set_scope(
if not (session_id or user_id or application_id):
return

tag_filter = Tag('application_id') == []
tag_filter = Tag("application_id") == []
if application_id:
tag_filter = tag_filter & (Tag("application_id") == application_id)
if user_id:
Expand All @@ -145,32 +144,29 @@ def set_scope(

self._tag_filter = tag_filter


def clear(self) -> None:
    """Clears the chat session history.

    Every key matching "<index prefix>:*" is queued for deletion on a
    non-transactional pipeline, and the queued deletes are sent with a
    single execute() call once the scan has finished.
    """
    # Resolve the client and the match pattern once instead of once per use.
    client = self._index.client
    key_pattern = f"{self._index.prefix}:*"
    with client.pipeline(transaction=False) as pipe:  # type: ignore
        for key in client.scan_iter(match=key_pattern):  # type: ignore
            pipe.delete(key)
        pipe.execute()


def delete(self) -> None:
    """Clear all conversation keys and remove the search index."""
    # NOTE(review): drop=True is presumably what removes the stored entries
    # along with the index definition; confirm against SearchIndex.delete.
    self._index.delete(drop=True)


def fetch_relevant(
self,
prompt: str,
as_text: bool = False,
top_k: int = 3,
fall_back: bool = False,
session_id: str = None,
user_id: str = None,
application_id: str = None,
raw: bool = False
) -> Union[List[str], List[Dict[str,str]]]:
""" Searches the chat history for information semantically related to
session_id: Optional[str] = None,
user_id: Optional[str] = None,
application_id: Optional[str] = None,
raw: bool = False,
) -> Union[List[str], List[Dict[str, str]]]:
"""Searches the chat history for information semantically related to
the specified prompt.
This method uses vector similarity search with a text prompt as input.
Expand Down Expand Up @@ -216,7 +212,7 @@ def fetch_relevant(
distance_threshold=self._distance_threshold,
num_results=top_k,
return_score=True,
filter_expression=self._tag_filter
filter_expression=self._tag_filter,
)
hits = self._index.query(query)

Expand All @@ -227,17 +223,16 @@ def fetch_relevant(
return hits
return self._format_context(hits, as_text)


def fetch_recent(
self,
as_text: bool = False,
top_k: int = 3,
session_id: str = None,
user_id: str = None,
application_id: str = None,
raw = False
) -> Union[List[str], List[Dict[str,str]]]:
""" Retrieve the recent conversation history in sequential order.
session_id: Optional[str] = None,
user_id: Optional[str] = None,
application_id: Optional[str] = None,
raw: bool = False,
) -> Union[List[str], List[Dict[str, str]]]:
"""Retrieve the recent conversation history in sequential order.
Args:
as_text bool: Whether to return the conversation as a single string,
Expand Down Expand Up @@ -265,27 +260,23 @@ def fetch_recent(
"timestamp",
]

count_key = ":".join([self._application_id, self._user_id, self._session_id, "count"])
count_key = ":".join(
[self._application_id, self._user_id, self._session_id, "count"]
)
count = self._redis_client.get(count_key) or 0
last_k_filter = Num("count") > int(count) - top_k
combined = self._tag_filter & last_k_filter

query = FilterQuery(
return_fields=return_fields,
filter_expression=combined
)
query = FilterQuery(return_fields=return_fields, filter_expression=combined)
hits = self._index.query(query)
if raw:
return hits
return self._format_context(hits, as_text)


def _format_context(
self,
hits: List[Dict[str, Any]],
as_text: bool
) -> Union[List[str], List[Dict[str, str]]]:
""" Extracts the prompt and response fields from the Redis hashes and
self, hits: List[Dict[str, Any]], as_text: bool
) -> Union[List[str], List[Dict[str, str]]]:
"""Extracts the prompt and response fields from the Redis hashes and
formats them as either flat dictionaries oor strings.
Args:
Expand All @@ -298,71 +289,68 @@ def _format_context(
or list of strings if as_text is false.
"""
if hits:
hits.sort(key=lambda x: x['timestamp']) # TODO move sorting to query.py
hits.sort(key=lambda x: x["timestamp"]) # TODO move sorting to query.py

if as_text:
statements = [self._preamble["_content"]]
text_statements = [self._preamble["_content"]]
for hit in hits:
statements.append(hit["prompt"])
statements.append(hit["response"])
text_statements.append(hit["prompt"])
text_statements.append(hit["response"])
return text_statements
else:
statements = [self._preamble]
for hit in hits:
statements.append({"role": "_user", "_content": hit["prompt"]})
statements.append({"role": "_llm", "_content": hit["response"]})
return statements

return statements

@property
def distance_threshold(self) -> float:
    """The distance threshold currently stored on this session manager."""
    return self._distance_threshold


def set_distance_threshold(self, threshold: float) -> None:
    """Replace the stored distance threshold.

    Args:
        threshold float: The new value. It is stored as given, with no
            validation, and is read back through the distance_threshold
            property.
    """
    self._distance_threshold = threshold


def store(self, exchange: Tuple[str, str]) -> str:
    """Insert a prompt:response pair into the session memory.

    A timestamp is associated with each exchange so that they can be later
    sorted in sequential ordering after retrieval.

    Args:
        exchange Tuple[str, str]: The user prompt and corresponding LLM
            response.

    Returns:
        str: The Redis key for the entry added to the database.
    """
    prompt, response = exchange[0], exchange[1]

    # Counter key has the form "<application_id>:<user_id>:<session_id>:count".
    count_key = ":".join(
        [self._application_id, self._user_id, self._session_id, "count"]
    )
    # The value returned by incr is stored on the entry as its "count".
    count = self._redis_client.incr(count_key)

    # The prompt and response are embedded together as one string.
    vector = self._vectorizer.embed(prompt + response)

    # Truncated to whole seconds: the same prompt stored twice within one
    # second produces the same "id" hash.
    timestamp = int(datetime.now().timestamp())

    payload = {
        "id": self.hash_input(prompt + str(timestamp)),
        "prompt": prompt,
        "response": response,
        "timestamp": timestamp,
        "session_id": self._session_id,
        "user_id": self._user_id,
        "application_id": self._application_id,
        "count": count,
        "token_count": 1,  # TODO get actual token count
        "combined_vector_field": array_to_buffer(vector),
    }

    # load is given a one-item list; the first returned key is this entry's.
    keys = self._index.load(data=[payload])
    return keys[0]

def set_preamble(self, prompt: str) -> None:
    """Add a preamble statement to the beginning of each session to be
    included in each subsequent LLM call.

    Args:
        prompt str: The preamble text. It is stored under the "_content"
            key of a dictionary whose "role" is "_preamble".
    """
    self._preamble = {"role": "_preamble", "_content": prompt}
    # TODO store this in Redis with assigned scope?


def hash_input(self, prompt: str) -> str:
    """Hashes the input using SHA256.

    Args:
        prompt str: The text to hash. It is encoded as UTF-8 before hashing.

    Returns:
        str: The hexadecimal SHA256 digest of the encoded text.
    """
    return hashlib.sha256(prompt.encode("utf-8")).hexdigest()

0 comments on commit c20d23e

Please sign in to comment.