Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langchain-google-vertexai: Update vision_models.py to allow kwargs usage #473

Merged
Merged 4 commits on Oct 3, 2024 (source and target branch names were not captured in this extract).
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions libs/vertexai/langchain_google_vertexai/vision_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,11 +10,12 @@
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation
from pydantic import BaseModel, ConfigDict, Field
from vertexai.preview.vision_models import ( # type: ignore[import-untyped]
from vertexai.vision_models import ( # type: ignore[import-untyped]
GeneratedImage,
Image,
ImageGenerationModel,
ImageTextModel,
)
from vertexai.vision_models import Image, ImageTextModel # type: ignore[import-untyped]

from langchain_google_vertexai._image_utils import (
ImageBytesLoader,
Expand Down Expand Up @@ -97,7 +98,7 @@ def _default_params(self) -> Dict[str, Any]:
def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
params = self._default_params
for key, value in kwargs.items():
if key in params and value is not None:
if value is not None:
params[key] = value
return params

Expand All @@ -110,6 +111,7 @@ def _get_captions(
image: Image,
number_of_results: Optional[int] = None,
language: Optional[str] = None,
**kwargs,
) -> List[str]:
"""Uses the sdk methods to generate a list of captions.

Expand All @@ -123,7 +125,7 @@ def _get_captions(
"""
with telemetry.tool_context_manager(self._user_agent):
params = self._prepare_params(
number_of_results=number_of_results, language=language
number_of_results=number_of_results, language=language, **kwargs
)
captions = self.client.get_captions(image=image, **params)
return captions
Expand Down Expand Up @@ -224,7 +226,7 @@ def _generate(
"{'type': 'image_url', 'image_url': {'image': <image_str>}}"
)

captions = self._get_captions(image, **messages[0].additional_kwargs)
captions = self._get_captions(image, **messages[0].additional_kwargs, **kwargs)

generations = [
ChatGeneration(message=AIMessage(content=caption)) for caption in captions
Expand Down Expand Up @@ -287,7 +289,7 @@ def _generate(
)

answers = self._ask_questions(
image=image, query=user_question, **messages[0].additional_kwargs
image=image, query=user_question, **messages[0].additional_kwargs, **kwargs
)

generations = [
Expand Down Expand Up @@ -361,7 +363,7 @@ def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]:
mapping = {"number_of_results": "number_of_images"}
for key, value in kwargs.items():
key = mapping.get(key, key)
if key in params and value is not None:
if value is not None:
params[key] = value
return {k: v for k, v in params.items() if v is not None}

Expand Down Expand Up @@ -477,7 +479,7 @@ def _generate(
)

image_str_list = self._generate_images(
prompt=user_query, **messages[0].additional_kwargs
prompt=user_query, **messages[0].additional_kwargs, **kwargs
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
Expand Down Expand Up @@ -526,7 +528,10 @@ def _generate(
)

image_str_list = self._edit_images(
image_str=image_str, prompt=user_query, **messages[0].additional_kwargs
image_str=image_str,
prompt=user_query,
**messages[0].additional_kwargs,
**kwargs,
)
image_content_part_list = [
create_image_content_part(image_str=image_str)
Expand Down
Loading