
Commit

refactor: clean up agent.step() by having a response model instead of a response tuple
cpacker committed Sep 13, 2024
1 parent 7e70082 commit 22ccd3d
Showing 4 changed files with 85 additions and 66 deletions.
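
The core of the refactor: Agent.step() previously returned a positional tuple that every call site had to unpack in the correct order, and it now returns an AgentStepResponse Pydantic model with named fields. A minimal sketch of the calling pattern before and after (illustrative, not the exact call sites):

    # Before: positional unpacking; order matters, and adding a field breaks every caller
    new_messages, heartbeat_request, function_failed, token_warning, usage = agent.step(user_message)

    # After: named fields on a response model; callers take only what they need
    step_response = agent.step(user_message)
    new_messages = step_response.messages
    if step_response.heartbeat_request:
        ...  # the caller decides how to schedule the follow-up step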
119 changes: 57 additions & 62 deletions memgpt/agent.py
@@ -22,7 +22,7 @@
from memgpt.memory import ArchivalMemory, RecallMemory, summarize_messages
from memgpt.metadata import MetadataStore
from memgpt.persistence_manager import LocalStateManager
-from memgpt.schemas.agent import AgentState
+from memgpt.schemas.agent import AgentState, AgentStepResponse
from memgpt.schemas.block import Block
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import MessageRole, OptionState
@@ -196,15 +196,10 @@ class BaseAgent(ABC):
Only two interfaces are required: step and update_state.
"""

-# @abstractmethod
-# def step(self, message: Message) -> List[Message]:
-# raise NotImplementedError
-
-# TODO cleanup
@abstractmethod
def step(
self,
-user_message: Union[Message, str], # NOTE: should be json.dump(dict)
+messages: Union[Message, List[Message], str], # TODO deprecate str inputs
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
Expand All @@ -214,10 +209,11 @@ def step(
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
-) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
+) -> AgentStepResponse:
"""
Top-level event message handler for the agent.
"""
+raise NotImplementedError

@abstractmethod
def update_state(self) -> AgentState:
@@ -708,45 +704,21 @@ def _handle_ai_response(

def step(
self,
-user_message: Union[Message, str], # NOTE: should be json.dump(dict)
+user_message: Union[Message, None, str], # NOTE: should be json.dump(dict)
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
-return_dicts: bool = True, # if True, return dicts, if False, return Message objects
+return_dicts: bool = True,
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreate the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
-) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
+) -> AgentStepResponse:
"""Top-level event message handler for the MemGPT agent"""

-def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
-"""If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
-try:
-user_message_json = dict(json_loads(user_message_text))
-# Special handling for AutoGen messages with 'name' field
-# Treat 'name' as a special field
-# If it exists in the input message, elevate it to the 'message' level
-name = user_message_json.pop("name", None)
-clean_message = json_dumps(user_message_json)
-
-except Exception as e:
-print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
-
-return clean_message, name
-
-def validate_json(user_message_text: str, raise_on_error: bool) -> str:
-try:
-user_message_json = dict(json_loads(user_message_text))
-user_message_json_val = json_dumps(user_message_json)
-return user_message_json_val
-except Exception as e:
-print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
-if raise_on_error:
-raise e

try:

# Step 0: update core memory
# only pulling latest block data if shared memory is being used
# TODO: ensure we're passing in metadata store from all surfaces
@@ -760,11 +732,14 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
# once we ensure we're correctly comparing whether in-memory core
# data is different than persisted core data.
self.rebuild_memory(force=True, ms=ms)

# Step 1: add user message
if user_message is not None:
if isinstance(user_message, Message):
assert user_message.text is not None

# Validate JSON via save/load
-user_message_text = validate_json(user_message.text, False)
+user_message_text = validate_json(user_message.text)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message_text)

if name is not None:
@@ -778,7 +753,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:

elif isinstance(user_message, str):
# Validate JSON via save/load
-user_message = validate_json(user_message, False)
+user_message = validate_json(user_message)
cleaned_user_message_text, name = strip_name_field_from_user_message(user_message)

# If user_message['name'] is not None, it will be handled properly by dict_to_message
@@ -799,14 +774,15 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
self.interface.user_message(user_message.text, msg_obj=user_message)

input_message_sequence = self._messages + [user_message]

# Alternatively, the requestor can send an empty user message
else:
input_message_sequence = self._messages

if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")

-# Step 2: send the conversation and available functions to GPT
+# Step 2: send the conversation and available functions to the LLM
if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
@@ -843,17 +819,6 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
response_message_id=response.id if stream else None,
)

-# Add the extra metadata to the assistant response
-# (e.g. enough metadata to enable recreating the API call)
-# assert "api_response" not in all_response_messages[0]
-# all_response_messages[0]["api_response"] = response_message_copy
-# assert "api_args" not in all_response_messages[0]
-# all_response_messages[0]["api_args"] = {
-# "model": self.model,
-# "messages": input_message_sequence,
-# "functions": self.functions,
-# }
-
# Step 6: extend the message history
if user_message is not None:
if isinstance(user_message, Message):
@@ -866,6 +831,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
# Check the memory pressure and potentially issue a memory pressure warning
current_total_tokens = response.usage.total_tokens
active_memory_warning = False

# We can't do summarize logic properly if context_window is undefined
if self.agent_state.llm_config.context_window is None:
# Fallback if for some reason context_window is missing, just set to the default
@@ -874,14 +840,17 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
self.agent_state.llm_config.context_window = (
LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
)

if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window):
printd(
f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
)

# Only deliver the alert if we haven't already (this period)
if not self.agent_alerted_about_memory_pressure:
active_memory_warning = True
self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this

else:
printd(
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
@@ -893,7 +862,13 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
# update state after each step
self.update_state()

-return messages_to_return, heartbeat_request, function_failed, active_memory_warning, response.usage
+return AgentStepResponse(
+messages=messages_to_return,
+heartbeat_request=heartbeat_request,
+function_failed=function_failed,
+in_context_memory_warning=active_memory_warning,
+usage=response.usage,
+)

except Exception as e:
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
@@ -937,15 +912,6 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]

-# if disallow_tool_as_first:
-# # We have to make sure that a "tool" call is not sitting at the front (after system message),
-# # otherwise we'll get an error from OpenAI (if using the OpenAI API)
-# while len(candidate_messages_to_summarize) > 0:
-# if candidate_messages_to_summarize[0]["role"] in ["tool", "function"]:
-# candidate_messages_to_summarize.pop(0)
-# else:
-# break
-
printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
printd(f"token_counts={token_counts}")
@@ -1386,9 +1352,11 @@ def retry_message(self) -> List[Message]:

self.pop_until_user()
user_message = self.pop_message(count=1)[0]
-messages, _, _, _, _ = self.step(user_message=user_message.text, return_dicts=False)
+step_response = self.step(user_message=user_message.text, return_dicts=False)
+messages = step_response.messages

-assert messages is not None and all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
+assert messages is not None
+assert all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
return messages


@@ -1430,3 +1398,30 @@ def save_agent_memory(agent: Agent, ms: MetadataStore):
if block.value is None:
block.value = ""
ms.update_or_create_block(block)


+def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
+"""If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
+try:
+user_message_json = dict(json_loads(user_message_text))
+# Special handling for AutoGen messages with 'name' field
+# Treat 'name' as a special field
+# If it exists in the input message, elevate it to the 'message' level
+name = user_message_json.pop("name", None)
+clean_message = json_dumps(user_message_json)
+return clean_message, name

+except Exception as e:
+print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
+raise e


+def validate_json(user_message_text: str) -> str:
+"""Make sure that the user input message is valid JSON"""
+try:
+user_message_json = dict(json_loads(user_message_text))
+user_message_json_val = json_dumps(user_message_json)
+return user_message_json_val
+except Exception as e:
+print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
+raise e
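
For reference, a quick example of how the relocated helpers behave on an AutoGen-style payload (input and values are illustrative, not from the commit):

    raw = '{"message": "hello", "name": "autogen_coordinator"}'
    validated = validate_json(raw)  # round-trips through json_loads/json_dumps, raising on invalid JSON
    clean, name = strip_name_field_from_user_message(validated)
    # clean == '{"message": "hello"}'
    # name == "autogen_coordinator"

Note that validate_json now always raises on failure (the raise_on_error flag is gone), and strip_name_field_from_user_message re-raises instead of falling through to a return of undefined locals, which is what lets step() drop the old error-handling plumbing.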
7 changes: 6 additions & 1 deletion memgpt/main.py
@@ -366,14 +366,19 @@ def run_agent_loop(
skip_next_user_input = False

def process_agent_step(user_message, no_verify):
-new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
+step_response = memgpt_agent.step(
user_message,
first_message=False,
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
ms=ms,
)
+new_messages = step_response.messages
+heartbeat_request = step_response.heartbeat_request
+function_failed = step_response.function_failed
+token_warning = step_response.in_context_memory_warning
+tokens_accumulated = step_response.usage

agent.save_agent(memgpt_agent, ms)
skip_next_user_input = False
17 changes: 15 additions & 2 deletions memgpt/schemas/agent.py
@@ -1,13 +1,15 @@
import uuid
from datetime import datetime
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

-from pydantic import Field, field_validator
+from pydantic import BaseModel, Field, field_validator

from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
+from memgpt.schemas.openai.chat_completion_response import UsageStatistics


class BaseAgent(MemGPTBase, validate_assignment=True):
@@ -102,3 +104,14 @@ class UpdateAgentState(BaseAgent):
# TODO: determine if these should be editable via this schema?
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")


+class AgentStepResponse(BaseModel):
+# TODO remove support for list of dicts
+messages: Union[List[Message], List[dict]] = Field(..., description="The messages generated during the agent's step.")
+heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).")
+function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.")
+in_context_memory_warning: bool = Field(
+..., description="Whether the agent step ended because the in-context memory is near its limit."
+)
+usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.")
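
A sketch of a populated AgentStepResponse, assuming UsageStatistics carries OpenAI-style token counts (all field values illustrative):

    step_response = AgentStepResponse(
        messages=[],  # Message objects (or dicts, until dict support is removed per the TODO)
        heartbeat_request=False,
        function_failed=False,
        in_context_memory_warning=False,
        usage=UsageStatistics(prompt_tokens=3400, completion_tokens=120, total_tokens=3520),
    )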
8 changes: 7 additions & 1 deletion memgpt/server/server.py
@@ -320,7 +320,7 @@ def _step(
total_usage = UsageStatistics()
step_count = 0
while True:
-new_messages, heartbeat_request, function_failed, token_warning, usage = memgpt_agent.step(
+step_response = memgpt_agent.step(
next_input_message,
first_message=False,
skip_verify=no_verify,
@@ -329,6 +329,12 @@
timestamp=timestamp,
ms=self.ms,
)
+new_messages = step_response.messages
+heartbeat_request = step_response.heartbeat_request
+function_failed = step_response.function_failed
+token_warning = step_response.in_context_memory_warning
+usage = step_response.usage

step_count += 1
total_usage += usage
counter += 1
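The total_usage += usage accumulation above implies that UsageStatistics supports addition; a minimal sketch of a model with that property (an assumption made for illustration, not the actual class definition):

    from pydantic import BaseModel

    class UsageStatistics(BaseModel):
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0

        def __add__(self, other: "UsageStatistics") -> "UsageStatistics":
            # += falls back to __add__ when __iadd__ is not defined
            return UsageStatistics(
                prompt_tokens=self.prompt_tokens + other.prompt_tokens,
                completion_tokens=self.completion_tokens + other.completion_tokens,
                total_tokens=self.total_tokens + other.total_tokens,
            )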
