From dce667873dc1180972d6351ee1b13a8bf0ddeb2a Mon Sep 17 00:00:00 2001
From: raushan
Date: Tue, 19 Mar 2024 15:15:01 +0100
Subject: [PATCH 01/49] add model draft

---
 docs/source/en/_toctree.yml                   |    6 +-
 docs/source/en/model_doc/video_llava.md       |   43 +
 src/transformers/__init__.py                  |   24 +
 src/transformers/models/__init__.py           |    1 +
 .../models/auto/configuration_auto.py         |    3 +
 .../models/auto/image_processing_auto.py      |    1 +
 src/transformers/models/auto/modeling_auto.py |    2 +
 .../models/auto/processing_auto.py            |    1 +
 .../models/auto/tokenization_auto.py          |    1 +
 .../models/video_llava/__init__.py            |   72 ++
 .../video_llava/configuration_video_llava.py  |  128 +++
 .../convert_video_llava_weights_to_hf.py      |  156 +++
 .../image_processing_video_llava.py           |  399 ++++++++
 .../video_llava/modeling_video_llava.py       |  943 ++++++++++++++++++
 .../video_llava/processing_video_llava.py     |  137 +++
 tests/models/video_llava/__init__.py          |    0
 .../video_llava/test_modeling_video_llava.py  |  375 +++++++
 17 files changed, 2290 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/en/model_doc/video_llava.md
 create mode 100644 src/transformers/models/video_llava/__init__.py
 create mode 100644 src/transformers/models/video_llava/configuration_video_llava.py
 create mode 100644 src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py
 create mode 100644 src/transformers/models/video_llava/image_processing_video_llava.py
 create mode 100644 src/transformers/models/video_llava/modeling_video_llava.py
 create mode 100644 src/transformers/models/video_llava/processing_video_llava.py
 create mode 100644 tests/models/video_llava/__init__.py
 create mode 100644 tests/models/video_llava/test_modeling_video_llava.py

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 47c7ff6602c0e1..78422784e00b92 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -134,7 +134,7 @@
     - local: custom_tools
       title: Custom Tools and Prompts
     - local: troubleshooting
-      title: Troubleshoot
+      title: Troubleshoot
     - local: hf_quantizer
       title: Contribute new quantization method
     title: Developer guides
@@ -693,7 +693,7 @@
       title: VideoMAE
     - local: model_doc/vivit
       title: ViViT
-    title: Video models
+    title: Video models
 - isExpanded: false
   sections:
   - local: model_doc/align
@@ -780,6 +780,8 @@
       title: TVP
     - local: model_doc/udop
       title: UDOP
+    - local: model_doc/video_llava
+      title: VideoLlava
     - local: model_doc/vilt
      title: ViLT
     - local: model_doc/vipllava
diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md
new file mode 100644
index 00000000000000..d2f8bc47283154
--- /dev/null
+++ b/docs/source/en/model_doc/video_llava.md
@@ -0,0 +1,43 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+-->
+
+# VideoLlava
+
+## Overview
+
+The VideoLlava model was proposed in [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/abs/2311.10122) by Bin Lin, Bin Zhu, Yang Ye, Munan Ning, Peng Jin and Li Yuan.
+
+Video-LLaVA binds image and video representations to a unified visual feature space *before* projection into the
+language model, so that a single LLaMA-based backbone can reason over images, videos, and mixed image-video inputs
+with one set of projection weights.
+
+Tips:
+
+- The model accepts both images and videos as input; inputs should be prepared with [`VideoLlavaProcessor`].
+
+This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
+The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA).
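+## Usage example
+
+The snippet below is a minimal generation sketch. The checkpoint id, the `<video>` placeholder token, and the
+`videos=` processor argument follow the conversion script and image processor in this patch and the original
+Video-LLaVA release; they are assumptions at this stage and may change while the port is finalized.
+
+```python
+import av
+import numpy as np
+from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration
+
+model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B")
+processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B")
+
+# sample 8 evenly spaced frames from a local video file with PyAV
+container = av.open("sample_demo.mp4")
+total_frames = container.streams.video[0].frames
+indices = np.linspace(0, total_frames - 1, num=8).astype(int)
+frames = [
+    frame.to_ndarray(format="rgb24")
+    for i, frame in enumerate(container.decode(video=0))
+    if i in indices
+]
+
+prompt = "USER: <video>Why is this video funny? ASSISTANT:"
+inputs = processor(text=prompt, videos=[frames], return_tensors="pt")
+
+generate_ids = model.generate(**inputs, max_new_tokens=80)
+print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
+```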
+ + +## VideoLlavaConfig + +[[autodoc]] VideoLlavaConfig + +## VideoLlavaForConditionalGeneration + +[[autodoc]] VideoLlavaForConditionalGeneration + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 631276d3c19b3b..80f9b68ae0bc18 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -885,6 +885,10 @@ "UnivNetFeatureExtractor", ], "models.upernet": ["UperNetConfig"], + "models.video_llava": [ + "VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", + "VideoLlavaConfig", + ], "models.videomae": ["VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "VideoMAEConfig"], "models.vilt": [ "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -3510,6 +3514,15 @@ "UperNetPreTrainedModel", ] ) + _import_structure["models.video_llava"].extend( + [ + "VIDEO_LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", + "VideoLlavaForConditionalGeneration", + "VideoLlavaImageProcessor", + "VideoLlavaPreTrainedModel", + "VideoLlavaProcessor", + ] + ) _import_structure["models.videomae"].extend( [ "VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5732,6 +5745,10 @@ UnivNetFeatureExtractor, ) from .models.upernet import UperNetConfig + from .models.video_llava import ( + VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, + VideoLlavaConfig, + ) from .models.videomae import VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP, VideoMAEConfig from .models.vilt import ( VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -7972,6 +7989,13 @@ UperNetForSemanticSegmentation, UperNetPreTrainedModel, ) + from .models.video_llava import ( + VIDEO_LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, + VideoLlavaForConditionalGeneration, + VideoLlavaImageProcessor, + VideoLlavaPreTrainedModel, + VideoLlavaProcessor, + ) from .models.videomae import ( VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST, VideoMAEForPreTraining, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index d0f3ba2688e207..6affffa43aa277 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -230,6 +230,7 @@ unispeech_sat, univnet, upernet, + video_llava, videomae, vilt, vipllava, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index cd5be302c42986..05c99a729f8d27 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -242,6 +242,7 @@ ("univnet", "UnivNetConfig"), ("upernet", "UperNetConfig"), ("van", "VanConfig"), + ("video_llava", "VideoLlavaConfig"), ("videomae", "VideoMAEConfig"), ("vilt", "ViltConfig"), ("vipllava", "VipLlavaConfig"), @@ -468,6 +469,7 @@ ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("univnet", "UNIVNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("video_llava", "VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("videomae", "VIDEOMAE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vipllava", "VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -737,6 +739,7 @@ ("univnet", "UnivNet"), ("upernet", "UPerNet"), ("van", "VAN"), + ("video_llava", "VideoLlava"), ("videomae", "VideoMAE"), ("vilt", "ViLT"), ("vipllava", "VipLlava"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 618732f23a75a4..36cfed1535983b 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -112,6 +112,7 @@ ("udop", "LayoutLMv3ImageProcessor"), ("upernet", 
"SegformerImageProcessor"), ("van", "ConvNextImageProcessor"), + ("video_llava", "VideoLlavaImageProcessor"), ("videomae", "VideoMAEImageProcessor"), ("vilt", "ViltImageProcessor"), ("vipllava", "CLIPImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d2cb6882cd351c..42e83de45b0f84 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -314,6 +314,7 @@ ("tvlt", "TvltForPreTraining"), ("unispeech", "UniSpeechForPreTraining"), ("unispeech-sat", "UniSpeechSatForPreTraining"), + ("video_llava", "VideoLlavaForConditionalGeneration"), ("videomae", "VideoMAEForPreTraining"), ("vipllava", "VipLlavaForConditionalGeneration"), ("visual_bert", "VisualBertForPreTraining"), @@ -676,6 +677,7 @@ ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), + ("video_llava", "VideoLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"), ] diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index e41e39e56eeea2..d0cde95bec144e 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -87,6 +87,7 @@ ("tvp", "TvpProcessor"), ("unispeech", "Wav2Vec2Processor"), ("unispeech-sat", "Wav2Vec2Processor"), + ("video_llava", "VideoLlavaProcessor"), ("vilt", "ViltProcessor"), ("vipllava", "LlavaProcessor"), ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index ca337ec0272e3a..f20d3dd836e151 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -435,6 +435,7 @@ "T5TokenizerFast" if is_tokenizers_available() else None, ), ), + ("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/video_llava/__init__.py b/src/transformers/models/video_llava/__init__.py new file mode 100644 index 00000000000000..02e90214de2a1f --- /dev/null +++ b/src/transformers/models/video_llava/__init__.py @@ -0,0 +1,72 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+    "configuration_video_llava": [
+        "VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "VideoLlavaConfig",
+    ],
+    "image_processing_video_llava": ["VideoLlavaImageProcessor"],
+    "processing_video_llava": ["VideoLlavaProcessor"],
+}
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_video_llava"] = [
+        "VIDEO_LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "VideoLlavaVisionTransformer",
+        "VideoLlavaPreTrainedModel",
+        "VideoLlavaForConditionalGeneration",
+    ]
+
+if TYPE_CHECKING:
+    from .configuration_video_llava import (
+        VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        VideoLlavaConfig,
+    )
+    from .processing_video_llava import VideoLlavaProcessor
+
+    try:
+        if not is_vision_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .image_processing_video_llava import VideoLlavaImageProcessor
+
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_video_llava import (
+            VIDEO_LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+            VideoLlavaForConditionalGeneration,
+            VideoLlavaPreTrainedModel,
+            VideoLlavaVisionTransformer,
+        )
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/video_llava/configuration_video_llava.py b/src/transformers/models/video_llava/configuration_video_llava.py
new file mode 100644
index 00000000000000..b184834136319c
--- /dev/null
+++ b/src/transformers/models/video_llava/configuration_video_llava.py
@@ -0,0 +1,128 @@
+# coding=utf-8
+# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" VideoLlava model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "LanguageBind/Video-LLaVA-7B": "https://huggingface.co/LanguageBind/Video-LLaVA-7B/resolve/main/config.json",
+}
+
+
+class VideoLlavaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`VideoLlavaForConditionalGeneration`]. It is used to instantiate a
+    VideoLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of [LanguageBind/Video-LLaVA-7B](https://huggingface.co/LanguageBind/Video-LLaVA-7B).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vision_config (`Union[CLIPVisionConfig, dict]`, *optional*):
+            The config object or dictionary of the vision backbone. Defaults to a CLIP ViT-L/14 vision config.
+        text_config (`Union[AutoConfig, dict]`, *optional*):
+            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
+        ignore_index (`int`, *optional*, defaults to -100):
+            The ignore index for the loss function.
+        image_token_index (`int`, *optional*, defaults to 32000):
+            The image token index to encode the image prompt.
+        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The activation function used by the multimodal projector.
+        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
+            The feature selection strategy used to select the vision feature from the CLIP backbone.
+        vision_feature_layer (`int`, *optional*, defaults to -2):
+            The index of the layer to select the vision feature.
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the VideoLlava model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`~VideoLlavaForConditionalGeneration`]
+
+    Example:
+
+    ```python
+    >>> from transformers import VideoLlavaForConditionalGeneration, VideoLlavaConfig, CLIPVisionConfig, LlamaConfig
+
+    >>> # Initializing a CLIP-vision config
+    >>> vision_config = CLIPVisionConfig()
+
+    >>> # Initializing a Llama config
+    >>> text_config = LlamaConfig()
+
+    >>> # Initializing a Video-LLaVA-7B style configuration
+    >>> configuration = VideoLlavaConfig(vision_config, text_config)
+
+    >>> # Initializing a model from the Video-LLaVA-7B style configuration
+    >>> model = VideoLlavaForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "video_llava"
+    is_composition = False
+
+    def __init__(
+        self,
+        vision_config=None,
+        text_config=None,
+        ignore_index=-100,
+        image_token_index=32000,
+        projector_hidden_act="gelu",
+        vision_feature_select_strategy="default",
+        vision_feature_layer=-2,
+        vocab_size=32000,
+        **kwargs,
+    ):
+        self.ignore_index = ignore_index
+        self.image_token_index = image_token_index
+        self.projector_hidden_act = projector_hidden_act
+        self.vision_feature_select_strategy = vision_feature_select_strategy
+        self.vision_feature_layer = vision_feature_layer
+        self.vocab_size = vocab_size
+
+        self.vision_config = vision_config
+
+        if isinstance(self.vision_config, dict):
+            vision_config["model_type"] = (
+                vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
+            )
+            self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+        elif vision_config is None:
+            self.vision_config = CONFIG_MAPPING["clip_vision_model"](
+                intermediate_size=4096,
+                hidden_size=1024,
+                patch_size=14,
+                image_size=224,
+                num_hidden_layers=24,
+                num_attention_heads=16,
+                vocab_size=32000,
+                projection_dim=768,
+            )
+
+        self.text_config = text_config
+
+        if isinstance(self.text_config, dict):
+            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
+            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+            self.vocab_size = self.text_config.vocab_size
+        elif text_config is None:
+            self.text_config = CONFIG_MAPPING["llama"]()
+
+        super().__init__(**kwargs)
diff --git a/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py b/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py
new file mode 100644
index 00000000000000..bee642440329d8
--- /dev/null
+++ b/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py
@@ -0,0 +1,156 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+
+import torch
+from huggingface_hub import hf_hub_download
+
+from transformers import (
+    AddedToken,
+    AutoConfig,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    LlavaProcessor,
+    VideoLlavaConfig,
+    VideoLlavaForConditionalGeneration,
+)
+
+
+EPILOG_TXT = """Example:
+    python transformers/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py --text_model_id lmsys/vicuna-7b-v1.5 --vision_model_id openai/clip-vit-large-patch14 --output_hub_path org/video_llava-7b --old_state_dict_id LanguageBind/Video-LLaVA-7B
+
+Example for creating the old state dict file with Python:
+
+    import torch
+    from video_llava.model.language_model.video_llava_llama import VideoLlavaLlamaForCausalLM
+
+    # load model
+    kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
+    model = VideoLlavaLlamaForCausalLM.from_pretrained("LanguageBind/Video-LLaVA-7B", low_cpu_mem_usage=True, **kwargs)
+
+    # load vision tower
+    model.get_vision_tower().load_model()
+
+    # Save state dict
+    torch.save(model.state_dict(), "tmp/hf_models/video_llava-7b/model_state_dict.bin")
+"""
+
+KEYS_TO_MODIFY_MAPPING = {
+    "model.video_tower.video_tower": "vision_tower",
+    "model.mm_projector": "multi_modal_projector",
+    "model": "language_model.model",
+    "lm_head": "language_model.lm_head",
+    "multi_modal_projector.0": "multi_modal_projector.linear_1",
+    "multi_modal_projector.2": "multi_modal_projector.linear_2",
+}
+
+KEYS_TO_IGNORE = ["model.image_tower.image_tower."]
+
+
+def convert_state_dict_to_hf(state_dict):
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key.endswith(".inv_freq") or any(key.startswith(key_to_ignore) for key_to_ignore in KEYS_TO_IGNORE):
+            continue
+        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+            if key_to_modify in key:
+                key = key.replace(key_to_modify, new_key)
+
+        new_state_dict[key] = value
+    return new_state_dict
+
+
+def convert_video_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
+    torch.set_default_dtype(torch.float16)
+    text_config = AutoConfig.from_pretrained(text_model_id)
+
+    tokenizer = AutoTokenizer.from_pretrained(text_model_id)
+    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
+    tokenizer.add_special_tokens({"pad_token": "<pad>"})
+
+    image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)
+
+    processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
+
+    config = VideoLlavaConfig(text_config=text_config)
+    config.pad_token_id = 32001
+
+    with torch.device("meta"):
+        model = VideoLlavaForConditionalGeneration(config)
+
+    model_state_dict = set(model.state_dict().keys())
+
+    # Pad to 64 for performance reasons
+    pad_shape = 64
+    state_dict_temp = "pytorch_model-0000{i}-of-00002.bin"
+    for shard in range(1, 3):
+        state_dict_path = hf_hub_download(old_state_dict_id, state_dict_temp.format(i=shard))
+        state_dict = torch.load(state_dict_path, map_location="cpu")
+        state_dict = convert_state_dict_to_hf(state_dict)
+        model.load_state_dict(state_dict, strict=False, assign=True)
+        model_state_dict -= set(state_dict.keys())
+
+    if len(model_state_dict) > 0:
+        raise RuntimeError(f"Missing keys in state dict: {model_state_dict}")
+
+    pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
+    mu = torch.mean(pre_expansion_embeddings, dim=0).float()
+    n = pre_expansion_embeddings.size()[0]
+    sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
+    dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
+
+    # We add an image token and a pad token, so we resize the embeddings accordingly
+    model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
+    model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
+        tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
+        dim=0,
+    )
+    model.language_model.lm_head.weight.data[32000:] = torch.stack(
+        tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
+        dim=0,
+    )
+
+    model.push_to_hub(output_hub_path)
+    processor.push_to_hub(output_hub_path)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        epilog=EPILOG_TXT,
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    parser.add_argument(
+        "--text_model_id",
+        help="Hub location of the text model",
+    )
+    parser.add_argument(
+        "--vision_model_id",
+        help="Hub location of the vision model",
+    )
+    parser.add_argument(
+        "--output_hub_path",
+        help="Location on the hub of the converted model",
+    )
+    parser.add_argument(
+        "--old_state_dict_id",
+        help="Location on the hub of the raw state dict of the original model. The weights are expected in two shards named `pytorch_model-00001-of-00002.bin` and `pytorch_model-00002-of-00002.bin`",
+    )
+    args = parser.parse_args()
+    convert_video_llava_llama_to_hf(
+        args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py
new file mode 100644
index 00000000000000..465b541a0054ca
--- /dev/null
+++ b/src/transformers/models/video_llava/image_processing_video_llava.py
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Image processor class for VideoLlava."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+    convert_to_rgb,
+    get_resize_output_image_size,
+    resize,
+    to_channel_dimension_format,
+)
+from ...image_utils import (
+    OPENAI_CLIP_MEAN,
+    OPENAI_CLIP_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    is_valid_image,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_kwargs,
+    validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+    import PIL
+
+
+def make_batched(videos) -> List[List[ImageInput]]:
+    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+        return videos
+
+    elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
+        return [videos]
+
+    elif is_valid_image(videos):
+        return [[videos]]
+
+    raise ValueError(f"Could not make batched video from {videos}")
+
+
+class VideoLlavaImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a VideoLlava image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
+            `do_resize` in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+            Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
+            the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
+            method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+            `preprocess` method.
+        crop_size (`Dict[str, int]` *optional*, defaults to 224):
+            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+            method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
+            the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_center_crop: bool = True,
+        crop_size: Dict[str, int] = None,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"shortest_edge": 224}
+        size = get_size_dict(size, default_to_square=False)
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+        self.do_convert_rgb = do_convert_rgb
+        self._valid_processor_keys = [
+            "images",
+            "videos",
+            "do_resize",
+            "size",
+            "resample",
+            "do_center_crop",
+            "crop_size",
+            "do_rescale",
+            "rescale_factor",
+            "do_normalize",
+            "image_mean",
+            "image_std",
+            "do_convert_rgb",
+            "return_tensors",
+            "data_format",
+            "input_data_format",
+        ]
+
+        # for backwards compatibility of KOSMOS-2
+        if "use_square_size" in kwargs:
+            self.size = {"height": size["shortest_edge"], "width": size["shortest_edge"]}
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+        resized to keep the input aspect ratio.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+        """
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+        output_size = get_resize_output_image_size(
+            image,
+            size=size,
+            default_to_square=default_to_square,
+            input_data_format=input_data_format,
+        )
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    def preprocess(
+        self,
+        images: ImageInput = None,
+        videos: ImageInput = None,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_center_crop: bool = None,
+        crop_size: int = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Preprocess an image, batch of images, or batch of videos.
+
+        Args:
+            images (`ImageInput`, *optional*):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            videos (`ImageInput`, *optional*):
+                Video frames to preprocess. Expects a single video or a batch of videos as lists of frames with pixel
+                values ranging from 0 to 255. Exactly one of `images` or `videos` should be passed.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
+                the longest edge resized to keep the input aspect ratio.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image.
+            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+                `True`.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                - Unset: Return a list of `np.ndarray`.
+                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        size = size if size is not None else self.size
+        size = get_size_dict(size, param_name="size", default_to_square=False)
+        resample = resample if resample is not None else self.resample
+        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+        crop_size = crop_size if crop_size is not None else self.crop_size
+        crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+        if (images is None) ^ (videos is not None):
+            raise ValueError("VideoLlava expects exactly one of `images` or `videos` as input, got both or neither")
+
+        if images is not None:
+            inputs = [make_list_of_images(images)]
+        elif videos is not None:
+            inputs = make_batched(videos)
+
+        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
+
+        if not valid_images(inputs):
+            raise ValueError(
+                "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+ ) + + pixel_values = [ + [ + self._preprocess_image( + image=frame, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for frame in video + ] + for video in inputs + ] + + encoded_outputs = BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors) + + return encoded_outputs + + def _preprocess_image( + self, + image: ImageInput = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_center_crop: bool = None, + crop_size: int = None, + do_convert_rgb: bool = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + image = convert_to_rgb(image) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if is_scaled_image(image) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images/video frames. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py new file mode 100644 index 00000000000000..676b4dd7915d67 --- /dev/null +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -0,0 +1,943 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch VideoLlava model.""" +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ... import PreTrainedModel +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..auto import AutoModelForCausalLM +from .configuration_video_llava import VideoLlavaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "VideoLlavaConfig" + +VIDEO_LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "LanguageBind/Video-LLaVA-7B", + # See all video_llava models at https://huggingface.co/models?filter=video_llava +] + + +@dataclass +# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->VideoLlava +class VideoLlavaCausalLMOutputWithPast(ModelOutput): + """ + Base class for VideoLlava causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. 
+ + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->VideoLlava +class VideoLlavaMultiModalProjector(nn.Module): + def __init__(self, config: VideoLlavaConfig): + super().__init__() + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +VIDEO_LLAVA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`VideoLlavaConfig`] or [`VideoLlavaVisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +# Copied from transformers.models.clip.modeling_clip.ClipVisionEmbeddings with Clip->VideoLlava +class VideoLlavaVisionEmbeddings(nn.Module): + def __init__(self, config: VideoLlavaConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.ClipVisionEmbeddings with Clip->VideoLlava +class VideLlavaVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+            )
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scale
+        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # apply the causal_attention_mask first
+        if causal_attention_mask is not None:
+            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {causal_attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->VideoLlavaVision
+class VideoLlavaVisionMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->VideoLlavaVision
+class VideoLlavaVisionEncoderLayer(nn.Module):
+    def __init__(self, config: VideoLlavaConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = VideoLlavaVisionAttention(config)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = VideoLlavaVisionMLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+                `(config.encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->VideoLlavaVision
+class VideoLlavaVisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`VideoLlavaVisionEncoderLayer`].
+
+    Args:
+        config: VideoLlavaConfig
+    """
+
+    def __init__(self, config: VideoLlavaConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([VideoLlavaVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Causal mask for the text model. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    causal_attention_mask,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIP->VideoLlava
+class VideoLlavaVisionTransformer(nn.Module):
+    def __init__(self, config: VideoLlavaConfig):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = VideoLlavaVisionEmbeddings(config)
+        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.encoder = VideoLlavaVisionEncoder(config)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=VideoLlavaConfig)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # if video input, merge batch with frames dim and later reshape back
+        if pixel_values.dim() == 5:
+            batch_size, num_frames, channels, height, width = pixel_values.shape
+            pixel_values = pixel_values.reshape(batch_size * num_frames, channels, height, width)
+        else:
+            batch_size, _, _, _ = pixel_values.shape
+            num_frames = 1
+
+        hidden_states = self.embeddings(pixel_values)
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        pooled_output = last_hidden_state[:, 0, :]
+        pooled_output = self.post_layernorm(pooled_output)
+
+        pooled_output = pooled_output.reshape(batch_size, num_frames, -1).mean(1)
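+        # NOTE: video frames were folded into the batch dimension above, so the pooled CLS output
+        # is reshaped back to (batch_size, num_frames, hidden_size) and averaged over frames to get
+        # a single per-video embedding; the encoder hidden states are reshaped the same way below so
+        # that per-frame patch features can later be selected for the multimodal projector.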
+        # reshape the per-layer hidden states back to (batch, frames, seq_len, dim);
+        # they only exist when `output_hidden_states` was requested
+        if return_dict and output_hidden_states:
+            encoder_outputs.hidden_states = [
+                hidden_state.reshape(batch_size, num_frames, -1, self.config.hidden_size)
+                for hidden_state in encoder_outputs.hidden_states
+            ]
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    VIDEO_LLAVA_START_DOCSTRING,
+)
+# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->VideoLlava,llava->video_llava
+class VideoLlavaPreTrainedModel(PreTrainedModel):
+    config_class = VideoLlavaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["VideoLlavaVisionAttention"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+
+    def _init_weights(self, module):
+        # important: this ported version of VideoLlava isn't meant for training from scratch - only
+        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
+        # https://github.com/PKU-YuanGroup/Video-LLaVA/tree/main should serve for that purpose
+        std = (
+            self.config.initializer_range
+            if hasattr(self.config, "initializer_range")
+            else self.config.text_config.initializer_range
+        )
+
+        if hasattr(module, "class_embedding"):
+            module.class_embedding.data.normal_(mean=0.0, std=std)
+
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def _supports_sdpa(self):
+        """
+        Retrieve language_model's attribute to check whether the model supports
+        SDPA or not.
+        """
+        return self.language_model._supports_sdpa
+
+
+VIDEO_LLAVA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
+            [`CLIPImageProcessor`] for processing images).
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+ + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
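The `attention_mask` / `position_ids` contract described above is the usual decoder-only one: position ids can always be rebuilt from the mask with a cumulative sum. A minimal sketch of that trick, the same pattern `prepare_inputs_for_generation` uses further down (the mask values here are made up):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sample
                               [1, 1, 1, 1, 1]])  # full-length sample

# running count of attended tokens, shifted to start at 0
position_ids = attention_mask.long().cumsum(-1) - 1
# padded positions get a harmless placeholder value
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```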
+""" + + +@add_start_docstrings( + """The VideoLlava model which consists of a vision backbone and a language model.""", + VIDEO_LLAVA_START_DOCSTRING, +) +# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VideoLlava,Llava->VideoLlava,llava-hf/llava-1.5-7b-hf->LanguageBind/Video-LLaVA-7B +class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): + def __init__(self, config: VideoLlavaConfig): + super().__init__(config) + self.vision_tower = VideoLlavaVisionTransformer(config.vision_config) + + self.multi_modal_projector = VideoLlavaMultiModalProjector(config) + self.vocab_size = config.vocab_size + self.language_model = AutoModelForCausalLM.from_config( + config.text_config, attn_implementation=config._attn_implementation + ) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): + batch_size, num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # Compute the maximum embed dimension + max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. 
Create the full embedding, already padded to the maximum position
+        final_embedding = torch.zeros(
+            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+        )
+        final_attention_mask = torch.zeros(
+            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
+        )
+        if labels is not None:
+            final_labels = torch.full(
+                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+            )
+        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
+        # set the corresponding tensors into their correct target device.
+        target_device = inputs_embeds.device
+        batch_indices, non_image_indices, text_to_overwrite = (
+            batch_indices.to(target_device),
+            non_image_indices.to(target_device),
+            text_to_overwrite.to(target_device),
+        )
+        attention_mask = attention_mask.to(target_device)
+
+        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
+        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
+        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
+        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+        if labels is not None:
+            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
+
+        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
+        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+
+        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
+            raise ValueError(
+                f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
+                f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
+            )
+
+        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+        final_attention_mask |= image_to_overwrite
+        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+
+        if labels is None:
+            final_labels = None
+
+        return final_embedding, final_attention_mask, final_labels, position_ids
+
+    @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        vision_feature_layer: Optional[int] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, VideoLlavaForConditionalGeneration
+
+        >>> model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B")
+        >>> processor = AutoProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B")
+
+        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_feature_select_strategy
+        )
+
+        if inputs_embeds is None:
+            # 1. Extract the input embeddings
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            # 2. Merge text and images
+            if pixel_values is not None and input_ids.shape[1] != 1:
+                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+                # this is not memory efficient at all; `output_hidden_states=True` will save all the hidden states
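To make the step that follows concrete: the vision tower returns one hidden state per layer, reshaped to `(videos, frames, tokens, dim)`, and only a single layer is kept. A toy sketch with assumed sizes (25 layers, 256 patches plus 1 CLS token per frame; the layer index is a hypothetical config value):

```python
import torch

# stand-in for image_outputs.hidden_states: one tensor per encoder layer
hidden_states = [torch.randn(1, 8, 257, 1024) for _ in range(25)]

vision_feature_layer = -2                       # assumed config value
selected = hidden_states[vision_feature_layer]  # pick one layer
selected = selected[:, :, 1:]                   # "default" strategy drops the CLS token per frame
print(selected.shape)                           # torch.Size([1, 8, 256, 1024])
```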
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, :, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" + ) + + image_features = self.multi_modal_projector(selected_image_feature) + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + if labels is None: + labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long) + else: + # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # generation with cache + if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_seqlen = first_layer_past_key_value.shape[-1] + 1 + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses VideoLlava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return VideoLlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + else: + cache_length = past_length = past_key_values[0][0].shape[2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + elif self.config.image_token_index in input_ids: + input_ids = input_ids[:, input_ids.shape[1] - 1 :] + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + ) + return model_inputs + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py new file mode 100644 index 00000000000000..8f58a89d72124d --- /dev/null +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Llava. 
+""" + + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class VideoLlavaProcessor(ProcessorMixin): + r""" + Constructs a VideoLlava processor which wraps a VideoLlava image processor and a Llava tokenizer into a single processor. + + [`VideoLlavaProcessor`] offers all the functionalities of [`VideoLlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the + [`~VideoLlavaProcessor.__call__`] and [`~VideoLlavaProcessor.decode`] for more information. + + Args: + image_processor ([`VideoLlavaImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "VideoLlavaImageProcessor" + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + videos: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + VideoLlavaImageProcessor's [`~VideoLlavaImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). 
+ max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if images is not None or videos is not None: + pixel_values = self.image_processor(images, videos=videos, return_tensors=return_tensors)["pixel_values"] + else: + pixel_values = None + text_inputs = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + + return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/tests/models/video_llava/__init__.py b/tests/models/video_llava/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py new file mode 100644 index 00000000000000..76122f9ea911de --- /dev/null +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -0,0 +1,375 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch VideoLlava model. """ + +import gc +import unittest + +import requests + +from transformers import ( + AutoProcessor, + VideoLlavaConfig, + VideoLlavaForConditionalGeneration, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch +else: + is_torch_greater_or_equal_than_2_0 = False + +if is_vision_available(): + from PIL import Image + + +class VideoLlavaVisionText2TextModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_index=0, + projector_hidden_act="gelu", + seq_length=7, + vision_feature_select_strategy="default", + vision_feature_layer=-1, + text_config={ + "model_type": "llama", + "seq_length": 7, + "is_training": True, + "use_input_mask": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 0, + }, + is_training=True, + vision_config={ + "batch_size": 12, + "image_size": 30, + "patch_size": 2, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "projection_dim": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.seq_length = seq_length + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = 3 + self.image_size = 336 + self.encoder_seq_length = 231 + + def get_config(self): + return VideoLlavaConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ignore_index=self.ignore_index, + image_token_index=self.image_token_index, + projector_hidden_act=self.projector_hidden_act, + vision_feature_select_strategy=self.vision_feature_select_strategy, + vision_feature_layer=self.vision_feature_layer, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) 
+ 1
+        attention_mask = input_ids.ne(1).to(torch_device)
+        # we are giving 3 images let's make sure we pass in 3 image tokens
+        input_ids[:, 1] = config.image_token_index
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        return config, inputs_dict
+
+
+@require_torch
+class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
+    """
+    Model tester for `VideoLlavaForConditionalGeneration`.
+    """
+
+    all_model_classes = (VideoLlavaForConditionalGeneration,) if is_torch_available() else ()
+    fx_compatible = False
+    test_pruning = False
+    test_resize_embeddings = True
+    test_head_masking = False
+
+    def setUp(self):
+        self.model_tester = VideoLlavaVisionText2TextModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=VideoLlavaConfig, has_text_modality=False)
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant(self):
+        pass
+
+    @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+    )
+    def test_training_gradient_checkpointing_use_reentrant_false(self):
+        pass
+
+
+@require_torch
+class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        self.processor = AutoProcessor.from_pretrained("video_llava-hf/bakVideoLlava-v1-hf")
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    @slow
+    @require_bitsandbytes
+    def test_small_model_integration_test(self):
+        # Let' s make sure we test the preprocessing to replace what is used
+        model = VideoLlavaForConditionalGeneration.from_pretrained(
+            "video_llava-hf/bakVideoLlava-v1-hf", load_in_4bit=True
+        )
+
+        prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
+        image_file = "https://video_llava-vl.github.io/static/images/view.jpg"
+        raw_image = Image.open(requests.get(image_file, stream=True).raw)
+        inputs = self.processor(prompt, raw_image, return_tensors="pt")
+
+        EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]])  # fmt: skip
+        self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
+
+        output = model.generate(**inputs, max_new_tokens=20)
+        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. 
Firstly," # fmt: skip + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "LanguageBind/Video-LLaVA-7B" + + model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B", load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:" + image_file = "https://video_llava-vl.github.io/static/images/view.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + + output = model.generate(**inputs, max_new_tokens=900, do_sample=False) + EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities." # fmt: skip + + self.assertEqual( + processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama_batched(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "LanguageBind/Video-LLaVA-7B" + + model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B", load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://video_llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. 
One cat is located on'] # fmt: skip + + self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_batch(self): + # Let' s make sure we test the preprocessing to replace what is used + model = VideoLlavaForConditionalGeneration.from_pretrained( + "video_llava-hf/bakVideoLlava-v1-hf", load_in_4bit=True + ) + # The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!. + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://video_llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip + self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_small_model_integration_test_llama_batched_regression(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "LanguageBind/Video-LLaVA-7B" + + # Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before) + model = VideoLlavaForConditionalGeneration.from_pretrained( + "LanguageBind/Video-LLaVA-7B", load_in_4bit=True, attn_implementation="eager" + ) + processor = AutoProcessor.from_pretrained(model_id, pad_token="") + + prompts = [ + "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", + "USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT:", + ] + image1 = Image.open(requests.get("https://video_llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? 
What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip + + self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) + + @slow + @require_bitsandbytes + def test_video_llava_index_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore + # Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for + # more details + model_id = "LanguageBind/Video-LLaVA-7B" + model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + + processor = AutoProcessor.from_pretrained(model_id) + + # Simulate a super long prompt + user_prompt = "Describe the image:?\n" * 200 + prompt = f"USER: \n{user_prompt}ASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + + # Make sure that `generate` works + _ = model.generate(**inputs, max_new_tokens=20) + + @slow + @require_torch_gpu + def test_video_llava_merge_inputs_error_bug(self): + # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore + model_id = "LanguageBind/Video-LLaVA-7B" + model = VideoLlavaForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ).to(torch_device) + + # Simulate some user inputs + pixel_values = torch.randn( + (2, 3, 336, 336), + dtype=torch.float, + device=torch_device, + ) + input_ids = torch.tensor( + [ + [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], + [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900], + ], + dtype=torch.long, + device=torch_device, + ) + attention_mask = torch.tensor( + [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]], + dtype=torch.long, + device=torch_device, + ) + + # Make sure that the loss is properly computed + loss = model( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + labels=input_ids, + ).loss + loss.backward() From 72626dff85a1673e71751a018c41e0e6c08061d5 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 20 Mar 2024 14:03:14 +0100 Subject: [PATCH 02/49] update docstring --- .../image_processing_video_llava.py | 15 +++-- .../video_llava/modeling_video_llava.py | 66 +++++++++++-------- .../video_llava/processing_video_llava.py | 8 ++- 3 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py index 465b541a0054ca..e30ab110739e6a 100644 --- a/src/transformers/models/video_llava/image_processing_video_llava.py +++ b/src/transformers/models/video_llava/image_processing_video_llava.py @@ -50,15 +50,15 @@ import PIL -def make_batched(videos) -> List[List[ImageInput]]: +def make_batched_videos(videos) -> List[List[ImageInput]]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): return videos - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - return [videos] + elif isinstance(videos, (list, 
tuple)) and is_valid_image(videos[0]) and len(videos[0].shape) == 4: + return [list(video) for video in videos] - elif is_valid_image(videos): - return [[videos]] + elif is_valid_image(videos) and len(videos.shape) == 4: + return [list(videos)] raise ValueError(f"Could not make batched video from {videos}") @@ -234,6 +234,9 @@ def preprocess( images (`ImageInput`): Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + videos (`ImageInput`): + Video frames to preprocess. Expects a single or batch of video frames with pixel values ranging from 0 + to 255. If passing in frames with pixel values between 0 and 1, set `do_rescale=False`. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`Dict[str, int]`, *optional*, defaults to `self.size`): @@ -298,7 +301,7 @@ def preprocess( if images is not None: inputs = [make_list_of_images(images)] elif videos is not None: - inputs = make_batched(videos) + inputs = make_batched_videos(videos) validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 676b4dd7915d67..d5b55e3a209e12 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -72,11 +72,6 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): - Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, - sequence_length, hidden_size)`. 
-
-            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
     """
 
     loss: Optional[torch.FloatTensor] = None
@@ -84,7 +79,6 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):
     past_key_values: Optional[List[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
-    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
 
 
 # Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->VideoLlava
@@ -156,7 +150,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
     return embeddings
 
 
-# Copied from transformers.models.clip.modeling_clip.ClipVisionEmbeddings with Clip->VideoLlava
+# Copied from transformers.models.clip.modeling_clip.ClipVisionAttention with Clip->VideoLlava
 class VideLlavaVisionAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -261,7 +255,7 @@ def forward(
     return attn_output, attn_weights_reshaped
 
 
-# Copied from transformers.models.clip.modeling_clip.ClipMLP with Clip->VideoLlava
+# Copied from transformers.models.clip.modeling_clip.ClipMLP with Clip->VideoLlavaVision
 class VideoLlavaVisionMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -277,7 +271,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     return hidden_states
 
 
-# Copied from transformers.models.clip.modeling_clip.ClipEncoderLayer with Clip->VideoLlava
+# Copied from transformers.models.clip.modeling_clip.ClipEncoderLayer with Clip->VideoLlavaVision
 class VideoLlavaVisionEncoderLayer(nn.Module):
     def __init__(self, config: VideoLlavaConfig):
         super().__init__()
@@ -328,14 +322,14 @@ def forward(
     return outputs
 
 
-# Copied from transformers.models.clip.modeling_clip.ClipEncoder with Clip->VideoLlava
+# Copied from transformers.models.clip.modeling_clip.ClipEncoder with Clip->VideoLlavaVision
 class VideoLlavaVisionEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
-    [`CLIPEncoderLayer`].
+    [`VideoLlavaVisionEncoderLayer`].
 
     Args:
-        config: CLIPConfig
+        config: VideoLlavaConfig
     """
 
     def __init__(self, config: VideoLlavaConfig):
@@ -426,7 +420,7 @@ def forward(
     )
 
 
-# Copied from transformers.models.clip.modeling_clip.ClipVisionTransformer with Clip->VideoLlava
+# Copied from transformers.models.clip.modeling_clip.ClipVisionTransformer with Clip -> VideoLlava
 class VideoLlavaVisionTransformer(nn.Module):
     def __init__(self, config: VideoLlavaConfig):
         super().__init__()
@@ -499,7 +493,6 @@ def forward(
 
 
 @add_start_docstrings(
-    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     VIDEO_LLAVA_START_DOCSTRING,
 )
 # Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->VideoLlava,llava->video_llava
@@ -554,8 +547,8 @@ def _supports_sdpa(self):
         [What are input IDs?](../glossary#input-ids)
         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
             The tensors corresponding to the input images. Pixel values can be obtained using
-            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
-            [`CLIPImageProcessor`] for processing images).
+            [`AutoImageProcessor`]. See [`VideoLlavaImageProcessor.__call__`] for details ([`VideoLlavaProcessor`] uses
+            [`VideoLlavaImageProcessor`] for processing images).
 attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
@@ -684,10 +677,6 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
         final_attention_mask = torch.zeros(
             batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
         )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
-            )
         # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
         # set the corresponding tensors into their correct target device.
         target_device = inputs_embeds.device
@@ -703,7 +692,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
         final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
         if labels is not None:
+            final_labels = torch.full(
+                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+            )
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
+        else:
+            final_labels = None
 
         # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
         image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
@@ -719,9 +713,6 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
         final_attention_mask |= image_to_overwrite
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
 
-        if labels is None:
-            final_labels = None
-
         return final_embedding, final_attention_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
@@ -756,21 +747,38 @@ def forward(
         ```python
         >>> from PIL import Image
         >>> import requests
+        >>> from decord import VideoReader
+        >>> from huggingface_hub import hf_hub_download
         >>> from transformers import AutoProcessor, VideoLlavaForConditionalGeneration
 
         >>> model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B")
         >>> processor = AutoProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B")
 
-        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
-        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> prompt = "USER: <video>Why is this video funny? ASSISTANT:"
+        >>> video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
+        >>> vr = VideoReader(uri=video_path, height=224, width=224)
 
-        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+        >>> # sample uniformly 8 frames from the video
+        >>> indices = [int(i * len(vr) / 8) for i in range(8)]
+        >>> frames = vr.get_batch(indices).asnumpy()
+
+        >>> inputs = processor(text=prompt, videos=frames, return_tensors="pt")
 
         >>> # Generate
-        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> generate_ids = model.generate(**inputs, max_length=80)
         >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
+        'USER: Why is this video funny? ASSISTANT: The video is funny because the baby is sitting on the bed and reading a book, which is an unusual and amusing sight.Ъ'
+
+        >>> # to generate from image
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> prompt = "USER: <image>How many cats are there in the image? ASSISTANT:"
+        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
+        'USER: How many cats are there in the image? ASSISTANT: There are two cats in the image'
         ```"""
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py
index 8f58a89d72124d..d285d41ef064d8 100644
--- a/src/transformers/models/video_llava/processing_video_llava.py
+++ b/src/transformers/models/video_llava/processing_video_llava.py
@@ -73,6 +73,10 @@ def __call__(
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                 number of channels, H and W are image height and width.
+            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                Video frames to preprocess. Expects a single or batch of video frames in NumPy array or PyTorch
+                tensor. Each video should be of shape (T, C, H, W), where T is the number of frames, C is the
+                number of channels, H and W are image height and width.
             padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
                 index) among:
@@ -104,7 +108,9 @@ def __call__(
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
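The call signature documented above routes `videos` to the image processor and `text` to the tokenizer, returning a single `BatchFeature`. A hypothetical usage sketch — the checkpoint name, frame count, and the `<video>` placeholder are assumptions for illustration, not values confirmed by this PR:

```python
import numpy as np
from transformers import VideoLlavaProcessor

processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B")  # assumed repo

# one synthetic 8-frame video in (T, C, H, W) layout, per the docstring above
video = np.random.randint(0, 255, size=(8, 3, 224, 224), dtype=np.uint8)
prompt = "USER: <video>Why is this video funny? ASSISTANT:"

inputs = processor(text=prompt, videos=video, return_tensors="pt")
print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']
```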
""" if images is not None or videos is not None: - pixel_values = self.image_processor(images, videos=videos, return_tensors=return_tensors)["pixel_values"] + pixel_values = self.image_processor(images=images, videos=videos, return_tensors=return_tensors)[ + "pixel_values" + ] else: pixel_values = None text_inputs = self.tokenizer( From 8cca7318b00c7c70388397688ad2cb2ef6d247a0 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 20 Mar 2024 14:03:19 +0100 Subject: [PATCH 03/49] add tests --- .../video_llava/test_modeling_video_llava.py | 202 +++++++++++------- 1 file changed, 124 insertions(+), 78 deletions(-) diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 76122f9ea911de..fde4397d21b2ab 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -17,17 +17,19 @@ import gc import unittest -import requests +import numpy as np +from huggingface_hub import hf_hub_download from transformers import ( - AutoProcessor, VideoLlavaConfig, VideoLlavaForConditionalGeneration, + VideoLlavaProcessor, is_torch_available, is_vision_available, ) from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device +from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -38,7 +40,7 @@ is_torch_greater_or_equal_than_2_0 = False if is_vision_available(): - from PIL import Image + pass class VideoLlavaVisionText2TextModelTester: @@ -49,6 +51,7 @@ def __init__( image_token_index=0, projector_hidden_act="gelu", seq_length=7, + num_frames=2, vision_feature_select_strategy="default", vision_feature_layer=-1, text_config={ @@ -100,6 +103,7 @@ def __init__( self.text_config = text_config self.vision_config = vision_config self.seq_length = seq_length + self.num_frames = num_frames self.num_hidden_layers = text_config["num_hidden_layers"] self.vocab_size = text_config["vocab_size"] @@ -110,7 +114,7 @@ def __init__( self.batch_size = 3 self.num_channels = 3 self.image_size = 336 - self.encoder_seq_length = 231 + self.encoder_seq_length = 455 def get_config(self): return VideoLlavaConfig( @@ -127,6 +131,7 @@ def prepare_config_and_inputs(self): pixel_values = floats_tensor( [ self.batch_size, + self.num_frames, self.vision_config["num_channels"], self.vision_config["image_size"], self.vision_config["image_size"], @@ -141,8 +146,8 @@ def prepare_config_and_inputs_for_common(self): config, pixel_values = config_and_inputs input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 attention_mask = input_ids.ne(1).to(torch_device) - # we are giving 3 images let's make sure we pass in 3 image tokens - input_ids[:, 1] = config.image_token_index + # we are giving 3 videos, each has 2 frames let's make sure we pass in 3 * 2 image tokens + input_ids[:, :2] = config.image_token_index inputs_dict = { "pixel_values": pixel_values, "input_ids": input_ids, @@ -152,7 +157,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase): +class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): """ Model tester for `VideoLlavaForConditionalGeneration`. 
""" @@ -189,7 +194,7 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): @require_torch class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase): def setUp(self): - self.processor = AutoProcessor.from_pretrained("video_llava-hf/bakVideoLlava-v1-hf") + self.processor = VideoLlavaProcessor.from_pretrained("RaushanTurganbay/video-llava-7b-hf") def tearDown(self): gc.collect() @@ -200,19 +205,21 @@ def tearDown(self): def test_small_model_integration_test(self): # Let' s make sure we test the preprocessing to replace what is used model = VideoLlavaForConditionalGeneration.from_pretrained( - "video_llava-hf/bakVideoLlava-v1-hf", load_in_4bit=True + "RaushanTurganbay/video-llava-7b-hf", load_in_4bit=True ) - prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" - image_file = "https://video_llava-vl.github.io/static/images/view.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(prompt, raw_image, return_tensors="pt") + prompt = "USER: Why is this video funny? ASSISTANT:" + video_file = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset" + ) + video_file = np.load(video_file) + inputs = self.processor(prompt, videos=video_file, return_tensors="pt") - EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip + EXPECTED_INPUT_IDS = torch.tensor([[1, 3148, 1001, 29901, 29871, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 32000, 3750, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901]]) # fmt: skip self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) - output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip + output = model.generate(**inputs, do_sample=False, max_new_tokens=20) + EXPECTED_DECODED_TEXT = "USER: Why is this video funny? ASSISTANT: The video is funny because the baby is playing with a Wii remote while sitting on a bed" # fmt: skip self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), @@ -223,18 +230,28 @@ def test_small_model_integration_test(self): @require_bitsandbytes def test_small_model_integration_test_llama(self): # Let' s make sure we test the preprocessing to replace what is used - model_id = "LanguageBind/Video-LLaVA-7B" + model_id = "RaushanTurganbay/video-llava-7b-hf" - model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B", load_in_4bit=True) - processor = AutoProcessor.from_pretrained(model_id) + model = VideoLlavaForConditionalGeneration.from_pretrained( + "RaushanTurganbay/video-llava-7b-hf", load_in_4bit=True + ) + processor = VideoLlavaProcessor.from_pretrained(model_id) - prompt = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:" - image_file = "https://video_llava-vl.github.io/static/images/view.jpg" - raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + prompt = ( + "USER: Describe the video in details. 
ASSISTANT:" + ) + video_file = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset" + ) + video_file = np.load(video_file) + inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16) output = model.generate(**inputs, max_new_tokens=900, do_sample=False) - EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities." # fmt: skip + EXPECTED_DECODED_TEXT = "USER: Describe the video in details. ASSISTANT: The video features a young child sitting on a bed, holding a book and reading it. " \ + "The child appears to be enjoying the book, as they are fully engaged in the reading process. The bed is located in a bedroom, and there is a chair nearby. " \ + "The child is wearing a light blue shirt and pink pants, and they have glasses on. The room is well-lit, and there is a clock on the wall. The child seems " \ + "to be in a comfortable and relaxed environment, which is conducive to reading and learning. Overall, the video captures a heartwarming moment of a child " \ + "engaging in a simple yet essential activity, which is reading." # fmt: skip self.assertEqual( processor.decode(output[0], skip_special_tokens=True), @@ -245,72 +262,68 @@ def test_small_model_integration_test_llama(self): @require_bitsandbytes def test_small_model_integration_test_llama_batched(self): # Let' s make sure we test the preprocessing to replace what is used - model_id = "LanguageBind/Video-LLaVA-7B" + model_id = "RaushanTurganbay/video-llava-7b-hf" - model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B", load_in_4bit=True) - processor = AutoProcessor.from_pretrained(model_id) + model = VideoLlavaForConditionalGeneration.from_pretrained( + "RaushanTurganbay/video-llava-7b-hf", load_in_4bit=True + ) + processor = VideoLlavaProcessor.from_pretrained(model_id) + processor.tokenizer.padding_side = "left" prompts = [ - "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", - "USER: \nWhat is this?\nASSISTANT:", + "USER: What is the baby doing? ASSISTANT:", + "USER: Who is sitting next to the woman? 
ASSISTANT:", ] - image1 = Image.open(requests.get("https://video_llava-vl.github.io/static/images/view.jpg", stream=True).raw) - image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + video_1 = np.load( + hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset") + ) + video_2 = np.load( + hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset") + ) - inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + inputs = processor(prompts, videos=[video_1, video_2], return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: The image features two cats lying down on a pink couch. One cat is located on'] # fmt: skip + EXPECTED_DECODED_TEXT = [ + 'USER: What is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.Ъ', + 'USER: Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman.Ъ' + ] # fmt: skip self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) - @slow - @require_bitsandbytes - def test_small_model_integration_test_batch(self): - # Let' s make sure we test the preprocessing to replace what is used - model = VideoLlavaForConditionalGeneration.from_pretrained( - "video_llava-hf/bakVideoLlava-v1-hf", load_in_4bit=True - ) - # The first batch is longer in terms of text, but only has 1 image. The second batch will be padded in text, but the first will be padded because images take more space!. - prompts = [ - "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", - "USER: \nWhat is this?\nASSISTANT:", - ] - image1 = Image.open(requests.get("https://video_llava-vl.github.io/static/images/view.jpg", stream=True).raw) - image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - - inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) - - output = model.generate(**inputs, max_new_tokens=20) - - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip - self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) - @slow @require_bitsandbytes def test_small_model_integration_test_llama_batched_regression(self): # Let' s make sure we test the preprocessing to replace what is used - model_id = "LanguageBind/Video-LLaVA-7B" + model_id = "RaushanTurganbay/video-llava-7b-hf" # Multi-image & multi-prompt (e.g. 
3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before) model = VideoLlavaForConditionalGeneration.from_pretrained( - "LanguageBind/Video-LLaVA-7B", load_in_4bit=True, attn_implementation="eager" + "RaushanTurganbay/video-llava-7b-hf", load_in_4bit=True, attn_implementation="eager" ) - processor = AutoProcessor.from_pretrained(model_id, pad_token="") + processor = VideoLlavaProcessor.from_pretrained(model_id, pad_token="") + processor.tokenizer.padding_side = "left" prompts = [ - "USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:", - "USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT:", + "USER: What is the baby doing? ASSISTANT:", + "USER: Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman. USER: What about this video? ASSITANT:", ] - image1 = Image.open(requests.get("https://video_llava-vl.github.io/static/images/view.jpg", stream=True).raw) - image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + video_1 = np.load( + hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset") + ) + video_2 = np.load( + hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset") + ) - inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True) + inputs = processor(prompts, videos=[video_1, video_2, video_1], return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip + EXPECTED_DECODED_TEXT = [ + 'USER: What is the baby doing? ASSISTANT: The baby is sitting on a bed and reading a book.Ћ', + 'USER: Who is sitting next to the woman? ASSISTANT: A small dog is sitting next to the woman. USER: What about this video? ASSITANT: The video shows a baby sitting on a bed, reading a book. 
The baby is wearing glass' + ] # fmt: skip self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) @@ -320,18 +333,17 @@ def test_video_llava_index_error_bug(self): # This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore # Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for # more details - model_id = "LanguageBind/Video-LLaVA-7B" + model_id = "RaushanTurganbay/video-llava-7b-hf" model = VideoLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) - processor = AutoProcessor.from_pretrained(model_id) - # Simulate a super long prompt - user_prompt = "Describe the image:?\n" * 200 - prompt = f"USER: \n{user_prompt}ASSISTANT:" - image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" - - raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + user_prompt = "Describe the video:?\n" * 200 + prompt = f"USER: {user_prompt}ASSISTANT:" + video_file = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset" + ) + video_file = np.load(video_file) + inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16) # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) @@ -340,27 +352,61 @@ def test_video_llava_index_error_bug(self): @require_torch_gpu def test_video_llava_merge_inputs_error_bug(self): # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore - model_id = "LanguageBind/Video-LLaVA-7B" + model_id = "RaushanTurganbay/video-llava-7b-hf" model = VideoLlavaForConditionalGeneration.from_pretrained( model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True ).to(torch_device) # Simulate some user inputs pixel_values = torch.randn( - (2, 3, 336, 336), + (2, 8, 3, 224, 224), dtype=torch.float, device=torch_device, ) input_ids = torch.tensor( [ - [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], - [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900], + [ + 32001, + 32001, + 1, + 15043, + 7084, + 32000, + 32000, + 32000, + 32000, + 32000, + 32000, + 32000, + 32000, + 29871, + 13, + 7900, + ], + [ + 1, + 15043, + 7084, + 29901, + 29871, + 32000, + 32000, + 32000, + 32000, + 32000, + 32000, + 32000, + 32000, + 29871, + 13, + 7900, + ], ], dtype=torch.long, device=torch_device, ) attention_mask = torch.tensor( - [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.long, device=torch_device, ) From 4ea4f70ed5353fdb43e194ea3e6d731815f43170 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 20 Mar 2024 16:58:53 +0100 Subject: [PATCH 04/49] support image and video as input --- .../convert_video_llava_weights_to_hf.py | 14 +- .../image_processing_video_llava.py | 101 ++++++++--- .../video_llava/modeling_video_llava.py | 158 +++++++++++++----- .../video_llava/processing_video_llava.py | 15 +- .../video_llava/test_modeling_video_llava.py | 14 +- 5 files changed, 218 insertions(+), 84 deletions(-) diff --git a/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py b/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py index 
bee642440329d8..ecd812fd843cd3 100644 --- a/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py +++ b/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py @@ -20,10 +20,10 @@ AddedToken, AutoConfig, AutoTokenizer, - CLIPImageProcessor, - LlavaProcessor, VideoLlavaConfig, VideoLlavaForConditionalGeneration, + VideoLlavaImageProcessor, + VideoLlavaProcessor, ) @@ -47,7 +47,8 @@ """ KEYS_TO_MODIFY_MAPPING = { - "model.video_tower.video_tower": "vision_tower", + "model.video_tower.video_tower": "video_tower", + "model.image_tower.image_tower": "image_tower", "model.mm_projector": "multi_modal_projector", "model": "language_model.model", "lm_head": "language_model.lm_head", @@ -55,7 +56,7 @@ "multi_modal_projector.2": "multi_modal_projector.linear_2", } -KEYS_TO_IGNORE = ["model.image_tower.image_tower."] +KEYS_TO_IGNORE = [] def convert_state_dict_to_hf(state_dict): @@ -78,10 +79,11 @@ def convert_video_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_p tokenizer = AutoTokenizer.from_pretrained(text_model_id) tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True) tokenizer.add_special_tokens({"pad_token": "<pad>"}) + tokenizer.padding_side = "left" - image_processor = CLIPImageProcessor.from_pretrained(vision_model_id) + image_processor = VideoLlavaImageProcessor.from_pretrained(vision_model_id) - processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor = VideoLlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) config = VideoLlavaConfig(text_config=text_config) config.pad_token_id = 32001 diff --git a/src/transformers/models/video_llava/image_processing_video_llava.py b/src/transformers/models/video_llava/image_processing_video_llava.py index e30ab110739e6a..419e48eb5e9ac1 100644 --- a/src/transformers/models/video_llava/image_processing_video_llava.py +++ b/src/transformers/models/video_llava/image_processing_video_llava.py @@ -40,7 +40,7 @@ validate_kwargs, validate_preprocess_arguments, ) -from ...utils import TensorType, is_vision_available, logging +from ...utils import TensorType, is_torch_available, is_vision_available, logging logger = logging.get_logger(__name__) @@ -49,6 +49,9 @@ if is_vision_available(): import PIL +if is_torch_available(): + import torch + def make_batched_videos(videos) -> List[List[ImageInput]]: if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): @@ -209,8 +212,8 @@ def resize( def preprocess( self, - images: ImageInput, - videos: ImageInput, + images, + visual_inputs: List[ImageInput], do_resize: bool = None, size: Dict[str, int] = None, resample: PILImageResampling = None, @@ -231,12 +234,9 @@ def preprocess( Preprocess an image or batch of images. Args: - images (`ImageInput`): - Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + visual_inputs (`ImageInput`): + List of images and/or videos to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. - videos (`ImageInput`): - Video frames to preprocess. Expects a single or batch of video frames with pixel values ranging from 0 - to 255. If passing in frames with pixel values between 0 and 1, set `do_rescale=False`. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image.
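Why `config.pad_token_id = 32001` in the conversion script above: a sketch of the token-id bookkeeping, assuming a standard 32000-token Llama/Vicuna vocabulary for `text_model_id` (the base checkpoint id below is an assumption, not taken from the patch).

```python
# Hedged sketch, not part of the patch: "<image>" is appended at id 32000 and
# "<pad>" at id 32001 on top of a 32000-token base vocab, which is why the
# conversion script can hard-code pad_token_id = 32001.
from transformers import AddedToken, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")  # assumed text backbone

tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
tokenizer.padding_side = "left"  # matches the conversion script

print(len(tokenizer))                                         # 32002
print(tokenizer.convert_tokens_to_ids(["<image>", "<pad>"]))  # [32000, 32001]
```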
size (`Dict[str, int]`, *optional*, defaults to `self.size`): @@ -295,26 +295,67 @@ def preprocess( image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - if (images is None) ^ (videos is not None): - raise ValueError("VideoLlava currently does not support both images and videos as input") + if not isinstance(visual_inputs, list): + visual_inputs = [visual_inputs] + + visual_positions = [-1] + img_count = 0 + images, videos = [], [] + for visual in visual_inputs: + if not isinstance(visual, PIL.Image.Image) and len(visual.shape) == 4: + visual_positions.extend(np.arange(1, 9) + max(visual_positions)) + videos.append(visual) + else: + visual_positions.append(-1) + img_count += 1 + images.append(visual) + img_positions = torch.arange( + max(visual_positions) + 1, max(visual_positions) + img_count + 1, dtype=torch.int32 + ) + visual_positions = torch.tensor(visual_positions[1:], dtype=torch.int32) + visual_positions[torch.where(visual_positions == -1)[0]] = img_positions - if images is not None: - inputs = [make_list_of_images(images)] - elif videos is not None: - inputs = make_batched_videos(videos) + if len(images) > 0: + images = make_list_of_images(images) + elif len(videos) > 0: + videos = make_batched_videos(videos) validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) - if not valid_images(inputs): + if not valid_images(videos) or not valid_images(images): raise ValueError( "Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) - pixel_values = [ - [ + if len(videos) > 0: + pixel_values_videos = [ + [ + self._preprocess_image( + image=frame, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_convert_rgb=do_convert_rgb, + data_format=data_format, + input_data_format=input_data_format, + ) + for frame in video + ] + for video in videos + ] + + if len(images) > 0: + pixel_values_images = [ self._preprocess_image( - image=frame, + image=image, do_resize=do_resize, size=size, resample=resample, @@ -329,12 +370,28 @@ def preprocess( data_format=data_format, input_data_format=input_data_format, ) - for frame in video + for image in images ] - for video in inputs - ] - encoded_outputs = BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors) + if len(images) > 0 and len(videos) > 0: + encoded_outputs = BatchFeature( + data={ + "pixel_values_videos": pixel_values_videos, + "pixel_values_images": pixel_values_images, + "visual_positions": visual_positions, + }, + tensor_type=return_tensors, + ) + elif len(images) > 0: + encoded_outputs = BatchFeature( + data={"pixel_values_images": pixel_values_images, "visual_positions": visual_positions}, + tensor_type=return_tensors, + ) + elif len(videos) > 0: + encoded_outputs = BatchFeature( + data={"pixel_values_videos": pixel_values_videos, "visual_positions": visual_positions}, + tensor_type=return_tensors, + ) return encoded_outputs diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index d5b55e3a209e12..0aabab63df813c 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ 
b/src/transformers/models/video_llava/modeling_video_llava.py @@ -453,14 +453,6 @@ def forward( if pixel_values is None: raise ValueError("You have to specify pixel_values") - # if video input, merge batch with frames dim and later reshape back - if pixel_values.dim() == 5: - batch_size, num_frames, channels, height, width = pixel_values.shape - pixel_values = pixel_values.reshape(batch_size * num_frames, channels, height, width) - else: - batch_size, _, _, _ = pixel_values.shape - num_frames = 1 - hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layrnorm(hidden_states) @@ -475,12 +467,6 @@ def forward( pooled_output = last_hidden_state[:, 0, :] pooled_output = self.post_layernorm(pooled_output) - pooled_output = pooled_output.reshape(batch_size, num_frames, -1).mean(1) - encoder_outputs.hidden_states = [ - hidden_state.reshape(batch_size, num_frames, -1, self.config.hidden_size) - for hidden_state in encoder_outputs.hidden_states - ] - if not return_dict: return (last_hidden_state, pooled_output) + encoder_outputs[1:] @@ -609,7 +595,8 @@ def _supports_sdpa(self): class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): def __init__(self, config: VideoLlavaConfig): super().__init__(config) - self.vision_tower = VideoLlavaVisionTransformer(config.vision_config) + self.video_tower = VideoLlavaVisionTransformer(config.vision_config) + self.image_tower = VideoLlavaVisionTransformer(config.vision_config) self.multi_modal_projector = VideoLlavaMultiModalProjector(config) self.vocab_size = config.vocab_size @@ -648,8 +635,8 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m self.vocab_size = model_embeds.num_embeddings return model_embeds - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): - batch_size, num_images, num_image_patches, embed_dim = image_features.shape + def _merge_input_ids_with_image_features(self, visual_features, inputs_embeds, input_ids, attention_mask, labels): + num_images, num_image_patches, embed_dim = visual_features.shape batch_size, sequence_length = input_ids.shape left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) # 1. Create a mask to know where special image tokens are @@ -703,24 +690,95 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in image_to_overwrite = torch.all(final_embedding == 0, dim=-1) image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + if image_to_overwrite.sum() != visual_features.shape[:-1].numel(): raise ValueError( f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." 
) - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device) final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) return final_embedding, final_attention_mask, final_labels, position_ids + def _get_vision_features( + self, + pixel_values_images: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + visual_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + if pixel_values_images is None and pixel_values_videos is None: + raise ValueError("You have to specify `pixel_values_images` or `pixel_values_videos`") + + if pixel_values_videos is not None: + batch_size, num_frames, channels, height, width = pixel_values_videos.shape + pixel_values = pixel_values_videos.reshape(batch_size * num_frames, channels, height, width) + video_outputs = self.video_tower( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if pixel_values_images is not None: + batch_size = pixel_values_images.shape[0] + image_outputs = self.image_tower( + pixel_values_images, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # return immediately if either one is None, otherwise have to merge them + if pixel_values_images is None: + return video_outputs + + if pixel_values_videos is None: + return image_outputs + + if visual_positions is None: + visual_positions = torch.arange( + pixel_values_images.shape[0] + pixel_values_videos.shape[0], + dtype=torch.int32, + device=pixel_values_images.device, + ) + + outputs = () + for idx in range(len(image_outputs)): + if idx == 1: + ordered_out = torch.cat([image_outputs[idx], video_outputs[idx]], dim=0) + elif isinstance(image_outputs[idx], torch.Tensor): + merged_out = torch.cat([image_outputs[idx], video_outputs[idx]], dim=0) + ordered_out = merged_out[visual_positions] + else: + ordered_out = [ + torch.cat([vid, img], dim=0)[visual_positions, ...] 
+ for img, vid in zip(image_outputs[idx], video_outputs[idx]) + ] + outputs += (ordered_out,) + + if isinstance(image_outputs, tuple): + return outputs + + return BaseModelOutputWithPooling( + last_hidden_state=outputs[0], + pooler_output=outputs[1], + hidden_states=outputs[2] if output_hidden_states else None, + attentions=outputs[3] if output_attentions else None, + ) + @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, + pixel_values_images: torch.FloatTensor = None, + pixel_values_videos: torch.FloatTensor = None, + visual_positions: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, @@ -747,12 +805,13 @@ def forward( ```python >>> from PIL import Image >>> import requests + >>> import numpy as np >>> from decord import VideoReader >>> from huggingface_hub import hf_hub_download - >>> from transformers import AutoProcessor, VideoLlavaForConditionalGeneration + >>> from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration >>> model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B") - >>> processor = AutoProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B") + >>> processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B") >>> prompt = "USER: Why is this video funny? ASSISTANT:" >>> video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset") @@ -762,23 +821,26 @@ def forward( >>> indices = np.arange(0, len(vr), len(vr) / 8).astype(int) >>> frames = vr.get_batch(indices).asnumpy() - >>> inputs = processor(text=prompt, videos=list(clip), return_tensors="pt") + >>> inputs = processor(text=prompt, visual_inputs=clip, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(**inputs, max_length=80) >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'USER: Why is this video funny? ASSISTANT: The video is funny because the baby is sitting on the bed and reading a book, which is an unusual and amusing sight.Ъ' + 'USER: Why is this video funny? ASSISTANT: The video is funny because the baby is sitting on the bed and reading a book, which is an unusual and amusing sight.Ъ' - >>> # to generate from image + >>> # to generate from image and video mix >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> prompt = "USER: How many cats are there in the image? ASSISTANT:" - >>> inputs = processor(text=prompt, videos=list(clip), return_tensors="pt") + >>> prompt = [ + "USER: How many cats are there in the image? ASSISTANT:", + "USER: Why is this video funny? ASSISTANT:" + ] + >>> inputs = processor(text=prompt, visual_inputs=[image, clip], padding=True, return_tensors="pt") >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0] - 'USER: How many cats are there in the image? 
ASSISTANT: There are two cats in the image' + >>> generate_ids = model.generate(**inputs, max_length=50) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + ['USER: How many cats are there in the image? ASSISTANT: There are two cats in the image.\nHow many cats are sleeping on the couch?\nThere are', 'USER: Why is this video funny? ASSISTANT: The video is funny because the baby is sitting on the bed and reading a book, which is an unusual and amusing'] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -800,30 +862,35 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) # 2. Merge text and images - if pixel_values is not None and input_ids.shape[1] != 1: - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1: + vision_outputs = self._get_vision_features( + pixel_values_images=pixel_values_images, + pixel_values_videos=pixel_values_videos, + visual_positions=visual_positions, + output_hidden_states=True, + ) # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. - selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + selected_visual_feature = vision_outputs.hidden_states[vision_feature_layer].squeeze(1) if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, :, 1:] + selected_visual_feature = selected_visual_feature[:, 1:] elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature + selected_visual_feature = selected_visual_feature else: raise ValueError( f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}" ) - image_features = self.multi_modal_projector(selected_image_feature) + visual_features = self.multi_modal_projector(selected_visual_feature) inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( - image_features, inputs_embeds, input_ids, attention_mask, labels + visual_features, inputs_embeds, input_ids, attention_mask, labels ) if labels is None: labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long) else: - # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of + # In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of # generation with cache - if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: + if past_key_values is not None and input_ids.shape[1] == 1: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] @@ -895,7 +962,14 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values_images=None, + pixel_values_videos=None, + attention_mask=None, + **kwargs, ): if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -942,7 +1016,9 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, - "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + 
"pixel_values_images": pixel_values_images, + "visual_positions": kwargs.get("visual_positions"), } ) return model_inputs diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index d285d41ef064d8..41bf352ee4e10d 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -50,8 +50,7 @@ def __init__(self, image_processor=None, tokenizer=None): def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - images: ImageInput = None, - videos: ImageInput = None, + visual_inputs: ImageInput = None, padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = None, max_length=None, @@ -107,17 +106,17 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ - if images is not None or videos is not None: - pixel_values = self.image_processor(images=images, videos=videos, return_tensors=return_tensors)[ - "pixel_values" - ] + if visual_inputs is not None: + image_kwargs = self.image_processor( + visual_inputs=visual_inputs, images=None, return_tensors=return_tensors + ) else: - pixel_values = None + image_kwargs = {} text_inputs = self.tokenizer( text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length ) - return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + return BatchFeature(data={**text_inputs, **image_kwargs}) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index fde4397d21b2ab..166ff3005a7bad 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -149,7 +149,7 @@ def prepare_config_and_inputs_for_common(self): # we are giving 3 videos, each has 2 frames let's make sure we pass in 3 * 2 image tokens input_ids[:, :2] = config.image_token_index inputs_dict = { - "pixel_values": pixel_values, + "pixel_values_videos": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask, } @@ -244,7 +244,7 @@ def test_small_model_integration_test_llama(self): repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset" ) video_file = np.load(video_file) - inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16) + inputs = self.processor(prompt, visual_inputs=video_file, return_tensors="pt").to(torch_device, torch.float16) output = model.generate(**inputs, max_new_tokens=900, do_sample=False) EXPECTED_DECODED_TEXT = "USER: Describe the video in details. ASSISTANT: The video features a young child sitting on a bed, holding a book and reading it. 
" \ @@ -281,7 +281,7 @@ def test_small_model_integration_test_llama_batched(self): hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset") ) - inputs = processor(prompts, videos=[video_1, video_2], return_tensors="pt", padding=True) + inputs = processor(prompts, visual_inputs=[video_1, video_2], return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -316,7 +316,7 @@ def test_small_model_integration_test_llama_batched_regression(self): hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset") ) - inputs = processor(prompts, videos=[video_1, video_2, video_1], return_tensors="pt", padding=True) + inputs = processor(prompts, visual_inputs=[video_1, video_2, video_1], return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -343,7 +343,7 @@ def test_video_llava_index_error_bug(self): repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset" ) video_file = np.load(video_file) - inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16) + inputs = self.processor(prompt, visual_inputs=video_file, return_tensors="pt").to(torch_device, torch.float16) # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) @@ -358,7 +358,7 @@ def test_video_llava_merge_inputs_error_bug(self): ).to(torch_device) # Simulate some user inputs - pixel_values = torch.randn( + pixel_values_videos = torch.randn( (2, 8, 3, 224, 224), dtype=torch.float, device=torch_device, @@ -413,7 +413,7 @@ def test_video_llava_merge_inputs_error_bug(self): # Make sure that the loss is properly computed loss = model( - pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, From c36819d1ef95a5d1331ea6446502b0a87af26c24 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 21 Mar 2024 14:01:43 +0100 Subject: [PATCH 05/49] update for better handling of mixed input and clean-up a bit --- README.md | 1 + README_de.md | 1 + README_es.md | 1 + README_fr.md | 1 + README_hd.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_pt-br.md | 1 + README_ru.md | 1 + README_te.md | 1 + README_vi.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/index.md | 1 + docs/source/en/model_doc/video_llava.md | 35 +- docs/source/en/perf_infer_gpu_one.md | 1 + .../models/video_llava/__init__.py | 11 +- .../video_llava/configuration_video_llava.py | 10 +- .../convert_video_llava_weights_to_hf.py | 13 +- .../image_processing_video_llava.py | 17 +- .../video_llava/modeling_video_llava.py | 554 +++--------------- src/transformers/utils/dummy_pt_objects.py | 31 + .../video_llava/test_modeling_video_llava.py | 149 ++++- 23 files changed, 334 insertions(+), 501 deletions(-) diff --git a/README.md b/README.md index fdbd59883dd541..bfe5381c69a837 100644 --- a/README.md +++ b/README.md @@ -523,6 +523,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim. 1. 
**[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu. +1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan. 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang. 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim. 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (from University of Wisconsin–Madison) released with the paper [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee. diff --git a/README_de.md b/README_de.md index a1a016767194d3..a86aa0ae103e36 100644 --- a/README_de.md +++ b/README_de.md @@ -519,6 +519,7 @@ Aktuelle Anzahl der Checkpoints: ![](https://img.shields.io/endpoint?url=https:/ 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim. 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu. +1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan. 1.
**[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang. 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim. 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (from University of Wisconsin–Madison) released with the paper [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee. diff --git a/README_es.md b/README_es.md index 61a5f284b88e23..a263867eaf958a 100644 --- a/README_es.md +++ b/README_es.md @@ -496,6 +496,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim. 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu. +1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan. 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang. 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim. 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (from University of Wisconsin–Madison) released with the paper [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee. 
diff --git a/README_fr.md b/README_fr.md index ab0a790c284dfb..e5fda2b82288c7 100644 --- a/README_fr.md +++ b/README_fr.md @@ -517,6 +517,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (de Kakao Corporation) a été publié dans l'article [UnivNet : un vocodeur neuronal avec des discriminateurs de spectrogramme multi-résolution pour la génération de formes d'onde haute fidélité](https://arxiv.org/abs/2106.07889) par Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim et Juntae Kim. 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (de l'Université de Pékin) a été publié dans l'article [Analyse perceptuelle unifiée pour la compréhension de scènes](https://arxiv.org/abs/1807.10221) par Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (de l'Université Tsinghua et de l'Université Nankai) publié dans l'article [Visual Attention Network](https://arxiv.org/abs/2202.09741) par Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu. +1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan. 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (du groupe d'informatique multimédia, Université de Nankin) publié dans l'article [VideoMAE : Les autoencodeurs masqués sont des apprenants efficaces en données pour l'auto-pré-entraînement vidéo supervisé](https://arxiv.org/abs/2203.12602) par Zhan Tong, Yibing Song, Jue Wang, Limin Wang. 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (du NAVER AI Lab/Kakao Enterprise/Kakao Brain) publié dans l'article [ViLT : Vision-and-Language Transformer sans convolution ni supervision de région](https://arxiv.org/abs/2102.03334) par Wonjae Kim, Bokyung Son, Ildoo Kim. 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (de l'Université du Wisconsin–Madison) publié dans l'article [Rendre les grands modèles multimodaux comprenant des incitations visuelles arbitraires](https://arxiv.org/abs/2312.00784) par Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee. diff --git a/README_hd.md b/README_hd.md index 4c1b6a419e448c..329bb7c4cee455 100644 --- a/README_hd.md +++ b/README_hd.md @@ -470,6 +470,7 @@ conda install conda-forge::transformers 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim. 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. 1. 
**[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (सिंघुआ यूनिवर्सिटी और ननकाई यूनिवर्सिटी से) साथ में पेपर [विजुअल अटेंशन नेटवर्क](https://arxiv.org/pdf/2202.09741.pdf) मेंग-हाओ गुओ, चेंग-ज़े लू, झेंग-निंग लियू, मिंग-मिंग चेंग, शि-मिन हू द्वारा। +1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan. 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (मल्टीमीडिया कम्प्यूटिंग ग्रुप, नानजिंग यूनिवर्सिटी से) साथ में पेपर [वीडियोएमएई: मास्क्ड ऑटोएन्कोडर स्व-पर्यवेक्षित वीडियो प्री-ट्रेनिंग के लिए डेटा-कुशल सीखने वाले हैं](https://arxiv.org/abs/2203.12602) ज़ान टोंग, यिबिंग सॉन्ग, जुए द्वारा वांग, लिमिन वांग द्वारा पोस्ट किया गया। 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (NAVER AI Lab/Kakao Enterprise/Kakao Brain से) साथ में कागज [ViLT: Vision-and-Language Transformer बिना कनवल्शन या रीजन सुपरविजन](https://arxiv.org/abs/2102.03334) वोनजे किम, बोक्यूंग सोन, इल्डू किम द्वारा पोस्ट किया गया। 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (University of Wisconsin–Madison से) Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee. द्वाराअनुसंधान पत्र [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) के साथ जारी किया गया diff --git a/README_ja.md b/README_ja.md index 7efc8cd0570637..f067a892aab9fe 100644 --- a/README_ja.md +++ b/README_ja.md @@ -530,6 +530,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim. 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (Peking University から) Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. から公開された研究論文 [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (Tsinghua University and Nankai University から) Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu から公開された研究論文: [Visual Attention Network](https://arxiv.org/abs/2202.09741) +1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan. 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (Multimedia Computing Group, Nanjing University から) Zhan Tong, Yibing Song, Jue Wang, Limin Wang から公開された研究論文: [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) 1. 
**[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (NAVER AI Lab/Kakao Enterprise/Kakao Brain から) Wonjae Kim, Bokyung Son, Ildoo Kim から公開された研究論文: [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (University of Wisconsin–Madison から) Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee. から公開された研究論文 [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) diff --git a/README_ko.md b/README_ko.md index 9004123d880cc0..9518bcf573c803 100644 --- a/README_ko.md +++ b/README_ko.md @@ -445,6 +445,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim. 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (Peking University 에서 제공)은 Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.의 [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221)논문과 함께 발표했습니다. 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (Tsinghua University and Nankai University 에서) Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu 의 [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) 논문과 함께 발표했습니다. +1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan. 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (Multimedia Computing Group, Nanjing University 에서) Zhan Tong, Yibing Song, Jue Wang, Limin Wang 의 [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) 논문과 함께 발표했습니다. 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (NAVER AI Lab/Kakao Enterprise/Kakao Brain 에서) Wonjae Kim, Bokyung Son, Ildoo Kim 의 [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 논문과 함께 발표했습니다. 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (University of Wisconsin–Madison 에서 제공)은 Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee.의 [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784)논문과 함께 발표했습니다. diff --git a/README_pt-br.md b/README_pt-br.md index ef4c9b201b5f62..a76ce7b21bbac0 100644 --- a/README_pt-br.md +++ b/README_pt-br.md @@ -528,6 +528,7 @@ Número atual de pontos de verificação: ![](https://img.shields.io/endpoint?ur 1. 
**[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim.
 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
+1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (from University of Wisconsin–Madison) released with the paper [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee.
diff --git a/README_ru.md b/README_ru.md
index 8139fa51b994a2..bfe86ef11f3644 100644
--- a/README_ru.md
+++ b/README_ru.md
@@ -518,6 +518,7 @@ conda install conda-forge::transformers
 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim.
 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
+1. 
**[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (from University of Wisconsin–Madison) released with the paper [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee.
diff --git a/README_te.md b/README_te.md
index 00acdccd28c6fd..52d5b20da3a40e 100644
--- a/README_te.md
+++ b/README_te.md
@@ -520,6 +520,7 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్‌స్టా
 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim.
 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
+1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
 1. 
**[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (from University of Wisconsin–Madison) released with the paper [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee.
diff --git a/README_vi.md b/README_vi.md
index 3a0a2f09db0227..e5aa42b7556ba4 100644
--- a/README_vi.md
+++ b/README_vi.md
@@ -519,6 +519,7 @@ Số lượng điểm kiểm tra hiện tại: ![](https://img.shields.io/endpoi
 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (từ Kakao Corporation) được phát hành với bài báo [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim.
 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (từ Peking University) được phát hành với bài báo [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (từ Tsinghua University and Nankai University) được phát hành với bài báo [Visual Attention Network](https://arxiv.org/abs/2202.09741) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
+1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (từ Multimedia Computing Group, Nanjing University) được phát hành với bài báo [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (từ NAVER AI Lab/Kakao Enterprise/Kakao Brain) được phát hành với bài báo [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (từ University of Wisconsin–Madison) được phát hành với bài báo [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index 52a244bd8aa3eb..3838f90e174f10 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -469,6 +469,7 @@ conda install conda-forge::transformers
 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim.
 1. 
**[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (来自 Peking University) 伴随论文 [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) 由 Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun 发布。
 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (来自 Tsinghua University and Nankai University) 伴随论文 [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) 由 Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu 发布。
+1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
 1. **[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (来自 Multimedia Computing Group, Nanjing University) 伴随论文 [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) 由 Zhan Tong, Yibing Song, Jue Wang, Limin Wang 发布。
 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (来自 NAVER AI Lab/Kakao Enterprise/Kakao Brain) 伴随论文 [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 由 Wonjae Kim, Bokyung Son, Ildoo Kim 发布。
 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (来自 University of Wisconsin–Madison) 伴随论文 [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) 由 Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index fd306c6176fa9d..86c82a1a21528e 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -481,6 +481,7 @@ conda install conda-forge::transformers
 1. **[UnivNet](https://huggingface.co/docs/transformers/model_doc/univnet)** (from Kakao Corporation) released with the paper [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889) by Won Jang, Dan Lim, Jaesam Yoon, Bongwan Kim, and Juntae Kim.
 1. **[UPerNet](https://huggingface.co/docs/transformers/model_doc/upernet)** (from Peking University) released with the paper [Unified Perceptual Parsing for Scene Understanding](https://arxiv.org/abs/1807.10221) by Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun.
 1. **[VAN](https://huggingface.co/docs/transformers/model_doc/van)** (from Tsinghua University and Nankai University) released with the paper [Visual Attention Network](https://arxiv.org/pdf/2202.09741.pdf) by Meng-Hao Guo, Cheng-Ze Lu, Zheng-Ning Liu, Ming-Ming Cheng, Shi-Min Hu.
+1. **[VideoLlava](https://huggingface.co/docs/transformers/main/model_doc/video_llava)** (from YUAN Lab, Peking University) released with the paper [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/pdf/2311.10122.pdf) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
 1. 
**[VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae)** (from Multimedia Computing Group, Nanjing University) released with the paper [VideoMAE: Masked Autoencoders are Data-Efficient Learners for Self-Supervised Video Pre-Training](https://arxiv.org/abs/2203.12602) by Zhan Tong, Yibing Song, Jue Wang, Limin Wang.
 1. **[ViLT](https://huggingface.co/docs/transformers/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
 1. **[VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)** (from University of Wisconsin–Madison) released with the paper [Making Large Multimodal Models Understand Arbitrary Visual Prompts](https://arxiv.org/abs/2312.00784) by Mu Cai, Haotian Liu, Siva Karthik Mustikovela, Gregory P. Meyer, Yuning Chai, Dennis Park, Yong Jae Lee.
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index ec8def2b2ef31b..faf9f336c32dec 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -291,6 +291,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [UnivNet](model_doc/univnet) | ✅ | ❌ | ❌ |
 | [UPerNet](model_doc/upernet) | ✅ | ❌ | ❌ |
 | [VAN](model_doc/van) | ✅ | ❌ | ❌ |
+| [VideoLlava](model_doc/video_llava) | ✅ | ❌ | ❌ |
 | [VideoMAE](model_doc/videomae) | ✅ | ❌ | ❌ |
 | [ViLT](model_doc/vilt) | ✅ | ❌ | ❌ |
 | [VipLlava](model_doc/vipllava) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md
index d2f8bc47283154..89d0cacc42d716 100644
--- a/docs/source/en/model_doc/video_llava.md
+++ b/docs/source/en/model_doc/video_llava.md
@@ -18,25 +18,50 @@ rendered properly in your Markdown viewer.
 
 ## Overview
 
-The video_llava model was proposed in []() by .
-
+Video-LLaVA is an open-source multimodal LLM trained by fine-tuning LLaMA/Vicuna on multimodal instruction-following data generated by LLaVA-1.5 and VideoChat. It is an auto-regressive language model based on the transformer architecture. Video-LLaVA unifies visual representations in the language feature space, enabling an LLM to perform visual reasoning on both images and videos simultaneously.
+
+The Video-LLaVA model was proposed in [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/abs/2311.10122) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munan Ning, Peng Jin, Li Yuan.
 
 The abstract from the paper is the following:
 
-**
+*The Large Vision-Language Model (LVLM) has enhanced the performance of various downstream tasks in visual-language understanding. Most existing approaches encode images and videos into separate feature spaces, which are then fed as inputs to large language models. However, due to the lack of unified tokenization for images and videos, namely misalignment before projection, it becomes challenging for a Large Language Model (LLM) to learn multi-modal interactions from several poor projection layers. In this work, we unify visual representation into the language feature space to advance the foundational LLM towards a unified LVLM. As a result, we establish a simple but robust LVLM baseline, Video-LLaVA, which learns from a mixed dataset of images and videos, mutually enhancing each other. Video-LLaVA achieves superior performances on a broad range of 9 image benchmarks across 5 image question-answering datasets and 4 image benchmark toolkits. 
Additionally, our Video-LLaVA also outperforms Video-ChatGPT by 5.8%, 9.9%, 18.6%, and 10.1% on MSRVTT, MSVD, TGIF, and ActivityNet, respectively. Notably, extensive experiments demonstrate that Video-LLaVA mutually benefits images and videos within a unified visual representation, outperforming models designed specifically for images or videos. We aim for this work to provide modest insights into the multi-modal inputs for the LLM.*
 
 Tips:
 
 
 
-This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/).
-The original code can be found [here]().
+This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
+The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA).
 
 
 ## VideoLlavaConfig
 
 [[autodoc]] VideoLlavaConfig
 
+## VideoLlavaImageProcessor
+
+[[autodoc]] VideoLlavaImageProcessor
+
+## VideoLlavaProcessor
+
+[[autodoc]] VideoLlavaProcessor
+
 ## VideoLlavaForConditionalGeneration
 
 [[autodoc]] VideoLlavaForConditionalGeneration
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 5a8fbd6d9e66d7..0920ffe13ab31e 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -50,6 +50,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
 * [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
 * [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
+* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
 * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
 * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
diff --git a/src/transformers/models/video_llava/__init__.py b/src/transformers/models/video_llava/__init__.py
index 02e90214de2a1f..971b130884e043 100644
--- a/src/transformers/models/video_llava/__init__.py
+++ b/src/transformers/models/video_llava/__init__.py
@@ -21,10 +21,17 @@
         "VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "VideoLlavaConfig",
     ],
-    "image_processing_video_llava": ["VideoLlavaImageProcessor"],
     "processing_video_llava": ["VideoLlavaProcessor"],
 }
 
+try:
+    if not is_vision_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["image_processing_video_llava"] = ["VideoLlavaImageProcessor"]
+
 try:
     if not is_torch_available():
         raise OptionalDependencyNotAvailable()
@@ -33,7 +40,6 @@
 else:
     _import_structure["modeling_video_llava"] = [
         "VIDEO_LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
-        "VideoLlavaVisionTransformer",
         "VideoLlavaPreTrainedModel",
         "VideoLlavaForConditionalGeneration",
     ]
@@ -63,7 +69,6 @@
         VIDEO_LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
         VideoLlavaForConditionalGeneration,
         VideoLlavaPreTrainedModel,
-        VideoLlavaVisionTransformer,
     )
 
 else:
diff --git a/src/transformers/models/video_llava/configuration_video_llava.py b/src/transformers/models/video_llava/configuration_video_llava.py
index b184834136319c..5e4588ad106bab 100644
--- a/src/transformers/models/video_llava/configuration_video_llava.py
+++ b/src/transformers/models/video_llava/configuration_video_llava.py
@@ -20,7 +20,7 @@
 
 logger = logging.get_logger(__name__)
 
-VideoLlava_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+VIDEO_LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"LanguageBind/Video-LLaVA-7B": "https://huggingface.co/LanguageBind/Video-LLaVA-7B/resolve/main/config.json", } @@ -29,7 +29,9 @@ class VideoLlavaConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`VideoLlavaForConditionalGeneration`]. It is used to instantiate an VideoLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the `LanguageBind/Video-LLaVA-7B`. + with the defaults will yield a similar configuration to that of the like LanguageBind/Video-LLaVA-7B. + + e.g. [LanguageBind/Video-LLaVA-7B](https://huggingface.co/LanguageBind/Video-LLaVA-7B) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -43,6 +45,8 @@ class VideoLlavaConfig(PretrainedConfig): The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32000): The image token index to encode the image prompt. + video_token_index (`int`, *optional*, defaults to 32001): + The video token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): @@ -83,6 +87,7 @@ def __init__( text_config=None, ignore_index=-100, image_token_index=32000, + video_token_index=32001, projector_hidden_act="gelu", vision_feature_select_strategy="default", vision_feature_layer=-2, @@ -91,6 +96,7 @@ def __init__( ): self.ignore_index = ignore_index self.image_token_index = image_token_index + self.video_token_index = video_token_index self.projector_hidden_act = projector_hidden_act self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer diff --git a/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py b/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py index ecd812fd843cd3..bdbff922f32098 100644 --- a/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py +++ b/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py @@ -52,17 +52,17 @@ "model.mm_projector": "multi_modal_projector", "model": "language_model.model", "lm_head": "language_model.lm_head", + "video_tower": "video_tower.vision_model", + "image_tower": "image_tower.vision_model", "multi_modal_projector.0": "multi_modal_projector.linear_1", "multi_modal_projector.2": "multi_modal_projector.linear_2", } -KEYS_TO_IGNORE = [] - def convert_state_dict_to_hf(state_dict): new_state_dict = {} for key, value in state_dict.items(): - if key.endswith(".inv_freq") or key in KEYS_TO_IGNORE: + if key.endswith(".inv_freq"): continue for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in key: @@ -78,6 +78,7 @@ def convert_video_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_p tokenizer = AutoTokenizer.from_pretrained(text_model_id) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) + tokenizer.add_tokens(AddedToken("