Skip to content

Commit

Permalink
Merge pull request #126 from Yiannis128/fcm_message_history
Browse files Browse the repository at this point in the history
Fix Code Mode: Add message history customization, history can be shown in the following way:
* Latest state only, this means that the LLM will forget about previous iterations
* Reversed, this will reverse all messages excluding system messages
  • Loading branch information
Yiannis128 committed Apr 20, 2024
2 parents 86a590b + 6617c99 commit dc92445
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 64 deletions.
1 change: 1 addition & 0 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"generate_solution": {
"max_attempts": 5,
"temperature": 1.3,
"message_history": "normal",
"scenarios": {
"division by zero": {
"system": [
Expand Down
82 changes: 53 additions & 29 deletions esbmc_ai/commands/fix_code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import sys
from typing import Any, Tuple
from typing_extensions import override
from langchain.schema import AIMessage, HumanMessage

from esbmc_ai.chat_response import FinishReason
from esbmc_ai.latest_state_solution_generator import LatestStateSolutionGenerator
from esbmc_ai.reverse_order_solution_generator import ReverseOrderSolutionGenerator

from .chat_command import ChatCommand
from .. import config
Expand All @@ -18,8 +19,6 @@
from ..solution_generator import (
ESBMCTimedOutException,
SolutionGenerator,
SourceCodeParseError,
get_esbmc_output_formatted,
)
from ..logging import print_horizontal_line, printv, printvv

Expand Down Expand Up @@ -61,21 +60,58 @@ def print_raw_conversation() -> None:
else "Using generic prompt..."
)

match config.fix_code_message_history:
case "normal":
solution_generator = SolutionGenerator(
ai_model_agent=config.chat_prompt_generator_mode,
ai_model=config.ai_model,
llm=config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
),
scenario=scenario,
source_code_format=config.source_code_format,
esbmc_output_type=config.esbmc_output_type,
)
case "latest_only":
solution_generator = LatestStateSolutionGenerator(
ai_model_agent=config.chat_prompt_generator_mode,
ai_model=config.ai_model,
llm=config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
),
scenario=scenario,
source_code_format=config.source_code_format,
esbmc_output_type=config.esbmc_output_type,
)
case "reverse":
solution_generator = ReverseOrderSolutionGenerator(
ai_model_agent=config.chat_prompt_generator_mode,
ai_model=config.ai_model,
llm=config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
),
scenario=scenario,
source_code_format=config.source_code_format,
esbmc_output_type=config.esbmc_output_type,
)
case _:
raise NotImplementedError(
f"error: {config.fix_code_message_history} has not been implemented in the Fix Code Command"
)

try:
solution_generator = SolutionGenerator(
ai_model_agent=config.chat_prompt_generator_mode,
solution_generator.update_state(
source_code=source_code,
esbmc_output=esbmc_output,
ai_model=config.ai_model,
llm=config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_generator_mode.temperature,
requests_max_tries=config.requests_max_tries,
requests_timeout=config.requests_timeout,
),
scenario=scenario,
source_code_format=config.source_code_format,
esbmc_output_type=config.esbmc_output_type,
)
except ESBMCTimedOutException:
print("error: ESBMC has timed out...")
Expand All @@ -93,9 +129,7 @@ def print_raw_conversation() -> None:
llm_solution, finish_reason = solution_generator.generate_solution()
self.anim.stop()
if finish_reason == FinishReason.length:
self.anim.start("Compressing message stack... Please Wait")
solution_generator.compress_message_stack()
self.anim.stop()
else:
source_code = llm_solution
break
Expand Down Expand Up @@ -135,26 +169,16 @@ def print_raw_conversation() -> None:

return False, source_code

