feat: add tool calling for cohere (#15144)
Anirudh31415926535 committed Aug 16, 2024
1 parent 5236a21 commit b326407
Showing 5 changed files with 510 additions and 78 deletions.
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from llama_index.core.base.llms.types import (
ChatMessage,
@@ -18,21 +18,30 @@
llm_chat_callback,
llm_completion_callback,
)
import uuid
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.llms.llm import ToolSelection
from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.llms.cohere.utils import (
CHAT_MODELS,
_get_message_cohere_format,
_message_to_cohere_tool_results,
_messages_to_cohere_tool_results_curr_chat_turn,
acompletion_with_retry,
cohere_modelname_to_contextsize,
completion_with_retry,
is_cohere_function_calling_model,
remove_documents_from_messages,
format_to_cohere_tools,
)

from llama_index.core.tools.types import BaseTool
import cohere
from cohere.types import (
ToolCall,
)


class Cohere(FunctionCallingLLM):
"""Cohere LLM.
Examples:
@@ -112,6 +121,7 @@ def metadata(self) -> LLMMetadata:
is_chat_model=True,
model_name=self.model,
system_role=MessageRole.CHATBOT,
is_function_calling_model=is_cohere_function_calling_model(self.model),
)

@property
@@ -131,13 +141,165 @@ def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
**kwargs,
}

def _prepare_chat_with_tools(
self,
tools: List["BaseTool"],
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
verbose: bool = False,
allow_parallel_tool_calls: bool = False,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare the chat with tools."""
chat_history = chat_history or []

if isinstance(user_msg, str):
user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)

if user_msg is not None:
chat_history.append(user_msg)

tools_cohere_format = format_to_cohere_tools(tools)
return {
"messages": chat_history,
"tools": tools_cohere_format or [],
**kwargs,
}
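# Illustrative usage sketch, not part of this diff: the `multiply` function
# and its FunctionTool wrapper are assumed for the example only.
#
#     from llama_index.core.tools import FunctionTool
#
#     def multiply(a: int, b: int) -> int:
#         """Multiply two integers."""
#         return a * b
#
#     tool = FunctionTool.from_defaults(fn=multiply)
#     payload = llm._prepare_chat_with_tools([tool], user_msg="What is 3 * 4?")
#     # payload["messages"] ends with the user message; payload["tools"]
#     # holds the Cohere-formatted schemas from format_to_cohere_tools.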

def get_tool_calls_from_response(
self,
response: "ChatResponse",
error_on_no_tool_call: bool = False,
) -> List[ToolSelection]:
"""Predict and call the tool."""
tool_calls: List[ToolCall] = (
response.message.additional_kwargs.get("tool_calls", []) or []
)

if len(tool_calls) < 1 and error_on_no_tool_call:
raise ValueError(
f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
)

tool_selections = []
for tool_call in tool_calls:
if not isinstance(tool_call, ToolCall):
raise ValueError("Invalid tool_call object")
tool_selections.append(
ToolSelection(
tool_id=uuid.uuid4().hex,
tool_name=tool_call.name,
tool_kwargs=tool_call.parameters,
)
)

return tool_selections
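# Illustrative sketch, not part of this diff (response shape assumed): a
# response whose additional_kwargs carry
#     ToolCall(name="multiply", parameters={"a": 3, "b": 4})
# yields roughly
#     ToolSelection(tool_id="<random hex>", tool_name="multiply",
#                   tool_kwargs={"a": 3, "b": 4})
# Note that tool_id is generated locally with uuid4; it is not an ID
# returned by the Cohere API.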

def get_cohere_chat_request(
self,
messages: List[ChatMessage],
*,
connectors: Optional[List[Dict[str, str]]] = None,
stop_sequences: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Get the request for the Cohere chat API.
Args:
messages: The messages.
connectors: The connectors.
**kwargs: The keyword arguments.
Returns:
The request for the Cohere chat API.
"""
additional_kwargs = messages[-1].additional_kwargs
messages, documents = remove_documents_from_messages(messages)

# the Cohere SDK will fail loudly if documents are supplied from two sources
if additional_kwargs.get("documents", []) and documents and len(documents) > 0:
raise ValueError(
"Received documents both as a keyword argument and as a prompt additional keyword argument. Please choose only one option."
)

tool_results: Optional[
List[Dict[str, Any]]
] = _messages_to_cohere_tool_results_curr_chat_turn(messages) or kwargs.get(
"tool_results"
)
if not tool_results:
tool_results = None

chat_history = []
temp_tool_results = []
# if force_single_step is False, the request's message is empty only when the current turn carries tool results
if not kwargs.get("force_single_step"):
for i, message in enumerate(messages[:-1]):
# Aggregate consecutive tool messages into a single tool-results entry in the chat history
if message.role == MessageRole.TOOL:
temp_tool_results += _message_to_cohere_tool_results(messages, i)

if (i == len(messages) - 1) or messages[i + 1].role != MessageRole.TOOL:
cohere_message = _get_message_cohere_format(
message, temp_tool_results
)
chat_history.append(cohere_message)
temp_tool_results = []
else:
chat_history.append(_get_message_cohere_format(message, None))

message_str = "" if tool_results else messages[-1].content

else:
message_str = ""
# if force_single_step is set to True, then message is the last human message in the conversation
for i, message in enumerate(messages[:-1]):
if message.role in (
MessageRole.CHATBOT,
MessageRole.ASSISTANT,
) and message.additional_kwargs.get("tool_calls"):
continue

# Aggregate consecutive tool messages into a single tool-results entry in the chat history
if message.role == MessageRole.TOOL:
temp_tool_results += _message_to_cohere_tool_results(messages, i)

if (i == len(messages) - 1) or messages[i + 1].role != MessageRole.TOOL:
cohere_message = _get_message_cohere_format(
message, temp_tool_results
)
chat_history.append(cohere_message)
temp_tool_results = []
else:
chat_history.append(_get_message_cohere_format(message, None))
# Use the last human message in the conversation as the message string
for message in messages[::-1]:
if (message.role == MessageRole.USER) and (message.content):
message_str = message.content
break

req = {
"message": message_str,
"chat_history": chat_history,
"tool_results": tool_results,
"documents": documents,
"connectors": connectors,
"stop_sequences": stop_sequences,
**kwargs,
}
return {k: v for k, v in req.items() if v is not None}
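# Illustrative sketch, not part of this diff (values assumed): for a plain
# exchange with no documents, connectors, or tool results, the returned
# request reduces to something like
#     {"message": "What is 3 * 4?",
#      "chat_history": [<Cohere-formatted prior messages>],
#      "model": "command-r", "temperature": 0.3}
# since every None-valued key is dropped by the final dict comprehension.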

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:

all_kwargs = self._get_all_kwargs(**kwargs)

chat_request = self.get_cohere_chat_request(messages=messages, **all_kwargs)

if all_kwargs["model"] not in CHAT_MODELS:
raise ValueError(f"{all_kwargs['model']} not supported for chat")

@@ -146,18 +308,27 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
"Parameter `stream` is not supported by the `chat` method."
"Use the `stream_chat` method instead"
)

response = completion_with_retry(
client=self._client, max_retries=self.max_retries, chat=True, **chat_request
)
if not isinstance(response, cohere.NonStreamedChatResponse):
tool_calls = response.get("tool_calls")
content = response.get("text")
response_raw = response

else:
tool_calls = response.tool_calls
content = response.text
response_raw = response.__dict__

return ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
content=content,
additional_kwargs={"tool_calls": tool_calls},
),
raw=response_raw,
)
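# Illustrative usage sketch, not part of this diff (model name assumed):
#
#     llm = Cohere(model="command-r", api_key="...")
#     resp = llm.chat([ChatMessage(role=MessageRole.USER, content="Hello")])
#     resp.message.additional_kwargs.get("tool_calls")  # empty unless the
#     # model decided to call a tool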

@llm_completion_callback()
@@ -188,22 +359,15 @@ def complete(
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:

all_kwargs = self._get_all_kwargs(**kwargs)
all_kwargs["stream"] = True
if all_kwargs["model"] not in CHAT_MODELS:
raise ValueError(f"{all_kwargs['model']} not supported for chat")

chat_request = self.get_cohere_chat_request(messages=messages, **all_kwargs)

response = completion_with_retry(
client=self._client, max_retries=self.max_retries, chat=True, **chat_request
)

def gen() -> ChatResponseGen:
@@ -253,8 +417,6 @@ def gen() -> CompletionResponseGen:
async def achat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponse:
all_kwargs = self._get_all_kwargs(**kwargs)
if all_kwargs["model"] not in CHAT_MODELS:
raise ValueError(f"{all_kwargs['model']} not supported for chat")
@@ -264,13 +426,10 @@ async def achat(
"Use the `stream_chat` method instead"
)

chat_request = self.get_cohere_chat_request(messages, **all_kwargs)

response = await acompletion_with_retry(
aclient=self._aclient, max_retries=self.max_retries, chat=True, **chat_request
)

return ChatResponse(
@@ -306,19 +465,15 @@ async def acomplete(
async def astream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseAsyncGen:
all_kwargs = self._get_all_kwargs(**kwargs)
all_kwargs["stream"] = True
if all_kwargs["model"] not in CHAT_MODELS:
raise ValueError(f"{all_kwargs['model']} not supported for chat")
chat_request = self.get_cohere_chat_request(messages, **all_kwargs)

response = await acompletion_with_retry(
aclient=self._aclient, max_retries=self.max_retries, chat=True, **chat_request
)

async def gen() -> ChatResponseAsyncGen: