Skip to content

Commit

Permalink
Merge branch 'develop' into converting-whitelistsharing-reacthookform
Browse files Browse the repository at this point in the history
  • Loading branch information
tomlynchRNA authored Sep 20, 2024
2 parents 6d74847 + 06f63f1 commit 71676c0
Show file tree
Hide file tree
Showing 48 changed files with 1,101 additions and 637 deletions.
26 changes: 25 additions & 1 deletion agent-backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions agent-backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ langgraph = "^0.2.16"
langchain-ollama = "^0.1.1"
pinecone-client = "^5.0.1"
langchain-google-genai = "^1.0.10"
motor = "^3.5.1"
minio = "^7.2.8"


Expand Down
12 changes: 10 additions & 2 deletions agent-backend/src/chat/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def send_to_socket(self, text='', event=SocketEvents.MESSAGE, first=True, chunk_
if timestamp is None:
timestamp = int(datetime.now().timestamp() * 1000)

if len(text.rstrip()) == 0 and event == SocketEvents.MESSAGE:
if len(text) == 0 and event == SocketEvents.MESSAGE:
return # Don't send empty first messages

# send the message
Expand Down Expand Up @@ -95,7 +95,7 @@ async def stream_execute(self):
config = {"configurable": {"thread_id": self.session_id}}

while True:
past_messages = self.graph.get_state(config).values.get("messages")
past_messages = (await self.graph.aget_state(config)).values.get("messages")

if past_messages:
if self._max_messages_limit_reached(past_messages):
Expand Down Expand Up @@ -148,6 +148,7 @@ async def stream_execute(self):
case "on_parser_stream":
self.logger.debug(f"Parser chunk ({kind}): {event['data']['chunk']}")

# tool chat message finished
case "on_chain_end":
# input_messages = event['data']['input']['messages'] \
# if ('input' in event['data'] and 'messages' in event['data']['input']) \
Expand Down Expand Up @@ -234,4 +235,11 @@ async def stream_execute(self):
```
""", event=SocketEvents.MESSAGE, first=True, chunk_id=str(uuid.uuid4()),
timestamp=datetime.now().timestamp() * 1000, display_type="bubble")

if 'Error code: 400' in str(chunk_error):
# Terminate on some 400's (from openAI) such as not being on a high enough tier for a model, to prevent an infinite loop of error
self.send_to_socket(
event=SocketEvents.STOP_GENERATING,
chunk_id=str(uuid.uuid4()),
)
pass
33 changes: 26 additions & 7 deletions agent-backend/src/chat/agents/open_ai.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid

from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.constants import END
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.prebuilt import ToolNode

from chat.agents.base import BaseChatAgent
from chat.mongo_db_saver import AsyncMongoDBSaver
from tools.global_tools import CustomHumanInput


Expand Down Expand Up @@ -66,11 +67,6 @@ def build_graph(self):
graph.add_node("tools", tools_node)
graph.add_node("human_input_invoker", self.invoke_human_input)

graph.add_edge(START, "human_input_invoker")
graph.add_edge("human_input_invoker", "human_input")
graph.add_edge("human_input", "chat_model")
graph.add_edge("tools", "chat_model")

def should_continue(state):
messages = state["messages"]
last_message = messages[-1]
Expand All @@ -84,6 +80,29 @@ def should_continue(state):
else:
return "continue"

def start_condition(state):
messages = state["messages"]
last_message = messages[-1]
if isinstance(last_message, AIMessage) and last_message.tool_calls:
if last_message.tool_calls[0]["name"] == "human_input":
return "human_input"
return "tools"
else:
return "human_input_invoker"

graph.add_conditional_edges(
START,
start_condition,
{
# If `tools`, then we call the tool node.
"tools": "tools",
"human_input": "human_input",
"human_input_invoker": "human_input_invoker"
},
)
graph.add_edge("human_input_invoker", "human_input")
graph.add_edge("human_input", "chat_model")
graph.add_edge("tools", "chat_model")
graph.add_conditional_edges(
"chat_model",
should_continue,
Expand All @@ -97,6 +116,6 @@ def should_continue(state):
)

# Set up memory
memory = MemorySaver()
memory = AsyncMongoDBSaver()

return graph.compile(checkpointer=memory)
41 changes: 20 additions & 21 deletions agent-backend/src/chat/chat_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,65 +3,65 @@
import re

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import SystemMessage
from langchain_core.tools import BaseTool
from socketio import SimpleClient
from socketio.exceptions import ConnectionError

from chat.agents.base import BaseChatAgent
from chat.agents.factory import chat_agent_factory
from init.env_variables import SOCKET_URL, AGENT_BACKEND_SOCKET_TOKEN
from init.env_variables import SOCKET_URL, AGENT_BACKEND_SOCKET_TOKEN, MONGO_DB_NAME
from init.mongo_session import start_mongo_session
from lang_models import model_factory as language_model_factory
from models.mongo import App, Tool, Datasource, Model, ToolType, Agent
from tools import RagTool, GoogleCloudFunctionTool
from tools.builtin_tools import BuiltinTools


class ChatAssistant:
chat_model: BaseLanguageModel
tools: list[BaseTool]
chat_agent: BaseChatAgent
system_message: str
agent_name: str
max_messages: int
logger = logging.getLogger(__name__)


class ChatAssistant:
def __init__(self, session_id: str):
self.session_id = session_id
self.socket = SimpleClient()
self.mongo_client = start_mongo_session()
self.mongo_conn = start_mongo_session()
self.chat_model: BaseLanguageModel
self.tools: list[BaseTool]
self.system_message: str
self.agent_name: str
self.max_messages: int
self.init_socket()
self.init_app_state()
self.chat_agent = chat_agent_factory(chat_assistant_obj=self)

def init_socket(self):
try:
# Initialize the socket client and connect
logging.debug(f"Socket URL: {SOCKET_URL}")
logger.debug(f"Socket URL: {SOCKET_URL}")
custom_headers = {"x-agent-backend-socket-token": AGENT_BACKEND_SOCKET_TOKEN}
self.socket.connect(url=SOCKET_URL, headers=custom_headers)
self.socket.emit("join_room", f"_{self.session_id}")
except ConnectionError as ce:
logging.error(f"Connection error occurred: {ce}")
logger.error(f"Connection error occurred: {ce}")
raise

def init_app_state(self):
session = self.mongo_client.get_session(self.session_id)
session = self.mongo_conn.get_session(self.session_id)

app = self.mongo_client.get_single_model_by_id("apps", App, session.get('appId'))
app = self.mongo_conn.get_single_model_by_id("apps", App, session.get('appId'))

app_config = app.chatAppConfig
if not app_config:
raise

agentcloud_agent = self.mongo_client.get_single_model_by_id("agents", Agent, app_config.agentId)
agentcloud_agent = self.mongo_conn.get_single_model_by_id("agents", Agent, app_config.agentId)
self.agent_name = agentcloud_agent.name

agentcloud_tools = self.mongo_client.get_models_by_ids("tools", Tool, agentcloud_agent.toolIds)
agentcloud_tools = self.mongo_conn.get_models_by_ids("tools", Tool, agentcloud_agent.toolIds)

self.system_message = '\n'.join([agentcloud_agent.role, agentcloud_agent.goal, agentcloud_agent.backstory])

model = self.mongo_client.get_single_model_by_id("models", Model, agentcloud_agent.modelId)
model = self.mongo_conn.get_single_model_by_id("models", Model, agentcloud_agent.modelId)
self.chat_model = language_model_factory(model)

self.tools = list(map(self._make_langchain_tool, agentcloud_tools))
Expand All @@ -82,9 +82,9 @@ def _make_langchain_tool(self, agentcloud_tool: Tool):
agentcloud_tool.name = self._transform_tool_name(agentcloud_tool.name)

if agentcloud_tool.type == ToolType.RAG_TOOL:
datasource = self.mongo_client.get_single_model_by_id("datasources", Datasource,
datasource = self.mongo_conn.get_single_model_by_id("datasources", Datasource,
agentcloud_tool.datasourceId)
embedding_model = self.mongo_client.get_single_model_by_id("models", Model, datasource.modelId)
embedding_model = self.mongo_conn.get_single_model_by_id("models", Model, datasource.modelId)
embedding = language_model_factory(embedding_model)

return RagTool.factory(tool=agentcloud_tool,
Expand All @@ -96,8 +96,7 @@ def _make_langchain_tool(self, agentcloud_tool: Tool):
elif agentcloud_tool.type == ToolType.BUILTIN_TOOL:
tool_name = agentcloud_tool.data.name
if agentcloud_tool.linkedToolId:
linked_tool = self.mongo_client.get_tool(agentcloud_tool.linkedToolId)
print(f"linked_tool: {linked_tool}")
linked_tool = self.mongo_conn.get_tool(agentcloud_tool.linkedToolId)
if linked_tool:
tool_class = BuiltinTools.get_tool_class(linked_tool.data.name)
else:
Expand Down
Loading

0 comments on commit 71676c0

Please sign in to comment.