From 22ccd3d8662e1a1efb7714891d17e9fb002dda1d Mon Sep 17 00:00:00 2001
From: cpacker
Date: Fri, 13 Sep 2024 16:56:28 -0700
Subject: [PATCH] refactor: clean up agent.step() by having a response model
 instead of a response tuple

---
 memgpt/agent.py         | 119 +++++++++++++++++++---------------------
 memgpt/main.py          |   7 ++-
 memgpt/schemas/agent.py |  17 +++++-
 memgpt/server/server.py |   8 ++-
 4 files changed, 85 insertions(+), 66 deletions(-)

diff --git a/memgpt/agent.py b/memgpt/agent.py
index c97daa219a..63c85d95b9 100644
--- a/memgpt/agent.py
+++ b/memgpt/agent.py
@@ -22,7 +22,7 @@
 from memgpt.memory import ArchivalMemory, RecallMemory, summarize_messages
 from memgpt.metadata import MetadataStore
 from memgpt.persistence_manager import LocalStateManager
-from memgpt.schemas.agent import AgentState
+from memgpt.schemas.agent import AgentState, AgentStepResponse
 from memgpt.schemas.block import Block
 from memgpt.schemas.embedding_config import EmbeddingConfig
 from memgpt.schemas.enums import MessageRole, OptionState
@@ -196,15 +196,10 @@ class BaseAgent(ABC):
     Only two interfaces are required: step and update_state.
     """

-    # @abstractmethod
-    # def step(self, message: Message) -> List[Message]:
-    #     raise NotImplementedError
-
-    # TODO cleanup
     @abstractmethod
     def step(
         self,
-        user_message: Union[Message, str],  # NOTE: should be json.dump(dict)
+        messages: Union[Message, List[Message], str],  # TODO deprecate str inputs
         first_message: bool = False,
         first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
         skip_verify: bool = False,
@@ -214,10 +209,11 @@ def step(
         timestamp: Optional[datetime.datetime] = None,
         inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
         ms: Optional[MetadataStore] = None,
-    ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
+    ) -> AgentStepResponse:
         """
         Top-level event message handler for the agent.
         """
+        raise NotImplementedError

     @abstractmethod
     def update_state(self) -> AgentState:
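For illustration, a minimal before/after sketch of a caller of step(). The
`agent` and `user_message` objects are assumed to already exist; the field
names come from the AgentStepResponse schema added later in this patch:

    # before: positional tuple unpacking -- easy to misorder, and the old
    # 4-element Tuple hint had drifted out of sync with the 5-element return
    # messages, heartbeat_request, function_failed, token_warning, usage = agent.step(user_message)

    # after: named attribute access on a typed Pydantic model
    step_response = agent.step(user_message)
    print(step_response.heartbeat_request, step_response.function_failed)
    print(step_response.in_context_memory_warning, step_response.usage)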
@@ -708,45 +704,21 @@ def _handle_ai_response(

     def step(
         self,
-        user_message: Union[Message, str],  # NOTE: should be json.dump(dict)
+        user_message: Optional[Union[Message, str]],  # NOTE: should be json.dumps(dict)
         first_message: bool = False,
         first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
         skip_verify: bool = False,
-        return_dicts: bool = True,  # if True, return dicts, if False, return Message objects
+        return_dicts: bool = True,
         recreate_message_timestamp: bool = True,  # if True, when input is a Message type, recreate the 'created_at' field
         stream: bool = False,  # TODO move to config?
         timestamp: Optional[datetime.datetime] = None,
         inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
         ms: Optional[MetadataStore] = None,
-    ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
+    ) -> AgentStepResponse:
         """Top-level event message handler for the MemGPT agent"""

-        def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
-            """If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
-            try:
-                user_message_json = dict(json_loads(user_message_text))
-                # Special handling for AutoGen messages with 'name' field
-                # Treat 'name' as a special field
-                # If it exists in the input message, elevate it to the 'message' level
-                name = user_message_json.pop("name", None)
-                clean_message = json_dumps(user_message_json)
-
-            except Exception as e:
-                print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
-
-            return clean_message, name
-
-        def validate_json(user_message_text: str, raise_on_error: bool) -> str:
-            try:
-                user_message_json = dict(json_loads(user_message_text))
-                user_message_json_val = json_dumps(user_message_json)
-                return user_message_json_val
-            except Exception as e:
-                print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
-                if raise_on_error:
-                    raise e
-
         try:
+            # Step 0: update core memory
             # only pulling latest block data if shared memory is being used
             # TODO: ensure we're passing in metadata store from all surfaces
@@ -760,11 +732,14 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
             # once we ensure we're correctly comparing whether in-memory core
             # data is different than persisted core data.
             self.rebuild_memory(force=True, ms=ms)
+
+            # Step 1: add user message
             if user_message is not None:
                 if isinstance(user_message, Message):
+                    assert user_message.text is not None
+
                     # Validate JSON via save/load
-                    user_message_text = validate_json(user_message.text, False)
+                    user_message_text = validate_json(user_message.text)
                     cleaned_user_message_text, name = strip_name_field_from_user_message(user_message_text)

                     if name is not None:
@@ -778,7 +753,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
                 elif isinstance(user_message, str):
                     # Validate JSON via save/load
-                    user_message = validate_json(user_message, False)
+                    user_message = validate_json(user_message)
                     cleaned_user_message_text, name = strip_name_field_from_user_message(user_message)

                     # If user_message['name'] is not None, it will be handled properly by dict_to_message
@@ -799,6 +774,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
                 self.interface.user_message(user_message.text, msg_obj=user_message)

                 input_message_sequence = self._messages + [user_message]
+            # Alternatively, the requestor can send an empty user message
             else:
                 input_message_sequence = self._messages
@@ -806,7 +782,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
             if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
                 printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")

-            # Step 2: send the conversation and available functions to GPT
+            # Step 2: send the conversation and available functions to the LLM
             if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
                 printd(f"This is the first message. Running extra verifier on AI response.")
                 counter = 0
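To make the Step 1 plumbing concrete, here is an illustrative round trip
through the two JSON helpers called above (their new module-level bodies
appear near the end of this file's diff; exact output formatting depends on
json_dumps):

    from memgpt.agent import strip_name_field_from_user_message, validate_json

    raw = '{"type": "user_message", "message": "hello", "name": "autogen-agent"}'
    validated = validate_json(raw)  # round-trips through json_loads/json_dumps, raises on invalid JSON
    clean, name = strip_name_field_from_user_message(validated)
    # clean: the same JSON payload with the "name" key removed
    # name: "autogen-agent" (None if no "name" field was present)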
@@ -843,17 +819,6 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
                 response_message_id=response.id if stream else None,
             )

-            # Add the extra metadata to the assistant response
-            # (e.g. enough metadata to enable recreating the API call)
-            # assert "api_response" not in all_response_messages[0]
-            # all_response_messages[0]["api_response"] = response_message_copy
-            # assert "api_args" not in all_response_messages[0]
-            # all_response_messages[0]["api_args"] = {
-            #     "model": self.model,
-            #     "messages": input_message_sequence,
-            #     "functions": self.functions,
-            # }
-
             # Step 6: extend the message history
             if user_message is not None:
                 if isinstance(user_message, Message):

             # Check the memory pressure and potentially issue a memory pressure warning
             current_total_tokens = response.usage.total_tokens
             active_memory_warning = False
+
             # We can't do summarize logic properly if context_window is undefined
             if self.agent_state.llm_config.context_window is None:
                 # Fallback if for some reason context_window is missing, just set to the default
                 self.agent_state.llm_config.context_window = (
                     LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"]
                 )
+
             if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window):
                 printd(
                     f"{CLI_WARNING_PREFIX}last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
                 )
+
                 # Only deliver the alert if we haven't already (this period)
                 if not self.agent_alerted_about_memory_pressure:
                     active_memory_warning = True
                     self.agent_alerted_about_memory_pressure = True  # it's up to the outer loop to handle this
+
             else:
                 printd(
                     f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}"
                 )
@@ -893,7 +862,13 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
             # update state after each step
             self.update_state()

-            return messages_to_return, heartbeat_request, function_failed, active_memory_warning, response.usage
+            return AgentStepResponse(
+                messages=messages_to_return,
+                heartbeat_request=heartbeat_request,
+                function_failed=function_failed,
+                in_context_memory_warning=active_memory_warning,
+                usage=response.usage,
+            )

         except Exception as e:
             printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
@@ -937,15 +912,6 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
             candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
             token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]

-        # if disallow_tool_as_first:
-        #     # We have to make sure that a "tool" call is not sitting at the front (after system message),
-        #     # otherwise we'll get an error from OpenAI (if using the OpenAI API)
-        #     while len(candidate_messages_to_summarize) > 0:
-        #         if candidate_messages_to_summarize[0]["role"] in ["tool", "function"]:
-        #             candidate_messages_to_summarize.pop(0)
-        #         else:
-        #             break
-
         printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
         printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
         printd(f"token_counts={token_counts}")
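A worked example of the memory-pressure check above, assuming
MESSAGE_SUMMARY_WARNING_FRAC = 0.75 (it is defined in memgpt.constants; the
shipped value may differ):

    MESSAGE_SUMMARY_WARNING_FRAC = 0.75
    context_window = 8192
    current_total_tokens = 6500

    # 6500 > 0.75 * 8192 = 6144.0, so the warning fires
    should_warn = current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(context_window)
    assert should_warn
    # active_memory_warning is raised at most once per period; reacting to it
    # (e.g. by summarizing) is left to the outer loop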
@@ -1386,9 +1352,11 @@ def retry_message(self) -> List[Message]:
         self.pop_until_user()
         user_message = self.pop_message(count=1)[0]

-        messages, _, _, _, _ = self.step(user_message=user_message.text, return_dicts=False)
+        step_response = self.step(user_message=user_message.text, return_dicts=False)
+        messages = step_response.messages

-        assert messages is not None and all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
+        assert messages is not None
+        assert all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects"
         return messages
@@ -1430,3 +1398,30 @@ def save_agent_memory(agent: Agent, ms: MetadataStore):
         if block.value is None:
             block.value = ""
         ms.update_or_create_block(block)
+
+
+def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
+    """If 'name' exists in the JSON string, remove it and return the cleaned text + name value"""
+    try:
+        user_message_json = dict(json_loads(user_message_text))
+        # Special handling for AutoGen messages with 'name' field
+        # Treat 'name' as a special field
+        # If it exists in the input message, elevate it to the 'message' level
+        name = user_message_json.pop("name", None)
+        clean_message = json_dumps(user_message_json)
+        return clean_message, name
+
+    except Exception as e:
+        print(f"{CLI_WARNING_PREFIX}handling of 'name' field failed with: {e}")
+        raise e
+
+
+def validate_json(user_message_text: str) -> str:
+    """Make sure that the user input message is valid JSON"""
+    try:
+        user_message_json = dict(json_loads(user_message_text))
+        user_message_json_val = json_dumps(user_message_json)
+        return user_message_json_val
+    except Exception as e:
+        print(f"{CLI_WARNING_PREFIX}couldn't parse user input message as JSON: {e}")
+        raise e

diff --git a/memgpt/main.py b/memgpt/main.py
index fceaff6b23..d13839372d 100644
--- a/memgpt/main.py
+++ b/memgpt/main.py
@@ -366,7 +366,7 @@ def run_agent_loop(
     skip_next_user_input = False

     def process_agent_step(user_message, no_verify):
-        new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
+        step_response = memgpt_agent.step(
             user_message,
             first_message=False,
             skip_verify=no_verify,
@@ -374,6 +374,11 @@ def process_agent_step(user_message, no_verify):
             inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
             ms=ms,
         )
+        new_messages = step_response.messages
+        heartbeat_request = step_response.heartbeat_request
+        function_failed = step_response.function_failed
+        token_warning = step_response.in_context_memory_warning
+        tokens_accumulated = step_response.usage

         agent.save_agent(memgpt_agent, ms)
         skip_next_user_input = False
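A sketch (not part of the patch) of how the surrounding CLI loop might branch
on the named fields that process_agent_step now unpacks; the actual handling
lives elsewhere in run_agent_loop:

    step_response = memgpt_agent.step(user_message, first_message=False, ms=ms)

    # step again without prompting the user when the agent requests a
    # heartbeat, or when a failed function call deserves a follow-up attempt
    skip_next_user_input = step_response.heartbeat_request or step_response.function_failed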
diff --git a/memgpt/schemas/agent.py b/memgpt/schemas/agent.py
index 54a9aec147..1826c78484 100644
--- a/memgpt/schemas/agent.py
+++ b/memgpt/schemas/agent.py
@@ -1,13 +1,15 @@
 import uuid
 from datetime import datetime
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

-from pydantic import Field, field_validator
+from pydantic import BaseModel, Field, field_validator

 from memgpt.schemas.embedding_config import EmbeddingConfig
 from memgpt.schemas.llm_config import LLMConfig
 from memgpt.schemas.memgpt_base import MemGPTBase
 from memgpt.schemas.memory import Memory
+from memgpt.schemas.message import Message
+from memgpt.schemas.openai.chat_completion_response import UsageStatistics


 class BaseAgent(MemGPTBase, validate_assignment=True):
@@ -102,3 +104,14 @@ class UpdateAgentState(BaseAgent):
     # TODO: determine if these should be editable via this schema?
     message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
     memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
+
+
+class AgentStepResponse(BaseModel):
+    # TODO: remove support for list of dicts
+    messages: Union[List[Message], List[dict]] = Field(..., description="The messages generated during the agent's step.")
+    heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).")
+    function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.")
+    in_context_memory_warning: bool = Field(
+        ..., description="Whether the agent step ended because the in-context memory is near its limit."
+    )
+    usage: UsageStatistics = Field(..., description="Usage statistics of the LLM call during the agent's step.")

diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 834117d046..3a25de0fa5 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -320,7 +320,7 @@ def _step(
         total_usage = UsageStatistics()
         step_count = 0
         while True:
-            new_messages, heartbeat_request, function_failed, token_warning, usage = memgpt_agent.step(
+            step_response = memgpt_agent.step(
                 next_input_message,
                 first_message=False,
                 skip_verify=no_verify,
@@ -329,6 +329,12 @@ def _step(
                 timestamp=timestamp,
                 ms=self.ms,
             )
+            new_messages = step_response.messages
+            heartbeat_request = step_response.heartbeat_request
+            function_failed = step_response.function_failed
+            token_warning = step_response.in_context_memory_warning
+            usage = step_response.usage
+
             step_count += 1
             total_usage += usage
             counter += 1
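Finally, a minimal sketch of the server-side stepping loop after this patch:
step until the agent stops requesting heartbeats, accumulating token usage
along the way. `memgpt_agent` and `next_input_message` are assumed to exist,
and UsageStatistics supporting `+=` is implied by the `total_usage += usage`
line above:

    from memgpt.schemas.openai.chat_completion_response import UsageStatistics

    total_usage = UsageStatistics()
    while True:
        step_response = memgpt_agent.step(next_input_message, first_message=False)
        total_usage += step_response.usage
        if not step_response.heartbeat_request:
            break
    print(total_usage)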