Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the LLM session manager classes #141

Merged
merged 41 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
26c656b
wip: defining base session manager class methods
justin-cechmanek Apr 16, 2024
872e0fe
removes base session manager parent class
justin-cechmanek Apr 22, 2024
65177b2
adds proposed schema to session manager description
justin-cechmanek Apr 22, 2024
ebc07d6
adds session manager init
justin-cechmanek Apr 25, 2024
9dbc928
wip: minimal working session manager
justin-cechmanek Apr 26, 2024
bb467c2
wip: adding full conversation history, and cleans up scoping
justin-cechmanek Apr 30, 2024
4ed6208
minor clean up
justin-cechmanek May 1, 2024
2c5c52a
wip: initial notebook demo on semantic session manager
justin-cechmanek May 1, 2024
4d0b9ce
wip: continues session manager work
justin-cechmanek May 2, 2024
ba2575a
cleans up first session manager notebook
justin-cechmanek May 2, 2024
5fda1ec
makes scope configurable on each call to fetch_context
justin-cechmanek May 3, 2024
37ae270
improves scope settings
justin-cechmanek May 4, 2024
4e432a4
adds notebook example of controlling session access scope
justin-cechmanek May 4, 2024
3d404e7
formatting
justin-cechmanek May 6, 2024
c20d23e
mypy formatting
justin-cechmanek May 6, 2024
29b4a05
black formatting
justin-cechmanek May 6, 2024
e612195
bumps notebook number
justin-cechmanek May 6, 2024
45427ad
corrects method name
justin-cechmanek May 6, 2024
d9aacf9
sets an asymmetric retrieval model as default vectorizer
justin-cechmanek May 7, 2024
3a117eb
Merge branch 'main' into jc/semantic-session-manager
justin-cechmanek May 8, 2024
f99967f
moves recency sorting into Redis query
justin-cechmanek May 8, 2024
5d4f34b
adds session manager notebook examples to index
justin-cechmanek May 8, 2024
dff924c
wip:refactor into multiple classes
justin-cechmanek May 15, 2024
c8b9325
Merge branch 'main' into jc/semantic-session-manager
justin-cechmanek May 15, 2024
6e8ca02
refactors session managers into multiple classes
justin-cechmanek May 27, 2024
90cfe59
adds tests for session managers
justin-cechmanek May 28, 2024
b2e80b5
removes redundant notebook
justin-cechmanek May 28, 2024
d898941
formatting
justin-cechmanek May 28, 2024
46fb9b5
formatting
justin-cechmanek May 28, 2024
58be5b7
fixes failing test
justin-cechmanek May 28, 2024
ce09632
changes user_id, session_id to user_tag, session_tag. Adds pydantic v…
justin-cechmanek May 29, 2024
d32b12e
makes system preamble fully optional, empty if not set
justin-cechmanek May 29, 2024
a5fb413
renames methods & properties to align with OpenAI and LangChain
justin-cechmanek Jun 12, 2024
c683532
changes session managers to match langchain chat history api
justin-cechmanek Jun 14, 2024
d666fad
adds messages property to match langchain
justin-cechmanek Jun 17, 2024
158eb52
adds test coverage for messages property
justin-cechmanek Jun 17, 2024
60a1a00
adds optional tool message type
justin-cechmanek Jun 28, 2024
85ba8d0
Merge branch 'main' into jc/semantic-session-manager
justin-cechmanek Jun 28, 2024
241199e
Bugfix in setting vectorizer. Uses index key_separator
justin-cechmanek Jul 2, 2024
5c5a385
updates doc strings
justin-cechmanek Jul 3, 2024
a0d0a21
empty array is returned when top_k=0
justin-cechmanek Jul 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions redisvl/extensions/session_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from redisvl.extensions.session_manager.session import SessionManager

# Public API of the session_manager extension package.
__all__ = ["SessionManager"]
293 changes: 293 additions & 0 deletions redisvl/extensions/session_manager/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
import hashlib
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime

from redis import Redis

from redisvl.index import SearchIndex
from redisvl.query import FilterQuery, RangeQuery
from redisvl.query.filter import Tag, Num
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.schema import IndexSchema
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer

