Skip to content

Commit

Permalink
Use ChatMessage pydantic model (#194)
Browse files Browse the repository at this point in the history
Uses pydantic for validation, docs, and custom serialization of chat
message objects on the way into and out of Redis.

---------

Co-authored-by: Justin Cechmanek <justin.cechmanek@redis.com>
  • Loading branch information
tylerhutcherson and justin-cechmanek authored Jul 31, 2024
1 parent 3844d57 commit 8bbd1b0
Show file tree
Hide file tree
Showing 8 changed files with 387 additions and 161 deletions.
35 changes: 35 additions & 0 deletions redisvl/extensions/router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic.v1 import BaseModel, Field, validator

from redisvl.schema import IndexInfo, IndexSchema


class Route(BaseModel):
"""Model representing a routing path with associated metadata and thresholds."""
Expand Down Expand Up @@ -80,3 +82,36 @@ def distance_threshold_must_be_valid(cls, v):
if v <= 0 or v > 1:
raise ValueError("distance_threshold must be between 0 and 1")
return v


class SemanticRouterIndexSchema(IndexSchema):
    """Index schema tailored to the SemanticRouter."""

    @classmethod
    def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
        """Build the router index schema from a name and a vector size.

        The router name is used both as the index name and as the key prefix.
        The schema declares a ``route_name`` tag field, a ``reference`` text
        field and a ``vector`` field (flat algorithm, cosine distance,
        float32 datatype).

        Args:
            name (str): The name of the index.
            vector_dims (int): The dimensions of the vectors.

        Returns:
            SemanticRouterIndexSchema: The constructed index schema.
        """
        index_info = IndexInfo(name=name, prefix=name)
        vector_attrs = {
            "algorithm": "flat",
            "dims": vector_dims,
            "distance_metric": "cosine",
            "datatype": "float32",
        }
        schema_fields = [
            {"name": "route_name", "type": "tag"},
            {"name": "reference", "type": "text"},
            {"name": "vector", "type": "vector", "attrs": vector_attrs},
        ]
        return cls(index=index_info, fields=schema_fields)  # type: ignore
35 changes: 1 addition & 34 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
Route,
RouteMatch,
RoutingConfig,
SemanticRouterIndexSchema,
)
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.redis.utils import convert_bytes, hashify, make_dict
from redisvl.schema import IndexInfo, IndexSchema
from redisvl.utils.log import get_logger
from redisvl.utils.utils import model_to_dict
from redisvl.utils.vectorize import (
Expand All @@ -29,39 +29,6 @@
logger = get_logger(__name__)


class SemanticRouterIndexSchema(IndexSchema):
    """Index schema tailored to the SemanticRouter."""

    @classmethod
    def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
        """Build the router index schema from a name and a vector size.

        The router name is used both as the index name and as the key prefix.
        The schema declares a ``route_name`` tag field, a ``reference`` text
        field and a ``vector`` field (flat algorithm, cosine distance,
        float32 datatype).

        Args:
            name (str): The name of the index.
            vector_dims (int): The dimensions of the vectors.

        Returns:
            SemanticRouterIndexSchema: The constructed index schema.
        """
        index_info = IndexInfo(name=name, prefix=name)
        vector_attrs = {
            "algorithm": "flat",
            "dims": vector_dims,
            "distance_metric": "cosine",
            "datatype": "float32",
        }
        schema_fields = [
            {"name": "route_name", "type": "tag"},
            {"name": "reference", "type": "text"},
            {"name": "vector", "type": "vector", "attrs": vector_attrs},
        ]
        return cls(index=index_info, fields=schema_fields)  # type: ignore


class SemanticRouter(BaseModel):
"""Semantic Router for managing and querying route vectors."""

Expand Down
55 changes: 24 additions & 31 deletions redisvl/extensions/session_manager/base_session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4

from redis import Redis

from redisvl.query.filter import FilterExpression
from redisvl.extensions.session_manager.schema import ChatMessage
from redisvl.utils.utils import create_uuid


class BaseSessionManager:
Expand Down Expand Up @@ -32,7 +30,7 @@ def __init__(
session. Defaults to instance uuid.
"""
self._name = name
self._session_tag = session_tag or uuid4().hex
self._session_tag = session_tag or create_uuid()

def clear(self) -> None:
"""Clears the chat session history."""
Expand Down Expand Up @@ -85,44 +83,39 @@ def get_recent(
raise NotImplementedError

def _format_context(
    self, messages: List[Dict[str, Any]], as_text: bool
) -> Union[List[str], List[Dict[str, str]]]:
    """Format session messages as content strings or role/content dicts.

    Each raw message is first validated by constructing a ``ChatMessage``
    from it, so a message that does not satisfy that model raises instead
    of being formatted.

    Args:
        messages (List[Dict[str, Any]]): The messages from the session index.
        as_text (bool): If true, return only the content string of each
            message. Otherwise return one dictionary per message keyed by
            the manager's role and content field names, plus the tool field
            name when the message carries a tool call id.

    Returns:
        Union[List[str], List[Dict[str, str]]]: One entry per input
            message, in input order: a list of content strings when
            ``as_text`` is true, otherwise a list of dictionaries.
    """
    context = []

    for message in messages:
        chat_message = ChatMessage(**message)

        if as_text:
            context.append(chat_message.content)
        else:
            chat_message_dict = {
                self.role_field_name: chat_message.role,
                self.content_field_name: chat_message.content,
            }
            # Only expose the tool call id for messages that actually have one.
            if chat_message.tool_call_id is not None:
                chat_message_dict[self.tool_field_name] = chat_message.tool_call_id

            context.append(chat_message_dict)  # type: ignore

    return context

def store(
self, prompt: str, response: str, session_tag: Optional[str] = None
Expand Down
94 changes: 94 additions & 0 deletions redisvl/extensions/session_manager/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Dict, List, Optional

from pydantic.v1 import BaseModel, Field, root_validator

from redisvl.redis.utils import array_to_buffer
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp


class ChatMessage(BaseModel):
    """A single chat message exchanged between a user and an LLM."""

    # NOTE(review): pydantic v1 does not appear to treat underscore-prefixed
    # names as model fields, so a caller-supplied `_id` is presumably dropped
    # and the value is always produced by `generate_id` below -- confirm.
    _id: Optional[str] = Field(default=None)
    """A unique identifier for the message."""
    role: str  # TODO -- do we enumify this?
    """The role of the message sender (e.g., 'user' or 'llm')."""
    content: str
    """The content of the message."""
    session_tag: str
    """Tag associated with the current session."""
    timestamp: float = Field(default_factory=current_timestamp)
    """The time the message was sent, in UTC, rounded to milliseconds."""
    tool_call_id: Optional[str] = Field(default=None)
    """An optional identifier for a tool call associated with the message."""
    vector_field: Optional[List[float]] = Field(default=None)
    """The vector representation of the message content."""

    class Config:
        arbitrary_types_allowed = True

    # skip_on_failure=True: when a field such as `session_tag` fails
    # validation it is absent from `values`; without this flag the lookups
    # below raise a bare KeyError instead of letting pydantic report a
    # ValidationError.
    @root_validator(pre=False, skip_on_failure=True)
    @classmethod
    def generate_id(cls, values):
        """Set ``_id`` to ``"<session_tag>:<timestamp>"`` when it is absent.

        Args:
            values (dict): The validated field values of the model.

        Returns:
            dict: The same mapping, with ``_id`` populated.
        """
        if "_id" not in values:
            values["_id"] = f'{values["session_tag"]}:{values["timestamp"]}'
        return values

    def to_dict(self) -> Dict:
        """Return the message as a dictionary with unset optionals removed.

        Starts from ``self.dict()``. ``vector_field`` is passed through
        ``array_to_buffer`` when present and removed when ``None``;
        ``tool_call_id`` is removed when ``None``.

        Returns:
            Dict: The dictionary form of the message.
        """
        data = self.dict()

        # handle optional fields
        if data["vector_field"] is not None:
            data["vector_field"] = array_to_buffer(data["vector_field"])
        else:
            del data["vector_field"]

        if data["tool_call_id"] is None:
            del data["tool_call_id"]

        return data


class StandardSessionIndexSchema(IndexSchema):
    """Index schema holding the tag, text and numeric fields of a chat message."""

    @classmethod
    def from_params(cls, name: str, prefix: str):
        """Build the schema for the given index name and key prefix.

        Args:
            name (str): The name of the index.
            prefix (str): The key prefix of the index.

        Returns:
            The constructed schema with ``role``, ``content``,
            ``tool_call_id``, ``timestamp`` and ``session_tag`` fields.
        """
        index_info = {"name": name, "prefix": prefix}
        field_types = (
            ("role", "tag"),
            ("content", "text"),
            ("tool_call_id", "tag"),
            ("timestamp", "numeric"),
            ("session_tag", "tag"),
        )
        schema_fields = [
            {"name": field_name, "type": field_type}
            for field_name, field_type in field_types
        ]
        return cls(index=index_info, fields=schema_fields)  # type: ignore


class SemanticSessionIndexSchema(IndexSchema):
    """Index schema holding chat message fields plus a vector field."""

    @classmethod
    def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
        """Build the schema for the given index name, prefix and vector size.

        Args:
            name (str): The name of the index.
            prefix (str): The key prefix of the index.
            vectorizer_dims (int): The ``dims`` attribute of the vector field.

        Returns:
            The constructed schema with ``role``, ``content``,
            ``tool_call_id``, ``timestamp`` and ``session_tag`` fields and a
            ``vector_field`` vector field (float32, cosine distance, flat
            algorithm).
        """
        index_info = {"name": name, "prefix": prefix}
        field_types = (
            ("role", "tag"),
            ("content", "text"),
            ("tool_call_id", "tag"),
            ("timestamp", "numeric"),
            ("session_tag", "tag"),
        )
        schema_fields = [
            {"name": field_name, "type": field_type}
            for field_name, field_type in field_types
        ]
        schema_fields.append(
            {
                "name": "vector_field",
                "type": "vector",
                "attrs": {
                    "dims": vectorizer_dims,
                    "datatype": "float32",
                    "distance_metric": "cosine",
                    "algorithm": "flat",
                },
            }
        )
        return cls(index=index_info, fields=schema_fields)  # type: ignore
Loading

0 comments on commit 8bbd1b0

Please sign in to comment.