Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

genai: fix pydantic structured_output with array #469

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 85 additions & 11 deletions libs/genai/langchain_google_genai/_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,20 +270,97 @@ def _convert_pydantic_to_genai_function(
name=tool_name if tool_name else schema.get("title"),
description=tool_description if tool_description else schema.get("description"),
parameters={
"properties": {
k: {
"type_": _get_type_from_schema(v),
"description": v.get("description"),
}
for k, v in schema["properties"].items()
},
"properties": _get_properties_from_schema_any(
schema.get("properties")
), # TODO: use _dict_to_gapic_schema() if possible
# "items": _get_items_from_schema_any(
# schema
# ), # TODO: fix it https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling?hl#schema
"required": schema.get("required", []),
"type_": TYPE_ENUM[schema["type"]],
},
)
return function_declaration


def _get_properties_from_schema_any(schema: Any) -> Dict[str, Any]:
if isinstance(schema, Dict):
return _get_properties_from_schema(schema)
return {}


# Build a mapping of property name -> converted property description from a
# JSON-schema style "properties" dict. Each converted entry may carry the keys
# "type_", "enum", "description", "items" and "properties", depending on what
# the source entry provides. Entries with a non-string key or a non-dict value
# are logged with a warning and left out of the result.
def _get_properties_from_schema(schema: Dict) -> Dict[str, Any]:
properties = {}
for k, v in schema.items():
# Only string keys are accepted as property names.
if not isinstance(k, str):
logger.warning(f"Key '{k}' is not supported in schema, type={type(k)}")
continue
# Each property must itself be described by a dict.
if not isinstance(v, Dict):
logger.warning(f"Value '{v}' is not supported in schema, ignoring v={v}")
continue
properties_item: Dict[str, Union[str, int, Dict, List]] = {}
# Resolve the type only when the entry declares "type" or "anyOf".
if v.get("type") or v.get("anyOf"):
properties_item["type_"] = _get_type_from_schema(v)

# Copy a non-empty enum through unchanged.
if v.get("enum"):
properties_item["enum"] = v["enum"]

# Keep the description only when it is a non-empty string.
description = v.get("description")
if description and isinstance(description, str):
properties_item["description"] = description

# Array properties: convert the element schema found under "items".
if v.get("type") == "array" and v.get("items"):
properties_item["items"] = _get_items_from_schema_any(v.get("items"))

# Object properties: recurse into the nested "properties" mapping.
if v.get("type") == "object" and v.get("properties"):
properties_item["properties"] = _get_properties_from_schema_any(
v.get("properties")
)
# A property literally named "title" with no description gets a synthesized
# one built from the key and the stringified entry.
# NOTE(review): indentation is not visible in this listing, so it is unclear
# whether this check is nested inside the object branch above or sits at
# loop level -- confirm against the real file.
if k == "title" and "description" not in properties_item:
properties_item["description"] = k + " is " + str(v)

properties[k] = properties_item

return properties


def _get_items_from_schema_any(schema: Any) -> Dict[str, Any]:
if isinstance(schema, Dict):
return _get_items_from_schema(schema)
if isinstance(schema, List):
return _get_items_from_schema(schema)
if isinstance(schema, str):
return _get_items_from_schema(schema)
return {}


def _get_items_from_schema(schema: Union[Dict, List, str]) -> Dict[str, Any]:
items: Dict = {}
if isinstance(schema, List):
for i, v in enumerate(schema):
items[f"item{i}"] = _get_properties_from_schema_any(v)
elif isinstance(schema, Dict):
item: Dict = {}
for k, v in schema.items():
item["type_"] = _get_type_from_schema(v)
if not isinstance(v, Dict):
logger.warning(
f"Value '{v}' is not supported in schema, ignoring v={v}"
)
continue
if v.get("type") == "object" and v.get("properties"):
item["properties"] = _get_properties_from_schema_any(
v.get("properties")
)
if k == "title" and "description" not in item:
item["description"] = v
items = item
else:
# str
items["type_"] = TYPE_ENUM.get(str(schema), glm.Type.STRING)
return items


