From 9ebd8c11168952b9c33f6dbbf0262155796c905d Mon Sep 17 00:00:00 2001 From: TommasoPetrolito Date: Thu, 3 Oct 2024 11:04:22 +0200 Subject: [PATCH] langchain-google-vertexai: Update vision_models.py to allow kwargs usage (#473) --- .../vision_models.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py index 12bcc0ec..616a1590 100644 --- a/libs/vertexai/langchain_google_vertexai/vision_models.py +++ b/libs/vertexai/langchain_google_vertexai/vision_models.py @@ -10,11 +10,12 @@ from langchain_core.outputs.chat_generation import ChatGeneration from langchain_core.outputs.generation import Generation from pydantic import BaseModel, ConfigDict, Field -from vertexai.preview.vision_models import ( # type: ignore[import-untyped] +from vertexai.vision_models import ( # type: ignore[import-untyped] GeneratedImage, + Image, ImageGenerationModel, + ImageTextModel, ) -from vertexai.vision_models import Image, ImageTextModel # type: ignore[import-untyped] from langchain_google_vertexai._image_utils import ( ImageBytesLoader, @@ -97,7 +98,7 @@ def _default_params(self) -> Dict[str, Any]: def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]: params = self._default_params for key, value in kwargs.items(): - if key in params and value is not None: + if value is not None: params[key] = value return params @@ -110,6 +111,7 @@ def _get_captions( image: Image, number_of_results: Optional[int] = None, language: Optional[str] = None, + **kwargs, ) -> List[str]: """Uses the sdk methods to generate a list of captions. @@ -123,7 +125,7 @@ def _get_captions( """ with telemetry.tool_context_manager(self._user_agent): params = self._prepare_params( - number_of_results=number_of_results, language=language + number_of_results=number_of_results, language=language, **kwargs ) captions = self.client.get_captions(image=image, **params) return captions @@ -224,7 +226,7 @@ def _generate( "{'type': 'image_url', 'image_url': {'image': }}" ) - captions = self._get_captions(image, **messages[0].additional_kwargs) + captions = self._get_captions(image, **messages[0].additional_kwargs, **kwargs) generations = [ ChatGeneration(message=AIMessage(content=caption)) for caption in captions @@ -287,7 +289,7 @@ def _generate( ) answers = self._ask_questions( - image=image, query=user_question, **messages[0].additional_kwargs + image=image, query=user_question, **messages[0].additional_kwargs, **kwargs ) generations = [ @@ -361,7 +363,7 @@ def _prepare_params(self, **kwargs: Any) -> Dict[str, Any]: mapping = {"number_of_results": "number_of_images"} for key, value in kwargs.items(): key = mapping.get(key, key) - if key in params and value is not None: + if value is not None: params[key] = value return {k: v for k, v in params.items() if v is not None} @@ -477,7 +479,7 @@ def _generate( ) image_str_list = self._generate_images( - prompt=user_query, **messages[0].additional_kwargs + prompt=user_query, **messages[0].additional_kwargs, **kwargs ) image_content_part_list = [ create_image_content_part(image_str=image_str) @@ -526,7 +528,10 @@ def _generate( ) image_str_list = self._edit_images( - image_str=image_str, prompt=user_query, **messages[0].additional_kwargs + image_str=image_str, + prompt=user_query, + **messages[0].additional_kwargs, + **kwargs, ) image_content_part_list = [ create_image_content_part(image_str=image_str)