-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce the LLM session manager classes (#141)
Here's the proposed class for LLM session management. It support recent or full conversation storage & retrieval, as well as relevance based conversation section retrieval.
- Loading branch information
1 parent
aa05797
commit 61e7338
Showing
8 changed files
with
1,975 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,5 +17,6 @@ llmcache_03 | |
vectorizers_04 | ||
hash_vs_json_05 | ||
rerankers_06 | ||
session_manager_07 | ||
``` | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from redisvl.extensions.session_manager.base_session import BaseSessionManager | ||
from redisvl.extensions.session_manager.semantic_session import SemanticSessionManager | ||
from redisvl.extensions.session_manager.standard_session import StandardSessionManager | ||
|
||
__all__ = ["BaseSessionManager", "StandardSessionManager", "SemanticSessionManager"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from redis import Redis | ||
|
||
|
||
class BaseSessionManager: | ||
id_field_name: str = "id_field" | ||
role_field_name: str = "role" | ||
content_field_name: str = "content" | ||
tool_field_name: str = "tool_call_id" | ||
timestamp_field_name: str = "timestamp" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
session_tag: str, | ||
user_tag: str, | ||
): | ||
"""Initialize session memory with index | ||
Session Manager stores the current and previous user text prompts and | ||
LLM responses to allow for enriching future prompts with session | ||
context. Session history is stored in individual user or LLM prompts and | ||
responses. | ||
Args: | ||
name (str): The name of the session manager index. | ||
session_tag (str): Tag to be added to entries to link to a specific | ||
session. | ||
user_tag (str): Tag to be added to entries to link to a specific user. | ||
""" | ||
self._name = name | ||
self._user_tag = user_tag | ||
self._session_tag = session_tag | ||
|
||
def set_scope( | ||
self, | ||
session_tag: Optional[str] = None, | ||
user_tag: Optional[str] = None, | ||
) -> None: | ||
"""Set the filter to apply to querries based on the desired scope. | ||
This new scope persists until another call to set_scope is made, or if | ||
scope specified in calls to get_recent. | ||
Args: | ||
session_tag (str): Id of the specific session to filter to. Default is | ||
None. | ||
user_tag (str): Id of the specific user to filter to. Default is None. | ||
""" | ||
raise NotImplementedError | ||
|
||
def clear(self) -> None: | ||
"""Clears the chat session history.""" | ||
raise NotImplementedError | ||
|
||
def delete(self) -> None: | ||
"""Clear all conversation history and remove any search indices.""" | ||
raise NotImplementedError | ||
|
||
def drop(self, id_field: Optional[str] = None) -> None: | ||
"""Remove a specific exchange from the conversation history. | ||
Args: | ||
id_field (Optional[str]): The id_field of the entry to delete. | ||
If None then the last entry is deleted. | ||
""" | ||
raise NotImplementedError | ||
|
||
@property | ||
def messages(self) -> Union[List[str], List[Dict[str, str]]]: | ||
"""Returns the full chat history.""" | ||
raise NotImplementedError | ||
|
||
def get_recent( | ||
self, | ||
top_k: int = 5, | ||
session_tag: Optional[str] = None, | ||
user_tag: Optional[str] = None, | ||
as_text: bool = False, | ||
raw: bool = False, | ||
) -> Union[List[str], List[Dict[str, str]]]: | ||
"""Retreive the recent conversation history in sequential order. | ||
Args: | ||
top_k (int): The number of previous exchanges to return. Default is 5. | ||
Note that one exchange contains both a prompt and response. | ||
session_tag (str): Tag to be added to entries to link to a specific | ||
session. | ||
user_tag (str): Tag to be added to entries to link to a specific user. | ||
as_text (bool): Whether to return the conversation as a single string, | ||
or list of alternating prompts and responses. | ||
raw (bool): Whether to return the full Redis hash entry or just the | ||
prompt and response | ||
Returns: | ||
Union[str, List[str]]: A single string transcription of the session | ||
or list of strings if as_text is false. | ||
Raises: | ||
ValueError: If top_k is not an integer greater than or equal to 0. | ||
""" | ||
raise NotImplementedError | ||
|
||
def _format_context( | ||
self, hits: List[Dict[str, Any]], as_text: bool | ||
) -> Union[List[str], List[Dict[str, str]]]: | ||
"""Extracts the prompt and response fields from the Redis hashes and | ||
formats them as either flat dictionaries or strings. | ||
Args: | ||
hits (List): The hashes containing prompt & response pairs from | ||
recent conversation history. | ||
as_text (bool): Whether to return the conversation as a single string, | ||
or list of alternating prompts and responses. | ||
Returns: | ||
Union[str, List[str]]: A single string transcription of the session | ||
or list of strings if as_text is false. | ||
""" | ||
if as_text: | ||
text_statements = [] | ||
for hit in hits: | ||
text_statements.append(hit[self.content_field_name]) | ||
return text_statements | ||
else: | ||
statements = [] | ||
for hit in hits: | ||
statements.append( | ||
{ | ||
self.role_field_name: hit[self.role_field_name], | ||
self.content_field_name: hit[self.content_field_name], | ||
} | ||
) | ||
if ( | ||
hasattr(hit, self.tool_field_name) | ||
or isinstance(hit, dict) | ||
and self.tool_field_name in hit | ||
): | ||
statements[-1].update( | ||
{self.tool_field_name: hit[self.tool_field_name]} | ||
) | ||
return statements | ||
|
||
def store(self, prompt: str, response: str) -> None: | ||
"""Insert a prompt:response pair into the session memory. A timestamp | ||
is associated with each exchange so that they can be later sorted | ||
in sequential ordering after retrieval. | ||
Args: | ||
prompt (str): The user prompt to the LLM. | ||
response (str): The corresponding LLM response. | ||
""" | ||
raise NotImplementedError | ||
|
||
def add_messages(self, messages: List[Dict[str, str]]) -> None: | ||
"""Insert a list of prompts and responses into the session memory. | ||
A timestamp is associated with each so that they can be later sorted | ||
in sequential ordering after retrieval. | ||
Args: | ||
messages (List[Dict[str, str]]): The list of user prompts and LLM responses. | ||
""" | ||
raise NotImplementedError | ||
|
||
def add_message(self, message: Dict[str, str]) -> None: | ||
"""Insert a single prompt or response into the session memory. | ||
A timestamp is associated with it so that it can be later sorted | ||
in sequential ordering after retrieval. | ||
Args: | ||
message (Dict[str,str]): The user prompt or LLM response. | ||
""" | ||
raise NotImplementedError |
Oops, something went wrong.