def _get_type_from_schema(schema: Dict[str, Any]) -> int:
if "anyOf" in schema:
types = [_get_type_from_schema(sub_schema) for sub_schema in schema["anyOf"]]
Expand All @@ -294,10 +371,7 @@ def _get_type_from_schema(schema: Dict[str, Any]) -> int:
pass
elif "type" in schema:
stype = str(schema["type"])
if stype in TYPE_ENUM:
return TYPE_ENUM[stype]
else:
pass
return TYPE_ENUM.get(stype, glm.Type.STRING)
else:
pass
return TYPE_ENUM["string"] # Default to string if no valid types found
Expand Down
36 changes: 25 additions & 11 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,43 +335,55 @@ def _check_tool_calls(response: BaseMessage, expected_name: str) -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content == ""

# function_call
function_call = response.additional_kwargs.get("function_call")
assert function_call
assert function_call["name"] == expected_name
arguments_str = function_call.get("arguments")
assert arguments_str
arguments = json.loads(arguments_str)
assert arguments == {
"name": "Erick",
"age": 27.0,
}
_check_tool_call_args(arguments)

# tool_calls
tool_calls = response.tool_calls
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["name"] == expected_name
assert tool_call["args"] == {"age": 27.0, "name": "Erick"}
_check_tool_call_args(tool_call["args"])


def _check_tool_call_args(tool_call_args: dict) -> None:
assert tool_call_args == {
"age": 27.0,
"name": "Erick",
"likes": ["apple", "banana"],
}


@pytest.mark.extended
def test_chat_vertexai_gemini_function_calling() -> None:
class MyModel(BaseModel):
name: str
age: int
likes: list[str]

safety: Dict[HarmCategory, HarmBlockThreshold] = {
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH # type: ignore[dict-item]
}
# Test .bind_tools with BaseModel
message = HumanMessage(content="My name is Erick and I am 27 years old")
message = HumanMessage(
content="My name is Erick and I am 27 years old. I like apple and banana."
)
model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
[MyModel]
)
response = model.invoke([message])
_check_tool_calls(response, "MyModel")

# Test .bind_tools with function
def my_model(name: str, age: int) -> None:
"""Invoke this with names and ages."""
def my_model(name: str, age: int, likes: list[str]) -> None:
"""Invoke this with names and age and likes."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
Expand All @@ -382,8 +394,8 @@ def my_model(name: str, age: int) -> None:

# Test .bind_tools with tool
@tool
def my_tool(name: str, age: int) -> None:
"""Invoke this with names and ages."""
def my_tool(name: str, age: int, likes: list[str]) -> None:
"""Invoke this with names and age and likes."""
pass

model = ChatGoogleGenerativeAI(model=_MODEL, safety_settings=safety).bind_tools(
Expand All @@ -405,7 +417,9 @@ def my_tool(name: str, age: int) -> None:
assert len(gathered.tool_call_chunks) == 1
tool_call_chunk = gathered.tool_call_chunks[0]
assert tool_call_chunk["name"] == "my_tool"
assert tool_call_chunk["args"] == '{"age": 27.0, "name": "Erick"}'
arguments_str = tool_call_chunk["args"]
arguments = json.loads(str(arguments_str))
_check_tool_call_args(arguments)


# Test with model that supports tool choice (gemini 1.5) and one that doesn't
Expand Down
1 change: 1 addition & 0 deletions libs/genai/tests/unit_tests/test_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def test_tool_to_dict_pydantic() -> None:
class MyModel(BaseModel):
name: str
age: int
likes: list[str]

gapic_tool = convert_to_genai_function_declarations([MyModel])
tool_dict = tool_to_dict(gapic_tool)
Expand Down
Loading