class SessionManager:
    def __init__(
        self,
        name: str,
        session_id: str,
        user_id: str,
        application_id: str,
        scope: str = "session",
        prefix: Optional[str] = None,
        vectorizer: Optional[BaseVectorizer] = None,
        distance_threshold: float = 0.3,
        redis_client: Optional[Redis] = None,
    ):
        """Initialize session memory with a Redis search index.

        Session Manager stores the current and previous user text prompts
        and LLM responses to allow for enriching future prompts with session
        context. Session history is stored in prompt:response pairs referred
        to as exchanges. Each exchange is stored with a single combined
        vector embedding constructed from the prompt & response concatenated
        into one string.

        Args:
            name str: The name of the session manager index.
            session_id str: Tag to be added to entries to link to a specific
                session.
            user_id str: Tag to be added to entries to link to a specific
                user.
            application_id str: Tag to be added to entries to link to a
                specific application.
            scope str: The level of access this session manager can retrieve
                data at. Must be one of 'session', 'user', 'application'.
            prefix Optional[str]: Prefix for the keys for this session data.
                Defaults to None and will be replaced with the index name.
            vectorizer Optional[BaseVectorizer]: The vectorizer to create
                embeddings with. Defaults to a HuggingFace all-mpnet-base-v2
                vectorizer when None.
            distance_threshold float: The maximum semantic distance to be
                included in the context. Defaults to 0.3.
            redis_client Optional[Redis]: A Redis client instance. Defaults
                to None, in which case a localhost connection is opened.

        Raises:
            ValueError: If scope is not one of 'session', 'user',
                'application'.
        """
        # The docstring restricts scope to three values; enforce it so a
        # typo doesn't silently fall through to application-wide access.
        if scope not in ("session", "user", "application"):
            raise ValueError(
                "scope must be one of 'session', 'user', 'application'"
            )

        prefix = prefix or name
        self._session_id = session_id
        self._user_id = user_id
        self._application_id = application_id
        self._scope = scope

        # Bugfix: the original only assigned self._vectorizer when the
        # caller did NOT supply one, so a caller-provided vectorizer was
        # ignored and the code crashed below at self._vectorizer.dims.
        if vectorizer is None:
            self._vectorizer = HFTextVectorizer(
                model="sentence-transformers/all-mpnet-base-v2"
            )
        else:
            self._vectorizer = vectorizer

        self.set_distance_threshold(distance_threshold)

        # Start with an empty preamble so fetch_context() works even if
        # set_preamble() is never called (previously an AttributeError).
        self._preamble: Dict[str, str] = {"role": "_preamble", "_content": ""}

        schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})

        schema.add_fields(
            [
                {"name": "prompt", "type": "text"},
                {"name": "response", "type": "text"},
                {"name": "timestamp", "type": "numeric"},
                {"name": "session_id", "type": "tag"},
                {"name": "user_id", "type": "tag"},
                {"name": "application_id", "type": "tag"},
                {"name": "count", "type": "numeric"},
                {"name": "token_count", "type": "numeric"},
                {
                    "name": "combined_vector_field",
                    "type": "vector",
                    "attrs": {
                        "dims": self._vectorizer.dims,
                        "datatype": "float32",
                        "distance_metric": "cosine",
                        "algorithm": "flat",
                    },
                },
            ]
        )

        self._index = SearchIndex(schema=schema)

        if redis_client:
            self._index.set_client(redis_client)
            self._redis_client = redis_client
        else:
            self._index.connect(redis_url="redis://localhost:6379")
            self._redis_client = Redis(decode_responses=True)

        self._index.create(overwrite=False)

        # Build the access filter once: 'application' scope filters only on
        # application_id, 'user' adds user_id, and 'session' narrows further
        # to the specific session.
        self._tag_filter = Tag("application_id") == self._application_id
        if self._scope == "user":
            user_filter = Tag("user_id") == self._user_id
            self._tag_filter = self._tag_filter & user_filter
        if self._scope == "session":
            session_filter = Tag("session_id") == self._session_id
            user_filter = Tag("user_id") == self._user_id
            self._tag_filter = self._tag_filter & user_filter & session_filter

    def clear(self):
        """Clear the chat session history.

        NOTE(review): not yet implemented — currently a no-op stub.
        """
        pass

    def fetch_context(
        self,
        prompt: str,
        as_text: bool = False,
        top_k: int = 3,
    ) -> Union[List[str], List[Dict[str, str]]]:
        """Search the chat history for information semantically related to
        the specified prompt.

        This method uses vector similarity search with a text prompt as
        input. It checks for semantically similar prompt:response pairs and
        fetches the top k most relevant previous prompt:response pairs to
        include as context for the next LLM call. If no semantic matches are
        found it falls back to the most recent exchanges.

        Args:
            prompt str: The text prompt to search for in session memory.
            as_text bool: Whether to return the prompt:response pairs as
                plain strings or as role/content dicts.
            top_k int: The number of previous exchanges to return.
                Default is 3.

        Returns:
            Union[List[str], List[Dict[str, str]]]: Either a list of strings
            or a list of role/content dicts containing the most relevant
            exchanges, preceded by the preamble. Empty when top_k == 0.

        Raises:
            ValueError: If top_k is not a non-negative integer.
        """
        # Enforce the documented ValueError contract; top_k == 0 means
        # "no context requested", so short-circuit with an empty list.
        if not isinstance(top_k, int) or top_k < 0:
            raise ValueError("top_k must be a non-negative integer")
        if top_k == 0:
            return []

        return_fields = [
            "session_id",
            "user_id",
            "application_id",
            "count",
            "prompt",
            "response",
            "timestamp",
            "combined_vector_field",
        ]

        query = RangeQuery(
            vector=self._vectorizer.embed(prompt),
            vector_field_name="combined_vector_field",
            return_fields=return_fields,
            distance_threshold=self._distance_threshold,
            num_results=top_k,
            return_score=True,
            filter_expression=self._tag_filter,
        )
        hits = self._index.query(query)

        # If we don't find semantic matches, fall back to returning the
        # most recent context instead of nothing.
        if not hits:
            hits = self.conversation_history(top_k=top_k)

        hits.sort(key=lambda x: x["timestamp"])  # TODO move sorting to query.py

        if as_text:
            statements = [self._preamble["_content"]]
            for hit in hits:
                statements.append(hit["prompt"])
                statements.append(hit["response"])
        else:
            statements = [self._preamble]
            for hit in hits:
                statements.append({"role": "_user", "_content": hit["prompt"]})
                statements.append({"role": "_llm", "_content": hit["response"]})

        return statements

    def conversation_history(
        self,
        as_text: bool = False,
        top_k: int = 3,
    ) -> Union[List[str], List[Dict[str, str]]]:
        """Retrieve the most recent exchanges in sequential order.

        Recency is determined by the monotonically increasing ``count``
        sequence number stored with each exchange, not by timestamp.

        Args:
            as_text bool: Whether to return the conversation as a single
                string, or a list of alternating prompts and responses.
                NOTE(review): not yet honored — raw hits are returned either
                way; confirm intended output format.
            top_k int: The number of previous exchanges to return.
                Default is 3.

        Returns:
            Union[List[str], List[Dict[str, str]]]: The most recent
            exchanges as returned by the index query.
        """
        return_fields = [
            "session_id",
            "user_id",
            "application_id",
            "count",
            "prompt",
            "response",
            "timestamp",
        ]

        # The per-(application, user, session) counter key mirrors the one
        # incremented in store(); the last top_k entries are the ones with
        # count greater than current_count - top_k.
        count_key = ":".join(
            [self._application_id, self._user_id, self._session_id, "count"]
        )
        count = self._redis_client.get(count_key) or 0
        last_k_filter = Num("count") > int(count) - top_k
        combined = self._tag_filter & last_k_filter

        query = FilterQuery(
            return_fields=return_fields,
            filter_expression=combined,
        )
        hits = self._index.query(query)
        return hits

    @property
    def distance_threshold(self):
        """The maximum semantic distance for context inclusion."""
        return self._distance_threshold

    def set_distance_threshold(self, threshold):
        """Update the maximum semantic distance for context inclusion."""
        self._distance_threshold = threshold

    def store(self, exchange: Tuple[str, str]) -> str:
        """Insert a prompt:response pair into the session memory.

        A timestamp is associated with each exchange so that they can later
        be sorted in sequential ordering after retrieval, and a per-session
        counter is incremented so conversation_history() can select the most
        recent entries.

        Args:
            exchange Tuple[str, str]: The user prompt and corresponding LLM
                response.

        Returns:
            str: The Redis key for the entry added to the database.
        """
        count_key = ":".join(
            [self._application_id, self._user_id, self._session_id, "count"]
        )
        count = self._redis_client.incr(count_key)
        # A single embedding is built from the concatenated prompt+response.
        vector = self._vectorizer.embed(exchange[0] + exchange[1])
        timestamp = int(datetime.now().timestamp())
        payload = {
            # Hash of prompt+timestamp gives a stable, collision-resistant id.
            "id": self.hash_input(exchange[0] + str(timestamp)),
            "prompt": exchange[0],
            "response": exchange[1],
            "timestamp": timestamp,
            "session_id": self._session_id,
            "user_id": self._user_id,
            "application_id": self._application_id,
            "count": count,
            "token_count": 1,  # TODO get actual token count
            "combined_vector_field": array_to_buffer(vector),
        }
        key = self._index.load(data=[payload])
        return key

    def set_preamble(self, prompt: str) -> None:
        """Add a preamble statement to the beginning of each session to be
        included in each subsequent LLM call.
        """
        self._preamble = {"role": "_preamble", "_content": prompt}
        # TODO store this in Redis with assigned scope?

    def timstamp_to_int(self, timestamp: datetime) -> int:
        """Convert a datetime object into an integer for storage as a
        numeric field in a hash.

        NOTE(review): unimplemented stub; method name keeps its original
        (typo'd) spelling for backward compatibility.
        """
        pass

    def int_to_timestamp(self, epoch_time: int) -> datetime:
        """Convert a numeric date expressed in epoch time into a datetime
        object.

        NOTE(review): unimplemented stub.
        """
        pass

    def hash_input(self, prompt: str):
        """Hashes the input using SHA256."""
        return hashlib.sha256(prompt.encode("utf-8")).hexdigest()

Loading