# TODO Move this process into Solution Generator since have (beginning) is done
# inside, and the other half is done here.
# Get formatted ESBMC output
try:
esbmc_output = get_esbmc_output_formatted(
esbmc_output_type=config.esbmc_output_type,
esbmc_output=esbmc_output,
)
except SourceCodeParseError:
pass
# Update state
solution_generator.update_state(source_code, esbmc_output)
except ESBMCTimedOutException:
print("error: ESBMC has timed out...")
sys.exit(1)

# Failure case
print(f"ESBMC-AI Notice: Failure {idx+1}/{max_retries}: Retrying...")

# Update state
solution_generator.update_state(source_code, esbmc_output)

if config.raw_conversation:
print_raw_conversation()

Expand Down
13 changes: 13 additions & 0 deletions esbmc_ai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
source_code_format: str = "full"

fix_code_max_attempts: int = 5
fix_code_message_history: str = ""

requests_max_tries: int = 5
requests_timeout: float = 60
Expand All @@ -57,6 +58,7 @@
cfg_path: str


# TODO Get rid of this class as soon as ConfigTool with the pyautoconfig
class AIAgentConversation(NamedTuple):
"""Immutable class describing the conversation definition for an AI agent. The
class represents the system messages of the AI agent defined and contains a load
Expand Down Expand Up @@ -384,6 +386,17 @@ def load_config(file_path: str) -> None:
f"ESBMC output type in the config is not valid: {esbmc_output_type}"
)

global fix_code_message_history
fix_code_message_history, _ = _load_config_value(
config_file=config_file["chat_modes"]["generate_solution"],
name="message_history",
)

if fix_code_message_history not in ["normal", "latest_only", "reverse"]:
raise ValueError(
f"error: fix code mode message history not valid: {fix_code_message_history}"
)

global requests_max_tries
requests_max_tries = int(
_load_config_real_number(
Expand Down
27 changes: 27 additions & 0 deletions esbmc_ai/latest_state_solution_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Author: Yiannis Charalambous

from typing_extensions import override
from langchain_core.messages import BaseMessage
from esbmc_ai.solution_generator import SolutionGenerator
from esbmc_ai.chat_response import FinishReason

# TODO Test me


class LatestStateSolutionGenerator(SolutionGenerator):
"""SolutionGenerator that only shows the latest source code and verifier
output state."""

@override
def generate_solution(self) -> tuple[str, FinishReason]:
# Backup message stack and clear before sending base message. We want
# to keep the message stack intact because we will print it with
# print_raw_conversation.
messages: list[BaseMessage] = self.messages
self.messages: list[BaseMessage] = []
solution, finish_reason = super().generate_solution()
# Append last messages to the messages stack
messages.extend(self.messages)
# Restore
self.messages = messages
return solution, finish_reason
34 changes: 34 additions & 0 deletions esbmc_ai/reverse_order_solution_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Author: Yiannis Charalambous

from langchain.schema import BaseMessage, HumanMessage
from typing_extensions import override, Optional
from esbmc_ai.solution_generator import (
SolutionGenerator,
get_source_code_formatted,
get_source_code_err_line_idx,
get_clang_err_line_index,
apply_line_patch,
)
from esbmc_ai.chat_response import FinishReason, ChatResponse

# TODO Test me


class ReverseOrderSolutionGenerator(SolutionGenerator):
"""SolutionGenerator that shows the source code and verifier output state in
reverse order."""

@override
def send_message(self, message: Optional[str] = None) -> ChatResponse:
# Reverse the messages
messages: list[BaseMessage] = self.messages.copy()
self.messages.reverse()

response: ChatResponse = super().send_message(message)

# Add to the reversed message the new message received by the LLM.
messages.append(self.messages[-1])
# Restore
self.messages = messages

return response
79 changes: 44 additions & 35 deletions esbmc_ai/solution_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,50 +82,43 @@ def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str
class SolutionGenerator(BaseChatInterface):
def __init__(
self,
ai_model_agent: DynamicAIModelAgent,
ai_model_agent: DynamicAIModelAgent | ChatPromptSettings,
llm: BaseLanguageModel,
source_code: str,
esbmc_output: str,
ai_model: AIModel,
scenario: str = "",
source_code_format: str = "full",
esbmc_output_type: str = "full",
) -> None:
# Convert to chat prompt
chat_prompt: ChatPromptSettings = DynamicAIModelAgent.to_chat_prompt_settings(
ai_model_agent=ai_model_agent, scenario=scenario
)
"""Initializes the solution generator. This ModelChat provides Dynamic
Prompting. Will get the correct scenario from the DynamicAIModelAgent
supplied and create a ChatPrompt."""

chat_prompt: ChatPromptSettings = ai_model_agent
if isinstance(ai_model_agent, DynamicAIModelAgent):
# Convert to chat prompt
chat_prompt = DynamicAIModelAgent.to_chat_prompt_settings(
ai_model_agent=ai_model_agent, scenario=scenario
)

super().__init__(
ai_model_agent=chat_prompt,
ai_model=ai_model,
llm=llm,
)

self.initial_prompt = ai_model_agent.initial_prompt

self.esbmc_output_type: str = esbmc_output_type
self.source_code_format: str = source_code_format
self.source_code_raw: str = source_code
# Used for resetting state.
self._original_source_code: str = source_code

# Format ESBMC output
try:
self.esbmc_output = get_esbmc_output_formatted(
esbmc_output_type=self.esbmc_output_type,
esbmc_output=esbmc_output,
)
except SourceCodeParseError:
# When clang output is displayed, show it entirely as it doesn't get very
# big.
self.esbmc_output = esbmc_output
self.source_code_raw: Optional[str] = None
self.source_code_formatted: Optional[str] = None
self.esbmc_output: Optional[str] = None

@override
def compress_message_stack(self) -> None:
# Resets the conversation - cannot summarize code
# If generate_solution is called after this point, it will start new
# with the currently set state.
self.messages: list[BaseMessage] = []
self.source_code_raw = self._original_source_code

@classmethod
def get_code_from_solution(cls, solution: str) -> str:
Expand Down Expand Up @@ -153,27 +146,43 @@ def get_code_from_solution(cls, solution: str) -> str:
pass
return solution

def update_state(
self, source_code: Optional[str] = None, esbmc_output: Optional[str] = None
) -> None:
if source_code:
self.source_code_raw = source_code
if esbmc_output:
self.esbmc_output = esbmc_output
def update_state(self, source_code: str, esbmc_output: str) -> None:
"""Updates the latest state of the code and ESBMC output. This should be
called before generate_solution."""
self.source_code_raw = source_code

def generate_solution(self) -> tuple[str, FinishReason]:
self.push_to_message_stack(HumanMessage(content=self.initial_prompt))
# Format ESBMC output
try:
self.esbmc_output = get_esbmc_output_formatted(
esbmc_output_type=self.esbmc_output_type,
esbmc_output=esbmc_output,
)
except SourceCodeParseError:
# When clang output is displayed, show it entirely as it doesn't get very
# big.
self.esbmc_output = esbmc_output

# Format source code
source_code_formatted: str = get_source_code_formatted(
self.source_code_formatted = get_source_code_formatted(
source_code_format=self.source_code_format,
source_code=self.source_code_raw,
source_code=source_code,
esbmc_output=self.esbmc_output,
)

def generate_solution(self) -> tuple[str, FinishReason]:
assert (
self.source_code_raw is not None
and self.source_code_formatted is not None
and self.esbmc_output is not None
), "Call update_state before calling generate_solution."

self.push_to_message_stack(
HumanMessage(content=self.ai_model_agent.initial_prompt)
)

# Apply template substitution to message stack
self.apply_template_value(
source_code=source_code_formatted,
source_code=self.source_code_formatted,
esbmc_output=self.esbmc_output,
)

Expand Down
Loading

0 comments on commit dc92445

Please sign in to comment.